mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
Cleanup argument handling in Model
class (#4614)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sebastian Stapf <42514241+Wiqzard@users.noreply.github.com>
This commit is contained in:
parent
53b4f8c713
commit
7e99804263
@ -17,6 +17,10 @@ keywords: Ultralytics, YOLO, Configuration, cfg2dict, handle_deprecation, merge_
|
||||
## ::: ultralytics.cfg.get_cfg
|
||||
<br><br>
|
||||
|
||||
---
|
||||
## ::: ultralytics.cfg.get_save_dir
|
||||
<br><br>
|
||||
|
||||
---
|
||||
## ::: ultralytics.cfg._handle_deprecation
|
||||
<br><br>
|
||||
|
@ -182,9 +182,11 @@ def test_export_openvino():
|
||||
|
||||
def test_export_coreml():
|
||||
if not WINDOWS: # RuntimeError: BlobWriter not loaded with coremltools 7.0 on windows
|
||||
f = YOLO(MODEL).export(format='coreml', nms=True)
|
||||
if MACOS:
|
||||
YOLO(f)(SOURCE) # model prediction only supported on macOS
|
||||
f = YOLO(MODEL).export(format='coreml')
|
||||
YOLO(f)(SOURCE) # model prediction only supported on macOS for nms=False models
|
||||
else:
|
||||
YOLO(MODEL).export(format='coreml', nms=True)
|
||||
|
||||
|
||||
def test_export_tflite(enabled=False):
|
||||
|
@ -5,12 +5,10 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from ultralytics.cfg import get_cfg, get_save_dir
|
||||
from ultralytics.engine.exporter import Exporter
|
||||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, emojis,
|
||||
yaml_load)
|
||||
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, emojis, yaml_load
|
||||
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
|
||||
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
|
||||
from ultralytics.utils.torch_utils import smart_inference_mode
|
||||
@ -118,9 +116,9 @@ class Model:
|
||||
cfg_dict = yaml_model_load(cfg)
|
||||
self.cfg = cfg
|
||||
self.task = task or guess_model_task(cfg_dict)
|
||||
model = model or self.smart_load('model')
|
||||
self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
|
||||
self.model = (model or self.smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1) # build model
|
||||
self.overrides['model'] = self.cfg
|
||||
self.overrides['task'] = self.task
|
||||
|
||||
# Below added to allow export from YAMLs
|
||||
args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
|
||||
@ -220,28 +218,22 @@ class Model:
|
||||
if source is None:
|
||||
source = ASSETS
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||
|
||||
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
|
||||
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
|
||||
# Check prompts for SAM/FastSAM
|
||||
prompts = kwargs.pop('prompts', None)
|
||||
overrides = self.overrides.copy()
|
||||
overrides['conf'] = 0.25
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
overrides['mode'] = kwargs.get('mode', 'predict')
|
||||
assert overrides['mode'] in ['track', 'predict']
|
||||
if not is_cli:
|
||||
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
|
||||
|
||||
custom = {'conf': 0.25, 'save': is_cli} # method defaults
|
||||
args = {**self.overrides, **custom, **kwargs, 'mode': 'predict'} # highest priority args on the right
|
||||
prompts = args.pop('prompts', None) # for SAM-type models
|
||||
|
||||
if not self.predictor:
|
||||
self.task = overrides.get('task') or self.task
|
||||
predictor = predictor or self.smart_load('predictor')
|
||||
self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
|
||||
self.predictor = (predictor or self.smart_load('predictor'))(overrides=args, _callbacks=self.callbacks)
|
||||
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||
if 'project' in overrides or 'name' in overrides:
|
||||
self.predictor.args = get_cfg(self.predictor.args, args)
|
||||
if 'project' in args or 'name' in args:
|
||||
self.predictor.save_dir = get_save_dir(self.predictor.args)
|
||||
# Set prompts for SAM/FastSAM
|
||||
if len and hasattr(self.predictor, 'set_prompts'):
|
||||
if prompts and hasattr(self.predictor, 'set_prompts'): # for SAM-type models
|
||||
self.predictor.set_prompts(prompts)
|
||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
||||
|
||||
@ -257,46 +249,31 @@ class Model:
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.engine.results.Results]): The tracking results.
|
||||
|
||||
"""
|
||||
if not hasattr(self.predictor, 'trackers'):
|
||||
from ultralytics.trackers import register_tracker
|
||||
register_tracker(self, persist)
|
||||
# ByteTrack-based method needs low confidence predictions as input
|
||||
conf = kwargs.get('conf') or 0.1
|
||||
kwargs['conf'] = conf
|
||||
kwargs['conf'] = kwargs.get('conf') or 0.1
|
||||
kwargs['mode'] = 'track'
|
||||
return self.predict(source=source, stream=stream, **kwargs)
|
||||
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, validator=None, **kwargs):
|
||||
def val(self, validator=None, **kwargs):
|
||||
"""
|
||||
Validate a model on a given dataset.
|
||||
|
||||
Args:
|
||||
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
||||
validator (BaseValidator): Customized validator.
|
||||
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
overrides = self.overrides.copy()
|
||||
overrides['rect'] = True # rect batches as default
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'val'
|
||||
if overrides.get('imgsz') is None:
|
||||
overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.data = data or args.data
|
||||
if 'task' in overrides:
|
||||
self.task = args.task
|
||||
else:
|
||||
args.task = self.task
|
||||
validator = validator or self.smart_load('validator')
|
||||
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
||||
custom = {'rect': True} # method defaults
|
||||
args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right
|
||||
args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1)
|
||||
|
||||
validator = validator(args=args, _callbacks=self.callbacks)
|
||||
validator = (validator or self.smart_load('validator'))(args=args, _callbacks=self.callbacks)
|
||||
validator(model=self.model)
|
||||
self.metrics = validator.metrics
|
||||
|
||||
return validator.metrics
|
||||
|
||||
@smart_inference_mode()
|
||||
@ -309,17 +286,16 @@ class Model:
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
from ultralytics.utils.benchmarks import benchmark
|
||||
overrides = self.model.args.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'benchmark'
|
||||
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
|
||||
|
||||
custom = {'verbose': False} # method defaults
|
||||
args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, 'mode': 'benchmark'}
|
||||
return benchmark(
|
||||
model=self,
|
||||
data=kwargs.get('data'), # if no 'data' argument passed set data=None for default datasets
|
||||
imgsz=overrides['imgsz'],
|
||||
half=overrides['half'],
|
||||
int8=overrides['int8'],
|
||||
device=overrides['device'],
|
||||
imgsz=args['imgsz'],
|
||||
half=args['half'],
|
||||
int8=args['int8'],
|
||||
device=args['device'],
|
||||
verbose=kwargs.get('verbose'))
|
||||
|
||||
def export(self, **kwargs):
|
||||
@ -327,22 +303,13 @@ class Model:
|
||||
Export model.
|
||||
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||
**kwargs : Any other args accepted by the Exporter. To see all args check 'configuration' section in docs.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'export'
|
||||
if overrides.get('imgsz') is None:
|
||||
overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
if 'batch' not in kwargs:
|
||||
overrides['batch'] = 1 # default to 1 if not modified
|
||||
if 'data' not in kwargs:
|
||||
overrides['data'] = None # default to None if not modified (avoid int8 calibration with coco.yaml)
|
||||
if 'verbose' not in kwargs:
|
||||
overrides['verbose'] = False
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.task = self.task
|
||||
from .exporter import Exporter
|
||||
|
||||
custom = {'imgsz': self.model.args['imgsz'], 'batch': 1, 'data': None, 'verbose': False} # method defaults
|
||||
args = {**self.overrides, **custom, **kwargs, 'mode': 'export'} # highest priority args on the right
|
||||
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
||||
|
||||
def train(self, trainer=None, **kwargs):
|
||||
@ -359,20 +326,15 @@ class Model:
|
||||
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
|
||||
kwargs = self.session.train_args
|
||||
check_pip_update_available()
|
||||
overrides = self.overrides.copy()
|
||||
if kwargs.get('cfg'):
|
||||
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
|
||||
overrides = yaml_load(check_yaml(kwargs['cfg']))
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'train'
|
||||
if not overrides.get('data'):
|
||||
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
|
||||
if overrides.get('resume'):
|
||||
overrides['resume'] = self.ckpt_path
|
||||
self.task = overrides.get('task') or self.task
|
||||
trainer = trainer or self.smart_load('trainer')
|
||||
self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
|
||||
if not overrides.get('resume'): # manually set model only if not resuming
|
||||
|
||||
overrides = yaml_load(check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
|
||||
custom = {'data': TASK2DATA[self.task]} # method defaults
|
||||
args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
|
||||
if args.get('resume'):
|
||||
args['resume'] = self.ckpt_path
|
||||
|
||||
self.trainer = (trainer or self.smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
|
||||
if not args.get('resume'): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||
self.model = self.trainer.model
|
||||
self.trainer.hub_session = self.session # attach optional HUB session
|
||||
@ -455,7 +417,7 @@ class Model:
|
||||
name = self.__class__.__name__
|
||||
mode = inspect.stack()[1][3] # get the function name.
|
||||
raise NotImplementedError(
|
||||
emojis(f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')) from e
|
||||
emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")) from e
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
|
@ -175,8 +175,13 @@ class BaseTrainer:
|
||||
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
|
||||
# Argument checks
|
||||
if self.args.rect:
|
||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
|
||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
|
||||
self.args.rect = False
|
||||
if self.args.batch == -1:
|
||||
LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
|
||||
"default 'batch=16'")
|
||||
self.args.batch = 16
|
||||
|
||||
# Command
|
||||
cmd, file = generate_ddp_command(world_size, self)
|
||||
try:
|
||||
@ -186,6 +191,7 @@ class BaseTrainer:
|
||||
raise e
|
||||
finally:
|
||||
ddp_cleanup(self, str(file))
|
||||
|
||||
else:
|
||||
self._do_train(world_size)
|
||||
|
||||
@ -248,9 +254,6 @@ class BaseTrainer:
|
||||
if self.batch_size == -1:
|
||||
if RANK == -1: # single-GPU only, estimate best batch size
|
||||
self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
|
||||
else:
|
||||
SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
|
||||
'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
|
||||
|
||||
# Dataloaders
|
||||
batch_size = self.batch_size // max(world_size, 1)
|
||||
|
@ -18,7 +18,6 @@ from PIL import Image
|
||||
from ultralytics.utils import ARM64, LINUX, LOGGER, ROOT, yaml_load
|
||||
from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml
|
||||
from ultralytics.utils.downloads import attempt_download_asset, is_url
|
||||
from ultralytics.utils.ops import xywh2xyxy
|
||||
|
||||
|
||||
def check_class_names(names):
|
||||
@ -363,9 +362,13 @@ class AutoBackend(nn.Module):
|
||||
# im = im.resize((192, 320), Image.BILINEAR)
|
||||
y = self.model.predict({'image': im_pil}) # coordinates are xywh normalized
|
||||
if 'confidence' in y:
|
||||
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
|
||||
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
|
||||
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
|
||||
raise TypeError('Ultralytics only supports inference of non-pipelined CoreML models exported with '
|
||||
f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export.")
|
||||
# TODO: CoreML NMS inference handling
|
||||
# from ultralytics.utils.ops import xywh2xyxy
|
||||
# box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
|
||||
# conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
|
||||
# y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
|
||||
elif len(y) == 1: # classification model
|
||||
y = list(y.values())
|
||||
elif len(y) == 2: # segmentation model
|
||||
|
@ -355,15 +355,15 @@ class DeformableTransformerDecoder(nn.Module):
|
||||
for i, layer in enumerate(self.layers):
|
||||
output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))
|
||||
|
||||
# refine bboxes, (bs, num_queries+num_denoising, 4)
|
||||
refined_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(refer_bbox))
|
||||
bbox = bbox_head[i](output)
|
||||
refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))
|
||||
|
||||
if self.training:
|
||||
dec_cls.append(score_head[i](output))
|
||||
if i == 0:
|
||||
dec_bboxes.append(refined_bbox)
|
||||
else:
|
||||
dec_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(last_refined_bbox)))
|
||||
dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox)))
|
||||
elif i == self.eval_idx:
|
||||
dec_cls.append(score_head[i](output))
|
||||
dec_bboxes.append(refined_bbox)
|
||||
|
Loading…
x
Reference in New Issue
Block a user