mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
add resuming (#63)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
de3e6ca54d
commit
fbeeb5d1e1
@ -26,8 +26,7 @@ import ultralytics.yolo.utils.callbacks as callbacks
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
|
||||
from ultralytics.yolo.utils.checks import check_file, print_args
|
||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||
from ultralytics.yolo.utils.modeling import get_model
|
||||
from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
|
||||
|
||||
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
|
||||
@ -38,6 +37,7 @@ class BaseTrainer:
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
||||
self.args = self._get_config(config, overrides)
|
||||
self.check_resume()
|
||||
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||
|
||||
self.console = LOGGER
|
||||
@ -50,6 +50,7 @@ class BaseTrainer:
|
||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
||||
self.batch_size = self.args.batch_size
|
||||
self.epochs = self.args.epochs
|
||||
self.start_epoch = 0
|
||||
print_args(dict(self.args))
|
||||
|
||||
# Save run settings
|
||||
@ -66,8 +67,6 @@ class BaseTrainer:
|
||||
else:
|
||||
self.data = check_dataset(self.data)
|
||||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
if self.args.model:
|
||||
self.model = self.get_model(self.args.model)
|
||||
self.ema = None
|
||||
|
||||
# Optimization utils init
|
||||
@ -136,15 +135,17 @@ class BaseTrainer:
|
||||
self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
|
||||
|
||||
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
|
||||
self.model = self.model.to(self.device)
|
||||
self.model = DDP(self.model, device_ids=[rank])
|
||||
|
||||
def _setup_train(self, rank, world_size):
|
||||
"""
|
||||
Builds dataloaders and optimizer on correct rank process
|
||||
"""
|
||||
# Optimizer
|
||||
# model
|
||||
ckpt = self.setup_model()
|
||||
self.set_model_attributes()
|
||||
if world_size > 1:
|
||||
self.model = DDP(self.model, device_ids=[rank])
|
||||
# Optimizer
|
||||
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
||||
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||
self.optimizer = build_optimizer(model=self.model,
|
||||
@ -158,6 +159,8 @@ class BaseTrainer:
|
||||
else:
|
||||
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
self.resume_training(ckpt)
|
||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||
|
||||
# dataloaders
|
||||
batch_size = self.batch_size // world_size
|
||||
@ -174,20 +177,18 @@ class BaseTrainer:
|
||||
def _do_train(self, rank=-1, world_size=1):
|
||||
if world_size > 1:
|
||||
self._setup_ddp(rank, world_size)
|
||||
else:
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
self.trigger_callbacks("before_train")
|
||||
self._setup_train(rank, world_size)
|
||||
self.trigger_callbacks("before_train")
|
||||
|
||||
self.epoch = 0
|
||||
self.epoch_time = None
|
||||
self.epoch_time_start = time.time()
|
||||
self.train_time_start = time.time()
|
||||
nb = len(self.train_loader) # number of batches
|
||||
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
|
||||
last_opt_step = -1
|
||||
for epoch in range(self.epochs):
|
||||
for epoch in range(self.start_epoch, self.epochs):
|
||||
self.epoch = epoch
|
||||
self.trigger_callbacks("on_epoch_start")
|
||||
self.model.train()
|
||||
if rank != -1:
|
||||
@ -257,11 +258,10 @@ class BaseTrainer:
|
||||
self.save_metrics(metrics=log_vals)
|
||||
|
||||
# save model
|
||||
if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
|
||||
if (not self.args.nosave) or (epoch + 1 == self.epochs):
|
||||
self.save_model()
|
||||
self.trigger_callbacks('on_model_save')
|
||||
|
||||
self.epoch += 1
|
||||
tnow = time.time()
|
||||
self.epoch_time = tnow - self.epoch_time_start
|
||||
self.epoch_time_start = tnow
|
||||
@ -301,17 +301,21 @@ class BaseTrainer:
|
||||
"""
|
||||
return data["train"], data.get("val") or data.get("test")
|
||||
|
||||
def get_model(self, model: Union[str, Path]):
|
||||
def setup_model(self):
|
||||
"""
|
||||
load/create/download model for any task
|
||||
"""
|
||||
pretrained = True
|
||||
if str(model).endswith(".yaml"):
|
||||
model = self.args.model
|
||||
pretrained = not (str(model).endswith(".yaml"))
|
||||
# config
|
||||
if not pretrained:
|
||||
model = check_file(model)
|
||||
pretrained = False
|
||||
return self.load_model(model_cfg=None if pretrained else model,
|
||||
weights=get_model(model) if pretrained else None,
|
||||
data=self.data) # model
|
||||
ckpt = self.load_ckpt(model) if pretrained else None
|
||||
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt).to(self.device) # model
|
||||
return ckpt
|
||||
|
||||
def load_ckpt(self, ckpt):
|
||||
return torch.load(ckpt, map_location='cpu')
|
||||
|
||||
def optimizer_step(self):
|
||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||
@ -350,7 +354,7 @@ class BaseTrainer:
|
||||
if rank in {-1, 0}:
|
||||
self.console.info(text)
|
||||
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
def load_model(self, model_cfg, weights):
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
def get_validator(self):
|
||||
@ -409,6 +413,40 @@ class BaseTrainer:
|
||||
if f is self.best:
|
||||
self.console.info(f'\nValidating {f}...')
|
||||
|
||||
def check_resume(self):
|
||||
resume = self.args.resume
|
||||
if resume:
|
||||
last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run())
|
||||
args_yaml = last.parent.parent / 'args.yaml' # train options yaml
|
||||
if args_yaml.is_file():
|
||||
args = self._get_config(args_yaml) # replace
|
||||
args.model, args.resume, args.exist_ok = str(last), True, True # reinstate
|
||||
self.args = args
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
if ckpt is None:
|
||||
return
|
||||
best_fitness = 0.0
|
||||
start_epoch = ckpt['epoch'] + 1
|
||||
if ckpt['optimizer'] is not None:
|
||||
self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
|
||||
best_fitness = ckpt['best_fitness']
|
||||
if self.ema and ckpt.get('ema'):
|
||||
self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
|
||||
self.ema.updates = ckpt['updates']
|
||||
if self.args.resume:
|
||||
assert start_epoch > 0, f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
|
||||
f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
|
||||
LOGGER.info(
|
||||
f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
|
||||
if self.epochs < start_epoch:
|
||||
LOGGER.info(
|
||||
f"{self.args.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
|
||||
)
|
||||
self.epochs += ckpt['epoch'] # finetune additional epochs
|
||||
self.best_fitness = best_fitness
|
||||
self.start_epoch = start_epoch
|
||||
|
||||
|
||||
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||
# TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
|
||||
|
@ -33,6 +33,7 @@ overlap_mask: True # masks overlap
|
||||
mask_ratio: 4 # mask downsample ratio
|
||||
# Classification
|
||||
dropout: False # use dropout
|
||||
resume: False
|
||||
|
||||
|
||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||
|
@ -1,4 +1,5 @@
|
||||
import contextlib
|
||||
import glob
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@ -74,3 +75,9 @@ def file_date(path=__file__):
|
||||
# Return human-readable file modification date, i.e. '2021-3-26'
|
||||
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
||||
return f'{t.year}-{t.month}-{t.day}'
|
||||
|
||||
|
||||
def get_latest_run(search_dir='.'):
|
||||
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
|
||||
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
|
||||
return max(last_list, key=os.path.getctime) if last_list else ''
|
||||
|
@ -4,6 +4,7 @@ import torch
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import build_classification_dataloader
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
||||
from ultralytics.yolo.utils.modeling import get_model
|
||||
from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
|
||||
|
||||
|
||||
@ -12,13 +13,13 @@ class ClassificationTrainer(BaseTrainer):
|
||||
def set_model_attributes(self):
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
def load_model(self, model_cfg, weights):
|
||||
# TODO: why treat clf models as unique. We should have clf yamls?
|
||||
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
|
||||
model = weights
|
||||
else:
|
||||
model = ClassificationModel(model_cfg, weights, data["nc"])
|
||||
ClassificationModel.reshape_outputs(model, data["nc"])
|
||||
model = ClassificationModel(model_cfg, weights, self.data["nc"])
|
||||
ClassificationModel.reshape_outputs(model, self.data["nc"])
|
||||
for m in model.modules():
|
||||
if not weights and hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
@ -28,6 +29,9 @@ class ClassificationTrainer(BaseTrainer):
|
||||
p.requires_grad = True # for training
|
||||
return model
|
||||
|
||||
def load_ckpt(self, ckpt):
|
||||
return get_model(ckpt)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.img_size,
|
||||
@ -46,6 +50,12 @@ class ClassificationTrainer(BaseTrainer):
|
||||
loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
|
||||
return loss, loss
|
||||
|
||||
def check_resume(self):
|
||||
pass
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
pass
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
|
@ -15,10 +15,10 @@ from .val import DetectionValidator
|
||||
# BaseTrainer python usage
|
||||
class DetectionTrainer(SegmentationTrainer):
|
||||
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
def load_model(self, model_cfg, weights):
|
||||
model = DetectionModel(model_cfg or weights["model"].yaml,
|
||||
ch=3,
|
||||
nc=data["nc"],
|
||||
nc=self.data["nc"],
|
||||
anchors=self.args.get("anchors"))
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
@ -26,10 +26,10 @@ class SegmentationTrainer(BaseTrainer):
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||
return batch
|
||||
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
def load_model(self, model_cfg, weights):
|
||||
model = SegmentationModel(model_cfg or weights["model"].yaml,
|
||||
ch=3,
|
||||
nc=data["nc"],
|
||||
nc=self.data["nc"],
|
||||
anchors=self.args.get("anchors"))
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
@ -242,7 +242,7 @@ class SegmentationValidator(BaseValidator):
|
||||
cls,
|
||||
bboxes,
|
||||
masks,
|
||||
paths,
|
||||
paths=paths,
|
||||
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
||||
names=self.names)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user