mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
ultralytics 8.1.25
OpenVINO LATENCY
and THROUGHPUT
modes (#8058)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Adrian Boguszewski <adekboguszewski@gmail.com>
This commit is contained in:
parent
6da7c9fb21
commit
90943946a7
@ -411,8 +411,8 @@ class Exporter:
|
|||||||
@try_export
|
@try_export
|
||||||
def export_openvino(self, prefix=colorstr("OpenVINO:")):
|
def export_openvino(self, prefix=colorstr("OpenVINO:")):
|
||||||
"""YOLOv8 OpenVINO export."""
|
"""YOLOv8 OpenVINO export."""
|
||||||
check_requirements("openvino>=2023.3") # requires openvino: https://pypi.org/project/openvino-dev/
|
check_requirements("openvino>=2023.3") # requires openvino: https://pypi.org/project/openvino/
|
||||||
import openvino as ov # noqa
|
import openvino as ov
|
||||||
|
|
||||||
LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
|
LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
|
||||||
assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
|
assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
|
||||||
@ -433,7 +433,7 @@ class Exporter:
|
|||||||
if self.model.task != "classify":
|
if self.model.task != "classify":
|
||||||
ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])
|
ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])
|
||||||
|
|
||||||
ov.save_model(ov_model, file, compress_to_fp16=self.args.half)
|
ov.runtime.save_model(ov_model, file, compress_to_fp16=self.args.half)
|
||||||
yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml
|
yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml
|
||||||
|
|
||||||
if self.args.int8:
|
if self.args.int8:
|
||||||
|
@ -135,8 +135,8 @@ class AutoBackend(nn.Module):
|
|||||||
if not (pt or triton or nn_module):
|
if not (pt or triton or nn_module):
|
||||||
w = attempt_download_asset(w)
|
w = attempt_download_asset(w)
|
||||||
|
|
||||||
# Load model
|
# In-memory PyTorch model
|
||||||
if nn_module: # in-memory PyTorch model
|
if nn_module:
|
||||||
model = weights.to(device)
|
model = weights.to(device)
|
||||||
model = model.fuse(verbose=verbose) if fuse else model
|
model = model.fuse(verbose=verbose) if fuse else model
|
||||||
if hasattr(model, "kpt_shape"):
|
if hasattr(model, "kpt_shape"):
|
||||||
@ -146,7 +146,9 @@ class AutoBackend(nn.Module):
|
|||||||
model.half() if fp16 else model.float()
|
model.half() if fp16 else model.float()
|
||||||
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
||||||
pt = True
|
pt = True
|
||||||
elif pt: # PyTorch
|
|
||||||
|
# PyTorch
|
||||||
|
elif pt:
|
||||||
from ultralytics.nn.tasks import attempt_load_weights
|
from ultralytics.nn.tasks import attempt_load_weights
|
||||||
|
|
||||||
model = attempt_load_weights(
|
model = attempt_load_weights(
|
||||||
@ -158,18 +160,24 @@ class AutoBackend(nn.Module):
|
|||||||
names = model.module.names if hasattr(model, "module") else model.names # get class names
|
names = model.module.names if hasattr(model, "module") else model.names # get class names
|
||||||
model.half() if fp16 else model.float()
|
model.half() if fp16 else model.float()
|
||||||
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
||||||
elif jit: # TorchScript
|
|
||||||
|
# TorchScript
|
||||||
|
elif jit:
|
||||||
LOGGER.info(f"Loading {w} for TorchScript inference...")
|
LOGGER.info(f"Loading {w} for TorchScript inference...")
|
||||||
extra_files = {"config.txt": ""} # model metadata
|
extra_files = {"config.txt": ""} # model metadata
|
||||||
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
|
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
|
||||||
model.half() if fp16 else model.float()
|
model.half() if fp16 else model.float()
|
||||||
if extra_files["config.txt"]: # load metadata dict
|
if extra_files["config.txt"]: # load metadata dict
|
||||||
metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
|
metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
|
||||||
elif dnn: # ONNX OpenCV DNN
|
|
||||||
|
# ONNX OpenCV DNN
|
||||||
|
elif dnn:
|
||||||
LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
|
LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
|
||||||
check_requirements("opencv-python>=4.5.4")
|
check_requirements("opencv-python>=4.5.4")
|
||||||
net = cv2.dnn.readNetFromONNX(w)
|
net = cv2.dnn.readNetFromONNX(w)
|
||||||
elif onnx: # ONNX Runtime
|
|
||||||
|
# ONNX Runtime
|
||||||
|
elif onnx:
|
||||||
LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
|
LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
|
||||||
check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
|
check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
@ -177,11 +185,13 @@ class AutoBackend(nn.Module):
|
|||||||
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
|
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
|
||||||
session = onnxruntime.InferenceSession(w, providers=providers)
|
session = onnxruntime.InferenceSession(w, providers=providers)
|
||||||
output_names = [x.name for x in session.get_outputs()]
|
output_names = [x.name for x in session.get_outputs()]
|
||||||
metadata = session.get_modelmeta().custom_metadata_map # metadata
|
metadata = session.get_modelmeta().custom_metadata_map
|
||||||
elif xml: # OpenVINO
|
|
||||||
|
# OpenVINO
|
||||||
|
elif xml:
|
||||||
LOGGER.info(f"Loading {w} for OpenVINO inference...")
|
LOGGER.info(f"Loading {w} for OpenVINO inference...")
|
||||||
check_requirements("openvino>=2023.3") # requires openvino: https://pypi.org/project/openvino-dev/
|
check_requirements("openvino>=2023.3")
|
||||||
import openvino as ov # noqa
|
import openvino as ov
|
||||||
|
|
||||||
core = ov.Core()
|
core = ov.Core()
|
||||||
w = Path(w)
|
w = Path(w)
|
||||||
@ -193,9 +203,18 @@ class AutoBackend(nn.Module):
|
|||||||
batch_dim = ov.get_batch(ov_model)
|
batch_dim = ov.get_batch(ov_model)
|
||||||
if batch_dim.is_static:
|
if batch_dim.is_static:
|
||||||
batch_size = batch_dim.get_length()
|
batch_size = batch_dim.get_length()
|
||||||
ov_compiled_model = core.compile_model(ov_model, device_name="AUTO") # AUTO selects best available device
|
|
||||||
|
inference_mode = "LATENCY" # either 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
|
||||||
|
ov_compiled_model = core.compile_model(
|
||||||
|
ov_model,
|
||||||
|
device_name="AUTO", # AUTO selects best available device, do not modify
|
||||||
|
config={"PERFORMANCE_HINT": inference_mode},
|
||||||
|
)
|
||||||
|
input_name = ov_compiled_model.input().get_any_name()
|
||||||
metadata = w.parent / "metadata.yaml"
|
metadata = w.parent / "metadata.yaml"
|
||||||
elif engine: # TensorRT
|
|
||||||
|
# TensorRT
|
||||||
|
elif engine:
|
||||||
LOGGER.info(f"Loading {w} for TensorRT inference...")
|
LOGGER.info(f"Loading {w} for TensorRT inference...")
|
||||||
try:
|
try:
|
||||||
import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
|
import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
|
||||||
@ -234,20 +253,26 @@ class AutoBackend(nn.Module):
|
|||||||
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
||||||
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
||||||
batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size
|
batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size
|
||||||
elif coreml: # CoreML
|
|
||||||
|
# CoreML
|
||||||
|
elif coreml:
|
||||||
LOGGER.info(f"Loading {w} for CoreML inference...")
|
LOGGER.info(f"Loading {w} for CoreML inference...")
|
||||||
import coremltools as ct
|
import coremltools as ct
|
||||||
|
|
||||||
model = ct.models.MLModel(w)
|
model = ct.models.MLModel(w)
|
||||||
metadata = dict(model.user_defined_metadata)
|
metadata = dict(model.user_defined_metadata)
|
||||||
elif saved_model: # TF SavedModel
|
|
||||||
|
# TF SavedModel
|
||||||
|
elif saved_model:
|
||||||
LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
|
LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
keras = False # assume TF1 saved_model
|
keras = False # assume TF1 saved_model
|
||||||
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
|
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
|
||||||
metadata = Path(w) / "metadata.yaml"
|
metadata = Path(w) / "metadata.yaml"
|
||||||
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
|
||||||
|
# TF GraphDef
|
||||||
|
elif pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
||||||
LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
|
LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@ -263,6 +288,8 @@ class AutoBackend(nn.Module):
|
|||||||
with open(w, "rb") as f:
|
with open(w, "rb") as f:
|
||||||
gd.ParseFromString(f.read())
|
gd.ParseFromString(f.read())
|
||||||
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
||||||
|
|
||||||
|
# TFLite or TFLite Edge TPU
|
||||||
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||||
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
||||||
from tflite_runtime.interpreter import Interpreter, load_delegate
|
from tflite_runtime.interpreter import Interpreter, load_delegate
|
||||||
@ -287,9 +314,13 @@ class AutoBackend(nn.Module):
|
|||||||
with zipfile.ZipFile(w, "r") as model:
|
with zipfile.ZipFile(w, "r") as model:
|
||||||
meta_file = model.namelist()[0]
|
meta_file = model.namelist()[0]
|
||||||
metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
|
metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
|
||||||
elif tfjs: # TF.js
|
|
||||||
|
# TF.js
|
||||||
|
elif tfjs:
|
||||||
raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
|
raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
|
||||||
elif paddle: # PaddlePaddle
|
|
||||||
|
# PaddlePaddle
|
||||||
|
elif paddle:
|
||||||
LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
|
LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
|
||||||
check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
|
check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
|
||||||
import paddle.inference as pdi # noqa
|
import paddle.inference as pdi # noqa
|
||||||
@ -304,7 +335,9 @@ class AutoBackend(nn.Module):
|
|||||||
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
|
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
|
||||||
output_names = predictor.get_output_names()
|
output_names = predictor.get_output_names()
|
||||||
metadata = w.parents[1] / "metadata.yaml"
|
metadata = w.parents[1] / "metadata.yaml"
|
||||||
elif ncnn: # NCNN
|
|
||||||
|
# NCNN
|
||||||
|
elif ncnn:
|
||||||
LOGGER.info(f"Loading {w} for NCNN inference...")
|
LOGGER.info(f"Loading {w} for NCNN inference...")
|
||||||
check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN
|
check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN
|
||||||
import ncnn as pyncnn
|
import ncnn as pyncnn
|
||||||
@ -317,18 +350,21 @@ class AutoBackend(nn.Module):
|
|||||||
net.load_param(str(w))
|
net.load_param(str(w))
|
||||||
net.load_model(str(w.with_suffix(".bin")))
|
net.load_model(str(w.with_suffix(".bin")))
|
||||||
metadata = w.parent / "metadata.yaml"
|
metadata = w.parent / "metadata.yaml"
|
||||||
elif triton: # NVIDIA Triton Inference Server
|
|
||||||
|
# NVIDIA Triton Inference Server
|
||||||
|
elif triton:
|
||||||
check_requirements("tritonclient[all]")
|
check_requirements("tritonclient[all]")
|
||||||
from ultralytics.utils.triton import TritonRemoteModel
|
from ultralytics.utils.triton import TritonRemoteModel
|
||||||
|
|
||||||
model = TritonRemoteModel(w)
|
model = TritonRemoteModel(w)
|
||||||
|
|
||||||
|
# Any other format (unsupported)
|
||||||
else:
|
else:
|
||||||
from ultralytics.engine.exporter import export_formats
|
from ultralytics.engine.exporter import export_formats
|
||||||
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"model='{w}' is not a supported model format. "
|
f"model='{w}' is not a supported model format. "
|
||||||
"See https://docs.ultralytics.com/modes/predict for help."
|
f"See https://docs.ultralytics.com/modes/predict for help.\n\n{export_formats()}"
|
||||||
f"\n\n{export_formats()}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load external metadata YAML
|
# Load external metadata YAML
|
||||||
@ -380,21 +416,51 @@ class AutoBackend(nn.Module):
|
|||||||
if self.nhwc:
|
if self.nhwc:
|
||||||
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
|
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
|
||||||
|
|
||||||
if self.pt or self.nn_module: # PyTorch
|
# PyTorch
|
||||||
|
if self.pt or self.nn_module:
|
||||||
y = self.model(im, augment=augment, visualize=visualize, embed=embed)
|
y = self.model(im, augment=augment, visualize=visualize, embed=embed)
|
||||||
elif self.jit: # TorchScript
|
|
||||||
|
# TorchScript
|
||||||
|
elif self.jit:
|
||||||
y = self.model(im)
|
y = self.model(im)
|
||||||
elif self.dnn: # ONNX OpenCV DNN
|
|
||||||
|
# ONNX OpenCV DNN
|
||||||
|
elif self.dnn:
|
||||||
im = im.cpu().numpy() # torch to numpy
|
im = im.cpu().numpy() # torch to numpy
|
||||||
self.net.setInput(im)
|
self.net.setInput(im)
|
||||||
y = self.net.forward()
|
y = self.net.forward()
|
||||||
elif self.onnx: # ONNX Runtime
|
|
||||||
|
# ONNX Runtime
|
||||||
|
elif self.onnx:
|
||||||
im = im.cpu().numpy() # torch to numpy
|
im = im.cpu().numpy() # torch to numpy
|
||||||
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
|
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
|
||||||
elif self.xml: # OpenVINO
|
|
||||||
|
# OpenVINO
|
||||||
|
elif self.xml:
|
||||||
im = im.cpu().numpy() # FP32
|
im = im.cpu().numpy() # FP32
|
||||||
y = list(self.ov_compiled_model(im).values())
|
|
||||||
elif self.engine: # TensorRT
|
if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}: # optimized for larger batch-sizes
|
||||||
|
n = im.shape[0] # number of images in batch
|
||||||
|
results = [None] * n # preallocate list with None to match the number of images
|
||||||
|
|
||||||
|
def callback(request, userdata):
|
||||||
|
"""Places result in preallocated list using userdata index."""
|
||||||
|
results[userdata] = request.results
|
||||||
|
|
||||||
|
# Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
|
||||||
|
async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model)
|
||||||
|
async_queue.set_callback(callback)
|
||||||
|
for i in range(n):
|
||||||
|
# Start async inference with userdata=i to specify the position in results list
|
||||||
|
async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW
|
||||||
|
async_queue.wait_all() # wait for all inference requests to complete
|
||||||
|
y = [list(r.values()) for r in results][0]
|
||||||
|
|
||||||
|
else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
|
||||||
|
y = list(self.ov_compiled_model(im).values())
|
||||||
|
|
||||||
|
# TensorRT
|
||||||
|
elif self.engine:
|
||||||
if self.dynamic and im.shape != self.bindings["images"].shape:
|
if self.dynamic and im.shape != self.bindings["images"].shape:
|
||||||
i = self.model.get_binding_index("images")
|
i = self.model.get_binding_index("images")
|
||||||
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
|
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
|
||||||
@ -407,7 +473,9 @@ class AutoBackend(nn.Module):
|
|||||||
self.binding_addrs["images"] = int(im.data_ptr())
|
self.binding_addrs["images"] = int(im.data_ptr())
|
||||||
self.context.execute_v2(list(self.binding_addrs.values()))
|
self.context.execute_v2(list(self.binding_addrs.values()))
|
||||||
y = [self.bindings[x].data for x in sorted(self.output_names)]
|
y = [self.bindings[x].data for x in sorted(self.output_names)]
|
||||||
elif self.coreml: # CoreML
|
|
||||||
|
# CoreML
|
||||||
|
elif self.coreml:
|
||||||
im = im[0].cpu().numpy()
|
im = im[0].cpu().numpy()
|
||||||
im_pil = Image.fromarray((im * 255).astype("uint8"))
|
im_pil = Image.fromarray((im * 255).astype("uint8"))
|
||||||
# im = im.resize((192, 320), Image.BILINEAR)
|
# im = im.resize((192, 320), Image.BILINEAR)
|
||||||
@ -426,12 +494,16 @@ class AutoBackend(nn.Module):
|
|||||||
y = list(y.values())
|
y = list(y.values())
|
||||||
elif len(y) == 2: # segmentation model
|
elif len(y) == 2: # segmentation model
|
||||||
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
|
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
|
||||||
elif self.paddle: # PaddlePaddle
|
|
||||||
|
# PaddlePaddle
|
||||||
|
elif self.paddle:
|
||||||
im = im.cpu().numpy().astype(np.float32)
|
im = im.cpu().numpy().astype(np.float32)
|
||||||
self.input_handle.copy_from_cpu(im)
|
self.input_handle.copy_from_cpu(im)
|
||||||
self.predictor.run()
|
self.predictor.run()
|
||||||
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
|
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
|
||||||
elif self.ncnn: # NCNN
|
|
||||||
|
# NCNN
|
||||||
|
elif self.ncnn:
|
||||||
mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
|
mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
|
||||||
ex = self.net.create_extractor()
|
ex = self.net.create_extractor()
|
||||||
input_names, output_names = self.net.input_names(), self.net.output_names()
|
input_names, output_names = self.net.input_names(), self.net.output_names()
|
||||||
@ -441,10 +513,14 @@ class AutoBackend(nn.Module):
|
|||||||
mat_out = self.pyncnn.Mat()
|
mat_out = self.pyncnn.Mat()
|
||||||
ex.extract(output_name, mat_out)
|
ex.extract(output_name, mat_out)
|
||||||
y.append(np.array(mat_out)[None])
|
y.append(np.array(mat_out)[None])
|
||||||
elif self.triton: # NVIDIA Triton Inference Server
|
|
||||||
|
# NVIDIA Triton Inference Server
|
||||||
|
elif self.triton:
|
||||||
im = im.cpu().numpy() # torch to numpy
|
im = im.cpu().numpy() # torch to numpy
|
||||||
y = self.model(im)
|
y = self.model(im)
|
||||||
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
|
|
||||||
|
# TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
|
||||||
|
else:
|
||||||
im = im.cpu().numpy()
|
im = im.cpu().numpy()
|
||||||
if self.saved_model: # SavedModel
|
if self.saved_model: # SavedModel
|
||||||
y = self.model(im, training=False) if self.keras else self.model(im)
|
y = self.model(im, training=False) if self.keras else self.model(im)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user