Integrate OpenVINO CUMULATIVE_THROUGHPUT mode batched inference (#8834)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Glenn Jocher 2024-03-11 18:50:29 +01:00 committed by GitHub
parent ff8ae0d2e2
commit 7638c5ce4d
3 changed files with 9 additions and 4 deletions

ultralytics/engine/predictor.py

@@ -294,11 +294,12 @@ class BasePredictor:
     def setup_model(self, model, verbose=True):
         """Initialize YOLO model with given parameters and set it to evaluation mode."""
         self.model = AutoBackend(
-            model or self.args.model,
+            weights=model or self.args.model,
             device=select_device(self.args.device, verbose=verbose),
             dnn=self.args.dnn,
             data=self.args.data,
             fp16=self.args.half,
+            batch=self.args.batch,
             fuse=True,
             verbose=verbose,
         )
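With this change, the predictor's batch argument reaches AutoBackend, where (per the autobackend.py hunk below) batch > 1 switches OpenVINO into CUMULATIVE_THROUGHPUT mode. A minimal usage sketch, assuming a model previously exported with `yolo export format=openvino`; the model directory and image paths are placeholders:

    from ultralytics import YOLO

    # Load an OpenVINO-exported model (directory produced by `yolo export format=openvino`)
    model = YOLO("yolov8n_openvino_model/")

    # batch > 1 now selects CUMULATIVE_THROUGHPUT instead of LATENCY in AutoBackend
    results = model.predict(["im1.jpg", "im2.jpg"], batch=2)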

ultralytics/engine/validator.py

@@ -122,7 +122,7 @@ class BaseValidator:
         else:
             callbacks.add_integration_callbacks(self)
             model = AutoBackend(
-                model or self.args.model,
+                weights=model or self.args.model,
                 device=select_device(self.args.device, self.args.batch),
                 dnn=self.args.dnn,
                 data=self.args.data,
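The validator gets the same keyword fix; its batch argument was already threaded through select_device and now also names the AutoBackend weights explicitly. A sketch of a batched validation run under the same assumptions (placeholder model directory; coco8.yaml is the small built-in sample dataset):

    from ultralytics import YOLO

    # Validation already exposes batch; with this commit it also reaches the OpenVINO backend
    metrics = YOLO("yolov8n_openvino_model/").val(data="coco8.yaml", batch=8)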

ultralytics/nn/autobackend.py

@@ -86,6 +86,7 @@ class AutoBackend(nn.Module):
         dnn=False,
         data=None,
         fp16=False,
+        batch=1,
         fuse=True,
         verbose=True,
     ):
@@ -98,6 +99,7 @@ class AutoBackend(nn.Module):
             dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False.
             data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional.
             fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False.
+            batch (int): Batch-size to assume for inference.
             fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True.
             verbose (bool): Enable verbose logging. Defaults to True.
         """
@@ -204,7 +206,9 @@ class AutoBackend(nn.Module):
             if batch_dim.is_static:
                 batch_size = batch_dim.get_length()
-            inference_mode = "LATENCY"  # either 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
+            # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
+            inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY"
+            LOGGER.info(f"Using OpenVINO {inference_mode} mode for batch-size={batch_size} inference...")
             ov_compiled_model = core.compile_model(
                 ov_model,
                 device_name="AUTO",  # AUTO selects best available device, do not modify
@@ -454,7 +458,7 @@ class AutoBackend(nn.Module):
                     # Start async inference with userdata=i to specify the position in results list
                     async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i)  # keep image as BCHW
                 async_queue.wait_all()  # wait for all inference requests to complete
-                y = [list(r.values()) for r in results][0]
+                y = np.concatenate([list(r.values())[0] for r in results])
             else:  # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
                 y = list(self.ov_compiled_model(im).values())
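The concatenation fix matters because each AsyncInferQueue request processes one image, so the old `[...][0]` kept only the first image's outputs; the per-request outputs must be reassembled into one batched array. A self-contained sketch of the pattern, assuming a placeholder IR path and input shape:

    import numpy as np
    import openvino as ov

    core = ov.Core()
    compiled = core.compile_model(
        core.read_model("model.xml"),  # placeholder IR path
        device_name="AUTO",
        config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"},
    )
    input_name = compiled.input().get_any_name()

    im = np.zeros((4, 3, 640, 640), dtype=np.float32)  # BCHW batch of 4 (placeholder input)
    results = [None] * im.shape[0]

    def callback(request, userdata):
        results[userdata] = request.results  # output dict for one image, stored at its batch position

    async_queue = ov.AsyncInferQueue(compiled)
    async_queue.set_callback(callback)
    for i in range(im.shape[0]):
        async_queue.start_async(inputs={input_name: im[i : i + 1]}, userdata=i)  # keep image as BCHW
    async_queue.wait_all()

    # Reassemble every request's first output into one batched array, in original order
    y = np.concatenate([list(r.values())[0] for r in results])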