Mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-05-23 21:44:22 +08:00)

[WIP] Model interface (#68)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
parent e6737f1207
commit 7ae45c6cc4
@@ -1,3 +1,5 @@
+from ultralytics.yolo import v8
+
 from .engine.model import YOLO
 from .engine.trainer import BaseTrainer
 from .engine.validator import BaseValidator
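Note: this first hunk appears to be the package `__init__` (it uses relative `engine` imports), so the main entry points are importable straight from `ultralytics.yolo`. A minimal sketch, assuming the exports shown above:

    from ultralytics.yolo import YOLO, BaseTrainer, BaseValidator, v8

    model = YOLO()  # the task-agnostic interface added by this commit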
@@ -1,55 +1,45 @@
 """
 Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13
 """
+import torch
 import yaml
 
 from ultralytics.yolo.utils import LOGGER
 from ultralytics.yolo.utils.checks import check_yaml
-from ultralytics.yolo.utils.modeling import get_model
+from ultralytics.yolo.utils.modeling import attempt_load_weights
 from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel
 
 # map head: [model, trainer]
 MODEL_MAP = {
-    "classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],
-    "detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],  # temp
-    "segment": []}
+    "classify": [ClassificationModel, 'yolo.VERSION.classify.ClassificationTrainer'],
+    "detect": [DetectionModel, 'yolo.VERSION.detect.DetectionTrainer'],
+    "segment": [SegmentationModel, 'yolo.VERSION.segment.SegmentationTrainer']}
 
 
 class YOLO:
 
-    def __init__(self, task=None, version=8) -> None:
+    def __init__(self, version=8) -> None:
         self.version = version
         self.ModelClass = None
         self.TrainerClass = None
         self.model = None
-        self.pretrained_weights = None
-        if task:
-            if task.lower() not in MODEL_MAP:
-                raise Exception(f"Unsupported task {task}. The supported tasks are: \n {MODEL_MAP.keys()}")
-            self.ModelClass, self.TrainerClass = MODEL_MAP[task]
-            self.TrainerClass = eval(self.trainer.replace("VERSION", f"v{self.version}"))
+        self.trainer = None
+        self.task = None
+        self.ckpt = None
 
     def new(self, cfg: str):
         cfg = check_yaml(cfg)  # check YAML
-        if self.model:
-            self.model = self.model(cfg)
-        else:
-            with open(cfg, encoding='ascii', errors='ignore') as f:
-                cfg = yaml.safe_load(f)  # model dict
-            self.ModelClass, self.TrainerClass = self._get_model_and_trainer(cfg["head"])
-            self.model = self.ModelClass(cfg)  # initialize
+        with open(cfg, encoding='ascii', errors='ignore') as f:
+            cfg = yaml.safe_load(f)  # model dict
+        self.ModelClass, self.TrainerClass, self.task = self._guess_model_trainer_and_task(cfg["head"][-1][-2])
+        self.model = self.ModelClass(cfg)  # initialize
 
-    def load(self, weights, autodownload=True):
-        if not isinstance(self.pretrained_weights, type(None)):
-            LOGGER.info("Overwriting weights")
-        # TODO: weights = smart_file_loader(weights)
-        if self.model:
-            self.model.load(weights)
-            LOGGER.info("Checkpoint loaded successfully")
-        else:
-            self.model = get_model(weights)
-            self.ModelClass, self.TrainerClass = self._guess_model_and_trainer(list(self.model.named_children()))
-        self.pretrained_weights = weights
+    def load(self, weights):
+        self.ckpt = torch.load(weights, map_location="cpu")
+        self.task = self.ckpt["train_args"]["task"]
+        _, trainer_class_literal = MODEL_MAP[self.task]
+        self.TrainerClass = eval(trainer_class_literal.replace("VERSION", f"v{self.version}"))
+        self.model = attempt_load_weights(weights)
 
     def reset(self):
         for m in self.model.modules():
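Note: `load()` now resolves the task and trainer from the checkpoint's saved `train_args` instead of inspecting the module tree. A rough sketch of that lookup, assuming `MODEL_MAP` is importable from the module shown above and using an illustrative checkpoint path:

    import torch

    from ultralytics.yolo.engine.model import MODEL_MAP  # module path implied by `.engine.model` in the first hunk

    ckpt = torch.load("yolov8n-cls.pt", map_location="cpu")  # illustrative file name
    task = ckpt["train_args"]["task"]                        # e.g. "classify"
    model_class, trainer_literal = MODEL_MAP[task]
    print(trainer_literal.replace("VERSION", "v8"))          # -> 'yolo.v8.classify.ClassificationTrainer'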
@@ -61,16 +51,31 @@ class YOLO:
     def train(self, **kwargs):
         if 'data' not in kwargs:
             raise Exception("data is required to train")
-        if not self.model:
+        if not self.model and not self.ckpt:
             raise Exception("model not initialized. Use .new() or .load()")
-        # kwargs["model"] = self.model
-        trainer = self.TrainerClass(overrides=kwargs)
-        trainer.model = self.model
-        trainer.train()
 
-    def _guess_model_and_trainer(self, cfg):
+        kwargs["task"] = self.task
+        kwargs["mode"] = "train"
+        self.trainer = self.TrainerClass(overrides=kwargs)
+        # load pre-trained weights if found, else use the loaded model
+        self.trainer.model = self.trainer.load_model(weights=self.ckpt) if self.ckpt else self.model
+        self.trainer.train()
+
+    def resume(self, task=None, model=None):
+        if not task:
+            raise Exception(
+                "pass the task type and/or model(optional) from which you want to resume: `model.resume(task="
+                ")`")
+        if task.lower() not in MODEL_MAP:
+            raise Exception(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
+        _, trainer_class_literal = MODEL_MAP[task.lower()]
+        self.TrainerClass = eval(trainer_class_literal.replace("VERSION", f"v{self.version}"))
+        self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model if model else True})
+        self.trainer.train()
+
+    def _guess_model_trainer_and_task(self, head):
         # TODO: warn
-        head = cfg[-1][-2]
+        task = None
         if head.lower() in ["classify", "classifier", "cls", "fc"]:
             task = "classify"
         if head.lower() in ["detect"]:
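Note: combining the old `__main__` example (removed in the next hunk) with the new `train()`/`resume()` signatures, end-to-end usage would look roughly like this; the dataset and weight file come from that example, and the `last.pt` path is illustrative:

    from ultralytics.yolo import YOLO

    model = YOLO()
    model.load("yolov5n-cls.pt")                           # weights from the old __main__ example
    model.train(data="imagenette160", epochs=1, lr0=0.01)  # args from the old __main__ example

    # resume an interrupted run of a known task
    model = YOLO()
    model.resume(task="classify")                          # or resume(task="classify", model="path/to/last.pt")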
@@ -81,11 +86,9 @@ class YOLO:
         # warning: eval is unsafe. Use with caution
         trainer_class = eval(trainer_class.replace("VERSION", f"v{self.version}"))
 
-        return model_class, trainer_class
+        return model_class, trainer_class, task
 
-
-if __name__ == "__main__":
-    model = YOLO()
-    # model.new("assets/dummy_model.yaml")
-    model.load("yolov5n-cls.pt")
-    model.train(data="imagenette160", epochs=1, lr0=0.01)
-
+    def __call__(self, imgs):
+        if not self.model:
+            LOGGER.info("model not initialized!")
+        return self.model(imgs)
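Note: the new `__call__` just forwards to the wrapped module, so a built or loaded model can be called like a function. A sketch with an illustrative config name and a dummy input:

    import torch

    from ultralytics.yolo import YOLO

    model = YOLO()
    model.new("yolov8n.yaml")            # illustrative config name
    imgs = torch.zeros(1, 3, 640, 640)   # dummy batch
    preds = model(imgs)                  # __call__ -> self.model(imgs)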
@@ -8,7 +8,6 @@ from collections import defaultdict
 from copy import deepcopy
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Union
 
 import numpy as np
 import torch
@@ -28,7 +27,6 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
 from ultralytics.yolo.utils.checks import check_file, print_args
 from ultralytics.yolo.utils.configs import get_config
 from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
-from ultralytics.yolo.utils.modeling import get_model
 from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
 
 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -63,6 +61,7 @@ class BaseTrainer:
         self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
 
         # Model and Dataloaders.
+        self.model = self.args.model
         self.data = self.args.data
         if self.data.endswith(".yaml"):
             self.data = check_dataset_yaml(self.data)
@@ -125,6 +124,7 @@ class BaseTrainer:
         """
         # model
         ckpt = self.setup_model()
+        self.model = self.model.to(self.device)
         self.set_model_attributes()
         if world_size > 1:
             self.model = DDP(self.model, device_ids=[rank])
@@ -288,13 +288,16 @@
         """
         load/create/download model for any task
         """
-        model = self.args.model
+        if isinstance(self.model, torch.nn.Module):  # if loaded model is passed
+            return
+        # We should improve the code flow here. This function looks hacky
+        model = self.model
         pretrained = not (str(model).endswith(".yaml"))
         # config
         if not pretrained:
             model = check_file(model)
         ckpt = self.load_ckpt(model) if pretrained else None
-        self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt).to(self.device)  # model
+        self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt)  # model
         return ckpt
 
     def load_ckpt(self, ckpt):
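Note: the early `return` lets a caller (such as `YOLO.train()` above) hand an already-built `nn.Module` to the trainer, in which case `setup_model()` skips the cfg/weights path. A simplified sketch; the trainer class, override values and `preloaded_module` are illustrative:

    from ultralytics.yolo.v8.detect import DetectionTrainer  # import path suggested by MODEL_MAP

    trainer = DetectionTrainer(overrides={"task": "detect", "mode": "train", "data": "coco128.yaml"})
    trainer.model = preloaded_module   # any torch.nn.Module, as YOLO.train() assigns
    ckpt = trainer.setup_model()       # returns immediately; ckpt is None and the model is kept as-is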
@@ -402,7 +405,7 @@ class BaseTrainer:
             last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run())
             args_yaml = last.parent.parent / 'args.yaml'  # train options yaml
             if args_yaml.is_file():
-                args = self._get_config(args_yaml)  # replace
+                args = get_config(args_yaml)  # replace
             args.model, args.resume, args.exist_ok = str(last), True, True  # reinstate
             self.args = args
 
@@ -424,8 +427,7 @@ class BaseTrainer:
                 f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
         if self.epochs < start_epoch:
             LOGGER.info(
-                f"{self.args.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
-            )
+                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
             self.epochs += ckpt['epoch']  # finetune additional epochs
         self.best_fitness = best_fitness
         self.start_epoch = start_epoch
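Note: the fine-tuning arithmetic above is easy to misread; a worked example with illustrative numbers:

    epochs_done = 100                               # ckpt['epoch']
    epochs_requested = 50                           # self.epochs, here smaller than start_epoch
    total_epochs = epochs_requested + epochs_done   # 150: training resumes at epoch 100 and runs 50 more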
@@ -460,9 +462,3 @@ def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
     LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                 f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
     return optimizer
-
-
-# Dummy validator
-def val(trainer: BaseTrainer):
-    trainer.console.info("validating")
-    return {"metric_1": 0.1, "metric_2": 0.2, "fitness": 1}
@@ -13,8 +13,10 @@ class ClassificationTrainer(BaseTrainer):
     def set_model_attributes(self):
         self.model.names = self.data["names"]
 
-    def load_model(self, model_cfg, weights):
+    def load_model(self, model_cfg=None, weights=None):
         # TODO: why treat clf models as unique. We should have clf yamls?
+        if isinstance(weights, dict):  # yolo ckpt
+            weights = weights["model"]
         if weights and not weights.__class__.__name__.startswith("yolo"):  # torchvision
             model = weights
         else:
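Note: with the new `isinstance(weights, dict)` guard, `ClassificationTrainer.load_model()` accepts either a raw YOLO checkpoint dict (unwrapped via `weights["model"]`) or a ready torchvision module. A sketch of both call styles, assuming `trainer` is an already-constructed `ClassificationTrainer` and the checkpoint path is illustrative:

    import torch
    import torchvision

    trainer.load_model(weights=torch.load("yolov8n-cls.pt", map_location="cpu"))  # ckpt dict -> weights["model"]
    trainer.load_model(weights=torchvision.models.resnet18())                     # torchvision module, used as-is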
@@ -15,7 +15,7 @@ from .val import DetectionValidator
 # BaseTrainer python usage
 class DetectionTrainer(SegmentationTrainer):
 
-    def load_model(self, model_cfg, weights):
+    def load_model(self, model_cfg=None, weights=None):
         model = DetectionModel(model_cfg or weights["model"].yaml,
                                ch=3,
                                nc=self.data["nc"],
@@ -26,7 +26,7 @@ class SegmentationTrainer(BaseTrainer):
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
         return batch
 
-    def load_model(self, model_cfg, weights):
+    def load_model(self, model_cfg=None, weights=None):
         model = SegmentationModel(model_cfg or weights["model"].yaml,
                                   ch=3,
                                   nc=self.data["nc"],