mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Integrate OpenVINO CUMULATIVE_THROUGHPUT
mode batched inference (#8834)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
ff8ae0d2e2
commit
7638c5ce4d
@ -294,11 +294,12 @@ class BasePredictor:
|
|||||||
def setup_model(self, model, verbose=True):
|
def setup_model(self, model, verbose=True):
|
||||||
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
|
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
|
||||||
self.model = AutoBackend(
|
self.model = AutoBackend(
|
||||||
model or self.args.model,
|
weights=model or self.args.model,
|
||||||
device=select_device(self.args.device, verbose=verbose),
|
device=select_device(self.args.device, verbose=verbose),
|
||||||
dnn=self.args.dnn,
|
dnn=self.args.dnn,
|
||||||
data=self.args.data,
|
data=self.args.data,
|
||||||
fp16=self.args.half,
|
fp16=self.args.half,
|
||||||
|
batch=self.args.batch,
|
||||||
fuse=True,
|
fuse=True,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
|
@ -122,7 +122,7 @@ class BaseValidator:
|
|||||||
else:
|
else:
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
model = AutoBackend(
|
model = AutoBackend(
|
||||||
model or self.args.model,
|
weights=model or self.args.model,
|
||||||
device=select_device(self.args.device, self.args.batch),
|
device=select_device(self.args.device, self.args.batch),
|
||||||
dnn=self.args.dnn,
|
dnn=self.args.dnn,
|
||||||
data=self.args.data,
|
data=self.args.data,
|
||||||
|
@ -86,6 +86,7 @@ class AutoBackend(nn.Module):
|
|||||||
dnn=False,
|
dnn=False,
|
||||||
data=None,
|
data=None,
|
||||||
fp16=False,
|
fp16=False,
|
||||||
|
batch=1,
|
||||||
fuse=True,
|
fuse=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
):
|
):
|
||||||
@ -98,6 +99,7 @@ class AutoBackend(nn.Module):
|
|||||||
dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False.
|
dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False.
|
||||||
data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional.
|
data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional.
|
||||||
fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False.
|
fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False.
|
||||||
|
batch (int): Batch-size to assume for inference.
|
||||||
fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True.
|
fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True.
|
||||||
verbose (bool): Enable verbose logging. Defaults to True.
|
verbose (bool): Enable verbose logging. Defaults to True.
|
||||||
"""
|
"""
|
||||||
@ -204,7 +206,9 @@ class AutoBackend(nn.Module):
|
|||||||
if batch_dim.is_static:
|
if batch_dim.is_static:
|
||||||
batch_size = batch_dim.get_length()
|
batch_size = batch_dim.get_length()
|
||||||
|
|
||||||
inference_mode = "LATENCY" # either 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
|
# OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
|
||||||
|
inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY"
|
||||||
|
LOGGER.info(f"Using OpenVINO {inference_mode} mode for batch-size={batch_size} inference...")
|
||||||
ov_compiled_model = core.compile_model(
|
ov_compiled_model = core.compile_model(
|
||||||
ov_model,
|
ov_model,
|
||||||
device_name="AUTO", # AUTO selects best available device, do not modify
|
device_name="AUTO", # AUTO selects best available device, do not modify
|
||||||
@ -454,7 +458,7 @@ class AutoBackend(nn.Module):
|
|||||||
# Start async inference with userdata=i to specify the position in results list
|
# Start async inference with userdata=i to specify the position in results list
|
||||||
async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW
|
async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW
|
||||||
async_queue.wait_all() # wait for all inference requests to complete
|
async_queue.wait_all() # wait for all inference requests to complete
|
||||||
y = [list(r.values()) for r in results][0]
|
y = np.concatenate([list(r.values())[0] for r in results])
|
||||||
|
|
||||||
else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
|
else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
|
||||||
y = list(self.ov_compiled_model(im).values())
|
y = list(self.ov_compiled_model(im).values())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user