From fbeeb5d1e10ae01950fdf3b000fb806ee8ebf5ee Mon Sep 17 00:00:00 2001
From: Laughing <61612323+Laughing-q@users.noreply.github.com>
Date: Mon, 5 Dec 2022 20:56:41 -0600
Subject: [PATCH] add resuming (#63)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 ultralytics/yolo/engine/trainer.py          | 82 +++++++++++++++------
 ultralytics/yolo/utils/configs/default.yaml |  1 +
 ultralytics/yolo/utils/files.py             |  7 ++
 ultralytics/yolo/v8/classify/train.py       | 16 +++-
 ultralytics/yolo/v8/detect/train.py         |  4 +-
 ultralytics/yolo/v8/segment/train.py        |  4 +-
 ultralytics/yolo/v8/segment/val.py          |  2 +-
 7 files changed, 86 insertions(+), 30 deletions(-)

diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 0899941b..8d461bef 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -26,8 +26,7 @@ import ultralytics.yolo.utils.callbacks as callbacks
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
 from ultralytics.yolo.utils.checks import check_file, print_args
-from ultralytics.yolo.utils.files import increment_path, save_yaml
-from ultralytics.yolo.utils.modeling import get_model
+from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
 from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
 
 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -38,6 +37,7 @@ class BaseTrainer:
 
     def __init__(self, config=DEFAULT_CONFIG, overrides={}):
         self.args = self._get_config(config, overrides)
+        self.check_resume()
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
 
         self.console = LOGGER
@@ -50,6 +50,7 @@ class BaseTrainer:
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
         self.batch_size = self.args.batch_size
         self.epochs = self.args.epochs
+        self.start_epoch = 0
         print_args(dict(self.args))
 
         # Save run settings
@@ -66,8 +67,6 @@ class BaseTrainer:
         else:
             self.data = check_dataset(self.data)
         self.trainset, self.testset = self.get_dataset(self.data)
-        if self.args.model:
-            self.model = self.get_model(self.args.model)
         self.ema = None
 
         # Optimization utils init
@@ -136,15 +135,17 @@ class BaseTrainer:
         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
 
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
-        self.model = self.model.to(self.device)
-        self.model = DDP(self.model, device_ids=[rank])
 
     def _setup_train(self, rank, world_size):
         """
         Builds dataloaders and optimizer on correct rank process
         """
-        # Optimizer
+        # model
+        ckpt = self.setup_model()
         self.set_model_attributes()
+        if world_size > 1:
+            self.model = DDP(self.model, device_ids=[rank])
+        # Optimizer
         self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
         self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
         self.optimizer = build_optimizer(model=self.model,
@@ -158,6 +159,8 @@ class BaseTrainer:
         else:
             self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+        self.resume_training(ckpt)
+        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
 
         # dataloaders
         batch_size = self.batch_size // world_size
@@ -174,20 +177,18 @@ class BaseTrainer:
     def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
             self._setup_ddp(rank, world_size)
-        else:
-            self.model = self.model.to(self.device)
 
-        self.trigger_callbacks("before_train")
         self._setup_train(rank, world_size)
+        self.trigger_callbacks("before_train")
 
-        self.epoch = 0
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
         last_opt_step = -1
-        for epoch in range(self.epochs):
+        for epoch in range(self.start_epoch, self.epochs):
+            self.epoch = epoch
             self.trigger_callbacks("on_epoch_start")
             self.model.train()
             if rank != -1:
@@ -257,11 +258,10 @@ class BaseTrainer:
                 self.save_metrics(metrics=log_vals)
 
                 # save model
-                if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
+                if (not self.args.nosave) or (epoch + 1 == self.epochs):
                     self.save_model()
                     self.trigger_callbacks('on_model_save')
 
-            self.epoch += 1
             tnow = time.time()
             self.epoch_time = tnow - self.epoch_time_start
             self.epoch_time_start = tnow
@@ -301,17 +301,21 @@ class BaseTrainer:
         """
         return data["train"], data.get("val") or data.get("test")
 
-    def get_model(self, model: Union[str, Path]):
+    def setup_model(self):
         """
         load/create/download model for any task
         """
-        pretrained = True
-        if str(model).endswith(".yaml"):
+        model = self.args.model
+        pretrained = not (str(model).endswith(".yaml"))
+        # config
+        if not pretrained:
             model = check_file(model)
-            pretrained = False
-        return self.load_model(model_cfg=None if pretrained else model,
-                               weights=get_model(model) if pretrained else None,
-                               data=self.data)  # model
+        ckpt = self.load_ckpt(model) if pretrained else None
+        self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt).to(self.device)  # model
+        return ckpt
+
+    def load_ckpt(self, ckpt):
+        return torch.load(ckpt, map_location='cpu')
 
     def optimizer_step(self):
         self.scaler.unscale_(self.optimizer)  # unscale gradients
@@ -350,7 +354,7 @@ class BaseTrainer:
         if rank in {-1, 0}:
             self.console.info(text)
 
-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
@@ -409,6 +413,40 @@ class BaseTrainer:
                 if f is self.best:
                     self.console.info(f'\nValidating {f}...')
 
+    def check_resume(self):
+        resume = self.args.resume
+        if resume:
+            last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run())
+            args_yaml = last.parent.parent / 'args.yaml'  # train options yaml
+            if args_yaml.is_file():
+                args = self._get_config(args_yaml)  # replace
+            args.model, args.resume, args.exist_ok = str(last), True, True  # reinstate
+            self.args = args
+
+    def resume_training(self, ckpt):
+        if ckpt is None:
+            return
+        best_fitness = 0.0
+        start_epoch = ckpt['epoch'] + 1
+        if ckpt['optimizer'] is not None:
+            self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
+            best_fitness = ckpt['best_fitness']
+        if self.ema and ckpt.get('ema'):
+            self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
+            self.ema.updates = ckpt['updates']
+        if self.args.resume:
+            assert start_epoch > 0, f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
+                                    f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
+            LOGGER.info(
+                f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
+        if self.epochs < start_epoch:
+            LOGGER.info(
+                f"{self.args.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
+            )
+            self.epochs += ckpt['epoch']  # finetune additional epochs
+        self.best_fitness = best_fitness
+        self.start_epoch = start_epoch
+
 
 def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
     # TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
diff --git a/ultralytics/yolo/utils/configs/default.yaml b/ultralytics/yolo/utils/configs/default.yaml
index 9dd3ab79..348b3977 100644
--- a/ultralytics/yolo/utils/configs/default.yaml
+++ b/ultralytics/yolo/utils/configs/default.yaml
@@ -33,6 +33,7 @@ overlap_mask: True  # masks overlap
 mask_ratio: 4  # mask downsample ratio
 # Classification
 dropout: False # use dropout
+resume: False
 
 
 # Val/Test settings ----------------------------------------------------------------------------------------------------
diff --git a/ultralytics/yolo/utils/files.py b/ultralytics/yolo/utils/files.py
index 2ae98120..0e97491d 100644
--- a/ultralytics/yolo/utils/files.py
+++ b/ultralytics/yolo/utils/files.py
@@ -1,4 +1,5 @@
 import contextlib
+import glob
 import os
 from datetime import datetime
 from pathlib import Path
@@ -74,3 +75,9 @@ def file_date(path=__file__):
     # Return human-readable file modification date, i.e. '2021-3-26'
     t = datetime.fromtimestamp(Path(path).stat().st_mtime)
     return f'{t.year}-{t.month}-{t.day}'
+
+
+def get_latest_run(search_dir='.'):
+    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
+    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
+    return max(last_list, key=os.path.getctime) if last_list else ''
diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py
index 370c4ad4..813278dc 100644
--- a/ultralytics/yolo/v8/classify/train.py
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -4,6 +4,7 @@ import torch
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
+from ultralytics.yolo.utils.modeling import get_model
 from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
 
 
@@ -12,13 +13,13 @@ class ClassificationTrainer(BaseTrainer):
     def set_model_attributes(self):
         self.model.names = self.data["names"]
 
-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         # TODO: why treat clf models as unique. We should have clf yamls?
         if weights and not weights.__class__.__name__.startswith("yolo"):  # torchvision
             model = weights
         else:
-            model = ClassificationModel(model_cfg, weights, data["nc"])
-        ClassificationModel.reshape_outputs(model, data["nc"])
+            model = ClassificationModel(model_cfg, weights, self.data["nc"])
+        ClassificationModel.reshape_outputs(model, self.data["nc"])
         for m in model.modules():
             if not weights and hasattr(m, 'reset_parameters'):
                 m.reset_parameters()
@@ -28,6 +29,9 @@ class ClassificationTrainer(BaseTrainer):
             p.requires_grad = True  # for training
         return model
 
+    def load_ckpt(self, ckpt):
+        return get_model(ckpt)
+
     def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
         return build_classification_dataloader(path=dataset_path,
                                                imgsz=self.args.img_size,
@@ -46,6 +50,12 @@ class ClassificationTrainer(BaseTrainer):
         loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
         return loss, loss
 
+    def check_resume(self):
+        pass
+
+    def resume_training(self, ckpt):
+        pass
+
 
 @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
 def train(cfg):
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
index 22e9e36e..8021b786 100644
--- a/ultralytics/yolo/v8/detect/train.py
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -15,10 +15,10 @@ from .val import DetectionValidator
 # BaseTrainer python usage
 class DetectionTrainer(SegmentationTrainer):
 
-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         model = DetectionModel(model_cfg or weights["model"].yaml,
                                ch=3,
-                               nc=data["nc"],
+                               nc=self.data["nc"],
                                anchors=self.args.get("anchors"))
         if weights:
             model.load(weights)
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index a5481d26..95bc417c 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -26,10 +26,10 @@ class SegmentationTrainer(BaseTrainer):
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
         return batch
 
-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         model = SegmentationModel(model_cfg or weights["model"].yaml,
                                   ch=3,
-                                  nc=data["nc"],
+                                  nc=self.data["nc"],
                                   anchors=self.args.get("anchors"))
         if weights:
             model.load(weights)
diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py
index 3784fd3d..7ada26ce 100644
--- a/ultralytics/yolo/v8/segment/val.py
+++ b/ultralytics/yolo/v8/segment/val.py
@@ -242,7 +242,7 @@ class SegmentationValidator(BaseValidator):
                               cls,
                               bboxes,
                               masks,
-                              paths,
+                              paths=paths,
                               fname=self.save_dir / f"val_batch{ni}_labels.jpg",
                               names=self.names)