diff --git a/ultralytics/utils/callbacks/tensorboard.py b/ultralytics/utils/callbacks/tensorboard.py index 98107e8c..59024ee9 100644 --- a/ultralytics/utils/callbacks/tensorboard.py +++ b/ultralytics/utils/callbacks/tensorboard.py @@ -45,6 +45,7 @@ def _log_tensorboard_graph(trainer): # Try simple method first (YOLO) with contextlib.suppress(Exception): + trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), []) LOGGER.info(f"{PREFIX}model graph visualization added ✅") return