mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 17:05:40 +08:00 
			
		
		
		
	ultralytics 8.1.4 RTDETR TensorBoard graph visualization fix (#7725)
				
					
				
			Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
		
							parent
							
								
									6535bcde2b
								
							
						
					
					
						commit
						7a0c27c7d7
					
				@ -93,7 +93,7 @@ dev = [
 | 
				
			|||||||
    "mkdocstrings[python]",
 | 
					    "mkdocstrings[python]",
 | 
				
			||||||
    "mkdocs-jupyter", # for notebooks
 | 
					    "mkdocs-jupyter", # for notebooks
 | 
				
			||||||
    "mkdocs-redirects", # for 301 redirects
 | 
					    "mkdocs-redirects", # for 301 redirects
 | 
				
			||||||
    "mkdocs-ultralytics-plugin>=0.0.34", # for meta descriptions and images, dates and authors
 | 
					    "mkdocs-ultralytics-plugin>=0.0.38", # for meta descriptions and images, dates and authors
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
export = [
 | 
					export = [
 | 
				
			||||||
    "onnx>=1.12.0", # ONNX export
 | 
					    "onnx>=1.12.0", # ONNX export
 | 
				
			||||||
 | 
				
			|||||||
@ -1,13 +1,25 @@
 | 
				
			|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
					# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__version__ = "8.1.3"
 | 
					__version__ = "8.1.4"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ultralytics.data.explorer.explorer import Explorer
 | 
					from ultralytics.data.explorer.explorer import Explorer
 | 
				
			||||||
from ultralytics.models import RTDETR, SAM, YOLO
 | 
					from ultralytics.models import RTDETR, SAM, YOLO
 | 
				
			||||||
from ultralytics.models.fastsam import FastSAM
 | 
					from ultralytics.models.fastsam import FastSAM
 | 
				
			||||||
from ultralytics.models.nas import NAS
 | 
					from ultralytics.models.nas import NAS
 | 
				
			||||||
from ultralytics.utils import SETTINGS as settings
 | 
					from ultralytics.utils import ASSETS, SETTINGS as settings
 | 
				
			||||||
from ultralytics.utils.checks import check_yolo as checks
 | 
					from ultralytics.utils.checks import check_yolo as checks
 | 
				
			||||||
from ultralytics.utils.downloads import download
 | 
					from ultralytics.utils.downloads import download
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__all__ = "__version__", "YOLO", "NAS", "SAM", "FastSAM", "RTDETR", "checks", "download", "settings", "Explorer"
 | 
					__all__ = (
 | 
				
			||||||
 | 
					    "__version__",
 | 
				
			||||||
 | 
					    "ASSETS",
 | 
				
			||||||
 | 
					    "YOLO",
 | 
				
			||||||
 | 
					    "NAS",
 | 
				
			||||||
 | 
					    "SAM",
 | 
				
			||||||
 | 
					    "FastSAM",
 | 
				
			||||||
 | 
					    "RTDETR",
 | 
				
			||||||
 | 
					    "checks",
 | 
				
			||||||
 | 
					    "download",
 | 
				
			||||||
 | 
					    "settings",
 | 
				
			||||||
 | 
					    "Explorer",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
				
			|||||||
@ -63,7 +63,9 @@ download: |
 | 
				
			|||||||
  # Download 'https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip' (deprecated S3 link)
 | 
					  # Download 'https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip' (deprecated S3 link)
 | 
				
			||||||
  dir = Path(yaml['path'])  # dataset root dir
 | 
					  dir = Path(yaml['path'])  # dataset root dir
 | 
				
			||||||
  urls = ['https://drive.google.com/file/d/1st9qW3BeIwQsnR0t8mRpvbsSWIo16ACi/view?usp=drive_link']
 | 
					  urls = ['https://drive.google.com/file/d/1st9qW3BeIwQsnR0t8mRpvbsSWIo16ACi/view?usp=drive_link']
 | 
				
			||||||
  download(urls, dir=dir)
 | 
					  print("\n\nWARNING: Argoverse dataset MUST be downloaded manually, autodownload will NOT work.")
 | 
				
			||||||
 | 
					  print(f"WARNING: Manually download Argoverse dataset '{urls[0]}' to '{dir}' and re-run your command.\n\n")
 | 
				
			||||||
 | 
					  # download(urls, dir=dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # Convert
 | 
					  # Convert
 | 
				
			||||||
  annotations_dir = 'Argoverse-HD/annotations/'
 | 
					  annotations_dir = 'Argoverse-HD/annotations/'
 | 
				
			||||||
 | 
				
			|||||||
@ -427,7 +427,9 @@ class Model(nn.Module):
 | 
				
			|||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def names(self):
 | 
					    def names(self):
 | 
				
			||||||
        """Returns class names of the loaded model."""
 | 
					        """Returns class names of the loaded model."""
 | 
				
			||||||
        return self.model.names if hasattr(self.model, "names") else None
 | 
					        from ultralytics.nn.autobackend import check_class_names
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return check_class_names(self.model.names) if hasattr(self.model, "names") else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def device(self):
 | 
					    def device(self):
 | 
				
			||||||
 | 
				
			|||||||
@ -376,7 +376,7 @@ class RTDETRDecoder(nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
 | 
					    def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
 | 
				
			||||||
        """Generates and prepares the input required for the decoder from the provided features and shapes."""
 | 
					        """Generates and prepares the input required for the decoder from the provided features and shapes."""
 | 
				
			||||||
        bs = len(feats)
 | 
					        bs = feats.shape[0]
 | 
				
			||||||
        # Prepare input for decoder
 | 
					        # Prepare input for decoder
 | 
				
			||||||
        anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
 | 
					        anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
 | 
				
			||||||
        features = self.enc_output(valid_mask * feats)  # bs, h*w, 256
 | 
					        features = self.enc_output(valid_mask * feats)  # bs, h*w, 256
 | 
				
			||||||
 | 
				
			|||||||
@ -101,10 +101,10 @@ class AIFI(TransformerEncoderLayer):
 | 
				
			|||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
 | 
					    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
 | 
				
			||||||
        """Builds 2D sine-cosine position embedding."""
 | 
					        """Builds 2D sine-cosine position embedding."""
 | 
				
			||||||
        grid_w = torch.arange(int(w), dtype=torch.float32)
 | 
					 | 
				
			||||||
        grid_h = torch.arange(int(h), dtype=torch.float32)
 | 
					 | 
				
			||||||
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
 | 
					 | 
				
			||||||
        assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
 | 
					        assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
 | 
				
			||||||
 | 
					        grid_w = torch.arange(w, dtype=torch.float32)
 | 
				
			||||||
 | 
					        grid_h = torch.arange(h, dtype=torch.float32)
 | 
				
			||||||
 | 
					        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
 | 
				
			||||||
        pos_dim = embed_dim // 4
 | 
					        pos_dim = embed_dim // 4
 | 
				
			||||||
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
 | 
					        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
 | 
				
			||||||
        omega = 1.0 / (temperature**omega)
 | 
					        omega = 1.0 / (temperature**omega)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,14 +1,21 @@
 | 
				
			|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
					# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
				
			||||||
 | 
					import contextlib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
 | 
					from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
    # WARNING: do not move import due to protobuf issue in https://github.com/ultralytics/ultralytics/pull/4674
 | 
					    # WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
 | 
				
			||||||
    from torch.utils.tensorboard import SummaryWriter
 | 
					    from torch.utils.tensorboard import SummaryWriter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert not TESTS_RUNNING  # do not log pytest
 | 
					    assert not TESTS_RUNNING  # do not log pytest
 | 
				
			||||||
    assert SETTINGS["tensorboard"] is True  # verify integration is enabled
 | 
					    assert SETTINGS["tensorboard"] is True  # verify integration is enabled
 | 
				
			||||||
    WRITER = None  # TensorBoard SummaryWriter instance
 | 
					    WRITER = None  # TensorBoard SummaryWriter instance
 | 
				
			||||||
 | 
					    PREFIX = colorstr("TensorBoard: ")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Imports below only required if TensorBoard enabled
 | 
				
			||||||
 | 
					    import warnings
 | 
				
			||||||
 | 
					    from copy import deepcopy
 | 
				
			||||||
 | 
					    from ultralytics.utils.torch_utils import de_parallel, torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
except (ImportError, AssertionError, TypeError, AttributeError):
 | 
					except (ImportError, AssertionError, TypeError, AttributeError):
 | 
				
			||||||
    # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
 | 
					    # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
 | 
				
			||||||
@ -25,20 +32,37 @@ def _log_scalars(scalars, step=0):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def _log_tensorboard_graph(trainer):
 | 
					def _log_tensorboard_graph(trainer):
 | 
				
			||||||
    """Log model graph to TensorBoard."""
 | 
					    """Log model graph to TensorBoard."""
 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        import warnings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from ultralytics.utils.torch_utils import de_parallel, torch
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Input image
 | 
				
			||||||
    imgsz = trainer.args.imgsz
 | 
					    imgsz = trainer.args.imgsz
 | 
				
			||||||
    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
 | 
					    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
 | 
				
			||||||
    p = next(trainer.model.parameters())  # for device, type
 | 
					    p = next(trainer.model.parameters())  # for device, type
 | 
				
			||||||
    im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)
 | 
					    im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with warnings.catch_warnings():
 | 
					    with warnings.catch_warnings():
 | 
				
			||||||
        warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
 | 
					        warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
 | 
				
			||||||
 | 
					        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)  # suppress jit trace warning
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Try simple method first (YOLO)
 | 
				
			||||||
 | 
					        with contextlib.suppress(Exception):
 | 
				
			||||||
            WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
 | 
					            WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
 | 
				
			||||||
 | 
					            LOGGER.info(f"{PREFIX}model graph visualization added ✅")
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Fallback to TorchScript export steps (RTDETR)
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            model = deepcopy(de_parallel(trainer.model))
 | 
				
			||||||
 | 
					            model.eval()
 | 
				
			||||||
 | 
					            model = model.fuse(verbose=False)
 | 
				
			||||||
 | 
					            for m in model.modules():
 | 
				
			||||||
 | 
					                if hasattr(m, "export"):  # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
 | 
				
			||||||
 | 
					                    m.export = True
 | 
				
			||||||
 | 
					                    m.format = "torchscript"
 | 
				
			||||||
 | 
					            model(im)  # dry run
 | 
				
			||||||
 | 
					            WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
 | 
				
			||||||
 | 
					            LOGGER.info(f"{PREFIX}model graph visualization added ✅")
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
        LOGGER.warning(f"WARNING ⚠️ TensorBoard graph visualization failure {e}")
 | 
					            LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def on_pretrain_routine_start(trainer):
 | 
					def on_pretrain_routine_start(trainer):
 | 
				
			||||||
@ -47,10 +71,9 @@ def on_pretrain_routine_start(trainer):
 | 
				
			|||||||
        try:
 | 
					        try:
 | 
				
			||||||
            global WRITER
 | 
					            global WRITER
 | 
				
			||||||
            WRITER = SummaryWriter(str(trainer.save_dir))
 | 
					            WRITER = SummaryWriter(str(trainer.save_dir))
 | 
				
			||||||
            prefix = colorstr("TensorBoard: ")
 | 
					            LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
 | 
				
			||||||
            LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
 | 
					 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
            LOGGER.warning(f"WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
 | 
					            LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def on_train_start(trainer):
 | 
					def on_train_start(trainer):
 | 
				
			||||||
 | 
				
			|||||||
@ -220,7 +220,7 @@ def non_max_suppression(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Settings
 | 
					    # Settings
 | 
				
			||||||
    # min_wh = 2  # (pixels) minimum box width and height
 | 
					    # min_wh = 2  # (pixels) minimum box width and height
 | 
				
			||||||
    time_limit = 0.5 + max_time_img * bs  # seconds to quit after
 | 
					    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
 | 
				
			||||||
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
 | 
					    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
 | 
					    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user