mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Add TensorBoard graph for model visualization (#4464)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
87ce15d383
commit
3acead7e79
@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, SegmentationTrainer, image segmentation, object det
|
|||||||
---
|
---
|
||||||
## ::: ultralytics.models.yolo.segment.train.SegmentationTrainer
|
## ::: ultralytics.models.yolo.segment.train.SegmentationTrainer
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
---
|
|
||||||
## ::: ultralytics.models.yolo.segment.train.train
|
|
||||||
<br><br>
|
|
||||||
|
@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, SegmentationValidator, model segmentation, image cl
|
|||||||
---
|
---
|
||||||
## ::: ultralytics.models.yolo.segment.val.SegmentationValidator
|
## ::: ultralytics.models.yolo.segment.val.SegmentationValidator
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
---
|
|
||||||
## ::: ultralytics.models.yolo.segment.val.val
|
|
||||||
<br><br>
|
|
||||||
|
@ -13,6 +13,10 @@ keywords: Ultralytics, YOLO, documentation, callback utilities, log_scalars, on_
|
|||||||
## ::: ultralytics.utils.callbacks.tensorboard._log_scalars
|
## ::: ultralytics.utils.callbacks.tensorboard._log_scalars
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
|
---
|
||||||
|
## ::: ultralytics.utils.callbacks.tensorboard._log_tensorboard_graph
|
||||||
|
<br><br>
|
||||||
|
|
||||||
---
|
---
|
||||||
## ::: ultralytics.utils.callbacks.tensorboard.on_pretrain_routine_start
|
## ::: ultralytics.utils.callbacks.tensorboard.on_pretrain_routine_start
|
||||||
<br><br>
|
<br><br>
|
||||||
|
@ -56,14 +56,3 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|||||||
def plot_metrics(self):
|
def plot_metrics(self):
|
||||||
"""Plots training/val metrics."""
|
"""Plots training/val metrics."""
|
||||||
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
|
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
|
||||||
|
|
||||||
|
|
||||||
def train(cfg=DEFAULT_CFG):
|
|
||||||
"""Train a YOLO segmentation model based on passed arguments."""
|
|
||||||
args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml')
|
|
||||||
trainer = SegmentationTrainer(overrides=args)
|
|
||||||
trainer.train()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
train()
|
|
||||||
|
@ -8,7 +8,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from ultralytics.models.yolo.detect import DetectionValidator
|
from ultralytics.models.yolo.detect import DetectionValidator
|
||||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, NUM_THREADS, ops
|
from ultralytics.utils import LOGGER, NUM_THREADS, ops
|
||||||
from ultralytics.utils.checks import check_requirements
|
from ultralytics.utils.checks import check_requirements
|
||||||
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
|
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
|
||||||
from ultralytics.utils.plotting import output_to_target, plot_images
|
from ultralytics.utils.plotting import output_to_target, plot_images
|
||||||
@ -243,14 +243,3 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.warning(f'pycocotools unable to run: {e}')
|
LOGGER.warning(f'pycocotools unable to run: {e}')
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
def val(cfg=DEFAULT_CFG):
|
|
||||||
"""Validate trained YOLO model on validation data."""
|
|
||||||
args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml')
|
|
||||||
validator = SegmentationValidator(args=args)
|
|
||||||
validator(model=args['model'])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
val()
|
|
||||||
|
@ -12,24 +12,43 @@ try:
|
|||||||
except (ImportError, AssertionError, TypeError):
|
except (ImportError, AssertionError, TypeError):
|
||||||
SummaryWriter = None
|
SummaryWriter = None
|
||||||
|
|
||||||
writer = None # TensorBoard SummaryWriter instance
|
WRITER = None # TensorBoard SummaryWriter instance
|
||||||
|
|
||||||
|
|
||||||
def _log_scalars(scalars, step=0):
|
def _log_scalars(scalars, step=0):
|
||||||
"""Logs scalar values to TensorBoard."""
|
"""Logs scalar values to TensorBoard."""
|
||||||
if writer:
|
if WRITER:
|
||||||
for k, v in scalars.items():
|
for k, v in scalars.items():
|
||||||
writer.add_scalar(k, v, step)
|
WRITER.add_scalar(k, v, step)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_tensorboard_graph(trainer):
|
||||||
|
# Log model graph to TensorBoard
|
||||||
|
try:
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from ultralytics.utils.torch_utils import de_parallel, torch
|
||||||
|
|
||||||
|
imgsz = trainer.args.imgsz
|
||||||
|
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
|
||||||
|
p = next(trainer.model.parameters()) # for device, type
|
||||||
|
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input (WARNING: must be zeros, not empty)
|
||||||
|
with warnings.catch_warnings(category=UserWarning):
|
||||||
|
warnings.simplefilter('ignore') # suppress jit trace warning
|
||||||
|
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
|
||||||
|
|
||||||
|
|
||||||
def on_pretrain_routine_start(trainer):
|
def on_pretrain_routine_start(trainer):
|
||||||
"""Initialize TensorBoard logging with SummaryWriter."""
|
"""Initialize TensorBoard logging with SummaryWriter."""
|
||||||
if SummaryWriter:
|
if SummaryWriter:
|
||||||
try:
|
try:
|
||||||
global writer
|
global WRITER
|
||||||
writer = SummaryWriter(str(trainer.save_dir))
|
WRITER = SummaryWriter(str(trainer.save_dir))
|
||||||
prefix = colorstr('TensorBoard: ')
|
prefix = colorstr('TensorBoard: ')
|
||||||
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
|
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
|
||||||
|
_log_tensorboard_graph(trainer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
|
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user