From 609a0cefbf4393016b1b5fd3fb1f800a54493e76 Mon Sep 17 00:00:00 2001 From: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:22:28 +0800 Subject: [PATCH] Add `model.eval()` in TensorBoad graph visualization to avoid BN stats changes (#8629) Co-authored-by: Glenn Jocher --- ultralytics/utils/callbacks/tensorboard.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ultralytics/utils/callbacks/tensorboard.py b/ultralytics/utils/callbacks/tensorboard.py index 98107e8c..59024ee9 100644 --- a/ultralytics/utils/callbacks/tensorboard.py +++ b/ultralytics/utils/callbacks/tensorboard.py @@ -45,6 +45,7 @@ def _log_tensorboard_graph(trainer): # Try simple method first (YOLO) with contextlib.suppress(Exception): + trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), []) LOGGER.info(f"{PREFIX}model graph visualization added ✅") return