mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-27 03:45:39 +08:00 
			
		
		
		
	Implement all missing docstrings (#5298)
Co-authored-by: snyk-bot <snyk-bot@snyk.io> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									e7f0658744
								
							
						
					
					
						commit
						7fd5dcbd86
					
				| @ -0,0 +1 @@ | ||||
| # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||
| @ -140,7 +140,7 @@ class Exporter: | ||||
|         Args: | ||||
|             cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. | ||||
|             overrides (dict, optional): Configuration overrides. Defaults to None. | ||||
|             _callbacks (list, optional): List of callback functions. Defaults to None. | ||||
|             _callbacks (dict, optional): Dictionary of callback functions. Defaults to None. | ||||
|         """ | ||||
|         self.args = get_cfg(cfg, overrides) | ||||
|         if self.args.format.lower() in ('coreml', 'mlmodel'):  # fix attempt for protobuf<3.20.x errors | ||||
|  | ||||
| @ -9,14 +9,45 @@ from ultralytics.utils import DEFAULT_CFG, ops | ||||
| 
 | ||||
| 
 | ||||
| class FastSAMPredictor(DetectionPredictor): | ||||
|     """ | ||||
|     FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics | ||||
|     YOLO framework. | ||||
| 
 | ||||
|     This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM. | ||||
|     It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing | ||||
|     for single-class segmentation. | ||||
| 
 | ||||
|     Attributes: | ||||
|         cfg (dict): Configuration parameters for prediction. | ||||
|         overrides (dict, optional): Optional parameter overrides for custom behavior. | ||||
|         _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): | ||||
|         """Initializes FastSAMPredictor class by inheriting from DetectionPredictor and setting task to 'segment'.""" | ||||
|         """ | ||||
|         Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'. | ||||
| 
 | ||||
|         Args: | ||||
|             cfg (dict): Configuration parameters for prediction. | ||||
|             overrides (dict, optional): Optional parameter overrides for custom behavior. | ||||
|             _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction. | ||||
|         """ | ||||
|         super().__init__(cfg, overrides, _callbacks) | ||||
|         self.args.task = 'segment' | ||||
| 
 | ||||
|     def postprocess(self, preds, img, orig_imgs): | ||||
|         """Postprocesses the predictions, applies non-max suppression, scales the boxes, and returns the results.""" | ||||
|         """ | ||||
|         Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image | ||||
|         size, and returns the final results. | ||||
| 
 | ||||
|         Args: | ||||
|             preds (list): The raw output predictions from the model. | ||||
|             img (torch.Tensor): The processed image tensor. | ||||
|             orig_imgs (list | torch.Tensor): The original image or list of images. | ||||
| 
 | ||||
|         Returns: | ||||
|             (list): A list of Results objects, each containing processed boxes, masks, and other metadata. | ||||
|         """ | ||||
|         p = ops.non_max_suppression( | ||||
|             preds[0], | ||||
|             self.args.conf, | ||||
|  | ||||
| @ -13,6 +13,15 @@ from ultralytics.utils import TQDM | ||||
| 
 | ||||
| 
 | ||||
| class FastSAMPrompt: | ||||
|     """ | ||||
|     Fast Segment Anything Model class for image annotation and visualization. | ||||
| 
 | ||||
|     Attributes: | ||||
|         device (str): Computing device ('cuda' or 'cpu'). | ||||
|         results: Object detection or segmentation results. | ||||
|         source: Source image or image path. | ||||
|         clip: CLIP model for linear assignment. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, source, results, device='cuda') -> None: | ||||
|         """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment.""" | ||||
| @ -92,6 +101,20 @@ class FastSAMPrompt: | ||||
|              better_quality=True, | ||||
|              retina=False, | ||||
|              with_contours=True): | ||||
|         """ | ||||
|         Plots annotations, bounding boxes, and points on images and saves the output. | ||||
| 
 | ||||
|         Args: | ||||
|             annotations (list): Annotations to be plotted. | ||||
|             output (str or Path): Output directory for saving the plots. | ||||
|             bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None. | ||||
|             points (list, optional): Points to be plotted. Defaults to None. | ||||
|             point_label (list, optional): Labels for the points. Defaults to None. | ||||
|             mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True. | ||||
|             better_quality (bool, optional): Whether to apply morphological transformations for better mask quality. Defaults to True. | ||||
|             retina (bool, optional): Whether to use retina mask. Defaults to False. | ||||
|             with_contours (bool, optional): Whether to plot contours. Defaults to True. | ||||
|         """ | ||||
|         pbar = TQDM(annotations, total=len(annotations)) | ||||
|         for ann in pbar: | ||||
|             result_name = os.path.basename(ann.path) | ||||
| @ -160,6 +183,20 @@ class FastSAMPrompt: | ||||
|         target_height=960, | ||||
|         target_width=960, | ||||
|     ): | ||||
|         """ | ||||
|         Quickly shows the mask annotations on the given matplotlib axis. | ||||
| 
 | ||||
|         Args: | ||||
|             annotation (array-like): Mask annotation. | ||||
|             ax (matplotlib.axes.Axes): Matplotlib axis. | ||||
|             random_color (bool, optional): Whether to use random color for masks. Defaults to False. | ||||
|             bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None. | ||||
|             points (list, optional): Points to be plotted. Defaults to None. | ||||
|             pointlabel (list, optional): Labels for the points. Defaults to None. | ||||
|             retinamask (bool, optional): Whether to use retina mask. Defaults to True. | ||||
|             target_height (int, optional): Target height for resizing. Defaults to 960. | ||||
|             target_width (int, optional): Target width for resizing. Defaults to 960. | ||||
|         """ | ||||
|         n, h, w = annotation.shape  # batch, height, width | ||||
| 
 | ||||
|         areas = np.sum(annotation, axis=(1, 2)) | ||||
|  | ||||
| @ -5,9 +5,35 @@ from ultralytics.utils.metrics import SegmentMetrics | ||||
| 
 | ||||
| 
 | ||||
| class FastSAMValidator(SegmentationValidator): | ||||
|     """ | ||||
|     Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework. | ||||
| 
 | ||||
|     Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class | ||||
|     sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled | ||||
|     to avoid errors during validation. | ||||
| 
 | ||||
|     Attributes: | ||||
|         dataloader: The data loader object used for validation. | ||||
|         save_dir (str): The directory where validation results will be saved. | ||||
|         pbar: A progress bar object. | ||||
|         args: Additional arguments for customization. | ||||
|         _callbacks: List of callback functions to be invoked during validation. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): | ||||
|         """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.""" | ||||
|         """ | ||||
|         Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics. | ||||
| 
 | ||||
|         Args: | ||||
|             dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation. | ||||
|             save_dir (Path, optional): Directory to save results. | ||||
|             pbar (tqdm.tqdm): Progress bar for displaying progress. | ||||
|             args (SimpleNamespace): Configuration for the validator. | ||||
|             _callbacks (dict): Dictionary to store various callback functions. | ||||
| 
 | ||||
|         Notes: | ||||
|             Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors. | ||||
|         """ | ||||
|         super().__init__(dataloader, save_dir, pbar, args, _callbacks) | ||||
|         self.args.task = 'segment' | ||||
|         self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors | ||||
|  | ||||
| @ -23,6 +23,26 @@ from .val import NASValidator | ||||
| 
 | ||||
| 
 | ||||
| class NAS(Model): | ||||
|     """ | ||||
|     YOLO NAS model for object detection. | ||||
| 
 | ||||
|     This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine. | ||||
|     It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models. | ||||
| 
 | ||||
|     Example: | ||||
|         ```python | ||||
|         from ultralytics import NAS | ||||
| 
 | ||||
|         model = NAS('yolo_nas_s') | ||||
|         results = model.predict('ultralytics/assets/bus.jpg') | ||||
|         ``` | ||||
| 
 | ||||
|     Attributes: | ||||
|         model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'. | ||||
| 
 | ||||
|     Note: | ||||
|         YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, model='yolo_nas_s.pt') -> None: | ||||
|         """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model.""" | ||||
|  | ||||
| @ -8,6 +8,29 @@ from ultralytics.utils import ops | ||||
| 
 | ||||
| 
 | ||||
| class NASPredictor(BasePredictor): | ||||
|     """ | ||||
|     Ultralytics YOLO NAS Predictor for object detection. | ||||
| 
 | ||||
|     This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the | ||||
|     raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and | ||||
|     scaling the bounding boxes to fit the original image dimensions. | ||||
| 
 | ||||
|     Attributes: | ||||
|         args (Namespace): Namespace containing various configurations for post-processing. | ||||
| 
 | ||||
|     Example: | ||||
|         ```python | ||||
|         from ultralytics import NAS | ||||
| 
 | ||||
|         model = NAS('yolo_nas_s') | ||||
|         predictor = model.predictor | ||||
|         # Assumes that raw_preds, img, orig_imgs are available | ||||
|         results = predictor.postprocess(raw_preds, img, orig_imgs) | ||||
|         ``` | ||||
| 
 | ||||
|     Note: | ||||
|         Typically, this class is not instantiated directly. It is used internally within the `NAS` class. | ||||
|     """ | ||||
| 
 | ||||
|     def postprocess(self, preds_in, img, orig_imgs): | ||||
|         """Postprocess predictions and returns a list of Results objects.""" | ||||
|  | ||||
| @ -9,6 +9,30 @@ __all__ = ['NASValidator'] | ||||
| 
 | ||||
| 
 | ||||
| class NASValidator(DetectionValidator): | ||||
|     """ | ||||
|     Ultralytics YOLO NAS Validator for object detection. | ||||
| 
 | ||||
|     Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions | ||||
|     generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes, | ||||
|     ultimately producing the final detections. | ||||
| 
 | ||||
|     Attributes: | ||||
|         args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU thresholds. | ||||
|         lb (torch.Tensor): Optional tensor for multilabel NMS. | ||||
| 
 | ||||
|     Example: | ||||
|         ```python | ||||
|         from ultralytics import NAS | ||||
| 
 | ||||
|         model = NAS('yolo_nas_s') | ||||
|         validator = model.validator | ||||
|         # Assumes that raw_preds are available | ||||
|         final_preds = validator.postprocess(raw_preds) | ||||
|         ``` | ||||
| 
 | ||||
|     Note: | ||||
|         This class is generally not instantiated directly but is used internally within the `NAS` class. | ||||
|     """ | ||||
| 
 | ||||
|     def postprocess(self, preds_in): | ||||
|         """Apply Non-maximum suppression to prediction outputs.""" | ||||
|  | ||||
| @ -12,14 +12,19 @@ from ultralytics.utils import colorstr, ops | ||||
| __all__ = 'RTDETRValidator',  # tuple or list | ||||
| 
 | ||||
| 
 | ||||
| # TODO: Temporarily RT-DETR does not need padding. | ||||
| class RTDETRDataset(YOLODataset): | ||||
|     """ | ||||
|     Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class. | ||||
| 
 | ||||
|     This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for | ||||
|     real-time detection and tracking tasks. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, *args, data=None, **kwargs): | ||||
|         """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.""" | ||||
|         super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs) | ||||
| 
 | ||||
|     # NOTE: add stretch version load_image for rtdetr mosaic | ||||
|     # NOTE: add stretch version load_image for RTDETR mosaic | ||||
|     def load_image(self, i, rect_mode=False): | ||||
|         """Loads 1 image from dataset index 'i', returns (im, resized hw).""" | ||||
|         return super().load_image(i=i, rect_mode=rect_mode) | ||||
| @ -46,7 +51,11 @@ class RTDETRDataset(YOLODataset): | ||||
| 
 | ||||
| class RTDETRValidator(DetectionValidator): | ||||
|     """ | ||||
|     A class extending the DetectionValidator class for validation based on an RT-DETR detection model. | ||||
|     RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for | ||||
|     the RT-DETR (Real-Time DETR) object detection model. | ||||
| 
 | ||||
|     The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for | ||||
|     post-processing, and updates evaluation metrics accordingly. | ||||
| 
 | ||||
|     Example: | ||||
|         ```python | ||||
| @ -56,6 +65,9 @@ class RTDETRValidator(DetectionValidator): | ||||
|         validator = RTDETRValidator(args=args) | ||||
|         validator() | ||||
|         ``` | ||||
| 
 | ||||
|     Note: | ||||
|         For further details on the attributes and methods, refer to the parent DetectionValidator class. | ||||
|     """ | ||||
| 
 | ||||
|     def build_dataset(self, img_path, mode='val', batch=None): | ||||
|  | ||||
| @ -10,6 +10,21 @@ from ultralytics.nn.modules import LayerNorm2d | ||||
| 
 | ||||
| 
 | ||||
| class MaskDecoder(nn.Module): | ||||
|     """ | ||||
|     Decoder module for generating masks and their associated quality scores, using a transformer architecture to predict | ||||
|     masks given image and prompt embeddings. | ||||
| 
 | ||||
|     Attributes: | ||||
|         transformer_dim (int): Channel dimension for the transformer module. | ||||
|         transformer (nn.Module): The transformer module used for mask prediction. | ||||
|         num_multimask_outputs (int): Number of masks to predict for disambiguating masks. | ||||
|         iou_token (nn.Embedding): Embedding for the IoU token. | ||||
|         num_mask_tokens (int): Number of mask tokens. | ||||
|         mask_tokens (nn.Embedding): Embedding for the mask tokens. | ||||
|         output_upscaling (nn.Sequential): Neural network sequence for upscaling the output. | ||||
|         output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks. | ||||
|         iou_prediction_head (nn.Module): MLP for predicting mask quality. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
| @ -136,7 +151,7 @@ class MaskDecoder(nn.Module): | ||||
| 
 | ||||
| class MLP(nn.Module): | ||||
|     """ | ||||
|     Lightly adapted from | ||||
|     MLP (Multi-Layer Perceptron) model lightly adapted from | ||||
|     https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py | ||||
|     """ | ||||
| 
 | ||||
| @ -148,6 +163,16 @@ class MLP(nn.Module): | ||||
|         num_layers: int, | ||||
|         sigmoid_output: bool = False, | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Initializes the MLP (Multi-Layer Perceptron) model. | ||||
| 
 | ||||
|         Args: | ||||
|             input_dim (int): The dimensionality of the input features. | ||||
|             hidden_dim (int): The dimensionality of the hidden layers. | ||||
|             output_dim (int): The dimensionality of the output layer. | ||||
|             num_layers (int): The number of hidden layers. | ||||
|             sigmoid_output (bool, optional): Whether to apply a sigmoid activation to the output layer. Defaults to False. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.num_layers = num_layers | ||||
|         h = [hidden_dim] * (num_layers - 1) | ||||
|  | ||||
| @ -12,6 +12,18 @@ from ultralytics.nn.modules import LayerNorm2d, MLPBlock | ||||
| 
 | ||||
| # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa | ||||
| class ImageEncoderViT(nn.Module): | ||||
|     """ | ||||
|     An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The | ||||
|     encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks. | ||||
|     The encoded patches are then processed through a neck to generate the final encoded representation. | ||||
| 
 | ||||
|     Attributes: | ||||
|         img_size (int): Dimension of input images, assumed to be square. | ||||
|         patch_embed (PatchEmbed): Module for patch embedding. | ||||
|         pos_embed (nn.Parameter, optional): Absolute positional embedding for patches. | ||||
|         blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings. | ||||
|         neck (nn.Sequential): Neck module to further process the output. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|             self, | ||||
| @ -112,6 +124,22 @@ class ImageEncoderViT(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class PromptEncoder(nn.Module): | ||||
|     """ | ||||
|     Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder | ||||
|     produces both sparse and dense embeddings for the input prompts. | ||||
| 
 | ||||
|     Attributes: | ||||
|         embed_dim (int): Dimension of the embeddings. | ||||
|         input_image_size (Tuple[int, int]): Size of the input image as (H, W). | ||||
|         image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W). | ||||
|         pe_layer (PositionEmbeddingRandom): Module for random position embedding. | ||||
|         num_point_embeddings (int): Number of point embeddings for different types of points. | ||||
|         point_embeddings (nn.ModuleList): List of point embeddings. | ||||
|         not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label. | ||||
|         mask_input_size (Tuple[int, int]): Size of the input mask. | ||||
|         mask_downscaling (nn.Sequential): Neural network for downscaling the mask. | ||||
|         no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|  | ||||
| @ -16,6 +16,20 @@ from .encoders import ImageEncoderViT, PromptEncoder | ||||
| 
 | ||||
| 
 | ||||
| class Sam(nn.Module): | ||||
|     """ | ||||
|     Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image | ||||
|     embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask | ||||
|     decoder to predict object masks. | ||||
| 
 | ||||
|     Attributes: | ||||
|         mask_threshold (float): Threshold value for mask prediction. | ||||
|         image_format (str): Format of the input image, default is 'RGB'. | ||||
|         image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings. | ||||
|         prompt_encoder (PromptEncoder): Encodes various types of input prompts. | ||||
|         mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings. | ||||
|         pixel_mean (List[float]): Mean pixel values for image normalization. | ||||
|         pixel_std (List[float]): Standard deviation values for image normalization. | ||||
|     """ | ||||
|     mask_threshold: float = 0.0 | ||||
|     image_format: str = 'RGB' | ||||
| 
 | ||||
| @ -28,18 +42,19 @@ class Sam(nn.Module): | ||||
|         pixel_std: List[float] = (58.395, 57.12, 57.375) | ||||
|     ) -> None: | ||||
|         """ | ||||
|         SAM predicts object masks from an image and input prompts. | ||||
|         Initialize the Sam class to predict object masks from an image and input prompts. | ||||
| 
 | ||||
|         Note: | ||||
|             All forward() operations moved to SAMPredictor. | ||||
| 
 | ||||
|         Args: | ||||
|           image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for | ||||
|             efficient mask prediction. | ||||
|           prompt_encoder (PromptEncoder): Encodes various types of input prompts. | ||||
|           mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. | ||||
|           pixel_mean (list(float)): Mean values for normalizing pixels in the input image. | ||||
|           pixel_std (list(float)): Std values for normalizing pixels in the input image. | ||||
|             image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings. | ||||
|             prompt_encoder (PromptEncoder): Encodes various types of input prompts. | ||||
|             mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. | ||||
|             pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to | ||||
|                 (123.675, 116.28, 103.53). | ||||
|             pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to | ||||
|                 (58.395, 57.12, 57.375). | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.image_encoder = image_encoder | ||||
|  | ||||
| @ -21,6 +21,7 @@ from ultralytics.utils.instance import to_2tuple | ||||
| 
 | ||||
| 
 | ||||
| class Conv2d_BN(torch.nn.Sequential): | ||||
|     """A sequential container that performs 2D convolution followed by batch normalization.""" | ||||
| 
 | ||||
|     def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): | ||||
|         """Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and | ||||
| @ -35,6 +36,7 @@ class Conv2d_BN(torch.nn.Sequential): | ||||
| 
 | ||||
| 
 | ||||
| class PatchEmbed(nn.Module): | ||||
|     """Embeds images into patches and projects them into a specified embedding dimension.""" | ||||
| 
 | ||||
|     def __init__(self, in_chans, embed_dim, resolution, activation): | ||||
|         """Initialize the PatchMerging class with specified input, output dimensions, resolution and activation | ||||
| @ -59,6 +61,7 @@ class PatchEmbed(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class MBConv(nn.Module): | ||||
|     """Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.""" | ||||
| 
 | ||||
|     def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path): | ||||
|         """Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation | ||||
| @ -96,6 +99,7 @@ class MBConv(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class PatchMerging(nn.Module): | ||||
|     """Merges neighboring patches in the feature map and projects to a new dimension.""" | ||||
| 
 | ||||
|     def __init__(self, input_resolution, dim, out_dim, activation): | ||||
|         """Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other | ||||
| @ -130,6 +134,11 @@ class PatchMerging(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class ConvLayer(nn.Module): | ||||
|     """ | ||||
|     Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv). | ||||
| 
 | ||||
|     Optionally applies downsample operations to the output, and provides support for gradient checkpointing. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
| @ -143,6 +152,20 @@ class ConvLayer(nn.Module): | ||||
|         out_dim=None, | ||||
|         conv_expand_ratio=4., | ||||
|     ): | ||||
|         """ | ||||
|         Initializes the ConvLayer with the given dimensions and settings. | ||||
| 
 | ||||
|         Args: | ||||
|             dim (int): The dimensionality of the input and output. | ||||
|             input_resolution (Tuple[int, int]): The resolution of the input image. | ||||
|             depth (int): The number of MBConv layers in the block. | ||||
|             activation (Callable): Activation function applied after each convolution. | ||||
|             drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv. | ||||
|             downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling. | ||||
|             use_checkpoint (bool): Whether to use gradient checkpointing to save memory. | ||||
|             out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`. | ||||
|             conv_expand_ratio (float): Expansion ratio for the MBConv layers. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.dim = dim | ||||
|         self.input_resolution = input_resolution | ||||
| @ -171,6 +194,11 @@ class ConvLayer(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class Mlp(nn.Module): | ||||
|     """ | ||||
|     Multi-layer Perceptron (MLP) for transformer architectures. | ||||
| 
 | ||||
|     This layer takes an input with in_features, applies layer normalization and two fully-connected layers. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): | ||||
|         """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc.""" | ||||
| @ -194,6 +222,14 @@ class Mlp(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class Attention(torch.nn.Module): | ||||
|     """ | ||||
|     Multi-head attention module with support for spatial awareness, applying attention biases based on spatial | ||||
|     resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution | ||||
|     grid. | ||||
| 
 | ||||
|     Attributes: | ||||
|         ab (Tensor, optional): Cached attention biases for inference, deleted during training. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|             self, | ||||
| @ -203,8 +239,21 @@ class Attention(torch.nn.Module): | ||||
|             attn_ratio=4, | ||||
|             resolution=(14, 14), | ||||
|     ): | ||||
|         """ | ||||
|         Initializes the Attention module. | ||||
| 
 | ||||
|         Args: | ||||
|             dim (int): The dimensionality of the input and output. | ||||
|             key_dim (int): The dimensionality of the keys and queries. | ||||
|             num_heads (int, optional): Number of attention heads. Default is 8. | ||||
|             attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4. | ||||
|             resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14). | ||||
| 
 | ||||
|         Raises: | ||||
|             AssertionError: If `resolution` is not a tuple of length 2. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         # (h, w) | ||||
| 
 | ||||
|         assert isinstance(resolution, tuple) and len(resolution) == 2 | ||||
|         self.num_heads = num_heads | ||||
|         self.scale = key_dim ** -0.5 | ||||
| @ -241,8 +290,9 @@ class Attention(torch.nn.Module): | ||||
|         else: | ||||
|             self.ab = self.attention_biases[:, self.attention_bias_idxs] | ||||
| 
 | ||||
|     def forward(self, x):  # x (B,N,C) | ||||
|         B, N, _ = x.shape | ||||
|     def forward(self, x):  # x | ||||
|         """Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values.""" | ||||
|         B, N, _ = x.shape  # B, N, C | ||||
| 
 | ||||
|         # Normalization | ||||
|         x = self.norm(x) | ||||
| @ -264,20 +314,7 @@ class Attention(torch.nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class TinyViTBlock(nn.Module): | ||||
|     """ | ||||
|     TinyViT Block. | ||||
| 
 | ||||
|     Args: | ||||
|         dim (int): Number of input channels. | ||||
|         input_resolution (tuple[int, int]): Input resolution. | ||||
|         num_heads (int): Number of attention heads. | ||||
|         window_size (int): Window size. | ||||
|         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | ||||
|         drop (float, optional): Dropout rate. Default: 0.0 | ||||
|         drop_path (float, optional): Stochastic depth rate. Default: 0.0 | ||||
|         local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3 | ||||
|         activation (torch.nn): the activation function. Default: nn.GELU | ||||
|     """ | ||||
|     """TinyViT Block that applies self-attention and a local convolution to the input.""" | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
| @ -291,6 +328,24 @@ class TinyViTBlock(nn.Module): | ||||
|         local_conv_size=3, | ||||
|         activation=nn.GELU, | ||||
|     ): | ||||
|         """ | ||||
|         Initializes the TinyViTBlock. | ||||
| 
 | ||||
|         Args: | ||||
|             dim (int): The dimensionality of the input and output. | ||||
|             input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. | ||||
|             num_heads (int): Number of attention heads. | ||||
|             window_size (int, optional): Window size for attention. Default is 7. | ||||
|             mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4. | ||||
|             drop (float, optional): Dropout rate. Default is 0. | ||||
|             drop_path (float, optional): Stochastic depth rate. Default is 0. | ||||
|             local_conv_size (int, optional): The kernel size of the local convolution. Default is 3. | ||||
|             activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU. | ||||
| 
 | ||||
|         Raises: | ||||
|             AssertionError: If `window_size` is not greater than 0. | ||||
|             AssertionError: If `dim` is not divisible by `num_heads`. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.dim = dim | ||||
|         self.input_resolution = input_resolution | ||||
| @ -367,24 +422,7 @@ class TinyViTBlock(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class BasicLayer(nn.Module): | ||||
|     """ | ||||
|     A basic TinyViT layer for one stage. | ||||
| 
 | ||||
|     Args: | ||||
|         dim (int): Number of input channels. | ||||
|         input_resolution (tuple[int]): Input resolution. | ||||
|         depth (int): Number of blocks. | ||||
|         num_heads (int): Number of attention heads. | ||||
|         window_size (int): Local window size. | ||||
|         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | ||||
|         drop (float, optional): Dropout rate. Default: 0.0 | ||||
|         drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 | ||||
|         downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None | ||||
|         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | ||||
|         local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3 | ||||
|         activation (torch.nn): the activation function. Default: nn.GELU | ||||
|         out_dim (int | optional): the output dimension of the layer. Default: None | ||||
|     """ | ||||
|     """A basic TinyViT layer for one stage in a TinyViT architecture.""" | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
| @ -402,6 +440,27 @@ class BasicLayer(nn.Module): | ||||
|         activation=nn.GELU, | ||||
|         out_dim=None, | ||||
|     ): | ||||
|         """ | ||||
|         Initializes the BasicLayer. | ||||
| 
 | ||||
|         Args: | ||||
|             dim (int): The dimensionality of the input and output. | ||||
|             input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. | ||||
|             depth (int): Number of TinyViT blocks. | ||||
|             num_heads (int): Number of attention heads. | ||||
|             window_size (int): Local window size. | ||||
|             mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4. | ||||
|             drop (float, optional): Dropout rate. Default is 0. | ||||
|             drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0. | ||||
|             downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None. | ||||
|             use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False. | ||||
|             local_conv_size (int, optional): Kernel size of the local convolution. Default is 3. | ||||
|             activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU. | ||||
|             out_dim (int | None, optional): The output dimension of the layer. Default is None. | ||||
| 
 | ||||
|         Raises: | ||||
|             ValueError: If `drop_path` is a list of float but its length doesn't match `depth`. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.dim = dim | ||||
|         self.input_resolution = input_resolution | ||||
| @ -456,6 +515,30 @@ class LayerNorm2d(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class TinyViT(nn.Module): | ||||
|     """ | ||||
|     The TinyViT architecture for vision tasks. | ||||
| 
 | ||||
|     Attributes: | ||||
|         img_size (int): Input image size. | ||||
|         in_chans (int): Number of input channels. | ||||
|         num_classes (int): Number of classification classes. | ||||
|         embed_dims (List[int]): List of embedding dimensions for each layer. | ||||
|         depths (List[int]): List of depths for each layer. | ||||
|         num_heads (List[int]): List of number of attention heads for each layer. | ||||
|         window_sizes (List[int]): List of window sizes for each layer. | ||||
|         mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. | ||||
|         drop_rate (float): Dropout rate for drop layers. | ||||
|         drop_path_rate (float): Drop path rate for stochastic depth. | ||||
|         use_checkpoint (bool): Use checkpointing for efficient memory usage. | ||||
|         mbconv_expand_ratio (float): Expansion ratio for MBConv layer. | ||||
|         local_conv_size (int): Local convolution kernel size. | ||||
|         layer_lr_decay (float): Layer-wise learning rate decay. | ||||
| 
 | ||||
|     Note: | ||||
|         This implementation is generalized to accept a list of depths, attention heads, | ||||
|         embedding dimensions and window sizes, which allows you to create a | ||||
|         "stack" of TinyViT models of varying configurations. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
| @ -474,6 +557,25 @@ class TinyViT(nn.Module): | ||||
|         local_conv_size=3, | ||||
|         layer_lr_decay=1.0, | ||||
|     ): | ||||
|         """ | ||||
|         Initializes the TinyViT model. | ||||
| 
 | ||||
|         Args: | ||||
|             img_size (int, optional): The input image size. Defaults to 224. | ||||
|             in_chans (int, optional): Number of input channels. Defaults to 3. | ||||
|             num_classes (int, optional): Number of classification classes. Defaults to 1000. | ||||
|             embed_dims (List[int], optional): List of embedding dimensions for each layer. Defaults to [96, 192, 384, 768]. | ||||
|             depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2]. | ||||
|             num_heads (List[int], optional): List of number of attention heads for each layer. Defaults to [3, 6, 12, 24]. | ||||
|             window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7]. | ||||
|             mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4. | ||||
|             drop_rate (float, optional): Dropout rate. Defaults to 0. | ||||
|             drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1. | ||||
|             use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False. | ||||
|             mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0. | ||||
|             local_conv_size (int, optional): Local convolution kernel size. Defaults to 3. | ||||
|             layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.img_size = img_size | ||||
|         self.num_classes = num_classes | ||||
|  | ||||
| @ -10,6 +10,21 @@ from ultralytics.nn.modules import MLPBlock | ||||
| 
 | ||||
| 
 | ||||
| class TwoWayTransformer(nn.Module): | ||||
|     """ | ||||
|     A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class | ||||
|     serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding | ||||
|     is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud | ||||
|     processing. | ||||
| 
 | ||||
|     Attributes: | ||||
|         depth (int): The number of layers in the transformer. | ||||
|         embedding_dim (int): The channel dimension for the input embeddings. | ||||
|         num_heads (int): The number of heads for multihead attention. | ||||
|         mlp_dim (int): The internal channel dimension for the MLP block. | ||||
|         layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer. | ||||
|         final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image. | ||||
|         norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
| @ -98,6 +113,23 @@ class TwoWayTransformer(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class TwoWayAttentionBlock(nn.Module): | ||||
|     """ | ||||
|     An attention block that performs both self-attention and cross-attention in two directions: queries to keys and | ||||
|     keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention | ||||
|     of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to | ||||
|     sparse inputs. | ||||
| 
 | ||||
|     Attributes: | ||||
|         self_attn (Attention): The self-attention layer for the queries. | ||||
|         norm1 (nn.LayerNorm): Layer normalization following the first attention block. | ||||
|         cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. | ||||
|         norm2 (nn.LayerNorm): Layer normalization following the second attention block. | ||||
|         mlp (MLPBlock): MLP block that transforms the query embeddings. | ||||
|         norm3 (nn.LayerNorm): Layer normalization following the MLP block. | ||||
|         norm4 (nn.LayerNorm): Layer normalization following the third attention block. | ||||
|         cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. | ||||
|         skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
| @ -180,6 +212,17 @@ class Attention(nn.Module): | ||||
|         num_heads: int, | ||||
|         downsample_rate: int = 1, | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Initializes the Attention model with the given dimensions and settings. | ||||
| 
 | ||||
|         Args: | ||||
|             embedding_dim (int): The dimensionality of the input embeddings. | ||||
|             num_heads (int): The number of attention heads. | ||||
|             downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1. | ||||
| 
 | ||||
|         Raises: | ||||
|             AssertionError: If 'num_heads' does not evenly divide the internal dimension (embedding_dim / downsample_rate). | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.embedding_dim = embedding_dim | ||||
|         self.internal_dim = embedding_dim // downsample_rate | ||||
| @ -191,13 +234,15 @@ class Attention(nn.Module): | ||||
|         self.v_proj = nn.Linear(embedding_dim, self.internal_dim) | ||||
|         self.out_proj = nn.Linear(self.internal_dim, embedding_dim) | ||||
| 
 | ||||
|     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: | ||||
|     @staticmethod | ||||
|     def _separate_heads(x: Tensor, num_heads: int) -> Tensor: | ||||
|         """Separate the input tensor into the specified number of attention heads.""" | ||||
|         b, n, c = x.shape | ||||
|         x = x.reshape(b, n, num_heads, c // num_heads) | ||||
|         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head | ||||
| 
 | ||||
|     def _recombine_heads(self, x: Tensor) -> Tensor: | ||||
|     @staticmethod | ||||
|     def _recombine_heads(x: Tensor) -> Tensor: | ||||
|         """Recombine the separated attention heads into a single tensor.""" | ||||
|         b, n_heads, n_tokens, c_per_head = x.shape | ||||
|         x = x.transpose(1, 2) | ||||
|  | ||||
| @ -17,6 +17,24 @@ from .build import build_sam | ||||
| 
 | ||||
| 
 | ||||
| class Predictor(BasePredictor): | ||||
|     """ | ||||
|     A prediction class for segmentation tasks, extending the BasePredictor. | ||||
| 
 | ||||
|     This class serves as an interface for model inference for segmentation tasks. | ||||
|     It can preprocess input images, perform inference, and postprocess the output. | ||||
|     It also supports handling various types of input prompts including bounding boxes, | ||||
|     points, and low-resolution masks for better prediction results. | ||||
| 
 | ||||
|     Attributes: | ||||
|         cfg (dict): Configuration dictionary. | ||||
|         overrides (dict): Dictionary of overriding values. | ||||
|         _callbacks (dict): Dictionary of callback functions. | ||||
|         args (namespace): Argument namespace. | ||||
|         im (torch.Tensor): Preprocessed image for current prediction. | ||||
|         features (torch.Tensor): Image features. | ||||
|         prompts (dict): Dictionary of prompts like bboxes, points, masks. | ||||
|         segment_all (bool): Whether to perform segmentation on all objects or not. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): | ||||
|         """Initializes the Predictor class with default or provided configuration, overrides, and callbacks.""" | ||||
|  | ||||
| @ -11,6 +11,24 @@ from .ops import HungarianMatcher | ||||
| 
 | ||||
| 
 | ||||
| class DETRLoss(nn.Module): | ||||
|     """ | ||||
|     DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the | ||||
|     DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary | ||||
|     losses. | ||||
| 
 | ||||
|     Attributes: | ||||
|         nc (int): The number of classes. | ||||
|         loss_gain (dict): Coefficients for different loss components. | ||||
|         aux_loss (bool): Whether to compute auxiliary losses. | ||||
|         use_fl (bool): Use FocalLoss or not. | ||||
|         use_vfl (bool): Use VarifocalLoss or not. | ||||
|         use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch. | ||||
|         uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True. | ||||
|         matcher (HungarianMatcher): Object to compute matching cost and indices. | ||||
|         fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None. | ||||
|         vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None. | ||||
|         device (torch.device): Device on which tensors are stored. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, | ||||
|                  nc=80, | ||||
|  | ||||
| @ -37,7 +37,12 @@ class DFL(nn.Module): | ||||
| class Proto(nn.Module): | ||||
|     """YOLOv8 mask Proto module for segmentation models.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c_=256, c2=32):  # ch_in, number of protos, number of masks | ||||
|     def __init__(self, c1, c_=256, c2=32): | ||||
|         """ | ||||
|         Initializes the YOLOv8 mask Proto module with specified number of protos and masks. | ||||
| 
 | ||||
|         Input arguments are ch_in, number of protos, number of masks. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.cv1 = Conv(c1, c_, k=3) | ||||
|         self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True)  # nn.Upsample(scale_factor=2, mode='nearest') | ||||
| @ -124,7 +129,12 @@ class SPP(nn.Module): | ||||
| class SPPF(nn.Module): | ||||
|     """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13)) | ||||
|     def __init__(self, c1, c2, k=5): | ||||
|         """ | ||||
|         Initializes the SPPF layer with given input/output channels and kernel size. | ||||
| 
 | ||||
|         This module is equivalent to SPP(k=(5, 9, 13)). | ||||
|         """ | ||||
|         super().__init__() | ||||
|         c_ = c1 // 2  # hidden channels | ||||
|         self.cv1 = Conv(c1, c_, 1, 1) | ||||
| @ -142,7 +152,8 @@ class SPPF(nn.Module): | ||||
| class C1(nn.Module): | ||||
|     """CSP Bottleneck with 1 convolution.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, n=1):  # ch_in, ch_out, number | ||||
|     def __init__(self, c1, c2, n=1): | ||||
|         """Initializes the CSP Bottleneck with configurations for 1 convolution with arguments ch_in, ch_out, number.""" | ||||
|         super().__init__() | ||||
|         self.cv1 = Conv(c1, c2, 1, 1) | ||||
|         self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n))) | ||||
| @ -156,7 +167,10 @@ class C1(nn.Module): | ||||
| class C2(nn.Module): | ||||
|     """CSP Bottleneck with 2 convolutions.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion | ||||
|     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): | ||||
|         """Initializes the CSP Bottleneck with 2 convolutions module with arguments ch_in, ch_out, number, shortcut, | ||||
|         groups, expansion. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.c = int(c2 * e)  # hidden channels | ||||
|         self.cv1 = Conv(c1, 2 * self.c, 1, 1) | ||||
| @ -173,7 +187,10 @@ class C2(nn.Module): | ||||
| class C2f(nn.Module): | ||||
|     """Faster Implementation of CSP Bottleneck with 2 convolutions.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion | ||||
|     def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): | ||||
|         """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups, | ||||
|         expansion. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.c = int(c2 * e)  # hidden channels | ||||
|         self.cv1 = Conv(c1, 2 * self.c, 1, 1) | ||||
| @ -196,7 +213,8 @@ class C2f(nn.Module): | ||||
| class C3(nn.Module): | ||||
|     """CSP Bottleneck with 3 convolutions.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion | ||||
|     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): | ||||
|         """Initialize the CSP Bottleneck with given channels, number, shortcut, groups, and expansion values.""" | ||||
|         super().__init__() | ||||
|         c_ = int(c2 * e)  # hidden channels | ||||
|         self.cv1 = Conv(c1, c_, 1, 1) | ||||
| @ -259,7 +277,8 @@ class C3Ghost(C3): | ||||
| class GhostBottleneck(nn.Module): | ||||
|     """Ghost Bottleneck https://github.com/huawei-noah/ghostnet.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride | ||||
|     def __init__(self, c1, c2, k=3, s=1): | ||||
|         """Initializes GhostBottleneck module with arguments ch_in, ch_out, kernel, stride.""" | ||||
|         super().__init__() | ||||
|         c_ = c2 // 2 | ||||
|         self.conv = nn.Sequential( | ||||
| @ -277,7 +296,10 @@ class GhostBottleneck(nn.Module): | ||||
| class Bottleneck(nn.Module): | ||||
|     """Standard bottleneck.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):  # ch_in, ch_out, shortcut, groups, kernels, expand | ||||
|     def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): | ||||
|         """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and | ||||
|         expansion. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         c_ = int(c2 * e)  # hidden channels | ||||
|         self.cv1 = Conv(c1, c_, k[0], 1) | ||||
| @ -292,7 +314,8 @@ class Bottleneck(nn.Module): | ||||
| class BottleneckCSP(nn.Module): | ||||
|     """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion | ||||
|     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): | ||||
|         """Initializes the CSP Bottleneck given arguments for ch_in, ch_out, number, shortcut, groups, expansion.""" | ||||
|         super().__init__() | ||||
|         c_ = int(c2 * e)  # hidden channels | ||||
|         self.cv1 = Conv(c1, c_, 1, 1) | ||||
|  | ||||
| @ -88,6 +88,7 @@ class DWConv(Conv): | ||||
|     """Depth-wise convolution.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, k=1, s=1, d=1, act=True):  # ch_in, ch_out, kernel, stride, dilation, activation | ||||
|         """Initialize Depth-wise convolution with given parameters.""" | ||||
|         super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act) | ||||
| 
 | ||||
| 
 | ||||
| @ -95,6 +96,7 @@ class DWConvTranspose2d(nn.ConvTranspose2d): | ||||
|     """Depth-wise transpose convolution.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):  # ch_in, ch_out, kernel, stride, padding, padding_out | ||||
|         """Initialize DWConvTranspose2d class with given parameters.""" | ||||
|         super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2)) | ||||
| 
 | ||||
| 
 | ||||
| @ -121,12 +123,18 @@ class ConvTranspose(nn.Module): | ||||
| class Focus(nn.Module): | ||||
|     """Focus wh information into c-space.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups | ||||
|     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): | ||||
|         """Initializes Focus object with user defined channel, convolution, padding, group and activation values.""" | ||||
|         super().__init__() | ||||
|         self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act) | ||||
|         # self.contract = Contract(gain=2) | ||||
| 
 | ||||
|     def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2) | ||||
|     def forward(self, x): | ||||
|         """ | ||||
|         Applies convolution to concatenated tensor and returns the output. | ||||
| 
 | ||||
|         Input shape is (b,c,w,h) and output shape is (b,4c,w/2,h/2). | ||||
|         """ | ||||
|         return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1)) | ||||
|         # return self.conv(self.contract(x)) | ||||
| 
 | ||||
| @ -134,7 +142,10 @@ class Focus(nn.Module): | ||||
| class GhostConv(nn.Module): | ||||
|     """Ghost Convolution https://github.com/huawei-noah/ghostnet.""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups | ||||
|     def __init__(self, c1, c2, k=1, s=1, g=1, act=True): | ||||
|         """Initializes the GhostConv object with input channels, output channels, kernel size, stride, groups and | ||||
|         activation. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         c_ = c2 // 2  # hidden channels | ||||
|         self.cv1 = Conv(c1, c_, k, s, None, g, act=act) | ||||
| @ -280,7 +291,8 @@ class SpatialAttention(nn.Module): | ||||
| class CBAM(nn.Module): | ||||
|     """Convolutional Block Attention Module.""" | ||||
| 
 | ||||
|     def __init__(self, c1, kernel_size=7):  # ch_in, kernels | ||||
|     def __init__(self, c1, kernel_size=7): | ||||
|         """Initialize CBAM with given input channel (c1) and kernel size.""" | ||||
|         super().__init__() | ||||
|         self.channel_attention = ChannelAttention(c1) | ||||
|         self.spatial_attention = SpatialAttention(kernel_size) | ||||
|  | ||||
| @ -25,7 +25,8 @@ class Detect(nn.Module): | ||||
|     anchors = torch.empty(0)  # init | ||||
|     strides = torch.empty(0)  # init | ||||
| 
 | ||||
|     def __init__(self, nc=80, ch=()):  # detection layer | ||||
|     def __init__(self, nc=80, ch=()): | ||||
|         """Initializes the YOLOv8 detection layer with specified number of classes and channels.""" | ||||
|         super().__init__() | ||||
|         self.nc = nc  # number of classes | ||||
|         self.nl = len(ch)  # number of detection layers | ||||
| @ -149,7 +150,10 @@ class Pose(Detect): | ||||
| class Classify(nn.Module): | ||||
|     """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2).""" | ||||
| 
 | ||||
|     def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups | ||||
|     def __init__(self, c1, c2, k=1, s=1, p=None, g=1): | ||||
|         """Initializes YOLOv8 classification head with specified input and output channels, kernel size, stride, | ||||
|         padding, and groups. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         c_ = 1280  # efficientnet_b0 size | ||||
|         self.conv = Conv(c1, c_, k, s, p, g) | ||||
| @ -166,6 +170,13 @@ class Classify(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class RTDETRDecoder(nn.Module): | ||||
|     """ | ||||
|     Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection. | ||||
| 
 | ||||
|     This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes | ||||
|     and class labels for objects in an image. It integrates features from multiple layers and runs through a series of | ||||
|     Transformer decoder layers to output the final predictions. | ||||
|     """ | ||||
|     export = False  # export mode | ||||
| 
 | ||||
|     def __init__( | ||||
| @ -186,6 +197,26 @@ class RTDETRDecoder(nn.Module): | ||||
|             label_noise_ratio=0.5, | ||||
|             box_noise_scale=1.0, | ||||
|             learnt_init_query=False): | ||||
|         """ | ||||
|         Initializes the RTDETRDecoder module with the given parameters. | ||||
| 
 | ||||
|         Args: | ||||
|             nc (int): Number of classes. Default is 80. | ||||
|             ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048). | ||||
|             hd (int): Dimension of hidden layers. Default is 256. | ||||
|             nq (int): Number of query points. Default is 300. | ||||
|             ndp (int): Number of decoder points. Default is 4. | ||||
|             nh (int): Number of heads in multi-head attention. Default is 8. | ||||
|             ndl (int): Number of decoder layers. Default is 6. | ||||
|             d_ffn (int): Dimension of the feed-forward networks. Default is 1024. | ||||
|             dropout (float): Dropout rate. Default is 0. | ||||
|             act (nn.Module): Activation function. Default is nn.ReLU. | ||||
|             eval_idx (int): Evaluation index. Default is -1. | ||||
|             nd (int): Number of denoising. Default is 100. | ||||
|             label_noise_ratio (float): Label noise ratio. Default is 0.5. | ||||
|             box_noise_scale (float): Box noise scale. Default is 1.0. | ||||
|             learnt_init_query (bool): Whether to learn initial query embeddings. Default is False. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.hidden_dim = hd | ||||
|         self.nhead = nh | ||||
|  | ||||
| @ -375,9 +375,9 @@ class RTDETRDetectionModel(DetectionModel): | ||||
|     """ | ||||
|     RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class. | ||||
| 
 | ||||
|     This class is responsible for constructing the RTDETR architecture, defining loss functions, and | ||||
|     facilitating both the training and inference processes. RTDETR is an object detection and tracking model | ||||
|     that extends from the DetectionModel base class. | ||||
|     This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both | ||||
|     the training and inference processes. RTDETR is an object detection and tracking model that extends from the | ||||
|     DetectionModel base class. | ||||
| 
 | ||||
|     Attributes: | ||||
|         cfg (str): The configuration file path or preset string. Default is 'rtdetr-l.yaml'. | ||||
| @ -418,7 +418,7 @@ class RTDETRDetectionModel(DetectionModel): | ||||
|             preds (torch.Tensor, optional): Precomputed model predictions. Defaults to None. | ||||
| 
 | ||||
|         Returns: | ||||
|             tuple: A tuple containing the total loss and main three losses in a tensor. | ||||
|             (tuple): A tuple containing the total loss and main three losses in a tensor. | ||||
|         """ | ||||
|         if not hasattr(self, 'criterion'): | ||||
|             self.criterion = self.init_criterion() | ||||
| @ -466,7 +466,7 @@ class RTDETRDetectionModel(DetectionModel): | ||||
|             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False. | ||||
| 
 | ||||
|         Returns: | ||||
|             torch.Tensor: Model's output tensor. | ||||
|             (torch.Tensor): Model's output tensor. | ||||
|         """ | ||||
|         y, dt = [], []  # outputs | ||||
|         for m in self.model[:-1]:  # except the head part | ||||
|  | ||||
| @ -184,6 +184,19 @@ class ProfileModels: | ||||
|                  half=True, | ||||
|                  trt=True, | ||||
|                  device=None): | ||||
|         """ | ||||
|         Initialize the ProfileModels class for profiling models. | ||||
| 
 | ||||
|         Args: | ||||
|             paths (list): List of paths of the models to be profiled. | ||||
|             num_timed_runs (int, optional): Number of timed runs for the profiling. Default is 100. | ||||
|             num_warmup_runs (int, optional): Number of warmup runs before the actual profiling starts. Default is 10. | ||||
|             min_time (float, optional): Minimum time in seconds for profiling a model. Default is 60. | ||||
|             imgsz (int, optional): Size of the image used during profiling. Default is 640. | ||||
|             half (bool, optional): Flag to indicate whether to use half-precision floating point for profiling. Default is True. | ||||
|             trt (bool, optional): Flag to indicate whether to profile using TensorRT. Default is True. | ||||
|             device (torch.device, optional): Device used for profiling. If None, it is determined automatically. Default is None. | ||||
|         """ | ||||
|         self.paths = paths | ||||
|         self.num_timed_runs = num_timed_runs | ||||
|         self.num_warmup_runs = num_warmup_runs | ||||
|  | ||||
| @ -4,6 +4,18 @@ from ultralytics.utils import emojis | ||||
| 
 | ||||
| 
 | ||||
| class HUBModelError(Exception): | ||||
|     """ | ||||
|     Custom exception class for handling errors related to model fetching in Ultralytics YOLO. | ||||
| 
 | ||||
|     This exception is raised when a requested model is not found or cannot be retrieved. | ||||
|     The message is also processed to include emojis for better user experience. | ||||
| 
 | ||||
|     Attributes: | ||||
|         message (str): The error message displayed when the exception is raised. | ||||
| 
 | ||||
|     Note: | ||||
|         The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, message='Model not found. Please check model URL and try again.'): | ||||
|         """Create an exception for when a model is not found.""" | ||||
|  | ||||
| @ -33,9 +33,17 @@ __all__ = 'Bboxes',  # tuple or list | ||||
| 
 | ||||
| class Bboxes: | ||||
|     """ | ||||
|     Bounding Boxes class. | ||||
|     A class for handling bounding boxes. | ||||
| 
 | ||||
|     Only numpy variables are supported. | ||||
|     The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'. | ||||
|     Bounding box data should be provided in numpy arrays. | ||||
| 
 | ||||
|     Attributes: | ||||
|         bboxes (numpy.ndarray): The bounding boxes stored in a 2D numpy array. | ||||
|         format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh'). | ||||
| 
 | ||||
|     Note: | ||||
|         This class does not handle normalization or denormalization of bounding boxes. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, bboxes, format='xyxy') -> None: | ||||
| @ -166,6 +174,36 @@ class Bboxes: | ||||
| 
 | ||||
| 
 | ||||
| class Instances: | ||||
|     """ | ||||
|     Container for bounding boxes, segments, and keypoints of detected objects in an image. | ||||
| 
 | ||||
|     Attributes: | ||||
|         _bboxes (Bboxes): Internal object for handling bounding box operations. | ||||
|         keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3]. Default is None. | ||||
|         normalized (bool): Flag indicating whether the bounding box coordinates are normalized. | ||||
|         segments (ndarray): Segments array with shape [N, 1000, 2] after resampling. | ||||
| 
 | ||||
|     Args: | ||||
|         bboxes (ndarray): An array of bounding boxes with shape [N, 4]. | ||||
|         segments (list | ndarray, optional): A list or array of object segments. Default is None. | ||||
|         keypoints (ndarray, optional): An array of keypoints with shape [N, 17, 3]. Default is None. | ||||
|         bbox_format (str, optional): The format of bounding boxes ('xywh' or 'xyxy'). Default is 'xywh'. | ||||
|         normalized (bool, optional): Whether the bounding box coordinates are normalized. Default is True. | ||||
| 
 | ||||
|     Examples: | ||||
|         ```python | ||||
|         # Create an Instances object | ||||
|         instances = Instances( | ||||
|             bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]), | ||||
|             segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])], | ||||
|             keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]) | ||||
|         ) | ||||
|         ``` | ||||
| 
 | ||||
|     Note: | ||||
|         The bounding box format is either 'xywh' or 'xyxy', and is determined by the `bbox_format` argument. | ||||
|         This class does not perform input validation, and it assumes the inputs are well-formed. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None: | ||||
|         """ | ||||
|  | ||||
| @ -59,6 +59,7 @@ class FocalLoss(nn.Module): | ||||
| 
 | ||||
| 
 | ||||
| class BboxLoss(nn.Module): | ||||
|     """Criterion class for computing training losses during training.""" | ||||
| 
 | ||||
|     def __init__(self, reg_max, use_dfl=False): | ||||
|         """Initialize the BboxLoss module with regularization maximum and DFL settings.""" | ||||
| @ -115,7 +116,7 @@ class v8DetectionLoss: | ||||
|     """Criterion class for computing training losses.""" | ||||
| 
 | ||||
|     def __init__(self, model):  # model must be de-paralleled | ||||
| 
 | ||||
|         """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function.""" | ||||
|         device = next(model.parameters()).device  # get model device | ||||
|         h = model.args  # hyperparameters | ||||
| 
 | ||||
| @ -211,6 +212,7 @@ class v8SegmentationLoss(v8DetectionLoss): | ||||
|     """Criterion class for computing training losses.""" | ||||
| 
 | ||||
|     def __init__(self, model):  # model must be de-paralleled | ||||
|         """Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument.""" | ||||
|         super().__init__(model) | ||||
|         self.overlap = model.args.overlap_mask | ||||
| 
 | ||||
| @ -375,6 +377,7 @@ class v8PoseLoss(v8DetectionLoss): | ||||
|     """Criterion class for computing training losses.""" | ||||
| 
 | ||||
|     def __init__(self, model):  # model must be de-paralleled | ||||
|         """Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance.""" | ||||
|         super().__init__(model) | ||||
|         self.kpt_shape = model.model[-1].kpt_shape | ||||
|         self.bce_pose = nn.BCEWithLogitsLoss() | ||||
|  | ||||
| @ -166,8 +166,19 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7): | ||||
|     return (torch.exp(-e) * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps) | ||||
| 
 | ||||
| 
 | ||||
| def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 | ||||
|     # return positive, negative label smoothing BCE targets | ||||
| def smooth_BCE(eps=0.1): | ||||
|     """ | ||||
|     Computes smoothed positive and negative Binary Cross-Entropy targets. | ||||
| 
 | ||||
|     This function calculates positive and negative label smoothing BCE targets based on a given epsilon value. | ||||
|     For implementation details, refer to https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441. | ||||
| 
 | ||||
|     Args: | ||||
|         eps (float, optional): The epsilon value for label smoothing. Defaults to 0.1. | ||||
| 
 | ||||
|     Returns: | ||||
|         (tuple): A tuple containing the positive and negative label smoothing BCE targets. | ||||
|     """ | ||||
|     return 1.0 - 0.5 * eps, 0.5 * eps | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -311,8 +311,10 @@ def initialize_weights(model): | ||||
|             m.inplace = True | ||||
| 
 | ||||
| 
 | ||||
| def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416) | ||||
|     # Scales img(bs,3,y,x) by ratio constrained to gs-multiple | ||||
| def scale_img(img, ratio=1.0, same_shape=False, gs=32): | ||||
|     """Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionally | ||||
|     retaining the original shape. | ||||
|     """ | ||||
|     if ratio == 1.0: | ||||
|         return img | ||||
|     h, w = img.shape[2:] | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Glenn Jocher
						Glenn Jocher