mirror of https://github.com/THU-MIG/yolov10.git
commit fbeeb5d1e1 (parent de3e6ca54d)

add resuming (#63)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -26,8 +26,7 @@ import ultralytics.yolo.utils.callbacks as callbacks
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
 from ultralytics.yolo.utils.checks import check_file, print_args
-from ultralytics.yolo.utils.files import increment_path, save_yaml
-from ultralytics.yolo.utils.modeling import get_model
+from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
 from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer

 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -38,6 +37,7 @@ class BaseTrainer:

     def __init__(self, config=DEFAULT_CONFIG, overrides={}):
         self.args = self._get_config(config, overrides)
+        self.check_resume()
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

         self.console = LOGGER
@@ -50,6 +50,7 @@ class BaseTrainer:
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
         self.batch_size = self.args.batch_size
         self.epochs = self.args.epochs
+        self.start_epoch = 0
         print_args(dict(self.args))

         # Save run settings
@@ -66,8 +67,6 @@ class BaseTrainer:
         else:
             self.data = check_dataset(self.data)
         self.trainset, self.testset = self.get_dataset(self.data)
-        if self.args.model:
-            self.model = self.get_model(self.args.model)
         self.ema = None

         # Optimization utils init
@@ -136,15 +135,17 @@ class BaseTrainer:
         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")

         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
-        self.model = self.model.to(self.device)
-        self.model = DDP(self.model, device_ids=[rank])

     def _setup_train(self, rank, world_size):
         """
         Builds dataloaders and optimizer on correct rank process
         """
-        # Optimizer
+        # model
+        ckpt = self.setup_model()
         self.set_model_attributes()
+        if world_size > 1:
+            self.model = DDP(self.model, device_ids=[rank])
+        # Optimizer
         self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
         self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
         self.optimizer = build_optimizer(model=self.model,
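
The two scaling lines above deserve a worked example. Assuming a nominal batch size nbs of 64 and a real batch size of 16 (values assumed here, not taken from this diff), gradients are accumulated over four batches per optimizer step, and weight decay is rescaled so regularization stays proportional to the effective batch size:

    # Worked example of the accumulate / weight_decay scaling (assumed values)
    nbs, batch_size, weight_decay = 64, 16, 0.0005
    accumulate = max(round(nbs / batch_size), 1)    # 4 -> step the optimizer every 4 batches
    weight_decay *= batch_size * accumulate / nbs   # 0.0005 * 16 * 4 / 64 = 0.0005 (unchanged)
    print(accumulate, weight_decay)                 # 4 0.0005
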
@@ -158,6 +159,8 @@ class BaseTrainer:
         else:
             self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+        self.resume_training(ckpt)
+        self.scheduler.last_epoch = self.start_epoch - 1  # do not move

         # dataloaders
         batch_size = self.batch_size // world_size
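
The "do not move" comment marks an ordering constraint: `resume_training` sets `self.start_epoch` first, and `LambdaLR` must then be told the last completed epoch before its first `step()` so the learning-rate curve continues where the interrupted run left off. A minimal standalone sketch of that interaction (toy model and numbers assumed):

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.01)
    lf = lambda x: (1 - x / 100) * (1.0 - 0.01) + 0.01  # same linear decay shape as self.lf
    scheduler = LambdaLR(opt, lr_lambda=lf)
    scheduler.last_epoch = 30 - 1   # pretend we resumed at start_epoch = 30
    scheduler.step()                # next lr is computed for epoch 30, not epoch 0
    print(opt.param_groups[0]['lr'])
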
@@ -174,20 +177,18 @@ class BaseTrainer:
     def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
             self._setup_ddp(rank, world_size)
-        else:
-            self.model = self.model.to(self.device)

-        self.trigger_callbacks("before_train")
         self._setup_train(rank, world_size)
+        self.trigger_callbacks("before_train")

-        self.epoch = 0
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
         last_opt_step = -1
-        for epoch in range(self.epochs):
+        for epoch in range(self.start_epoch, self.epochs):
+            self.epoch = epoch
             self.trigger_callbacks("on_epoch_start")
             self.model.train()
             if rank != -1:
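
Starting the loop at `self.start_epoch`, and assigning `self.epoch = epoch` at the top of the body instead of incrementing a counter at the bottom, is what makes a resumed run count epochs correctly. In isolation:

    start_epoch, epochs = 43, 100       # assumed: ckpt['epoch'] was 42
    for epoch in range(start_epoch, epochs):
        pass                            # 57 remaining epochs, numbered 43..99
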
@@ -257,11 +258,10 @@ class BaseTrainer:
                 self.save_metrics(metrics=log_vals)

                 # save model
-                if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
+                if (not self.args.nosave) or (epoch + 1 == self.epochs):
                     self.save_model()
                     self.trigger_callbacks('on_model_save')

-            self.epoch += 1
             tnow = time.time()
             self.epoch_time = tnow - self.epoch_time_start
             self.epoch_time_start = tnow
@@ -301,17 +301,21 @@ class BaseTrainer:
         """
         return data["train"], data.get("val") or data.get("test")

-    def get_model(self, model: Union[str, Path]):
+    def setup_model(self):
         """
         load/create/download model for any task
         """
-        pretrained = True
-        if str(model).endswith(".yaml"):
+        model = self.args.model
+        pretrained = not (str(model).endswith(".yaml"))
+        # config
+        if not pretrained:
             model = check_file(model)
-            pretrained = False
-        return self.load_model(model_cfg=None if pretrained else model,
-                               weights=get_model(model) if pretrained else None,
-                               data=self.data)  # model
+        ckpt = self.load_ckpt(model) if pretrained else None
+        self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt).to(self.device)  # model
+        return ckpt
+
+    def load_ckpt(self, ckpt):
+        return torch.load(ckpt, map_location='cpu')

     def optimizer_step(self):
         self.scaler.unscale_(self.optimizer)  # unscale gradients
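
`setup_model` folds the old `get_model` logic into one decision: a `.yaml` path means build from config with no checkpoint, while anything else is treated as a weights file handed to `load_ckpt`. A quick sketch of the branch (both file names are made up for illustration):

    for model in ('yolov8n.yaml', 'runs/exp/weights/last.pt'):   # hypothetical paths
        pretrained = not str(model).endswith('.yaml')
        print(model, '->', 'load_ckpt + load_model' if pretrained else 'load_model from cfg only')
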
@@ -350,7 +354,7 @@ class BaseTrainer:
         if rank in {-1, 0}:
             self.console.info(text)

-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         raise NotImplementedError("This task trainer doesn't support loading cfg files")

     def get_validator(self):
@@ -409,6 +413,40 @@ class BaseTrainer:
             if f is self.best:
                 self.console.info(f'\nValidating {f}...')

+    def check_resume(self):
+        resume = self.args.resume
+        if resume:
+            last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run())
+            args_yaml = last.parent.parent / 'args.yaml'  # train options yaml
+            if args_yaml.is_file():
+                args = self._get_config(args_yaml)  # replace
+                args.model, args.resume, args.exist_ok = str(last), True, True  # reinstate
+                self.args = args
+
+    def resume_training(self, ckpt):
+        if ckpt is None:
+            return
+        best_fitness = 0.0
+        start_epoch = ckpt['epoch'] + 1
+        if ckpt['optimizer'] is not None:
+            self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
+            best_fitness = ckpt['best_fitness']
+        if self.ema and ckpt.get('ema'):
+            self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
+            self.ema.updates = ckpt['updates']
+        if self.args.resume:
+            assert start_epoch > 0, f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
+                                    f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
+            LOGGER.info(
+                f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
+        if self.epochs < start_epoch:
+            LOGGER.info(
+                f"{self.args.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
+            )
+            self.epochs += ckpt['epoch']  # finetune additional epochs
+        self.best_fitness = best_fitness
+        self.start_epoch = start_epoch
+

 def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
     # TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
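
`resume_training` only ever reads the keys 'epoch', 'best_fitness', 'optimizer', 'ema' and 'updates' from the checkpoint dict; the task-level `load_model` additionally reads ckpt['model']. The `save_model` side is not part of this diff, so the following is only a hedged sketch of a `last.pt` that would satisfy the resume path (a real checkpoint would also carry a 'model' entry):

    import torch

    model = torch.nn.Linear(2, 2)                       # stand-in for the real model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    ckpt = {
        'epoch': 42,                                    # resume restarts at epoch 43
        'best_fitness': 0.51,
        'optimizer': optimizer.state_dict(),
        'ema': None,                                    # skipped by the ckpt.get('ema') guard
        'updates': 0,
    }
    torch.save(ckpt, 'last.pt')
    print(torch.load('last.pt', map_location='cpu')['epoch'])  # what load_ckpt() hands back
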
--- a/ultralytics/yolo/utils/configs/default.yaml
+++ b/ultralytics/yolo/utils/configs/default.yaml
@@ -33,6 +33,7 @@ overlap_mask: True  # masks overlap
 mask_ratio: 4  # mask downsample ratio
 # Classification
 dropout: False  # use dropout
+resume: False


 # Val/Test settings ----------------------------------------------------------------------------------------------------
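
The new `resume: False` default feeds `check_resume` above, which accepts either a boolean or an explicit checkpoint path:

    # Sketch of the two value shapes check_resume() handles (path is illustrative only)
    for resume in (True, 'runs/exp/weights/last.pt'):
        print('explicit checkpoint via check_file()' if isinstance(resume, str)
              else 'latest run via get_latest_run()')
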
--- a/ultralytics/yolo/utils/files.py
+++ b/ultralytics/yolo/utils/files.py
@@ -1,4 +1,5 @@
 import contextlib
+import glob
 import os
 from datetime import datetime
 from pathlib import Path
@@ -74,3 +75,9 @@ def file_date(path=__file__):
     # Return human-readable file modification date, i.e. '2021-3-26'
     t = datetime.fromtimestamp(Path(path).stat().st_mtime)
     return f'{t.year}-{t.month}-{t.day}'
+
+
+def get_latest_run(search_dir='.'):
+    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
+    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
+    return max(last_list, key=os.path.getctime) if last_list else ''
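
`get_latest_run` is a plain glob-and-ctime search, so it can be exercised outside the trainer. A self-contained check in a temporary tree (directory names assumed):

    import glob, os, tempfile, time

    root = tempfile.mkdtemp()
    for name in ('exp', 'exp2'):                        # two fake runs
        os.makedirs(os.path.join(root, name, 'weights'))
        open(os.path.join(root, name, 'weights', 'last.pt'), 'w').close()
        time.sleep(0.01)                                # guarantee distinct ctimes
    last_list = glob.glob(f'{root}/**/last*.pt', recursive=True)
    print(max(last_list, key=os.path.getctime))         # ends with exp2/weights/last.pt

Note the directory convention `check_resume` relies on: last.pt sits in <run>/weights/, and the args.yaml it reloads sits two levels up, in <run>/.
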
--- a/ultralytics/yolo/v8/classify/train.py
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -4,6 +4,7 @@ import torch
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
+from ultralytics.yolo.utils.modeling import get_model
 from ultralytics.yolo.utils.modeling.tasks import ClassificationModel


@@ -12,13 +13,13 @@ class ClassificationTrainer(BaseTrainer):
     def set_model_attributes(self):
         self.model.names = self.data["names"]

-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         # TODO: why treat clf models as unique. We should have clf yamls?
         if weights and not weights.__class__.__name__.startswith("yolo"):  # torchvision
             model = weights
         else:
-            model = ClassificationModel(model_cfg, weights, data["nc"])
-            ClassificationModel.reshape_outputs(model, data["nc"])
+            model = ClassificationModel(model_cfg, weights, self.data["nc"])
+            ClassificationModel.reshape_outputs(model, self.data["nc"])
         for m in model.modules():
             if not weights and hasattr(m, 'reset_parameters'):
                 m.reset_parameters()
@@ -28,6 +29,9 @@ class ClassificationTrainer(BaseTrainer):
             p.requires_grad = True  # for training
         return model

+    def load_ckpt(self, ckpt):
+        return get_model(ckpt)
+
     def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
         return build_classification_dataloader(path=dataset_path,
                                                imgsz=self.args.img_size,
@@ -46,6 +50,12 @@ class ClassificationTrainer(BaseTrainer):
         loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
         return loss, loss

+    def check_resume(self):
+        pass
+
+    def resume_training(self, ckpt):
+        pass
+

 @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
 def train(cfg):
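
The classification trainer opts out of the new machinery: its `load_ckpt` returns a model module via `get_model` (e.g. torchvision weights) rather than the `torch.load` dict that `resume_training` expects, so `check_resume` and `resume_training` are overridden as no-ops instead of inheriting behavior that would presumably fail on such checkpoints.
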
--- a/ultralytics/yolo/v8/detect/train.py
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -15,10 +15,10 @@ from .val import DetectionValidator
 # BaseTrainer python usage
 class DetectionTrainer(SegmentationTrainer):

-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         model = DetectionModel(model_cfg or weights["model"].yaml,
                                ch=3,
-                               nc=data["nc"],
+                               nc=self.data["nc"],
                                anchors=self.args.get("anchors"))
         if weights:
             model.load(weights)
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -26,10 +26,10 @@ class SegmentationTrainer(BaseTrainer):
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
         return batch

-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         model = SegmentationModel(model_cfg or weights["model"].yaml,
                                   ch=3,
-                                  nc=data["nc"],
+                                  nc=self.data["nc"],
                                   anchors=self.args.get("anchors"))
         if weights:
             model.load(weights)
--- a/ultralytics/yolo/v8/segment/val.py
+++ b/ultralytics/yolo/v8/segment/val.py
@@ -242,7 +242,7 @@ class SegmentationValidator(BaseValidator):
                          cls,
                          bboxes,
                          masks,
-                         paths,
+                         paths=paths,
                          fname=self.save_dir / f"val_batch{ni}_labels.jpg",
                          names=self.names)