mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-24 05:55:51 +08:00

Avoid CUDA round-trip for relevant export formats (#3727)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent c5991d7cd8
commit 135a10f1fa
@@ -83,16 +83,23 @@ class AutoBackend(nn.Module):
         nn_module = isinstance(weights, torch.nn.Module)
         pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
             self._model_type(w)
-        fp16 &= pt or jit or onnx or engine or nn_module or triton  # FP16
+        fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
         nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCWH)
         stride = 32  # default stride
         model, metadata = None, None
-        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
-        if not (pt or triton or nn_module):
-            w = attempt_download_asset(w)  # download if not local
+
+        # Set device
+        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
+        if cuda and not any([nn_module, pt, jit, engine]):  # GPU dataloader formats
+            device = torch.device('cpu')
+            cuda = False
+
+        # Download if not local
+        if not (pt or triton or nn_module):
+            w = attempt_download_asset(w)
 
-        # NOTE: special case: in-memory pytorch model
-        if nn_module:
+        # Load model
+        if nn_module:  # in-memory PyTorch model
             model = weights.to(device)
             model = model.fuse(verbose=verbose) if fuse else model
             if hasattr(model, 'kpt_shape'):
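The effect of the new device guard is easiest to see at the call site. Below is a minimal usage sketch, not from the commit: it assumes a CUDA machine and a local 'yolov8n.tflite' export (both hypothetical). Because TFLite is not one of the GPU-capable loaders (nn_module, pt, jit, engine), the guard rewrites the requested CUDA device to CPU up front, so batches are never staged on the GPU only to be copied back for inference.

    import torch
    from ultralytics.nn.autobackend import AutoBackend  # assumed module path

    # Hypothetical example: 'yolov8n.tflite' is assumed to exist locally.
    backend = AutoBackend('yolov8n.tflite', device=torch.device('cuda:0'))
    print(backend.device)  # expected: device(type='cpu') -- no CUDA round-trip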
@@ -269,14 +276,13 @@ class AutoBackend(nn.Module):
             net.load_model(str(w.with_suffix('.bin')))
             metadata = w.parent / 'metadata.yaml'
         elif triton:  # NVIDIA Triton Inference Server
-            LOGGER.info('Triton Inference Server not supported...')
-            '''
-            TODO:
+            """TODO
             check_requirements('tritonclient[all]')
             from utils.triton import TritonRemoteModel
             model = TritonRemoteModel(url=w)
             nhwc = model.runtime.startswith("tensorflow")
-            '''
+            """
+            raise NotImplementedError('Triton Inference Server is not currently supported.')
         else:
             from ultralytics.yolo.engine.exporter import export_formats
             raise TypeError(f"model='{w}' is not a supported model format. "
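Previously this branch only logged a message and fell through with model=None, so failures surfaced later as confusing attribute errors; the commit makes Triton inputs fail fast instead. A small sketch of the resulting behavior (the URL is hypothetical, and it assumes _model_type classifies http/grpc URLs as Triton):

    from ultralytics.nn.autobackend import AutoBackend  # assumed module path

    try:
        AutoBackend('http://localhost:8000/yolov8n')  # hypothetical Triton endpoint
    except NotImplementedError as err:
        print(err)  # -> Triton Inference Server is not currently supported.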
@@ -18,7 +18,9 @@ from .build import build_sam
 
 class Predictor(BasePredictor):
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides={}, _callbacks=None):
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        if overrides is None:
+            overrides = {}
         overrides.update(dict(task='segment', mode='predict', imgsz=1024))
         super().__init__(cfg, overrides, _callbacks)
         # SAM needs retina_masks=True, or the results would be a mess.
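The overrides change fixes Python's classic mutable-default-argument pitfall: a default {} is evaluated once at function definition time, so every call that mutates it leaks state into later calls. A self-contained sketch of the failure mode and of the pattern the commit adopts (names are illustrative only):

    def bad(overrides={}):  # one shared dict, created at definition time
        overrides['calls'] = overrides.get('calls', 0) + 1
        return overrides

    print(bad())  # {'calls': 1}
    print(bad())  # {'calls': 2}  <- state leaked from the previous call

    def good(overrides=None):  # the pattern the commit adopts
        if overrides is None:
            overrides = {}  # fresh dict on every call
        overrides['calls'] = overrides.get('calls', 0) + 1
        return overrides

    print(good())  # {'calls': 1}
    print(good())  # {'calls': 1}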
@@ -90,7 +92,7 @@ class Predictor(BasePredictor):
         of masks and H=W=256. These low resolution logits can be passed to
         a subsequent iteration as mask input.
         """
-        if all([i is None for i in [bboxes, points, masks]]):
+        if all(i is None for i in [bboxes, points, masks]):
             return self.generate(im, *args, **kwargs)
         return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
 
@@ -284,7 +286,7 @@ class Predictor(BasePredictor):
 
         return pred_masks, pred_scores, pred_bboxes
 
-    def setup_model(self, model):
+    def setup_model(self, model, verbose=True):
         """Set up YOLO model with specified thresholds and device."""
         device = select_device(self.args.device)
         if model is None:
@@ -306,7 +308,7 @@ class Predictor(BasePredictor):
         # (N, 1, H, W), (N, 1)
         pred_masks, pred_scores = preds[:2]
         pred_bboxes = preds[2] if self.segment_all else None
-        names = dict(enumerate([str(i) for i in range(len(pred_masks))]))
+        names = dict(enumerate(str(i) for i in range(len(pred_masks))))
         results = []
         for i, masks in enumerate([pred_masks]):
             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
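The two hunks above also swap list comprehensions for generator expressions inside all() and dict(enumerate(...)). The results are identical, but the generator version avoids materializing an intermediate list and lets all() short-circuit as soon as it sees a prompt that is not None. A quick equivalence check, with stand-in values:

    bboxes, points, masks = None, None, [[0, 1]]

    # The list form builds the whole list before all() sees it...
    assert all([i is None for i in [bboxes, points, masks]]) is False

    # ...the generator form yields lazily, so all() can stop early.
    assert all(i is None for i in [bboxes, points, masks]) is False

    # Same equivalence for the names mapping built in postprocess():
    pred_masks = ['m0', 'm1', 'm2']  # stand-in for real mask tensors
    names = dict(enumerate(str(i) for i in range(len(pred_masks))))
    assert names == {0: '0', 1: '1', 2: '2'}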
@@ -300,17 +300,16 @@ class BasePredictor:
 
     def setup_model(self, model, verbose=True):
         """Initialize YOLO model with given parameters and set it to evaluation mode."""
-        device = select_device(self.args.device, verbose=verbose)
-        model = model or self.args.model
-        self.args.half &= device.type != 'cpu'  # half precision only supported on CUDA
-        self.model = AutoBackend(model,
-                                 device=device,
+        self.model = AutoBackend(model or self.args.model,
+                                 device=select_device(self.args.device, verbose=verbose),
                                  dnn=self.args.dnn,
                                  data=self.args.data,
                                  fp16=self.args.half,
                                  fuse=True,
                                  verbose=verbose)
-        self.device = device
+
+        self.device = self.model.device  # update device
+        self.args.half = self.model.fp16  # update half
         self.model.eval()
 
     def show(self, p):
@@ -109,17 +109,19 @@ class BaseValidator:
         callbacks.add_integration_callbacks(self)
         self.run_callbacks('on_val_start')
         assert model is not None, 'Either trainer or model is needed for validation'
-        self.device = select_device(self.args.device, self.args.batch)
-        self.args.half &= self.device.type != 'cpu'
-        model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
+        model = AutoBackend(model,
+                            device=select_device(self.args.device, self.args.batch),
+                            dnn=self.args.dnn,
+                            data=self.args.data,
+                            fp16=self.args.half)
         self.model = model
+        self.device = model.device  # update device
+        self.args.half = model.fp16  # update half
         stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
         imgsz = check_imgsz(self.args.imgsz, stride=stride)
         if engine:
             self.args.batch = model.batch_size
-        else:
-            self.device = model.device
-            if not pt and not jit:
-                self.args.batch = 1  # export.py models default to batch-size 1
-                LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
+        elif not pt and not jit:
+            self.args.batch = 1  # export.py models default to batch-size 1
+            LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
 
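BasePredictor.setup_model and BaseValidator.__call__ now share one pattern: pass the requested device and half flag into AutoBackend, then read back what the backend actually chose, since the new guard may have downgraded a CUDA request to CPU (and FP16 to FP32). A minimal sketch of that pattern with a stand-in backend class; everything here except torch is illustrative:

    import torch

    class FakeBackend:
        """Stand-in for AutoBackend; may downgrade the requested device/precision."""

        def __init__(self, device, fp16):
            gpu_capable = False  # pretend this is e.g. a TFLite model
            self.device = device if gpu_capable else torch.device('cpu')
            self.fp16 = fp16 and self.device.type != 'cpu'

    requested = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    backend = FakeBackend(device=requested, fp16=True)

    # The commit's pattern: trust the backend's final decision.
    device = backend.device  # update device
    half = backend.fp16      # update half
    print(device, half)      # -> cpu False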
@@ -213,7 +213,6 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
     prefix = colorstr('red', 'bold', 'requirements:')
     check_python()  # check python version
     check_torchvision()  # check torch-torchvision compatibility
-    file = None
     if isinstance(requirements, Path):  # requirements.txt file
         file = requirements.resolve()
         assert file.exists(), f'{prefix} {file} not found, check failed.'
@@ -225,13 +224,13 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
     s = ''  # console string
     pkgs = []
     for r in requirements:
-        rmin = r.split('/')[-1].replace('.git', '')  # replace git+https://org/repo.git -> 'repo'
+        r_stripped = r.split('/')[-1].replace('.git', '')  # replace git+https://org/repo.git -> 'repo'
         try:
-            pkg.require(rmin)
+            pkg.require(r_stripped)
         except (pkg.VersionConflict, pkg.DistributionNotFound):  # exception if requirements not met
             try:  # attempt to import (slower but more accurate)
                 import importlib
-                importlib.import_module(next(pkg.parse_requirements(rmin)).name)
+                importlib.import_module(next(pkg.parse_requirements(r_stripped)).name)
             except ImportError:
                 s += f'"{r}" '
                 pkgs.append(r)
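The rmin -> r_stripped rename clarifies what the variable holds: for a requirement written as a git URL, only the trailing repo segment can be checked against installed package metadata. A hedged sketch of the check-then-import fallback used here, with pkg_resources aliased as in the source and a hypothetical requirement string:

    import importlib
    import pkg_resources as pkg

    r = 'git+https://github.com/org/repo.git'          # hypothetical requirement
    r_stripped = r.split('/')[-1].replace('.git', '')  # -> 'repo'
    try:
        pkg.require(r_stripped)  # fast metadata-based check
    except (pkg.VersionConflict, pkg.DistributionNotFound):
        try:  # slower but more accurate: can the module actually be imported?
            importlib.import_module(next(pkg.parse_requirements(r_stripped)).name)
        except ImportError:
            print(f'missing: "{r}"')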