mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-06-09 09:34:24 +08:00
New guess_model_task()
function (#614)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
520825c4b2
commit
59d4335664
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.0.18"
|
__version__ = "8.0.19"
|
||||||
|
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
from ultralytics.yolo.utils import ops
|
from ultralytics.yolo.utils import ops
|
||||||
|
@ -251,7 +251,7 @@ class ClassificationModel(BaseModel):
|
|||||||
ch=3,
|
ch=3,
|
||||||
nc=1000,
|
nc=1000,
|
||||||
cutoff=10,
|
cutoff=10,
|
||||||
verbose=True): # yaml, model, number of classes, cutoff index
|
verbose=True): # yaml, model, channels, number of classes, cutoff index, verbose flag
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
||||||
|
|
||||||
@ -457,3 +457,53 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||||||
ch = []
|
ch = []
|
||||||
ch.append(c2)
|
ch.append(c2)
|
||||||
return nn.Sequential(*layers), sorted(save)
|
return nn.Sequential(*layers), sorted(save)
|
||||||
|
|
||||||
|
|
||||||
|
def guess_model_task(model):
|
||||||
|
"""
|
||||||
|
Guess the task of a PyTorch model from its architecture or configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module) or (dict): PyTorch model or model configuration in YAML format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Task of the model ('detect', 'segment', 'classify').
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SyntaxError: If the task of the model could not be determined.
|
||||||
|
"""
|
||||||
|
cfg, task = None, None
|
||||||
|
if isinstance(model, dict):
|
||||||
|
cfg = model
|
||||||
|
elif isinstance(model, nn.Module): # PyTorch model
|
||||||
|
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
cfg = eval(x)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Guess from YAML dictionary
|
||||||
|
if cfg:
|
||||||
|
m = cfg["head"][-1][-2].lower() # output module name
|
||||||
|
if m in ["classify", "classifier", "cls", "fc"]:
|
||||||
|
task = "classify"
|
||||||
|
if m in ["detect"]:
|
||||||
|
task = "detect"
|
||||||
|
if m in ["segment"]:
|
||||||
|
task = "segment"
|
||||||
|
|
||||||
|
# Guess from PyTorch model
|
||||||
|
if task is None and isinstance(model, nn.Module):
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, Detect):
|
||||||
|
task = "detect"
|
||||||
|
elif isinstance(m, Segment):
|
||||||
|
task = "segment"
|
||||||
|
elif isinstance(m, Classify):
|
||||||
|
task = "classify"
|
||||||
|
|
||||||
|
# Unable to determine task from model
|
||||||
|
if task is None:
|
||||||
|
raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
|
||||||
|
"i.e. 'task=detect', 'task=segment' or 'task=classify'.")
|
||||||
|
else:
|
||||||
|
return task
|
||||||
|
@ -66,7 +66,7 @@ import torch
|
|||||||
|
|
||||||
import ultralytics
|
import ultralytics
|
||||||
from ultralytics.nn.modules import Detect, Segment
|
from ultralytics.nn.modules import Detect, Segment
|
||||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, guess_model_task
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||||
from ultralytics.yolo.data.utils import check_det_dataset
|
from ultralytics.yolo.data.utils import check_det_dataset
|
||||||
@ -74,7 +74,7 @@ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get
|
|||||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||||
from ultralytics.yolo.utils.files import file_size
|
from ultralytics.yolo.utils.files import file_size
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, select_device, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||||
|
|
||||||
MACOS = platform.system() == 'Darwin' # macOS environment
|
MACOS = platform.system() == 'Darwin' # macOS environment
|
||||||
|
|
||||||
@ -235,7 +235,7 @@ class Exporter:
|
|||||||
# Finish
|
# Finish
|
||||||
f = [str(x) for x in f if x] # filter out '' and None
|
f = [str(x) for x in f if x] # filter out '' and None
|
||||||
if any(f):
|
if any(f):
|
||||||
task = guess_task_from_model_yaml(model)
|
task = guess_model_task(model)
|
||||||
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
|
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
|
||||||
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
|
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
|
||||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||||
|
@ -3,12 +3,13 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from ultralytics import yolo # noqa
|
from ultralytics import yolo # noqa
|
||||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
||||||
|
guess_model_task)
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.engine.exporter import Exporter
|
from ultralytics.yolo.engine.exporter import Exporter
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_yaml
|
from ultralytics.yolo.utils.checks import check_yaml
|
||||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||||
|
|
||||||
# Map head to model, trainer, validator, and predictor classes
|
# Map head to model, trainer, validator, and predictor classes
|
||||||
MODEL_MAP = {
|
MODEL_MAP = {
|
||||||
@ -73,9 +74,9 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
cfg = check_yaml(cfg) # check YAML
|
cfg = check_yaml(cfg) # check YAML
|
||||||
cfg_dict = yaml_load(cfg, append_filename=True) # model dict
|
cfg_dict = yaml_load(cfg, append_filename=True) # model dict
|
||||||
self.task = guess_task_from_model_yaml(cfg_dict)
|
self.task = guess_model_task(cfg_dict)
|
||||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||||
self._guess_ops_from_task(self.task)
|
self._assign_ops_from_task(self.task)
|
||||||
self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize
|
self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
@ -92,7 +93,7 @@ class YOLO:
|
|||||||
self.overrides = self.model.args
|
self.overrides = self.model.args
|
||||||
self._reset_ckpt_args(self.overrides)
|
self._reset_ckpt_args(self.overrides)
|
||||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||||
self._guess_ops_from_task(self.task)
|
self._assign_ops_from_task(self.task)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
@ -217,7 +218,7 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
self.model.to(device)
|
self.model.to(device)
|
||||||
|
|
||||||
def _guess_ops_from_task(self, task):
|
def _assign_ops_from_task(self, task):
|
||||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
|
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
|
||||||
# warning: eval is unsafe. Use with caution
|
# warning: eval is unsafe. Use with caution
|
||||||
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
"""
|
"""
|
||||||
Auto-batch utils
|
AutoBatch utils
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
@ -308,23 +308,6 @@ def strip_optimizer(f='best.pt', s=''):
|
|||||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||||
|
|
||||||
|
|
||||||
def guess_task_from_model_yaml(model):
|
|
||||||
try:
|
|
||||||
cfg = model if isinstance(model, dict) else model.yaml # model cfg dict
|
|
||||||
m = cfg["head"][-1][-2].lower() # output module name
|
|
||||||
task = None
|
|
||||||
if m in ["classify", "classifier", "cls", "fc"]:
|
|
||||||
task = "classify"
|
|
||||||
if m in ["detect"]:
|
|
||||||
task = "detect"
|
|
||||||
if m in ["segment"]:
|
|
||||||
task = "segment"
|
|
||||||
except Exception as e:
|
|
||||||
raise SyntaxError('Unknown task. Define task explicitly, i.e. task=detect when running your command. '
|
|
||||||
'Valid tasks are detect, segment, classify.') from e
|
|
||||||
return task
|
|
||||||
|
|
||||||
|
|
||||||
def profile(input, ops, n=10, device=None):
|
def profile(input, ops, n=10, device=None):
|
||||||
""" YOLOv8 speed/memory/FLOPs profiler
|
""" YOLOv8 speed/memory/FLOPs profiler
|
||||||
Usage:
|
Usage:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user