mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Add pred, export and val callbacks (#126)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
63c7a74691
commit
c6eb6720de
@ -56,6 +56,7 @@ import re
|
||||
import subprocess
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
@ -71,6 +72,7 @@ from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||
from ultralytics.yolo.data.utils import check_dataset
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, get_default_args, yaml_save
|
||||
from ultralytics.yolo.utils.callbacks import default_callbacks
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||
from ultralytics.yolo.utils.files import file_size, increment_path
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
@ -142,8 +144,14 @@ class Exporter:
|
||||
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# callbacks
|
||||
self.callbacks = defaultdict([])
|
||||
for callback, func in default_callbacks.items():
|
||||
self.add_callback(callback, func)
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, model=None):
|
||||
self.run_callbacks("on_export_start")
|
||||
t = time.time()
|
||||
format = self.args.format.lower() # to lowercase
|
||||
fmts = tuple(export_formats()['Argument'][1:]) # available export formats
|
||||
@ -245,6 +253,8 @@ class Exporter:
|
||||
f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}"
|
||||
f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}"
|
||||
f"\nVisualize: https://netron.app")
|
||||
|
||||
self.run_callbacks("on_export_end")
|
||||
return f # return list of exported files/dirs
|
||||
|
||||
@try_export
|
||||
@ -755,6 +765,22 @@ class Exporter:
|
||||
LOGGER.info(f'{prefix} pipeline success')
|
||||
return model
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
appends the given callback
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
overrides the existing callbacks with the given callback
|
||||
"""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def export(cfg):
|
||||
|
@ -26,6 +26,7 @@ Usage - formats:
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import platform
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
@ -35,6 +36,7 @@ from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops
|
||||
from ultralytics.yolo.utils.callbacks import default_callbacks
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||
@ -89,6 +91,11 @@ class BasePredictor:
|
||||
self.annotator = None
|
||||
self.data_path = None
|
||||
|
||||
# callbacks
|
||||
self.callbacks = defaultdict([])
|
||||
for callback, func in default_callbacks.items():
|
||||
self.add_callback(callback, func)
|
||||
|
||||
def preprocess(self, img):
|
||||
pass
|
||||
|
||||
@ -143,9 +150,11 @@ class BasePredictor:
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, source=None, model=None):
|
||||
self.run_callbacks("on_predict_start")
|
||||
model = self.model if self.done_setup else self.setup(source, model)
|
||||
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
|
||||
for batch in self.dataset:
|
||||
self.run_callbacks("on_predict_batch_start")
|
||||
path, im, im0s, vid_cap, s = batch
|
||||
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
||||
with self.dt[0]:
|
||||
@ -176,6 +185,8 @@ class BasePredictor:
|
||||
# Print time (inference-only)
|
||||
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
|
||||
|
||||
self.run_callbacks("on_predict_batch_end")
|
||||
|
||||
# Print results
|
||||
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
||||
LOGGER.info(
|
||||
@ -185,6 +196,8 @@ class BasePredictor:
|
||||
s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||
|
||||
self.run_callbacks("on_predict_end")
|
||||
|
||||
def show(self, p):
|
||||
im0 = self.annotator.result()
|
||||
if platform.system() == 'Linux' and p not in self.windows:
|
||||
@ -213,3 +226,19 @@ class BasePredictor:
|
||||
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
|
||||
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
self.vid_writer[idx].write(im0)
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
appends the given callback
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
overrides the existing callbacks with the given callback
|
||||
"""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
@ -136,20 +136,20 @@ class BaseTrainer:
|
||||
if RANK in {0, -1}:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def add_callback(self, onevent: str, callback):
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
appends the given callback
|
||||
"""
|
||||
self.callbacks[onevent].append(callback)
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, onevent: str, callback):
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
overrides the existing callbacks with the given callback
|
||||
"""
|
||||
self.callbacks[onevent] = [callback]
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
def trigger_callbacks(self, onevent: str):
|
||||
for callback in self.callbacks.get(onevent, []):
|
||||
def run_callbacks(self, event: str):
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
def train(self):
|
||||
@ -178,7 +178,7 @@ class BaseTrainer:
|
||||
Builds dataloaders and optimizer on correct rank process
|
||||
"""
|
||||
# model
|
||||
self.trigger_callbacks("on_pretrain_routine_start")
|
||||
self.run_callbacks("on_pretrain_routine_start")
|
||||
ckpt = self.setup_model()
|
||||
self.model = self.model.to(self.device)
|
||||
self.set_model_attributes()
|
||||
@ -210,7 +210,7 @@ class BaseTrainer:
|
||||
metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
|
||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||
self.ema = ModelEMA(self.model)
|
||||
self.trigger_callbacks("on_pretrain_routine_end")
|
||||
self.run_callbacks("on_pretrain_routine_end")
|
||||
|
||||
def _do_train(self, rank=-1, world_size=1):
|
||||
if world_size > 1:
|
||||
@ -224,14 +224,14 @@ class BaseTrainer:
|
||||
nb = len(self.train_loader) # number of batches
|
||||
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
|
||||
last_opt_step = -1
|
||||
self.trigger_callbacks("on_train_start")
|
||||
self.run_callbacks("on_train_start")
|
||||
self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
|
||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||
f"Starting training for {self.epochs} epochs...")
|
||||
for epoch in range(self.start_epoch, self.epochs):
|
||||
self.epoch = epoch
|
||||
self.trigger_callbacks("on_train_epoch_start")
|
||||
self.run_callbacks("on_train_epoch_start")
|
||||
self.model.train()
|
||||
if rank != -1:
|
||||
self.train_loader.sampler.set_epoch(epoch)
|
||||
@ -242,7 +242,7 @@ class BaseTrainer:
|
||||
self.tloss = None
|
||||
self.optimizer.zero_grad()
|
||||
for i, batch in pbar:
|
||||
self.trigger_callbacks("on_train_batch_start")
|
||||
self.run_callbacks("on_train_batch_start")
|
||||
|
||||
# Update dataloader attributes (optional)
|
||||
if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
@ -287,35 +287,34 @@ class BaseTrainer:
|
||||
pbar.set_description(
|
||||
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
||||
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
|
||||
self.trigger_callbacks('on_batch_end')
|
||||
self.run_callbacks('on_batch_end')
|
||||
if self.args.plots and ni < 3:
|
||||
self.plot_training_samples(batch, ni)
|
||||
|
||||
self.trigger_callbacks("on_train_batch_end")
|
||||
self.run_callbacks("on_train_batch_end")
|
||||
|
||||
lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
self.scheduler.step()
|
||||
self.trigger_callbacks("on_train_epoch_end")
|
||||
self.run_callbacks("on_train_epoch_end")
|
||||
|
||||
if rank in {-1, 0}:
|
||||
|
||||
# Validation
|
||||
self.trigger_callbacks('on_val_start')
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
final_epoch = (epoch + 1 == self.epochs)
|
||||
if self.args.val or final_epoch:
|
||||
self.metrics, self.fitness = self.validate()
|
||||
self.trigger_callbacks('on_val_end')
|
||||
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr})
|
||||
|
||||
# Save model
|
||||
if self.args.save or (epoch + 1 == self.epochs):
|
||||
self.save_model()
|
||||
self.trigger_callbacks('on_model_save')
|
||||
self.run_callbacks('on_model_save')
|
||||
|
||||
tnow = time.time()
|
||||
self.epoch_time = tnow - self.epoch_time_start
|
||||
self.epoch_time_start = tnow
|
||||
|
||||
self.run_callbacks("on_fit_epoch_end")
|
||||
# TODO: termination condition
|
||||
|
||||
if rank in {-1, 0}:
|
||||
@ -326,9 +325,9 @@ class BaseTrainer:
|
||||
if self.args.plots:
|
||||
self.plot_metrics()
|
||||
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
self.trigger_callbacks('on_train_end')
|
||||
self.run_callbacks('on_train_end')
|
||||
torch.cuda.empty_cache()
|
||||
self.trigger_callbacks('teardown')
|
||||
self.run_callbacks('teardown')
|
||||
|
||||
def save_model(self):
|
||||
ckpt = {
|
||||
@ -470,7 +469,7 @@ class BaseTrainer:
|
||||
self.validator.args.save_json = True
|
||||
self.metrics = self.validator(model=f)
|
||||
self.metrics.pop('fitness', None)
|
||||
self.trigger_callbacks('on_val_end')
|
||||
self.run_callbacks('on_val_end')
|
||||
|
||||
def check_resume(self):
|
||||
resume = self.args.resume
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@ -8,6 +9,7 @@ from tqdm import tqdm
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT
|
||||
from ultralytics.yolo.utils.callbacks import default_callbacks
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
@ -64,12 +66,18 @@ class BaseValidator:
|
||||
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# callbacks
|
||||
self.callbacks = defaultdict(list)
|
||||
for callback, func in default_callbacks.items():
|
||||
self.add_callback(callback, func)
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, trainer=None, model=None):
|
||||
"""
|
||||
Supports validation of a pre-trained model if passed or a model being trained
|
||||
if trainer is passed (trainer gets priority).
|
||||
"""
|
||||
self.run_callbacks('on_val_start')
|
||||
self.training = trainer is not None
|
||||
if self.training:
|
||||
self.device = trainer.device
|
||||
@ -116,6 +124,7 @@ class BaseValidator:
|
||||
self.init_metrics(de_parallel(model))
|
||||
self.jdict = [] # empty before each val
|
||||
for batch_i, batch in enumerate(bar):
|
||||
self.run_callbacks('on_val_batch_start')
|
||||
self.batch_i = batch_i
|
||||
# pre-process
|
||||
with dt[0]:
|
||||
@ -139,10 +148,12 @@ class BaseValidator:
|
||||
self.plot_val_samples(batch, batch_i)
|
||||
self.plot_predictions(batch, preds, batch_i)
|
||||
|
||||
self.run_callbacks('on_val_batch_end')
|
||||
stats = self.get_stats()
|
||||
self.check_stats(stats)
|
||||
self.print_results()
|
||||
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
|
||||
self.run_callbacks('on_val_end')
|
||||
if self.training:
|
||||
model.float()
|
||||
return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
|
||||
@ -156,6 +167,22 @@ class BaseValidator:
|
||||
stats = self.eval_json(stats) # update stats
|
||||
return stats
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
appends the given callback
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
overrides the existing callbacks with the given callback
|
||||
"""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
raise NotImplementedError("get_dataloader function not implemented for this validator")
|
||||
|
||||
|
@ -1,3 +1,7 @@
|
||||
# Ultralytics YOLO base callbacks
|
||||
|
||||
|
||||
# Trainer callbacks ----------------------------------------------------------------------------------------------------
|
||||
def on_pretrain_routine_start(trainer):
|
||||
pass
|
||||
|
||||
@ -34,26 +38,6 @@ def on_train_epoch_end(trainer):
|
||||
pass
|
||||
|
||||
|
||||
def on_val_start(trainer):
|
||||
pass
|
||||
|
||||
|
||||
def on_val_batch_start(trainer):
|
||||
pass
|
||||
|
||||
|
||||
def on_val_image_end(trainer):
|
||||
pass
|
||||
|
||||
|
||||
def on_val_batch_end(trainer):
|
||||
pass
|
||||
|
||||
|
||||
def on_val_end(trainer):
|
||||
pass
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
pass
|
||||
|
||||
@ -74,7 +58,51 @@ def teardown(trainer):
|
||||
pass
|
||||
|
||||
|
||||
# Validator callbacks --------------------------------------------------------------------------------------------------
|
||||
def on_val_start(validator):
|
||||
pass
|
||||
|
||||
|
||||
def on_val_batch_start(validator):
|
||||
pass
|
||||
|
||||
|
||||
def on_val_batch_end(validator):
|
||||
pass
|
||||
|
||||
|
||||
def on_val_end(validator):
|
||||
pass
|
||||
|
||||
|
||||
# Predictor callbacks --------------------------------------------------------------------------------------------------
|
||||
def on_predict_start(predictor):
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_batch_start(predictor):
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_batch_end(predictor):
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_end(predictor):
|
||||
pass
|
||||
|
||||
|
||||
# Exporter callbacks ---------------------------------------------------------------------------------------------------
|
||||
def on_export_start(exporter):
|
||||
pass
|
||||
|
||||
|
||||
def on_export_end(exporter):
|
||||
pass
|
||||
|
||||
|
||||
default_callbacks = {
|
||||
# Run in trainer
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_pretrain_routine_end': on_pretrain_routine_end,
|
||||
'on_train_start': on_train_start,
|
||||
@ -84,16 +112,27 @@ default_callbacks = {
|
||||
'on_before_zero_grad': on_before_zero_grad,
|
||||
'on_train_batch_end': on_train_batch_end,
|
||||
'on_train_epoch_end': on_train_epoch_end,
|
||||
'on_val_start': on_val_start,
|
||||
'on_val_batch_start': on_val_batch_start,
|
||||
'on_val_image_end': on_val_image_end,
|
||||
'on_val_batch_end': on_val_batch_end,
|
||||
'on_val_end': on_val_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end, # fit = train + val
|
||||
'on_model_save': on_model_save,
|
||||
'on_train_end': on_train_end,
|
||||
'on_params_update': on_params_update,
|
||||
'teardown': teardown}
|
||||
'teardown': teardown,
|
||||
|
||||
# Run in validator
|
||||
'on_val_start': on_val_start,
|
||||
'on_val_batch_start': on_val_batch_start,
|
||||
'on_val_batch_end': on_val_batch_end,
|
||||
'on_val_end': on_val_end,
|
||||
|
||||
# Run in predictor
|
||||
'on_predict_start': on_predict_start,
|
||||
'on_predict_batch_start': on_predict_batch_start,
|
||||
'on_predict_batch_end': on_predict_batch_end,
|
||||
'on_predict_end': on_predict_end,
|
||||
|
||||
# Run in exporter
|
||||
'on_export_start': on_export_start,
|
||||
'on_export_end': on_export_end}
|
||||
|
||||
|
||||
def add_integration_callbacks(trainer):
|
||||
|
@ -18,7 +18,7 @@ def _log_images(imgs_dict, group="", step=0):
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
# TODO: reuse existing task
|
||||
task = Task.init(project_name=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8',
|
||||
task = Task.init(project_name=trainer.args.project or "YOLOv8",
|
||||
task_name=trainer.args.name,
|
||||
tags=['YOLOv8'],
|
||||
output_uri=True,
|
||||
@ -32,7 +32,7 @@ def on_train_epoch_end(trainer):
|
||||
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
|
||||
|
||||
|
||||
def on_val_end(trainer):
|
||||
def on_fit_epoch_end(trainer):
|
||||
if trainer.epoch == 0:
|
||||
model_info = {
|
||||
"Parameters": get_num_params(trainer.model),
|
||||
@ -50,5 +50,5 @@ def on_train_end(trainer):
|
||||
callbacks = {
|
||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||
"on_train_epoch_end": on_train_epoch_end,
|
||||
"on_val_end": on_val_end,
|
||||
"on_fit_epoch_end": on_fit_epoch_end,
|
||||
"on_train_end": on_train_end} if clearml else {}
|
||||
|
@ -17,11 +17,11 @@ def on_batch_end(trainer):
|
||||
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
|
||||
|
||||
|
||||
def on_val_end(trainer):
|
||||
def on_fit_epoch_end(trainer):
|
||||
_log_scalars(trainer.metrics, trainer.epoch + 1)
|
||||
|
||||
|
||||
callbacks = {
|
||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||
"on_val_end": on_val_end,
|
||||
"on_fit_epoch_end": on_fit_epoch_end,
|
||||
"on_batch_end": on_batch_end}
|
||||
|
@ -9,12 +9,11 @@ except (ImportError, AssertionError):
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
wandb.init(project=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8',
|
||||
name=trainer.args.name,
|
||||
config=dict(trainer.args)) if not wandb.run else wandb.run
|
||||
wandb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=dict(
|
||||
trainer.args)) if not wandb.run else wandb.run
|
||||
|
||||
|
||||
def on_val_end(trainer):
|
||||
def on_fit_epoch_end(trainer):
|
||||
wandb.run.log(trainer.metrics, step=trainer.epoch + 1)
|
||||
if trainer.epoch == 0:
|
||||
model_info = {
|
||||
@ -42,5 +41,5 @@ def on_train_end(trainer):
|
||||
callbacks = {
|
||||
"on_pretrain_routine_start": on_pretrain_routine_start,
|
||||
"on_train_epoch_end": on_train_epoch_end,
|
||||
"on_val_end": on_val_end,
|
||||
"on_fit_epoch_end": on_fit_epoch_end,
|
||||
"on_train_end": on_train_end} if wandb else {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user