diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index 5cea938c..2c33aa7b 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -411,8 +411,8 @@ class Exporter:
     @try_export
     def export_openvino(self, prefix=colorstr("OpenVINO:")):
         """YOLOv8 OpenVINO export."""
-        check_requirements("openvino>=2023.3")  # requires openvino: https://pypi.org/project/openvino-dev/
-        import openvino as ov  # noqa
+        check_requirements("openvino>=2023.3")  # requires openvino: https://pypi.org/project/openvino/
+        import openvino as ov

         LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
         assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
@@ -433,7 +433,7 @@ class Exporter:
         if self.model.task != "classify":
             ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])

-        ov.save_model(ov_model, file, compress_to_fp16=self.args.half)
+        ov.runtime.save_model(ov_model, file, compress_to_fp16=self.args.half)
         yaml_save(Path(file).parent / "metadata.yaml", self.metadata)  # add metadata.yaml

         if self.args.int8:
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index 0feb0011..4ae96f43 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -135,8 +135,8 @@ class AutoBackend(nn.Module):
         if not (pt or triton or nn_module):
             w = attempt_download_asset(w)

-        # Load model
-        if nn_module:  # in-memory PyTorch model
+        # In-memory PyTorch model
+        if nn_module:
             model = weights.to(device)
             model = model.fuse(verbose=verbose) if fuse else model
             if hasattr(model, "kpt_shape"):
@@ -146,7 +146,9 @@ class AutoBackend(nn.Module):
             model.half() if fp16 else model.float()
             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
             pt = True
-        elif pt:  # PyTorch
+
+        # PyTorch
+        elif pt:
             from ultralytics.nn.tasks import attempt_load_weights

             model = attempt_load_weights(
@@ -158,18 +160,24 @@ class AutoBackend(nn.Module):
             names = model.module.names if hasattr(model, "module") else model.names  # get class names
             model.half() if fp16 else model.float()
             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
-        elif jit:  # TorchScript
+
+        # TorchScript
+        elif jit:
             LOGGER.info(f"Loading {w} for TorchScript inference...")
             extra_files = {"config.txt": ""}  # model metadata
             model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
             model.half() if fp16 else model.float()
             if extra_files["config.txt"]:  # load metadata dict
                 metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
-        elif dnn:  # ONNX OpenCV DNN
+
+        # ONNX OpenCV DNN
+        elif dnn:
             LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
             check_requirements("opencv-python>=4.5.4")
             net = cv2.dnn.readNetFromONNX(w)
-        elif onnx:  # ONNX Runtime
+
+        # ONNX Runtime
+        elif onnx:
             LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
             check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
             import onnxruntime
@@ -177,11 +185,13 @@ class AutoBackend(nn.Module):
             providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
             session = onnxruntime.InferenceSession(w, providers=providers)
             output_names = [x.name for x in session.get_outputs()]
-            metadata = session.get_modelmeta().custom_metadata_map  # metadata
-        elif xml:  # OpenVINO
+            metadata = session.get_modelmeta().custom_metadata_map
+
+        # OpenVINO
+        elif xml:
             LOGGER.info(f"Loading {w} for OpenVINO inference...")
-            check_requirements("openvino>=2023.3")  # requires openvino: https://pypi.org/project/openvino-dev/
-            import openvino as ov  # noqa
+            check_requirements("openvino>=2023.3")
+            import openvino as ov

             core = ov.Core()
             w = Path(w)
@@ -193,9 +203,18 @@ class AutoBackend(nn.Module):
             batch_dim = ov.get_batch(ov_model)
             if batch_dim.is_static:
                 batch_size = batch_dim.get_length()
-            ov_compiled_model = core.compile_model(ov_model, device_name="AUTO")  # AUTO selects best available device
+
+            inference_mode = "LATENCY"  # either 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
+            ov_compiled_model = core.compile_model(
+                ov_model,
+                device_name="AUTO",  # AUTO selects best available device, do not modify
+                config={"PERFORMANCE_HINT": inference_mode},
+            )
+            input_name = ov_compiled_model.input().get_any_name()
             metadata = w.parent / "metadata.yaml"
-        elif engine:  # TensorRT
+
+        # TensorRT
+        elif engine:
             LOGGER.info(f"Loading {w} for TensorRT inference...")
             try:
                 import tensorrt as trt  # noqa https://developer.nvidia.com/nvidia-tensorrt-download
@@ -234,20 +253,26 @@ class AutoBackend(nn.Module):
                 bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
             batch_size = bindings["images"].shape[0]  # if dynamic, this is instead max batch size
-        elif coreml:  # CoreML
+
+        # CoreML
+        elif coreml:
             LOGGER.info(f"Loading {w} for CoreML inference...")
             import coremltools as ct

             model = ct.models.MLModel(w)
             metadata = dict(model.user_defined_metadata)
-        elif saved_model:  # TF SavedModel
+
+        # TF SavedModel
+        elif saved_model:
             LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
             import tensorflow as tf

             keras = False  # assume TF1 saved_model
             model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
             metadata = Path(w) / "metadata.yaml"
-        elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+
+        # TF GraphDef
+        elif pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
             LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
             import tensorflow as tf

@@ -263,6 +288,8 @@ class AutoBackend(nn.Module):
             with open(w, "rb") as f:
                 gd.ParseFromString(f.read())
             frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
+
+        # TFLite or TFLite Edge TPU
         elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
             try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                 from tflite_runtime.interpreter import Interpreter, load_delegate
@@ -287,9 +314,13 @@ class AutoBackend(nn.Module):
                 with zipfile.ZipFile(w, "r") as model:
                     meta_file = model.namelist()[0]
                     metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
-        elif tfjs:  # TF.js
+
+        # TF.js
+        elif tfjs:
             raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
-        elif paddle:  # PaddlePaddle
+
+        # PaddlePaddle
+        elif paddle:
             LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
             check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
             import paddle.inference as pdi  # noqa
@@ -304,7 +335,9 @@ class AutoBackend(nn.Module):
             input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
             output_names = predictor.get_output_names()
             metadata = w.parents[1] / "metadata.yaml"
-        elif ncnn:  # NCNN
+
+        # NCNN
+        elif ncnn:
             LOGGER.info(f"Loading {w} for NCNN inference...")
             check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn")  # requires NCNN
             import ncnn as pyncnn
@@ -317,18 +350,21 @@ class AutoBackend(nn.Module):
             net.load_param(str(w))
             net.load_model(str(w.with_suffix(".bin")))
             metadata = w.parent / "metadata.yaml"
-        elif triton:  # NVIDIA Triton Inference Server
+
+        # NVIDIA Triton Inference Server
+        elif triton:
             check_requirements("tritonclient[all]")
             from ultralytics.utils.triton import TritonRemoteModel

             model = TritonRemoteModel(w)
+
+        # Any other format (unsupported)
         else:
             from ultralytics.engine.exporter import export_formats

             raise TypeError(
                 f"model='{w}' is not a supported model format. "
-                "See https://docs.ultralytics.com/modes/predict for help."
-                f"\n\n{export_formats()}"
+                f"See https://docs.ultralytics.com/modes/predict for help.\n\n{export_formats()}"
             )

         # Load external metadata YAML
@@ -380,21 +416,51 @@ class AutoBackend(nn.Module):
         if self.nhwc:
             im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

-        if self.pt or self.nn_module:  # PyTorch
+        # PyTorch
+        if self.pt or self.nn_module:
             y = self.model(im, augment=augment, visualize=visualize, embed=embed)
-        elif self.jit:  # TorchScript
+
+        # TorchScript
+        elif self.jit:
             y = self.model(im)
-        elif self.dnn:  # ONNX OpenCV DNN
+
+        # ONNX OpenCV DNN
+        elif self.dnn:
             im = im.cpu().numpy()  # torch to numpy
             self.net.setInput(im)
             y = self.net.forward()
-        elif self.onnx:  # ONNX Runtime
+
+        # ONNX Runtime
+        elif self.onnx:
             im = im.cpu().numpy()  # torch to numpy
             y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
-        elif self.xml:  # OpenVINO
+
+        # OpenVINO
+        elif self.xml:
             im = im.cpu().numpy()  # FP32
-            y = list(self.ov_compiled_model(im).values())
-        elif self.engine:  # TensorRT
+
+            if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}:  # optimized for larger batch-sizes
+                n = im.shape[0]  # number of images in batch
+                results = [None] * n  # preallocate list with None to match the number of images
+
+                def callback(request, userdata):
+                    """Places result in preallocated list using userdata index."""
+                    results[userdata] = request.results
+
+                # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
+                async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model)
+                async_queue.set_callback(callback)
+                for i in range(n):
+                    # Start async inference with userdata=i to specify the position in results list
+                    async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i)  # keep image as BCHW
+                async_queue.wait_all()  # wait for all inference requests to complete
+                y = [list(r.values()) for r in results][0]
+
+            else:  # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
+                y = list(self.ov_compiled_model(im).values())
+
+        # TensorRT
+        elif self.engine:
             if self.dynamic and im.shape != self.bindings["images"].shape:
                 i = self.model.get_binding_index("images")
                 self.context.set_binding_shape(i, im.shape)  # reshape if dynamic
@@ -407,7 +473,9 @@ class AutoBackend(nn.Module):
             self.binding_addrs["images"] = int(im.data_ptr())
             self.context.execute_v2(list(self.binding_addrs.values()))
             y = [self.bindings[x].data for x in sorted(self.output_names)]
-        elif self.coreml:  # CoreML
+
+        # CoreML
+        elif self.coreml:
             im = im[0].cpu().numpy()
             im_pil = Image.fromarray((im * 255).astype("uint8"))
             # im = im.resize((192, 320), Image.BILINEAR)
@@ -426,12 +494,16 @@ class AutoBackend(nn.Module):
                     y = list(y.values())
                 elif len(y) == 2:  # segmentation model
                     y = list(reversed(y.values()))  # reversed for segmentation models (pred, proto)
-        elif self.paddle:  # PaddlePaddle
+
+        # PaddlePaddle
+        elif self.paddle:
             im = im.cpu().numpy().astype(np.float32)
             self.input_handle.copy_from_cpu(im)
             self.predictor.run()
             y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
-        elif self.ncnn:  # NCNN
+
+        # NCNN
+        elif self.ncnn:
             mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
             ex = self.net.create_extractor()
             input_names, output_names = self.net.input_names(), self.net.output_names()
@@ -441,10 +513,14 @@ class AutoBackend(nn.Module):
                 mat_out = self.pyncnn.Mat()
                 ex.extract(output_name, mat_out)
                 y.append(np.array(mat_out)[None])
-        elif self.triton:  # NVIDIA Triton Inference Server
+
+        # NVIDIA Triton Inference Server
+        elif self.triton:
             im = im.cpu().numpy()  # torch to numpy
             y = self.model(im)
-        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+
+        # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+        else:
             im = im.cpu().numpy()
             if self.saved_model:  # SavedModel
                 y = self.model(im, training=False) if self.keras else self.model(im)
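
Note (not part of the patch): the THROUGHPUT / CUMULATIVE_THROUGHPUT branch added to AutoBackend.forward() is built on OpenVINO's AsyncInferQueue. The standalone sketch below reproduces that callback/userdata pattern outside the class so it is easier to follow; the model path, input shape, and batch size are hypothetical placeholders, not values taken from the diff.

# Minimal sketch, assuming an already exported single-input OpenVINO model.
import numpy as np
import openvino as ov

core = ov.Core()
model = core.read_model("yolov8n_openvino_model/yolov8n.xml")  # placeholder path to an exported model
compiled = core.compile_model(model, device_name="AUTO", config={"PERFORMANCE_HINT": "THROUGHPUT"})
input_name = compiled.input().get_any_name()  # single-input model assumed

batch = np.random.rand(4, 3, 640, 640).astype(np.float32)  # dummy BCHW batch of 4 images
results = [None] * batch.shape[0]  # one slot per image, filled by the callback


def callback(request, userdata):
    """Store each finished request's outputs at its original batch index."""
    results[userdata] = request.results


queue = ov.runtime.AsyncInferQueue(compiled)  # default jobs count lets OpenVINO pick an optimal number
queue.set_callback(callback)
for i in range(batch.shape[0]):
    queue.start_async(inputs={input_name: batch[i : i + 1]}, userdata=i)  # submit one BCHW image per request
queue.wait_all()  # block until every queued request has completed

outputs = [list(r.values()) for r in results]  # per-image output tensors, in submission order

Submitting one image per request keeps a statically shaped (batch-1) model usable, and indexing the preallocated list by userdata preserves batch order even though requests may complete out of order; the default LATENCY hint in the diff skips all of this and calls the compiled model synchronously.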