mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-06-09 01:24:25 +08:00
Add pred, export and val callbacks (#126)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
63c7a74691
commit
c6eb6720de
@ -56,6 +56,7 @@ import re
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import defaultdict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -71,6 +72,7 @@ from ultralytics.yolo.configs import get_config
|
|||||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||||
from ultralytics.yolo.data.utils import check_dataset
|
from ultralytics.yolo.data.utils import check_dataset
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, get_default_args, yaml_save
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, get_default_args, yaml_save
|
||||||
|
from ultralytics.yolo.utils.callbacks import default_callbacks
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||||
from ultralytics.yolo.utils.files import file_size, increment_path
|
from ultralytics.yolo.utils.files import file_size, increment_path
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
@ -142,8 +144,14 @@ class Exporter:
|
|||||||
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# callbacks
|
||||||
|
self.callbacks = defaultdict([])
|
||||||
|
for callback, func in default_callbacks.items():
|
||||||
|
self.add_callback(callback, func)
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def __call__(self, model=None):
|
def __call__(self, model=None):
|
||||||
|
self.run_callbacks("on_export_start")
|
||||||
t = time.time()
|
t = time.time()
|
||||||
format = self.args.format.lower() # to lowercase
|
format = self.args.format.lower() # to lowercase
|
||||||
fmts = tuple(export_formats()['Argument'][1:]) # available export formats
|
fmts = tuple(export_formats()['Argument'][1:]) # available export formats
|
||||||
@ -245,6 +253,8 @@ class Exporter:
|
|||||||
f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}"
|
f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}"
|
||||||
f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}"
|
f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}"
|
||||||
f"\nVisualize: https://netron.app")
|
f"\nVisualize: https://netron.app")
|
||||||
|
|
||||||
|
self.run_callbacks("on_export_end")
|
||||||
return f # return list of exported files/dirs
|
return f # return list of exported files/dirs
|
||||||
|
|
||||||
@try_export
|
@try_export
|
||||||
@ -755,6 +765,22 @@ class Exporter:
|
|||||||
LOGGER.info(f'{prefix} pipeline success')
|
LOGGER.info(f'{prefix} pipeline success')
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def add_callback(self, event: str, callback):
|
||||||
|
"""
|
||||||
|
appends the given callback
|
||||||
|
"""
|
||||||
|
self.callbacks[event].append(callback)
|
||||||
|
|
||||||
|
def set_callback(self, event: str, callback):
|
||||||
|
"""
|
||||||
|
overrides the existing callbacks with the given callback
|
||||||
|
"""
|
||||||
|
self.callbacks[event] = [callback]
|
||||||
|
|
||||||
|
def run_callbacks(self, event: str):
|
||||||
|
for callback in self.callbacks.get(event, []):
|
||||||
|
callback(self)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||||
def export(cfg):
|
def export(cfg):
|
||||||
|
@ -26,6 +26,7 @@ Usage - formats:
|
|||||||
yolov8n_paddle_model # PaddlePaddle
|
yolov8n_paddle_model # PaddlePaddle
|
||||||
"""
|
"""
|
||||||
import platform
|
import platform
|
||||||
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -35,6 +36,7 @@ from ultralytics.yolo.configs import get_config
|
|||||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
|
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
|
||||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops
|
||||||
|
from ultralytics.yolo.utils.callbacks import default_callbacks
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||||
@ -89,6 +91,11 @@ class BasePredictor:
|
|||||||
self.annotator = None
|
self.annotator = None
|
||||||
self.data_path = None
|
self.data_path = None
|
||||||
|
|
||||||
|
# callbacks
|
||||||
|
self.callbacks = defaultdict([])
|
||||||
|
for callback, func in default_callbacks.items():
|
||||||
|
self.add_callback(callback, func)
|
||||||
|
|
||||||
def preprocess(self, img):
|
def preprocess(self, img):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -143,9 +150,11 @@ class BasePredictor:
|
|||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def __call__(self, source=None, model=None):
|
def __call__(self, source=None, model=None):
|
||||||
|
self.run_callbacks("on_predict_start")
|
||||||
model = self.model if self.done_setup else self.setup(source, model)
|
model = self.model if self.done_setup else self.setup(source, model)
|
||||||
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
|
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
|
||||||
for batch in self.dataset:
|
for batch in self.dataset:
|
||||||
|
self.run_callbacks("on_predict_batch_start")
|
||||||
path, im, im0s, vid_cap, s = batch
|
path, im, im0s, vid_cap, s = batch
|
||||||
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
||||||
with self.dt[0]:
|
with self.dt[0]:
|
||||||
@ -176,6 +185,8 @@ class BasePredictor:
|
|||||||
# Print time (inference-only)
|
# Print time (inference-only)
|
||||||
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
|
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
|
||||||
|
|
||||||
|
self.run_callbacks("on_predict_batch_end")
|
||||||
|
|
||||||
# Print results
|
# Print results
|
||||||
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
@ -185,6 +196,8 @@ class BasePredictor:
|
|||||||
s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
||||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||||
|
|
||||||
|
self.run_callbacks("on_predict_end")
|
||||||
|
|
||||||
def show(self, p):
|
def show(self, p):
|
||||||
im0 = self.annotator.result()
|
im0 = self.annotator.result()
|
||||||
if platform.system() == 'Linux' and p not in self.windows:
|
if platform.system() == 'Linux' and p not in self.windows:
|
||||||
@ -213,3 +226,19 @@ class BasePredictor:
|
|||||||
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
|
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
|
||||||
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||||
self.vid_writer[idx].write(im0)
|
self.vid_writer[idx].write(im0)
|
||||||
|
|
||||||
|
def add_callback(self, event: str, callback):
|
||||||
|
"""
|
||||||
|
appends the given callback
|
||||||
|
"""
|
||||||
|
self.callbacks[event].append(callback)
|
||||||
|
|
||||||
|
def set_callback(self, event: str, callback):
|
||||||
|
"""
|
||||||
|
overrides the existing callbacks with the given callback
|
||||||
|
"""
|
||||||
|
self.callbacks[event] = [callback]
|
||||||
|
|
||||||
|
def run_callbacks(self, event: str):
|
||||||
|
for callback in self.callbacks.get(event, []):
|
||||||
|
callback(self)
|
||||||
|
@ -136,20 +136,20 @@ class BaseTrainer:
|
|||||||
if RANK in {0, -1}:
|
if RANK in {0, -1}:
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
def add_callback(self, onevent: str, callback):
|
def add_callback(self, event: str, callback):
|
||||||
"""
|
"""
|
||||||
appends the given callback
|
appends the given callback
|
||||||
"""
|
"""
|
||||||
self.callbacks[onevent].append(callback)
|
self.callbacks[event].append(callback)
|
||||||
|
|
||||||
def set_callback(self, onevent: str, callback):
|
def set_callback(self, event: str, callback):
|
||||||
"""
|
"""
|
||||||
overrides the existing callbacks with the given callback
|
overrides the existing callbacks with the given callback
|
||||||
"""
|
"""
|
||||||
self.callbacks[onevent] = [callback]
|
self.callbacks[event] = [callback]
|
||||||
|
|
||||||
def trigger_callbacks(self, onevent: str):
|
def run_callbacks(self, event: str):
|
||||||
for callback in self.callbacks.get(onevent, []):
|
for callback in self.callbacks.get(event, []):
|
||||||
callback(self)
|
callback(self)
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
@ -178,7 +178,7 @@ class BaseTrainer:
|
|||||||
Builds dataloaders and optimizer on correct rank process
|
Builds dataloaders and optimizer on correct rank process
|
||||||
"""
|
"""
|
||||||
# model
|
# model
|
||||||
self.trigger_callbacks("on_pretrain_routine_start")
|
self.run_callbacks("on_pretrain_routine_start")
|
||||||
ckpt = self.setup_model()
|
ckpt = self.setup_model()
|
||||||
self.model = self.model.to(self.device)
|
self.model = self.model.to(self.device)
|
||||||
self.set_model_attributes()
|
self.set_model_attributes()
|
||||||
@ -210,7 +210,7 @@ class BaseTrainer:
|
|||||||
metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
|
metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
|
||||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||||
self.ema = ModelEMA(self.model)
|
self.ema = ModelEMA(self.model)
|
||||||
self.trigger_callbacks("on_pretrain_routine_end")
|
self.run_callbacks("on_pretrain_routine_end")
|
||||||
|
|
||||||
def _do_train(self, rank=-1, world_size=1):
|
def _do_train(self, rank=-1, world_size=1):
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
@ -224,14 +224,14 @@ class BaseTrainer:
|
|||||||
nb = len(self.train_loader) # number of batches
|
nb = len(self.train_loader) # number of batches
|
||||||
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
|
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
|
||||||
last_opt_step = -1
|
last_opt_step = -1
|
||||||
self.trigger_callbacks("on_train_start")
|
self.run_callbacks("on_train_start")
|
||||||
self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
|
self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
|
||||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||||
f"Starting training for {self.epochs} epochs...")
|
f"Starting training for {self.epochs} epochs...")
|
||||||
for epoch in range(self.start_epoch, self.epochs):
|
for epoch in range(self.start_epoch, self.epochs):
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
self.trigger_callbacks("on_train_epoch_start")
|
self.run_callbacks("on_train_epoch_start")
|
||||||
self.model.train()
|
self.model.train()
|
||||||
if rank != -1:
|
if rank != -1:
|
||||||
self.train_loader.sampler.set_epoch(epoch)
|
self.train_loader.sampler.set_epoch(epoch)
|
||||||
@ -242,7 +242,7 @@ class BaseTrainer:
|
|||||||
self.tloss = None
|
self.tloss = None
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
for i, batch in pbar:
|
for i, batch in pbar:
|
||||||
self.trigger_callbacks("on_train_batch_start")
|
self.run_callbacks("on_train_batch_start")
|
||||||
|
|
||||||
# Update dataloader attributes (optional)
|
# Update dataloader attributes (optional)
|
||||||
if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'):
|
if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'):
|
||||||
@ -287,35 +287,34 @@ class BaseTrainer:
|
|||||||
pbar.set_description(
|
pbar.set_description(
|
||||||
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
||||||
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
|
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
|
||||||
self.trigger_callbacks('on_batch_end')
|
self.run_callbacks('on_batch_end')
|
||||||
if self.args.plots and ni < 3:
|
if self.args.plots and ni < 3:
|
||||||
self.plot_training_samples(batch, ni)
|
self.plot_training_samples(batch, ni)
|
||||||
|
|
||||||
self.trigger_callbacks("on_train_batch_end")
|
self.run_callbacks("on_train_batch_end")
|
||||||
|
|
||||||
lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
self.trigger_callbacks("on_train_epoch_end")
|
self.run_callbacks("on_train_epoch_end")
|
||||||
|
|
||||||
if rank in {-1, 0}:
|
if rank in {-1, 0}:
|
||||||
|
|
||||||
# Validation
|
# Validation
|
||||||
self.trigger_callbacks('on_val_start')
|
|
||||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||||
final_epoch = (epoch + 1 == self.epochs)
|
final_epoch = (epoch + 1 == self.epochs)
|
||||||
if self.args.val or final_epoch:
|
if self.args.val or final_epoch:
|
||||||
self.metrics, self.fitness = self.validate()
|
self.metrics, self.fitness = self.validate()
|
||||||
self.trigger_callbacks('on_val_end')
|
|
||||||
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr})
|
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr})
|
||||||
|
|
||||||
# Save model
|
# Save model
|
||||||
if self.args.save or (epoch + 1 == self.epochs):
|
if self.args.save or (epoch + 1 == self.epochs):
|
||||||
self.save_model()
|
self.save_model()
|
||||||
self.trigger_callbacks('on_model_save')
|
self.run_callbacks('on_model_save')
|
||||||
|
|
||||||
tnow = time.time()
|
tnow = time.time()
|
||||||
self.epoch_time = tnow - self.epoch_time_start
|
self.epoch_time = tnow - self.epoch_time_start
|
||||||
self.epoch_time_start = tnow
|
self.epoch_time_start = tnow
|
||||||
|
self.run_callbacks("on_fit_epoch_end")
|
||||||
# TODO: termination condition
|
# TODO: termination condition
|
||||||
|
|
||||||
if rank in {-1, 0}:
|
if rank in {-1, 0}:
|
||||||
@ -326,9 +325,9 @@ class BaseTrainer:
|
|||||||
if self.args.plots:
|
if self.args.plots:
|
||||||
self.plot_metrics()
|
self.plot_metrics()
|
||||||
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
|
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||||
self.trigger_callbacks('on_train_end')
|
self.run_callbacks('on_train_end')
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
self.trigger_callbacks('teardown')
|
self.run_callbacks('teardown')
|
||||||
|
|
||||||
def save_model(self):
|
def save_model(self):
|
||||||
ckpt = {
|
ckpt = {
|
||||||
@ -470,7 +469,7 @@ class BaseTrainer:
|
|||||||
self.validator.args.save_json = True
|
self.validator.args.save_json = True
|
||||||
self.metrics = self.validator(model=f)
|
self.metrics = self.validator(model=f)
|
||||||
self.metrics.pop('fitness', None)
|
self.metrics.pop('fitness', None)
|
||||||
self.trigger_callbacks('on_val_end')
|
self.run_callbacks('on_val_end')
|
||||||
|
|
||||||
def check_resume(self):
|
def check_resume(self):
|
||||||
resume = self.args.resume
|
resume = self.args.resume
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -8,6 +9,7 @@ from tqdm import tqdm
|
|||||||
from ultralytics.nn.autobackend import AutoBackend
|
from ultralytics.nn.autobackend import AutoBackend
|
||||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT
|
||||||
|
from ultralytics.yolo.utils.callbacks import default_callbacks
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz
|
from ultralytics.yolo.utils.checks import check_imgsz
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
@ -64,12 +66,18 @@ class BaseValidator:
|
|||||||
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
|
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
|
||||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# callbacks
|
||||||
|
self.callbacks = defaultdict(list)
|
||||||
|
for callback, func in default_callbacks.items():
|
||||||
|
self.add_callback(callback, func)
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def __call__(self, trainer=None, model=None):
|
def __call__(self, trainer=None, model=None):
|
||||||
"""
|
"""
|
||||||
Supports validation of a pre-trained model if passed or a model being trained
|
Supports validation of a pre-trained model if passed or a model being trained
|
||||||
if trainer is passed (trainer gets priority).
|
if trainer is passed (trainer gets priority).
|
||||||
"""
|
"""
|
||||||
|
self.run_callbacks('on_val_start')
|
||||||
self.training = trainer is not None
|
self.training = trainer is not None
|
||||||
if self.training:
|
if self.training:
|
||||||
self.device = trainer.device
|
self.device = trainer.device
|
||||||
@ -116,6 +124,7 @@ class BaseValidator:
|
|||||||
self.init_metrics(de_parallel(model))
|
self.init_metrics(de_parallel(model))
|
||||||
self.jdict = [] # empty before each val
|
self.jdict = [] # empty before each val
|
||||||
for batch_i, batch in enumerate(bar):
|
for batch_i, batch in enumerate(bar):
|
||||||
|
self.run_callbacks('on_val_batch_start')
|
||||||
self.batch_i = batch_i
|
self.batch_i = batch_i
|
||||||
# pre-process
|
# pre-process
|
||||||
with dt[0]:
|
with dt[0]:
|
||||||
@ -139,10 +148,12 @@ class BaseValidator:
|
|||||||
self.plot_val_samples(batch, batch_i)
|
self.plot_val_samples(batch, batch_i)
|
||||||
self.plot_predictions(batch, preds, batch_i)
|
self.plot_predictions(batch, preds, batch_i)
|
||||||
|
|
||||||
|
self.run_callbacks('on_val_batch_end')
|
||||||
stats = self.get_stats()
|
stats = self.get_stats()
|
||||||
self.check_stats(stats)
|
self.check_stats(stats)
|
||||||
self.print_results()
|
self.print_results()
|
||||||
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
|
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
|
||||||
|
self.run_callbacks('on_val_end')
|
||||||
if self.training:
|
if self.training:
|
||||||
model.float()
|
model.float()
|
||||||
return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
|
return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
|
||||||
@ -156,6 +167,22 @@ class BaseValidator:
|
|||||||
stats = self.eval_json(stats) # update stats
|
stats = self.eval_json(stats) # update stats
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
def add_callback(self, event: str, callback):
|
||||||
|
"""
|
||||||
|
appends the given callback
|
||||||
|
"""
|
||||||
|
self.callbacks[event].append(callback)
|
||||||
|
|
||||||
|
def set_callback(self, event: str, callback):
|
||||||
|
"""
|
||||||
|
overrides the existing callbacks with the given callback
|
||||||
|
"""
|
||||||
|
self.callbacks[event] = [callback]
|
||||||
|
|
||||||
|
def run_callbacks(self, event: str):
|
||||||
|
for callback in self.callbacks.get(event, []):
|
||||||
|
callback(self)
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size):
|
def get_dataloader(self, dataset_path, batch_size):
|
||||||
raise NotImplementedError("get_dataloader function not implemented for this validator")
|
raise NotImplementedError("get_dataloader function not implemented for this validator")
|
||||||
|
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
# Ultralytics YOLO base callbacks
|
||||||
|
|
||||||
|
|
||||||
|
# Trainer callbacks ----------------------------------------------------------------------------------------------------
|
||||||
def on_pretrain_routine_start(trainer):
|
def on_pretrain_routine_start(trainer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -34,26 +38,6 @@ def on_train_epoch_end(trainer):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def on_val_start(trainer):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def on_val_batch_start(trainer):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def on_val_image_end(trainer):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def on_val_batch_end(trainer):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def on_val_end(trainer):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def on_fit_epoch_end(trainer):
|
def on_fit_epoch_end(trainer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -74,7 +58,51 @@ def teardown(trainer):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Validator callbacks --------------------------------------------------------------------------------------------------
|
||||||
|
def on_val_start(validator):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def on_val_batch_start(validator):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def on_val_batch_end(validator):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def on_val_end(validator):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Predictor callbacks --------------------------------------------------------------------------------------------------
|
||||||
|
def on_predict_start(predictor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def on_predict_batch_start(predictor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def on_predict_batch_end(predictor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def on_predict_end(predictor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Exporter callbacks ---------------------------------------------------------------------------------------------------
|
||||||
|
def on_export_start(exporter):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def on_export_end(exporter):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
default_callbacks = {
|
default_callbacks = {
|
||||||
|
# Run in trainer
|
||||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||||
'on_pretrain_routine_end': on_pretrain_routine_end,
|
'on_pretrain_routine_end': on_pretrain_routine_end,
|
||||||
'on_train_start': on_train_start,
|
'on_train_start': on_train_start,
|
||||||
@ -84,16 +112,27 @@ default_callbacks = {
|
|||||||
'on_before_zero_grad': on_before_zero_grad,
|
'on_before_zero_grad': on_before_zero_grad,
|
||||||
'on_train_batch_end': on_train_batch_end,
|
'on_train_batch_end': on_train_batch_end,
|
||||||
'on_train_epoch_end': on_train_epoch_end,
|
'on_train_epoch_end': on_train_epoch_end,
|
||||||
'on_val_start': on_val_start,
|
|
||||||
'on_val_batch_start': on_val_batch_start,
|
|
||||||
'on_val_image_end': on_val_image_end,
|
|
||||||
'on_val_batch_end': on_val_batch_end,
|
|
||||||
'on_val_end': on_val_end,
|
|
||||||
'on_fit_epoch_end': on_fit_epoch_end, # fit = train + val
|
'on_fit_epoch_end': on_fit_epoch_end, # fit = train + val
|
||||||
'on_model_save': on_model_save,
|
'on_model_save': on_model_save,
|
||||||
'on_train_end': on_train_end,
|
'on_train_end': on_train_end,
|
||||||
'on_params_update': on_params_update,
|
'on_params_update': on_params_update,
|
||||||
'teardown': teardown}
|
'teardown': teardown,
|
||||||
|
|
||||||
|
# Run in validator
|
||||||
|
'on_val_start': on_val_start,
|
||||||
|
'on_val_batch_start': on_val_batch_start,
|
||||||
|
'on_val_batch_end': on_val_batch_end,
|
||||||
|
'on_val_end': on_val_end,
|
||||||
|
|
||||||
|
# Run in predictor
|
||||||
|
'on_predict_start': on_predict_start,
|
||||||
|
'on_predict_batch_start': on_predict_batch_start,
|
||||||
|
'on_predict_batch_end': on_predict_batch_end,
|
||||||
|
'on_predict_end': on_predict_end,
|
||||||
|
|
||||||
|
# Run in exporter
|
||||||
|
'on_export_start': on_export_start,
|
||||||
|
'on_export_end': on_export_end}
|
||||||
|
|
||||||
|
|
||||||
def add_integration_callbacks(trainer):
|
def add_integration_callbacks(trainer):
|
||||||
|
@ -18,7 +18,7 @@ def _log_images(imgs_dict, group="", step=0):
|
|||||||
|
|
||||||
def on_pretrain_routine_start(trainer):
|
def on_pretrain_routine_start(trainer):
|
||||||
# TODO: reuse existing task
|
# TODO: reuse existing task
|
||||||
task = Task.init(project_name=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8',
|
task = Task.init(project_name=trainer.args.project or "YOLOv8",
|
||||||
task_name=trainer.args.name,
|
task_name=trainer.args.name,
|
||||||
tags=['YOLOv8'],
|
tags=['YOLOv8'],
|
||||||
output_uri=True,
|
output_uri=True,
|
||||||
@ -32,7 +32,7 @@ def on_train_epoch_end(trainer):
|
|||||||
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
|
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
|
||||||
|
|
||||||
|
|
||||||
def on_val_end(trainer):
|
def on_fit_epoch_end(trainer):
|
||||||
if trainer.epoch == 0:
|
if trainer.epoch == 0:
|
||||||
model_info = {
|
model_info = {
|
||||||
"Parameters": get_num_params(trainer.model),
|
"Parameters": get_num_params(trainer.model),
|
||||||
@ -50,5 +50,5 @@ def on_train_end(trainer):
|
|||||||
callbacks = {
|
callbacks = {
|
||||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||||
"on_train_epoch_end": on_train_epoch_end,
|
"on_train_epoch_end": on_train_epoch_end,
|
||||||
"on_val_end": on_val_end,
|
"on_fit_epoch_end": on_fit_epoch_end,
|
||||||
"on_train_end": on_train_end} if clearml else {}
|
"on_train_end": on_train_end} if clearml else {}
|
||||||
|
@ -17,11 +17,11 @@ def on_batch_end(trainer):
|
|||||||
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
|
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
|
||||||
|
|
||||||
|
|
||||||
def on_val_end(trainer):
|
def on_fit_epoch_end(trainer):
|
||||||
_log_scalars(trainer.metrics, trainer.epoch + 1)
|
_log_scalars(trainer.metrics, trainer.epoch + 1)
|
||||||
|
|
||||||
|
|
||||||
callbacks = {
|
callbacks = {
|
||||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||||
"on_val_end": on_val_end,
|
"on_fit_epoch_end": on_fit_epoch_end,
|
||||||
"on_batch_end": on_batch_end}
|
"on_batch_end": on_batch_end}
|
||||||
|
@ -9,12 +9,11 @@ except (ImportError, AssertionError):
|
|||||||
|
|
||||||
|
|
||||||
def on_pretrain_routine_start(trainer):
|
def on_pretrain_routine_start(trainer):
|
||||||
wandb.init(project=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8',
|
wandb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=dict(
|
||||||
name=trainer.args.name,
|
trainer.args)) if not wandb.run else wandb.run
|
||||||
config=dict(trainer.args)) if not wandb.run else wandb.run
|
|
||||||
|
|
||||||
|
|
||||||
def on_val_end(trainer):
|
def on_fit_epoch_end(trainer):
|
||||||
wandb.run.log(trainer.metrics, step=trainer.epoch + 1)
|
wandb.run.log(trainer.metrics, step=trainer.epoch + 1)
|
||||||
if trainer.epoch == 0:
|
if trainer.epoch == 0:
|
||||||
model_info = {
|
model_info = {
|
||||||
@ -42,5 +41,5 @@ def on_train_end(trainer):
|
|||||||
callbacks = {
|
callbacks = {
|
||||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||||
"on_train_epoch_end": on_train_epoch_end,
|
"on_train_epoch_end": on_train_epoch_end,
|
||||||
"on_val_end": on_val_end,
|
"on_fit_epoch_end": on_fit_epoch_end,
|
||||||
"on_train_end": on_train_end} if wandb else {}
|
"on_train_end": on_train_end} if wandb else {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user