mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Add train loss and lr to loggers (#6732)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
4425abce59
commit
9618025416
@ -15,6 +15,10 @@ keywords: Ultralytics, MLflow, Callbacks, on_pretrain_routine_end, on_train_end,
|
||||
|
||||
<br><br>
|
||||
|
||||
## ::: ultralytics.utils.callbacks.mlflow.on_train_epoch_end
|
||||
|
||||
<br><br>
|
||||
|
||||
## ::: ultralytics.utils.callbacks.mlflow.on_fit_epoch_end
|
||||
|
||||
<br><br>
|
||||
|
@ -27,7 +27,7 @@ keywords: Ultralytics, YOLO, documentation, callback utilities, log_scalars, on_
|
||||
|
||||
<br><br>
|
||||
|
||||
## ::: ultralytics.utils.callbacks.tensorboard.on_batch_end
|
||||
## ::: ultralytics.utils.callbacks.tensorboard.on_train_epoch_end
|
||||
|
||||
<br><br>
|
||||
|
||||
|
@ -90,8 +90,10 @@ def on_train_epoch_end(trainer):
|
||||
if trainer.epoch == 1:
|
||||
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
|
||||
# Report the current training progress
|
||||
for k, v in trainer.validator.metrics.results_dict.items():
|
||||
for k, v in trainer.label_loss_items(trainer.tloss, prefix='train').items():
|
||||
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
|
||||
for k, v in trainer.lr.items():
|
||||
task.get_logger().report_scalar('lr', k, v, iteration=trainer.epoch)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
@ -102,6 +104,8 @@ def on_fit_epoch_end(trainer):
|
||||
series='Epoch Time',
|
||||
value=trainer.epoch_time,
|
||||
iteration=trainer.epoch)
|
||||
for k, v in trainer.metrics.items():
|
||||
task.get_logger().report_scalar('val', k, v, iteration=trainer.epoch)
|
||||
if trainer.epoch == 0:
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
for k, v in model_info_for_loggers(trainer).items():
|
||||
|
@ -32,7 +32,9 @@ try:
|
||||
|
||||
assert hasattr(mlflow, '__version__') # verify package is not directory
|
||||
from pathlib import Path
|
||||
|
||||
PREFIX = colorstr('MLflow: ')
|
||||
SANITIZE = lambda x: {k.replace('(', '').replace(')', ''): float(v) for k, v in x.items()}
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
mlflow = None
|
||||
@ -81,11 +83,18 @@ def on_pretrain_routine_end(trainer):
|
||||
f'{PREFIX}WARNING ⚠️ Not tracking this run')
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Log training metrics at the end of each train epoch to MLflow."""
|
||||
if mlflow:
|
||||
mlflow.log_metrics(metrics=SANITIZE(trainer.label_loss_items(trainer.tloss, prefix='train')),
|
||||
step=trainer.epoch)
|
||||
mlflow.log_metrics(metrics=SANITIZE(trainer.lr), step=trainer.epoch)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Log training metrics at the end of each fit epoch to MLflow."""
|
||||
if mlflow:
|
||||
sanitized_metrics = {k.replace('(', '').replace(')', ''): float(v) for k, v in trainer.metrics.items()}
|
||||
mlflow.log_metrics(metrics=sanitized_metrics, step=trainer.epoch)
|
||||
mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch)
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
|
@ -58,9 +58,10 @@ def on_train_start(trainer):
|
||||
_log_tensorboard_graph(trainer)
|
||||
|
||||
|
||||
def on_batch_end(trainer):
|
||||
"""Logs scalar statistics at the end of a training batch."""
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Logs scalar statistics at the end of a training epoch."""
|
||||
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
|
||||
_log_scalars(trainer.lr, trainer.epoch + 1)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
@ -72,4 +73,4 @@ callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_train_start': on_train_start,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_batch_end': on_batch_end} if SummaryWriter else {}
|
||||
'on_train_epoch_end': on_train_epoch_end} if SummaryWriter else {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user