ultralytics 8.0.159 add Classify training resume feature (#4482)

This commit is contained in:
Glenn Jocher 2023-08-21 20:54:05 +02:00 committed by GitHub
parent b2f279ffdd
commit c0a9660310
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 9 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.158'
__version__ = '8.0.159'
from ultralytics.hub import start
from ultralytics.models import RTDETR, SAM, YOLO

View File

@ -419,7 +419,7 @@ def entrypoint(debug=''):
overrides['source'] = DEFAULT_CFG.source or ASSETS
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'):
if 'data' not in overrides:
if 'data' not in overrides and 'resume' not in overrides:
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
elif mode == 'export':

View File

@ -62,10 +62,10 @@ class ClassificationTrainer(BaseTrainer):
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model = str(self.model)
model, ckpt = str(self.model), None
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith('.pt'):
self.model, _ = attempt_load_one_weight(model, device='cpu')
self.model, ckpt = attempt_load_one_weight(model, device='cpu')
for p in self.model.parameters():
p.requires_grad = True # for training
elif model.split('.')[-1] in ('yaml', 'yml'):
@ -76,7 +76,7 @@ class ClassificationTrainer(BaseTrainer):
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
return # do not return ckpt. Classification doesn't support resume
return ckpt
def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
@ -122,10 +122,6 @@ class ClassificationTrainer(BaseTrainer):
loss_items = [round(float(loss_items), 5)]
return dict(zip(keys, loss_items))
def resume_training(self, ckpt):
"""Resumes training from a given checkpoint."""
pass
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png