From c0a9660310db038bd83072283fbb3a5eeef5e448 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 21 Aug 2023 20:54:05 +0200 Subject: [PATCH] `ultralytics 8.0.159` add Classify training `resume` feature (#4482) --- ultralytics/__init__.py | 2 +- ultralytics/cfg/__init__.py | 2 +- ultralytics/models/yolo/classify/train.py | 10 +++------- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index ad8fa441..87f2b675 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.158' +__version__ = '8.0.159' from ultralytics.hub import start from ultralytics.models import RTDETR, SAM, YOLO diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 25f8016b..a409ffca 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -419,7 +419,7 @@ def entrypoint(debug=''): overrides['source'] = DEFAULT_CFG.source or ASSETS LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.") elif mode in ('train', 'val'): - if 'data' not in overrides: + if 'data' not in overrides and 'resume' not in overrides: overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.") elif mode == 'export': diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py index 8c798f00..b09b5201 100644 --- a/ultralytics/models/yolo/classify/train.py +++ b/ultralytics/models/yolo/classify/train.py @@ -62,10 +62,10 @@ class ClassificationTrainer(BaseTrainer): if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed return - model = str(self.model) + model, ckpt = str(self.model), None # Load a YOLO model locally, from torchvision, or from Ultralytics assets if model.endswith('.pt'): - self.model, _ = attempt_load_one_weight(model, device='cpu') + self.model, ckpt = attempt_load_one_weight(model, device='cpu') for p in self.model.parameters(): p.requires_grad = True # for training elif model.split('.')[-1] in ('yaml', 'yml'): @@ -76,7 +76,7 @@ class ClassificationTrainer(BaseTrainer): FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') ClassificationModel.reshape_outputs(self.model, self.data['nc']) - return # do not return ckpt. Classification doesn't support resume + return ckpt def build_dataset(self, img_path, mode='train', batch=None): return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train') @@ -122,10 +122,6 @@ class ClassificationTrainer(BaseTrainer): loss_items = [round(float(loss_items), 5)] return dict(zip(keys, loss_items)) - def resume_training(self, ckpt): - """Resumes training from a given checkpoint.""" - pass - def plot_metrics(self): """Plots metrics from a CSV file.""" plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png