Mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-05-23 21:44:22 +08:00)
Fix resume (#138)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent: 82c849c163
commit: 340376f7a6
@@ -293,6 +293,8 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
 
         # Model compatibility updates
         ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}
+        if not hasattr(ckpt, 'stride'):
+            ckpt.stride = torch.tensor([32.])
 
         # Append
         model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval())  # model in eval mode
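
Note: a minimal sketch of the compatibility guard added above, assuming only that
older checkpoints may lack a `stride` attribute; the helper name `ensure_stride`
and its standalone use are illustrative, not part of the commit.

import torch
import torch.nn as nn


def ensure_stride(model: nn.Module, default: float = 32.0) -> nn.Module:
    # Supply the usual maximum-stride default so image-size checks keep working
    # for checkpoints saved before the attribute existed.
    if not hasattr(model, "stride"):
        model.stride = torch.tensor([default])
    return model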
@@ -136,6 +136,15 @@ class YOLODataset(BaseDataset):
                    batch_idx=True))
         return transforms
 
+    def close_mosaic(self, hyp):
+        self.transforms = affine_transforms(self.imgsz, hyp)
+        self.transforms.append(
+            Format(bbox_format="xywh",
+                   normalize=True,
+                   return_mask=self.use_segments,
+                   return_keypoint=self.use_keypoints,
+                   batch_idx=True))
+
     def update_labels_info(self, label):
         """custom your label format here"""
         # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
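
Note: the new close_mosaic() hook is meant to be called by the trainer shortly
before the final epochs, as the trainer hunk further below does. A hedged sketch
of that call pattern follows; `dataset`, `epochs` and `close_mosaic_epochs` are
illustrative names, not the ultralytics API.

def maybe_close_mosaic(dataset, epoch, epochs, close_mosaic_epochs, hyp=None):
    # Disable mosaic augmentation for the last `close_mosaic_epochs` epochs.
    if close_mosaic_epochs and epoch == epochs - close_mosaic_epochs:
        if hasattr(dataset, "mosaic"):
            dataset.mosaic = False
        if hasattr(dataset, "close_mosaic"):
            dataset.close_mosaic(hyp=hyp)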
@@ -15,6 +15,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from omegaconf import OmegaConf  # noqa
+from omegaconf import open_dict
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import lr_scheduler
@@ -90,10 +91,15 @@ class BaseTrainer:
         # Dirs
         project = self.args.project or f"runs/{self.args.task}"
         name = self.args.name or f"{self.args.mode}"
-        self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
+        self.save_dir = Path(
+            self.args.get(
+                "save_dir",
+                increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)))
         self.wdir = self.save_dir / 'weights'  # weights dir
         if RANK in {-1, 0}:
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
+            with open_dict(self.args):
+                self.args.save_dir = str(self.save_dir)
             yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))  # save run args
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
 
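
Note: the save_dir change lets a resumed run reuse the directory recorded in its
reloaded args.yaml instead of incrementing to a fresh runs/... path, and
open_dict is what allows writing the new save_dir key into a struct-locked
config. A small self-contained sketch with plain OmegaConf (values illustrative):

from pathlib import Path

from omegaconf import OmegaConf, open_dict

args = OmegaConf.create({"project": "runs/detect", "name": "train"})  # stand-in config

# reuse a save_dir recorded by a previous run when resuming, otherwise build one
save_dir = Path(args.get("save_dir", Path(args.project) / args.name))

OmegaConf.set_struct(args, True)   # hydra-style configs are usually struct-locked
with open_dict(args):              # temporarily allow adding the new key
    args.save_dir = str(save_dir)

print(args.save_dir)               # e.g. runs/detect/train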
@@ -131,6 +137,7 @@ class BaseTrainer:
         self.tloss = None
         self.loss_names = None
         self.csv = self.save_dir / 'results.csv'
+        self.plot_idx = [0, 1, 2]
 
         # Callbacks
         self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks
@@ -199,7 +206,6 @@ class BaseTrainer:
         else:
             self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
-        self.resume_training(ckpt)
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
 
         # dataloaders
@@ -211,6 +217,7 @@ class BaseTrainer:
         metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
         self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
         self.ema = ModelEMA(self.model)
+        self.resume_training(ckpt)
         self.run_callbacks("on_pretrain_routine_end")
 
     def _do_train(self, rank=-1, world_size=1):
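
Note: resume_training(ckpt) now runs after self.ema is created, so the restored
EMA weights and update counter have an object to land in. A rough sketch of that
restore step; the 'optimizer' and 'epoch' checkpoint keys are assumptions
consistent with the trainer but not shown in this diff.

def restore_training_state(trainer, ckpt):
    # No-op when training starts from scratch.
    if ckpt is None:
        return
    if ckpt.get("optimizer") is not None:
        trainer.optimizer.load_state_dict(ckpt["optimizer"])
    if trainer.ema and ckpt.get("ema"):
        trainer.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())
        trainer.ema.updates = ckpt["updates"]
    trainer.start_epoch = ckpt.get("epoch", -1) + 1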
@@ -230,6 +237,9 @@ class BaseTrainer:
                           f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                           f"Logging results to {colorstr('bold', self.save_dir)}\n"
                           f"Starting training for {self.epochs} epochs...")
+        if self.args.close_mosaic:
+            base_idx = (self.epochs - self.args.close_mosaic) * nb
+            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
         for epoch in range(self.start_epoch, self.epochs):
             self.epoch = epoch
             self.run_callbacks("on_train_epoch_start")
@@ -237,19 +247,21 @@ class BaseTrainer:
             if rank != -1:
                 self.train_loader.sampler.set_epoch(epoch)
             pbar = enumerate(self.train_loader)
+            # Update dataloader attributes (optional)
+            if epoch == (self.epochs - self.args.close_mosaic):
+                self.console.info("Closing dataloader mosaic")
+                if hasattr(self.train_loader.dataset, 'mosaic'):
+                    self.train_loader.dataset.mosaic = False
+                if hasattr(self.train_loader.dataset, 'close_mosaic'):
+                    self.train_loader.dataset.close_mosaic(hyp=self.args)
+
             if rank in {-1, 0}:
                 self.console.info(self.progress_string())
-                pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT)
+                pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
             self.tloss = None
             self.optimizer.zero_grad()
             for i, batch in pbar:
                 self.run_callbacks("on_train_batch_start")
-
-                # Update dataloader attributes (optional)
-                if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'):
-                    LOGGER.info("Closing dataloader mosaic")
-                    self.train_loader.dataset.mosaic = False
-
                 # Warmup
                 ni = i + nb * epoch
                 if ni <= nw:
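
Note: the progress bar now uses the precomputed batch count nb, the same value
already driving warmup and plot indices, instead of re-evaluating
len(self.train_loader). A tiny sketch of the pattern (names illustrative):

from tqdm import tqdm


def iter_batches(loader, show_bar):
    nb = len(loader)                 # batch count, computed once per epoch
    pbar = enumerate(loader)
    if show_bar:                     # only the main process draws the bar
        pbar = tqdm(pbar, total=nb)
    yield from pbar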
@@ -289,7 +301,7 @@ class BaseTrainer:
                         ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
                         (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
                 self.run_callbacks('on_batch_end')
-                if self.args.plots and ni < 3:
+                if self.args.plots and ni in self.plot_idx:
                     self.plot_training_samples(batch, ni)
 
                 self.run_callbacks("on_train_batch_end")
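
Note: with plot_idx, the first three batches of the run are plotted as before,
and when close_mosaic is set the first three batches after mosaic shuts off are
plotted too. Worked example with illustrative values:

epochs, close_mosaic, nb = 100, 10, 250   # nb = batches per epoch
plot_idx = [0, 1, 2]
if close_mosaic:
    base_idx = (epochs - close_mosaic) * nb
    plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])

ni = 22500                    # global batch counter: i + nb * epoch
print(ni in plot_idx)         # True: first batch of the first mosaic-free epoch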
@@ -367,7 +379,8 @@ class BaseTrainer:
         if not pretrained:
             model = check_file(model)
         ckpt = self.load_ckpt(model) if pretrained else None
-        self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt["model"])  # model
+        weights = ckpt["model"] if isinstance(ckpt, dict) else ckpt  # torchvision weights are not dicts
+        self.model = self.load_model(model_cfg=None if pretrained else model, weights=weights)
         return ckpt
 
     def load_ckpt(self, ckpt):
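
Note: the setup change unwraps weights only when the checkpoint is a dict, since
some sources hand back a bare module instead. Minimal sketch of the guard:

def unwrap_weights(ckpt):
    # ultralytics checkpoints are dicts with a "model" entry; other sources
    # (such as torchvision) may return the module directly.
    return ckpt["model"] if isinstance(ckpt, dict) else ckpt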
@@ -479,8 +492,9 @@ class BaseTrainer:
             args_yaml = last.parent.parent / 'args.yaml'  # train options yaml
             if args_yaml.is_file():
                 args = get_config(args_yaml)  # replace
-            args.model, args.resume, args.exist_ok = str(last), True, True  # reinstate
+            args.model, resume = str(last), True  # reinstate
             self.args = args
+        self.resume = resume
 
     def resume_training(self, ckpt):
         if ckpt is None:
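
Note: the resume decision is now kept in self.resume rather than written back
into args, so the reloaded run config stays exactly what the original run saved.
A hedged sketch of the resolution flow; find_latest_run is an illustrative
callable, and the sketch assumes args.yaml stores a model entry, as ultralytics
run configs do.

from pathlib import Path

from omegaconf import OmegaConf


def resolve_resume(resume, find_latest_run):
    # resume is either a path to a last.pt or True ("pick the newest run")
    last = Path(resume if isinstance(resume, str) else find_latest_run())
    args = OmegaConf.load(last.parent.parent / "args.yaml")  # original run config
    args.model = str(last)   # continue from the saved checkpoint
    return args, True        # the resume flag lives outside the config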
@@ -493,7 +507,7 @@ class BaseTrainer:
             if self.ema and ckpt.get('ema'):
                 self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
                 self.ema.updates = ckpt['updates']
-        if self.args.resume:
+        if self.resume:
             assert start_epoch > 0, \
                 f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
                 f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
@@ -111,6 +111,7 @@ class BaseValidator:
             self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
             self.dataloader = self.dataloader or \
                               self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
+            self.data = data
 
         model.eval()
 
@@ -24,11 +24,12 @@ def find_free_network_port() -> int:
 def generate_ddp_file(trainer):
     import_path = '.'.join(str(trainer.__class__).split(".")[1:-1])
 
-    shutil.rmtree(trainer.save_dir)  # remove the save_dir
-    content = f'''overrides = {dict(trainer.args)} \nif __name__ == "__main__":
+    if not trainer.resume:
+        shutil.rmtree(trainer.save_dir)  # remove the save_dir
+    content = f'''config = {dict(trainer.args)} \nif __name__ == "__main__":
     from ultralytics.{import_path} import {trainer.__class__.__name__}
 
-    trainer = {trainer.__class__.__name__}(overrides=overrides)
+    trainer = {trainer.__class__.__name__}(config=config)
     trainer.train()'''
     (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
     with tempfile.NamedTemporaryFile(prefix="_temp_",
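
Note: the DDP helper now deletes save_dir only for fresh runs; a resumed run must
keep it because last.pt and args.yaml live there. Sketch of that decision
(ignore_errors is an addition for the sketch, not in the commit):

import shutil


def maybe_clear_save_dir(save_dir, resume):
    # A fresh run may clear a stale directory before relaunching under DDP;
    # a resumed run reuses it.
    if not resume:
        shutil.rmtree(save_dir, ignore_errors=True)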
@@ -1,3 +1,5 @@
+from copy import copy
+
 import hydra
 import torch
 import torch.nn as nn
@@ -64,7 +66,7 @@ class DetectionTrainer(BaseTrainer):
         return v8.detect.DetectionValidator(self.test_loader,
                                             save_dir=self.save_dir,
                                             logger=self.console,
-                                            args=self.args)
+                                            args=copy(self.args))
 
     def criterion(self, preds, batch):
         if not hasattr(self, 'compute_loss'):
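
Note: the validator receives copy(self.args) because it mutates flags such as
save_json during init_metrics; the copy keeps those mutations from leaking back
into the trainer's config. Sketch with a plain namespace standing in for the
real OmegaConf args:

from copy import copy
from types import SimpleNamespace

trainer_args = SimpleNamespace(save_json=False)
validator_args = copy(trainer_args)
validator_args.save_json |= True          # validator-side mutation

print(trainer_args.save_json)             # False: trainer config untouched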
@@ -42,7 +42,6 @@ class DetectionValidator(BaseValidator):
 
     def init_metrics(self, model):
         head = model.model[-1] if self.training else model.model.model[-1]
-        if self.data:
-            self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt')  # is COCO dataset
-            self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
-            self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO
+        self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt')  # is COCO dataset
+        self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
+        self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO
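
Note: the `if self.data:` guard can go because BaseValidator now always stores
self.data before metrics are initialised (see the validator hunk above). The COCO
check itself is just a path-suffix test; tiny sketch with an illustrative dict:

import os


def is_coco_val(data):
    return data.get("val", "").endswith(f"coco{os.sep}val2017.txt")


print(is_coco_val({"val": f"coco{os.sep}val2017.txt"}))  # True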
@@ -1,3 +1,5 @@
+from copy import copy
+
 import hydra
 import torch
 import torch.nn as nn
@@ -27,7 +29,7 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
         return v8.segment.SegmentationValidator(self.test_loader,
                                                 save_dir=self.save_dir,
                                                 logger=self.console,
-                                                args=self.args)
+                                                args=copy(self.args))
 
     def criterion(self, preds, batch):
         if not hasattr(self, 'compute_loss'):
@@ -37,7 +37,6 @@ class SegmentationValidator(DetectionValidator):
 
     def init_metrics(self, model):
         head = model.model[-1] if self.training else model.model.model[-1]
-        if self.data:
-            self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt')  # is COCO dataset
-            self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
-            self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO
+        self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt')  # is COCO dataset
+        self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
+        self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO