From 42744a1717f7e9ccaf2b3ab551332cd09ac24653 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 19 Feb 2024 14:24:30 +0100 Subject: [PATCH] Expand `Model` method type hinting (#8279) Signed-off-by: Glenn Jocher --- ultralytics/data/converter.py | 4 +- ultralytics/engine/model.py | 116 +++++++++++++++++++++--------- ultralytics/nn/tasks.py | 4 ++ ultralytics/trackers/utils/gmc.py | 20 +++--- ultralytics/utils/metrics.py | 8 +-- ultralytics/utils/plotting.py | 2 +- 6 files changed, 104 insertions(+), 50 deletions(-) diff --git a/ultralytics/data/converter.py b/ultralytics/data/converter.py index 62be0b1f..eff4dac1 100644 --- a/ultralytics/data/converter.py +++ b/ultralytics/data/converter.py @@ -418,8 +418,8 @@ def min_index(arr1, arr2): Find a pair of indexes with the shortest distance between two arrays of 2D points. Args: - arr1 (np.array): A NumPy array of shape (N, 2) representing N 2D points. - arr2 (np.array): A NumPy array of shape (M, 2) representing M 2D points. + arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points. + arr2 (np.ndarray): A NumPy array of shape (M, 2) representing M 2D points. Returns: (tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively. diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 1a85913e..166c80fa 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -5,6 +5,10 @@ import sys from pathlib import Path from typing import Union +import PIL +import numpy as np +import torch + from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir from ultralytics.hub.utils import HUB_WEB_ROOT from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load @@ -78,7 +82,12 @@ class Model(nn.Module): NotImplementedError: If a specific model task or mode is not supported. """ - def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None, verbose=False) -> None: + def __init__( + self, + model: Union[str, Path] = "yolov8n.pt", + task: str = None, + verbose: bool = False, + ) -> None: """ Initializes a new instance of the YOLO model class. @@ -135,7 +144,12 @@ class Model(nn.Module): self.model_name = model - def __call__(self, source=None, stream=False, **kwargs): + def __call__( + self, + source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None, + stream: bool = False, + **kwargs, + ) -> list: """ An alias for the predict method, enabling the model instance to be callable. @@ -143,8 +157,9 @@ class Model(nn.Module): with the required arguments for prediction. Args: - source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions. - Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to None. + source (str | Path | int | PIL.Image | np.ndarray, optional): The source of the image for making + predictions. Accepts various types, including file paths, URLs, PIL images, and numpy arrays. + Defaults to None. stream (bool, optional): If True, treats the input source as a continuous stream for predictions. Defaults to False. **kwargs (dict): Additional keyword arguments for configuring the prediction process. @@ -163,7 +178,7 @@ class Model(nn.Module): return session if session.client.authenticated else None @staticmethod - def is_triton_model(model): + def is_triton_model(model: str) -> bool: """Is model a Triton Server URL string, i.e. 
:////""" from urllib.parse import urlsplit @@ -171,7 +186,7 @@ class Model(nn.Module): return url.netloc and url.path and url.scheme in {"http", "grpc"} @staticmethod - def is_hub_model(model): + def is_hub_model(model: str) -> bool: """Check if the provided model is a HUB model.""" return any( ( @@ -181,7 +196,7 @@ class Model(nn.Module): ) ) - def _new(self, cfg: str, task=None, model=None, verbose=False): + def _new(self, cfg: str, task=None, model=None, verbose=False) -> None: """ Initializes a new model and infers the task type from the model definitions. @@ -202,7 +217,7 @@ class Model(nn.Module): self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args) self.model.task = self.task - def _load(self, weights: str, task=None): + def _load(self, weights: str, task=None) -> None: """ Initializes a new model and infers the task type from the model head. @@ -224,7 +239,7 @@ class Model(nn.Module): self.overrides["model"] = weights self.overrides["task"] = self.task - def _check_is_pytorch_model(self): + def _check_is_pytorch_model(self) -> None: """Raises TypeError is model is not a PyTorch model.""" pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt" pt_module = isinstance(self.model, nn.Module) @@ -237,7 +252,7 @@ class Model(nn.Module): f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'" ) - def reset_weights(self): + def reset_weights(self) -> "Model": """ Resets the model parameters to randomly initialized values, effectively discarding all training information. @@ -259,7 +274,7 @@ class Model(nn.Module): p.requires_grad = True return self - def load(self, weights="yolov8n.pt"): + def load(self, weights: Union[str, Path] = "yolov8n.pt") -> "Model": """ Loads parameters from the specified weights file into the model. @@ -281,24 +296,22 @@ class Model(nn.Module): self.model.load(weights) return self - def save(self, filename="model.pt"): + def save(self, filename: Union[str, Path] = "saved_model.pt") -> None: """ Saves the current model state to a file. This method exports the model's checkpoint (ckpt) to the specified filename. Args: - filename (str): The name of the file to save the model to. Defaults to 'model.pt'. + filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'. Raises: AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() - import torch - torch.save(self.ckpt, filename) - def info(self, detailed=False, verbose=True): + def info(self, detailed: bool = False, verbose: bool = True): """ Logs or returns model information. @@ -330,7 +343,12 @@ class Model(nn.Module): self._check_is_pytorch_model() self.model.fuse() - def embed(self, source=None, stream=False, **kwargs): + def embed( + self, + source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None, + stream: bool = False, + **kwargs, + ) -> list: """ Generates image embeddings based on the provided source. 
@@ -353,7 +371,13 @@ class Model(nn.Module): kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed return self.predict(source, stream, **kwargs) - def predict(self, source=None, stream=False, predictor=None, **kwargs): + def predict( + self, + source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None, + stream: bool = False, + predictor=None, + **kwargs, + ) -> list: """ Performs predictions on the given image source using the YOLO model. @@ -405,7 +429,13 @@ class Model(nn.Module): self.predictor.set_prompts(prompts) return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) - def track(self, source=None, stream=False, persist=False, **kwargs): + def track( + self, + source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None, + stream: bool = False, + persist: bool = False, + **kwargs, + ) -> list: """ Conducts object tracking on the specified input source using the registered trackers. @@ -438,7 +468,11 @@ class Model(nn.Module): kwargs["mode"] = "track" return self.predict(source=source, stream=stream, **kwargs) - def val(self, validator=None, **kwargs): + def val( + self, + validator=None, + **kwargs, + ): """ Validates the model using a specified dataset and validation configuration. @@ -471,7 +505,10 @@ class Model(nn.Module): self.metrics = validator.metrics return validator.metrics - def benchmark(self, **kwargs): + def benchmark( + self, + **kwargs, + ): """ Benchmarks the model across various export formats to evaluate performance. @@ -509,7 +546,10 @@ class Model(nn.Module): verbose=kwargs.get("verbose"), ) - def export(self, **kwargs): + def export( + self, + **kwargs, + ): """ Exports the model to a different format suitable for deployment. @@ -537,7 +577,11 @@ class Model(nn.Module): args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model) - def train(self, trainer=None, **kwargs): + def train( + self, + trainer=None, + **kwargs, + ): """ Trains the model using the specified dataset and training configuration. @@ -607,7 +651,13 @@ class Model(nn.Module): self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP return self.metrics - def tune(self, use_ray=False, iterations=10, *args, **kwargs): + def tune( + self, + use_ray=False, + iterations=10, + *args, + **kwargs, + ): """ Conducts hyperparameter tuning for the model, with an option to use Ray Tune. @@ -640,7 +690,7 @@ class Model(nn.Module): args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations) - def _apply(self, fn): + def _apply(self, fn) -> "Model": """Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers.""" self._check_is_pytorch_model() self = super()._apply(fn) # noqa @@ -649,7 +699,7 @@ class Model(nn.Module): return self @property - def names(self): + def names(self) -> list: """ Retrieves the class names associated with the loaded model. @@ -664,7 +714,7 @@ class Model(nn.Module): return check_class_names(self.model.names) if hasattr(self.model, "names") else None @property - def device(self): + def device(self) -> torch.device: """ Retrieves the device on which the model's parameters are allocated. 
@@ -688,7 +738,7 @@ class Model(nn.Module): """ return self.model.transforms if hasattr(self.model, "transforms") else None - def add_callback(self, event: str, func): + def add_callback(self, event: str, func) -> None: """ Adds a callback function for a specified event. @@ -704,7 +754,7 @@ class Model(nn.Module): """ self.callbacks[event].append(func) - def clear_callback(self, event: str): + def clear_callback(self, event: str) -> None: """ Clears all callback functions registered for a specified event. @@ -718,7 +768,7 @@ class Model(nn.Module): """ self.callbacks[event] = [] - def reset_callbacks(self): + def reset_callbacks(self) -> None: """ Resets all callbacks to their default functions. @@ -729,7 +779,7 @@ class Model(nn.Module): self.callbacks[event] = [callbacks.default_callbacks[event][0]] @staticmethod - def _reset_ckpt_args(args): + def _reset_ckpt_args(args: dict) -> dict: """Reset arguments when loading a PyTorch model.""" include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model return {k: v for k, v in args.items() if k in include} @@ -739,7 +789,7 @@ class Model(nn.Module): # name = self.__class__.__name__ # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") - def _smart_load(self, key): + def _smart_load(self, key: str): """Load model/trainer/validator/predictor.""" try: return self.task_map[self.task][key] @@ -751,7 +801,7 @@ class Model(nn.Module): ) from e @property - def task_map(self): + def task_map(self) -> dict: """ Map head to model, trainer, validator, and predictor classes. diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index cf3e295b..86203e44 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -761,6 +761,8 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): for m in ensemble.modules(): if hasattr(m, "inplace"): m.inplace = inplace + elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"): + m.recompute_scale_factor = None # torch 1.11.0 compatibility # Return model if len(ensemble) == 1: @@ -794,6 +796,8 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): for m in model.modules(): if hasattr(m, "inplace"): m.inplace = inplace + elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"): + m.recompute_scale_factor = None # torch 1.11.0 compatibility # Return model and ckpt return model, ckpt diff --git a/ultralytics/trackers/utils/gmc.py b/ultralytics/trackers/utils/gmc.py index 03ba200f..806f1b5e 100644 --- a/ultralytics/trackers/utils/gmc.py +++ b/ultralytics/trackers/utils/gmc.py @@ -18,9 +18,9 @@ class GMC: Attributes: method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. downscale (int): Factor by which to downscale the frames for processing. - prevFrame (np.array): Stores the previous frame for tracking. + prevFrame (np.ndarray): Stores the previous frame for tracking. prevKeyPoints (list): Stores the keypoints from the previous frame. - prevDescriptors (np.array): Stores the descriptors from the previous frame. + prevDescriptors (np.ndarray): Stores the descriptors from the previous frame. initializedFirstFrame (bool): Flag to indicate if the first frame has been processed. Methods: @@ -82,11 +82,11 @@ class GMC: Apply object detection on a raw frame using specified method. Args: - raw_frame (np.array): The raw frame to be processed. 
+ raw_frame (np.ndarray): The raw frame to be processed. detections (list): List of detections to be used in the processing. Returns: - (np.array): Processed frame. + (np.ndarray): Processed frame. Examples: >>> gmc = GMC() @@ -108,10 +108,10 @@ class GMC: Apply ECC algorithm to a raw frame. Args: - raw_frame (np.array): The raw frame to be processed. + raw_frame (np.ndarray): The raw frame to be processed. Returns: - (np.array): Processed frame. + (np.ndarray): Processed frame. Examples: >>> gmc = GMC() @@ -154,11 +154,11 @@ class GMC: Apply feature-based methods like ORB or SIFT to a raw frame. Args: - raw_frame (np.array): The raw frame to be processed. + raw_frame (np.ndarray): The raw frame to be processed. detections (list): List of detections to be used in the processing. Returns: - (np.array): Processed frame. + (np.ndarray): Processed frame. Examples: >>> gmc = GMC() @@ -296,10 +296,10 @@ class GMC: Apply Sparse Optical Flow method to a raw frame. Args: - raw_frame (np.array): The raw frame to be processed. + raw_frame (np.ndarray): The raw frame to be processed. Returns: - (np.array): Processed frame. + (np.ndarray): Processed frame. Examples: >>> gmc = GMC() diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py index 32a14c54..7d79df51 100644 --- a/ultralytics/utils/metrics.py +++ b/ultralytics/utils/metrics.py @@ -22,13 +22,13 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7): Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format. Args: - box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes. - box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes. + box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes. + box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes. iou (bool): Calculate the standard iou if True else return inter_area/box2_area. eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. Returns: - (np.array): A numpy array of shape (n, m) representing the intersection over box2 area. + (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area. """ # Get the coordinates of bounding boxes @@ -295,7 +295,7 @@ class ConfusionMatrix: Attributes: task (str): The type of task, either 'detect' or 'classify'. - matrix (np.array): The confusion matrix, with dimensions depending on the task. + matrix (np.ndarray): The confusion matrix, with dimensions depending on the task. nc (int): The number of classes. conf (float): The confidence threshold for detections. iou_thres (float): The Intersection over Union threshold. diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index da7e5851..979ba2d8 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -27,7 +27,7 @@ class Colors: Attributes: palette (list of tuple): List of RGB color values. n (int): The number of colors in the palette. - pose_palette (np.array): A specific color palette array with dtype np.uint8. + pose_palette (np.ndarray): A specific color palette array with dtype np.uint8. """ def __init__(self):
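Not part of the patch itself, but for reviewers, a minimal usage sketch of the API surface the expanded hints describe. It assumes the ultralytics package is installed, that yolov8n.pt can be downloaded, and that network access is available; the in-memory inputs below are placeholder images rather than meaningful data.

import numpy as np
import torch
from PIL import Image

from ultralytics import YOLO

# YOLO is a Model subclass; this loads a pretrained PyTorch checkpoint.
model = YOLO("yolov8n.pt")

# Each call matches the new Union hint on __call__()/predict()/embed()/track():
# str | Path | int | list | tuple | PIL.Image.Image | np.ndarray | torch.Tensor
results = model("https://ultralytics.com/images/bus.jpg")    # URL string
results = model(Image.new("RGB", (640, 640)))                # PIL image
results = model(np.zeros((640, 640, 3), dtype=np.uint8))     # HWC uint8 ndarray
results = model(torch.zeros(1, 3, 640, 640))                 # BCHW float tensor in [0, 1]

# The added -> "Model" return annotations make method chaining explicit:
model.reset_weights().load("yolov8n.pt")

# save() now defaults to 'saved_model.pt' rather than 'model.pt':
model.save()

The sketch only exercises calls whose signatures appear in the diff above; source types such as webcam indices (int) or lists of paths are covered by the same hint but omitted here to keep the example self-contained.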