mirror of https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00

ultralytics 8.0.179 base Model class from nn.Module (#4911)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

This commit is contained in:
parent c8de4fe634
commit c17106db1f

.github/workflows/ci.yaml (vendored): 2 changes
.github/workflows/ci.yaml
@@ -257,10 +257,10 @@ jobs:
           activate-environment: anaconda-client-env
       - name: Install Libmamba
         run: |
-          # conda install conda-libmamba-solver
           conda config --set solver libmamba
       - name: Install Ultralytics package from conda-forge
         run: |
+          conda install pytorch torchvision cpuonly -c pytorch
           conda install -c conda-forge ultralytics
       - name: Install pip packages
         run: |
docker/Dockerfile-conda
@@ -18,8 +18,9 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt .
 # Install conda packages
 # mkl required to fix 'OSError: libmkl_intel_lp64.so.2: cannot open shared object file: No such file or directory'
 RUN conda config --set solver libmamba && \
+    conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia && \
     conda install -c conda-forge ultralytics mkl
-# conda install -c pytorch -c nvidia -c conda-forge pytorch torchvision pytorch-cuda=11.8 ultralytics
+# conda install -c pytorch -c nvidia -c conda-forge pytorch torchvision pytorch-cuda=11.8 ultralytics mkl


 # Usage Examples -------------------------------------------------------------------------------------------------------
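Both install tweaks above pin the torch build explicitly (cpuonly for CI, pytorch-cuda=11.8 for the conda image) before conda-forge resolves ultralytics. A quick sanity check of the resulting environment, purely illustrative:

    # Post-install sanity check (run inside the built environment); illustrative only.
    import torch

    print(torch.__version__)          # the conda-installed build
    print(torch.version.cuda)         # None for the cpuonly CI build; '11.8' in the CUDA image
    print(torch.cuda.is_available())  # True only on a GPU host with the NVIDIA runtime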
tests/conftest.py
@@ -39,6 +39,19 @@ def pytest_runtest_setup(item):
         pytest.skip('skip slow tests unless --slow is set')


+def pytest_collection_modifyitems(config, items):
+    """
+    Modify the list of test items to remove tests marked as slow if the --slow option is not provided.
+
+    Args:
+        config (pytest.config.Config): The pytest config object.
+        items (list): List of test items to be executed.
+    """
+    if not config.getoption('--slow'):
+        # Remove the item entirely from the list of test items if it's marked as 'slow'
+        items[:] = [item for item in items if 'slow' not in item.keywords]
+
+
 def pytest_sessionstart(session):
     """
     Initialize session configurations for pytest.
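With the hook above, slow tests are deselected at collection time instead of being skipped one by one, so they no longer clutter reports as skip entries. A minimal sketch of a test that opts in to this behavior (the test name is illustrative):

    import pytest

    @pytest.mark.slow
    def test_full_export_matrix():
        """Collected only when pytest is invoked with --slow; removed from the run otherwise."""
        assert True

Plain `pytest` drops it during collection; `pytest --slow` keeps it in the run.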
ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = '8.0.178'
+__version__ = '8.0.179'

 from ultralytics.models import RTDETR, SAM, YOLO
 from ultralytics.models.fastsam import FastSAM
ultralytics/engine/model.py
@@ -8,15 +8,14 @@ from typing import Union
 from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
 from ultralytics.hub.utils import HUB_WEB_ROOT
 from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
-from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, emojis, yaml_load
+from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, emojis, yaml_load
 from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
 from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
-from ultralytics.utils.torch_utils import smart_inference_mode


-class Model:
+class Model(nn.Module):
     """
-    A base model class to unify apis for all the models.
+    A base class to unify APIs for all models.

     Args:
         model (str, Path): Path to the model file to load or create.
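Deriving Model from nn.Module means device/dtype plumbing (cpu(), cuda(), half(), state_dict(), hooks) is inherited instead of re-implemented. A minimal sketch of what subclassing buys, using an illustrative wrapper class rather than the real Model:

    import torch.nn as nn

    class Wrapper(nn.Module):
        """Illustrative only: any nn.Module subclass gets moving/casting for free."""

        def __init__(self):
            super().__init__()
            self.net = nn.Linear(4, 2)  # auto-registered as a submodule

    w = Wrapper()
    w.half()  # inherited from nn.Module; casts self.net parameters to fp16
    print(next(w.parameters()).dtype)  # torch.float16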
@@ -63,6 +62,7 @@ class Model:
             model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
             task (Any, optional): Task type for the YOLO model. Defaults to None.
         """
+        super().__init__()
         self.callbacks = callbacks.get_default_callbacks()
         self.predictor = None  # reuse predictor
         self.model = None  # model object
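The new super().__init__() call must come before any attribute assignments: nn.Module.__setattr__ routes module and parameter assignments through internal registries (_modules, _parameters, _buffers) that only exist after nn.Module.__init__ runs. A sketch of the failure mode this avoids:

    import torch.nn as nn

    class Bad(nn.Module):
        def __init__(self):
            self.layer = nn.Linear(2, 2)  # AttributeError: Module.__init__() not called yet

    class Good(nn.Module):
        def __init__(self):
            super().__init__()            # creates _modules/_parameters/_buffers first
            self.layer = nn.Linear(2, 2)  # registered correctly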
@@ -116,13 +116,12 @@ class Model:
         cfg_dict = yaml_model_load(cfg)
         self.cfg = cfg
         self.task = task or guess_model_task(cfg_dict)
-        self.model = (model or self.smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1)  # build model
+        self.model = (model or self._smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1)  # build model
         self.overrides['model'] = self.cfg
         self.overrides['task'] = self.task

         # Below added to allow export from YAMLs
-        args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine model and default args, preferring model args
-        self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
+        self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine default and model args (prefer model args)
         self.model.task = self.task

     def _load(self, weights: str, task=None):
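The consolidated line relies on dict-unpacking precedence: in {**DEFAULT_CFG_DICT, **self.overrides} the rightmost dict wins on key conflicts, so user overrides shadow defaults. A tiny standalone illustration:

    defaults = {'imgsz': 640, 'device': ''}   # stand-in for DEFAULT_CFG_DICT
    overrides = {'device': 0}                 # stand-in for self.overrides
    merged = {**defaults, **overrides}
    print(merged)  # {'imgsz': 640, 'device': 0}, the override wins on conflict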
@@ -154,12 +153,13 @@ class Model:
         pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
         pt_module = isinstance(self.model, nn.Module)
         if not (pt_module or pt_str):
-            raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
-                            f'PyTorch models can be used to train, val, predict and export, i.e. '
-                            f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
-                            f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
+            raise TypeError(
+                f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
+                f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
+                f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
+                f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
+                f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'")

-    @smart_inference_mode()
     def reset_weights(self):
         """
         Resets the model modules parameters to randomly initialized values, losing all training information.
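The reworded error spells out the mode matrix: PyTorch .pt weights support train/val/predict/export, while exported artifacts only support predict and val. Illustrative usage, assuming a local yolov8n.onnx produced by a prior export:

    from ultralytics import YOLO

    onnx_model = YOLO('yolov8n.onnx')      # exported format
    onnx_model.predict('bus.jpg')          # OK: predict is supported
    onnx_model.val(data='coco8.yaml')      # OK: val is supported
    onnx_model.train(data='coco8.yaml')    # raises the TypeError above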
@@ -172,7 +172,6 @@ class Model:
             p.requires_grad = True
         return self

-    @smart_inference_mode()
     def load(self, weights='yolov8n.pt'):
         """
         Transfers parameters with matching names and shapes from 'weights' to model.
@@ -199,7 +198,6 @@ class Model:
         self._check_is_pytorch_model()
         self.model.fuse()

-    @smart_inference_mode()
     def predict(self, source=None, stream=False, predictor=None, **kwargs):
         """
         Perform prediction using the YOLO model.
@@ -227,7 +225,7 @@ class Model:
         prompts = args.pop('prompts', None)  # for SAM-type models

         if not self.predictor:
-            self.predictor = (predictor or self.smart_load('predictor'))(overrides=args, _callbacks=self.callbacks)
+            self.predictor = (predictor or self._smart_load('predictor'))(overrides=args, _callbacks=self.callbacks)
             self.predictor.setup_model(model=self.model, verbose=is_cli)
         else:  # only update args if predictor is already setup
             self.predictor.args = get_cfg(self.predictor.args, args)
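As the branch above shows, the predictor is constructed once via _smart_load('predictor') and then cached; later calls only refresh its args. Illustrative usage (image paths are placeholders):

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    model.predict(source='img1.jpg')             # first call builds and caches the predictor
    model.predict(source='img2.jpg', conf=0.5)   # cached predictor reused, args updated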
@@ -258,7 +256,6 @@ class Model:
         kwargs['mode'] = 'track'
         return self.predict(source=source, stream=stream, **kwargs)

-    @smart_inference_mode()
     def val(self, validator=None, **kwargs):
         """
         Validate a model on a given dataset.
@@ -271,12 +268,11 @@ class Model:
         args = {**self.overrides, **custom, **kwargs, 'mode': 'val'}  # highest priority args on the right
         args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1)

-        validator = (validator or self.smart_load('validator'))(args=args, _callbacks=self.callbacks)
+        validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks)
         validator(model=self.model)
         self.metrics = validator.metrics
         return validator.metrics

-    @smart_inference_mode()
     def benchmark(self, **kwargs):
         """
         Benchmark a model on all export formats.
@@ -333,7 +329,7 @@ class Model:
         if args.get('resume'):
             args['resume'] = self.ckpt_path

-        self.trainer = (trainer or self.smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
+        self.trainer = (trainer or self._smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
         if not args.get('resume'):  # manually set model only if not resuming
             self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
             self.model = self.trainer.model
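For reference, this is the path exercised by an ordinary training call; _smart_load('trainer') resolves the task-specific trainer class. Illustrative invocation using the standard sample dataset config:

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    model.train(data='coco8.yaml', epochs=1, imgsz=640)  # trainer resolved via _smart_load('trainer')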
@@ -365,15 +361,12 @@ class Model:
         args = {**self.overrides, **custom, **kwargs, 'mode': 'train'}  # highest priority args on the right
         return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)

-    def to(self, device):
-        """
-        Sends the model to the given device.
-
-        Args:
-            device (str): device
-        """
+    def _apply(self, fn):
+        """Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers."""
         self._check_is_pytorch_model()
-        self.model.to(device)
+        self = super()._apply(fn)  # noqa
+        self.predictor = None  # reset predictor as device may have changed
+        self.overrides['device'] = str(self.device)  # i.e. device(type='cuda', index=0) -> 'cuda:0'
         return self

     @property
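nn.Module._apply is the one hook underneath to(), cpu(), cuda(), half() and float(), which is why replacing the old to() method with an _apply override covers every move/cast at once; the cached predictor is dropped because its tensors may still live on the old device or dtype. A rough sketch of the pattern with a hypothetical module:

    import torch.nn as nn

    class CachedModule(nn.Module):
        """Hypothetical: shows all move/cast calls funneling through _apply."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2)
            self.cache = 'warm'  # stand-in for a cached predictor

        def _apply(self, fn):
            self = super()._apply(fn)  # moves/casts parameters and buffers
            self.cache = None          # invalidate device/dtype-bound state
            return self

    m = CachedModule()
    m.half()                      # routed through _apply
    print(m.cache)                # None: cache was reset
    print(m.linear.weight.dtype)  # torch.float16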
@@ -410,12 +403,12 @@ class Model:
         for event in callbacks.default_callbacks.keys():
             self.callbacks[event] = [callbacks.default_callbacks[event][0]]

-    def __getattr__(self, attr):
-        """Raises error if object has no requested attribute."""
-        name = self.__class__.__name__
-        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+    # def __getattr__(self, attr):
+    #    """Raises error if object has no requested attribute."""
+    #    name = self.__class__.__name__
+    #    raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

-    def smart_load(self, key):
+    def _smart_load(self, key):
         """Load model/trainer/validator/predictor."""
         try:
             return self.task_map[self.task][key]
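_smart_load (renamed with a leading underscore since it is internal) is a plain lookup into task_map, which each subclass fills with its task-specific classes. A hedged sketch of the expected shape; the class names are placeholders borrowed from the detect task, not verified imports:

    # Sketch of a task_map entry consumed by _smart_load(key); names illustrative.
    task_map = {
        'detect': {
            'model': 'DetectionModel',
            'trainer': 'DetectionTrainer',
            'validator': 'DetectionValidator',
            'predictor': 'DetectionPredictor',
        }
    }

    def _smart_load(task, key):
        return task_map[task][key]

    print(_smart_load('detect', 'trainer'))  # 'DetectionTrainer'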
ultralytics/utils/benchmarks.py
@@ -100,10 +100,10 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
             # Export
             if format == '-':
                 filename = model.ckpt_path or model.cfg
-                export = model  # PyTorch format
+                exported_model = model  # PyTorch format
             else:
                 filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
-                export = YOLO(filename, task=model.task)
+                exported_model = YOLO(filename, task=model.task)
                 assert suffix in str(filename), 'export failed'
             emoji = '❎'  # indicates export succeeded

|
|||||||
assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
|
assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
|
||||||
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
|
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
|
||||||
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
|
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
|
||||||
export.predict(ASSETS / 'bus.jpg', imgsz=imgsz, device=device, half=half)
|
exported_model.predict(ASSETS / 'bus.jpg', imgsz=imgsz, device=device, half=half)
|
||||||
|
|
||||||
# Validate
|
# Validate
|
||||||
data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
|
data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
|
||||||
key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
|
key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
|
||||||
results = export.val(data=data,
|
results = exported_model.val(data=data,
|
||||||
batch=1,
|
batch=1,
|
||||||
imgsz=imgsz,
|
imgsz=imgsz,
|
||||||
plots=False,
|
plots=False,
|
||||||
device=device,
|
device=device,
|
||||||
half=half,
|
half=half,
|
||||||
int8=int8,
|
int8=int8,
|
||||||
verbose=False)
|
verbose=False)
|
||||||
metric, speed = results.results_dict[key], results.speed['inference']
|
metric, speed = results.results_dict[key], results.speed['inference']
|
||||||
y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
|
y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
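For reference, a typical entry point into the loop above; arguments are illustrative and the weights download on first use:

    from ultralytics.utils.benchmarks import benchmark

    # Exports yolov8n to each format, then runs predict + val per format as shown above
    benchmark(model='yolov8n.pt', imgsz=640, half=False, device='cpu')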
ultralytics/utils/torch_utils.py
@@ -16,7 +16,7 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

-from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
+from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
 from ultralytics.utils.checks import check_version

 try:
@@ -60,13 +60,48 @@ def get_cpu_info():


 def select_device(device='', batch=0, newline=False, verbose=True):
-    """Selects PyTorch Device. Options are device = None or 'cpu' or 0 or '0' or '0,1,2,3'."""
+    """
+    Selects the appropriate PyTorch device based on the provided arguments.
+
+    The function takes a string specifying the device or a torch.device object and returns a torch.device object
+    representing the selected device. The function also validates the number of available devices and raises an
+    exception if the requested device(s) are not available.
+
+    Args:
+        device (str | torch.device, optional): Device string or torch.device object.
+            Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
+            the first available GPU, or CPU if no GPU is available.
+        batch (int, optional): Batch size being used in your model. Defaults to 0.
+        newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
+        verbose (bool, optional): If True, logs the device information. Defaults to True.
+
+    Returns:
+        torch.device: Selected device.
+
+    Raises:
+        ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
+            devices when using multiple GPUs.
+
+    Examples:
+        >>> select_device('cuda:0')
+        device(type='cuda', index=0)
+
+        >>> select_device('cpu')
+        device(type='cpu')
+
+    Note:
+        Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
+    """
+
+    if isinstance(device, torch.device):
+        return device
+
     s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
     device = str(device).lower()
     for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
         device = device.replace(remove, '')  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
     cpu = device == 'cpu'
-    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
+    mps = device in ('mps', 'mps:0')  # Apple Metal Performance Shaders (MPS)
     if cpu or mps:
         os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
     elif device:  # non-cpu device requested
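The normalization loop in the body is worth seeing in isolation: it canonicalizes many user spellings to a bare index list before any CUDA checks run. A standalone re-creation:

    def normalize(device):
        """Re-creation of select_device's string cleanup, for illustration only."""
        device = str(device).lower()
        for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
            device = device.replace(remove, '')
        return device

    print(normalize('CUDA:0'))   # '0'
    print(normalize('(0, 1)'))   # '0,1'
    print(normalize(None))       # ''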
@@ -105,7 +140,7 @@ def select_device(device='', batch=0, newline=False, verbose=True):
         s += f'CPU ({get_cpu_info()})\n'
         arg = 'cpu'

-    if verbose and RANK == -1:
+    if verbose:
         LOGGER.info(s if newline else s.rstrip())
     return torch.device(arg)

@@ -204,12 +239,15 @@ def model_info_for_loggers(trainer):
     """
     Return model info dict with useful model information.

-    Example for YOLOv8n:
-        {'model/parameters': 3151904,
-         'model/GFLOPs': 8.746,
-         'model/speed_ONNX(ms)': 41.244,
-         'model/speed_TensorRT(ms)': 3.211,
-         'model/speed_PyTorch(ms)': 18.755}
+    Example:
+        YOLOv8n info for loggers
+        ```python
+        results = {'model/parameters': 3151904,
+                   'model/GFLOPs': 8.746,
+                   'model/speed_ONNX(ms)': 41.244,
+                   'model/speed_TensorRT(ms)': 3.211,
+                   'model/speed_PyTorch(ms)': 18.755}
+        ```
     """
     if trainer.args.profile:  # profile ONNX and TensorRT times
         from ultralytics.utils.benchmarks import ProfileModels