mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 05:55:51 +08:00
ultralytics 8.0.159
add Classify training resume
feature (#4482)
This commit is contained in:
parent
b2f279ffdd
commit
c0a9660310
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = '8.0.158'
|
__version__ = '8.0.159'
|
||||||
|
|
||||||
from ultralytics.hub import start
|
from ultralytics.hub import start
|
||||||
from ultralytics.models import RTDETR, SAM, YOLO
|
from ultralytics.models import RTDETR, SAM, YOLO
|
||||||
|
@ -419,7 +419,7 @@ def entrypoint(debug=''):
|
|||||||
overrides['source'] = DEFAULT_CFG.source or ASSETS
|
overrides['source'] = DEFAULT_CFG.source or ASSETS
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
||||||
elif mode in ('train', 'val'):
|
elif mode in ('train', 'val'):
|
||||||
if 'data' not in overrides:
|
if 'data' not in overrides and 'resume' not in overrides:
|
||||||
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
||||||
elif mode == 'export':
|
elif mode == 'export':
|
||||||
|
@ -62,10 +62,10 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||||
return
|
return
|
||||||
|
|
||||||
model = str(self.model)
|
model, ckpt = str(self.model), None
|
||||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||||
if model.endswith('.pt'):
|
if model.endswith('.pt'):
|
||||||
self.model, _ = attempt_load_one_weight(model, device='cpu')
|
self.model, ckpt = attempt_load_one_weight(model, device='cpu')
|
||||||
for p in self.model.parameters():
|
for p in self.model.parameters():
|
||||||
p.requires_grad = True # for training
|
p.requires_grad = True # for training
|
||||||
elif model.split('.')[-1] in ('yaml', 'yml'):
|
elif model.split('.')[-1] in ('yaml', 'yml'):
|
||||||
@ -76,7 +76,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
||||||
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
|
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
|
||||||
|
|
||||||
return # do not return ckpt. Classification doesn't support resume
|
return ckpt
|
||||||
|
|
||||||
def build_dataset(self, img_path, mode='train', batch=None):
|
def build_dataset(self, img_path, mode='train', batch=None):
|
||||||
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
|
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
|
||||||
@ -122,10 +122,6 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
loss_items = [round(float(loss_items), 5)]
|
loss_items = [round(float(loss_items), 5)]
|
||||||
return dict(zip(keys, loss_items))
|
return dict(zip(keys, loss_items))
|
||||||
|
|
||||||
def resume_training(self, ckpt):
|
|
||||||
"""Resumes training from a given checkpoint."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def plot_metrics(self):
|
def plot_metrics(self):
|
||||||
"""Plots metrics from a CSV file."""
|
"""Plots metrics from a CSV file."""
|
||||||
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
|
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
|
||||||
|
Loading…
x
Reference in New Issue
Block a user