Add model.eval() in TensorBoad graph visualization to avoid BN stats changes (#8629)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Mohammed Yasin 2024-03-06 09:22:28 +08:00 committed by GitHub
parent 9c42596145
commit 609a0cefbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -45,6 +45,7 @@ def _log_tensorboard_graph(trainer):
# Try simple method first (YOLO)
with contextlib.suppress(Exception):
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
return