Expand Model method type hinting (#8279)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Glenn Jocher, 2024-02-19 14:24:30 +01:00, committed by GitHub
parent fbed8499da
commit 42744a1717
GPG Key ID: B5690EEEBB952194
6 changed files with 104 additions and 50 deletions

View File

@@ -418,8 +418,8 @@ def min_index(arr1, arr2):
     Find a pair of indexes with the shortest distance between two arrays of 2D points.

     Args:
-        arr1 (np.array): A NumPy array of shape (N, 2) representing N 2D points.
-        arr2 (np.array): A NumPy array of shape (M, 2) representing M 2D points.
+        arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points.
+        arr2 (np.ndarray): A NumPy array of shape (M, 2) representing M 2D points.

     Returns:
         (tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively.

View File

@@ -5,6 +5,10 @@ import sys
 from pathlib import Path
 from typing import Union

+import PIL
+import numpy as np
+import torch
+
 from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
 from ultralytics.hub.utils import HUB_WEB_ROOT
 from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
@@ -78,7 +82,12 @@ class Model(nn.Module):
         NotImplementedError: If a specific model task or mode is not supported.
     """

-    def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None, verbose=False) -> None:
+    def __init__(
+        self,
+        model: Union[str, Path] = "yolov8n.pt",
+        task: str = None,
+        verbose: bool = False,
+    ) -> None:
         """
         Initializes a new instance of the YOLO model class.
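For illustration (not part of the diff), the expanded constructor signature is used the same way as before; YOLO subclasses Model and forwards these keyword arguments:

from ultralytics import YOLO  # YOLO subclasses Model

# Keyword names match the annotated parameters above
model = YOLO(model="yolov8n.pt", task="detect", verbose=False)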
@@ -135,7 +144,12 @@ class Model(nn.Module):
         self.model_name = model

-    def __call__(self, source=None, stream=False, **kwargs):
+    def __call__(
+        self,
+        source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
+        stream: bool = False,
+        **kwargs,
+    ) -> list:
         """
         An alias for the predict method, enabling the model instance to be callable.
@@ -143,8 +157,9 @@ class Model(nn.Module):
         with the required arguments for prediction.

         Args:
-            source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions.
-                Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to None.
+            source (str | Path | int | PIL.Image | np.ndarray, optional): The source of the image for making
+                predictions. Accepts various types, including file paths, URLs, PIL images, and numpy arrays.
+                Defaults to None.
             stream (bool, optional): If True, treats the input source as a continuous stream for predictions.
                 Defaults to False.
             **kwargs (dict): Additional keyword arguments for configuring the prediction process.
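For illustration (not part of the diff), a minimal sketch of the source types the widened annotation documents; the local file name is a placeholder:

import numpy as np
from PIL import Image
from ultralytics import YOLO

model = YOLO("yolov8n.pt")

# Each call returns a list with one result per image or frame
results = model("https://ultralytics.com/images/bus.jpg")   # str URL or file path
results = model(Image.open("bus.jpg"))                      # PIL.Image.Image (placeholder file name)
results = model(np.zeros((640, 640, 3), dtype=np.uint8))    # np.ndarray in HWC layout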
@@ -163,7 +178,7 @@ class Model(nn.Module):
         return session if session.client.authenticated else None

     @staticmethod
-    def is_triton_model(model):
+    def is_triton_model(model: str) -> bool:
         """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
         from urllib.parse import urlsplit
@@ -171,7 +186,7 @@ class Model(nn.Module):
         return url.netloc and url.path and url.scheme in {"http", "grpc"}

     @staticmethod
-    def is_hub_model(model):
+    def is_hub_model(model: str) -> bool:
         """Check if the provided model is a HUB model."""
         return any(
             (
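For reference (not part of the diff), a standalone sketch of the URL test that is_triton_model performs; the function name here is illustrative:

from urllib.parse import urlsplit

def looks_like_triton_url(model: str) -> bool:
    # A Triton source needs an http/grpc scheme, a host, and an endpoint path
    url = urlsplit(model)
    return bool(url.netloc and url.path and url.scheme in {"http", "grpc"})

print(looks_like_triton_url("http://localhost:8000/yolov8n/detect"))  # True
print(looks_like_triton_url("yolov8n.pt"))                            # False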
@@ -181,7 +196,7 @@ class Model(nn.Module):
             )
         )

-    def _new(self, cfg: str, task=None, model=None, verbose=False):
+    def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
         """
         Initializes a new model and infers the task type from the model definitions.
@@ -202,7 +217,7 @@ class Model(nn.Module):
         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine default and model args (prefer model args)
         self.model.task = self.task

-    def _load(self, weights: str, task=None):
+    def _load(self, weights: str, task=None) -> None:
         """
         Initializes a new model and infers the task type from the model head.
@@ -224,7 +239,7 @@ class Model(nn.Module):
         self.overrides["model"] = weights
         self.overrides["task"] = self.task

-    def _check_is_pytorch_model(self):
+    def _check_is_pytorch_model(self) -> None:
         """Raises TypeError is model is not a PyTorch model."""
         pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
         pt_module = isinstance(self.model, nn.Module)
@@ -237,7 +252,7 @@ class Model(nn.Module):
             f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
         )

-    def reset_weights(self):
+    def reset_weights(self) -> "Model":
         """
         Resets the model parameters to randomly initialized values, effectively discarding all training information.
@@ -259,7 +274,7 @@ class Model(nn.Module):
             p.requires_grad = True
         return self

-    def load(self, weights="yolov8n.pt"):
+    def load(self, weights: Union[str, Path] = "yolov8n.pt") -> "Model":
         """
         Loads parameters from the specified weights file into the model.
@@ -281,24 +296,22 @@ class Model(nn.Module):
         self.model.load(weights)
         return self

-    def save(self, filename="model.pt"):
+    def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
         """
         Saves the current model state to a file.

         This method exports the model's checkpoint (ckpt) to the specified filename.

         Args:
-            filename (str): The name of the file to save the model to. Defaults to 'model.pt'.
+            filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'.

         Raises:
             AssertionError: If the model is not a PyTorch model.
         """
         self._check_is_pytorch_model()
-        import torch
-
         torch.save(self.ckpt, filename)

-    def info(self, detailed=False, verbose=True):
+    def info(self, detailed: bool = False, verbose: bool = True):
         """
         Logs or returns model information.
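A short usage sketch for the widened save() signature (not part of the diff; the checkpoint name is a placeholder):

from pathlib import Path
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
model.save("custom_checkpoint.pt")        # str, as before
model.save(Path("custom_checkpoint.pt"))  # Path now matches the type hint as well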
@@ -330,7 +343,12 @@ class Model(nn.Module):
         self._check_is_pytorch_model()
         self.model.fuse()

-    def embed(self, source=None, stream=False, **kwargs):
+    def embed(
+        self,
+        source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
+        stream: bool = False,
+        **kwargs,
+    ) -> list:
         """
         Generates image embeddings based on the provided source.
@@ -353,7 +371,13 @@ class Model(nn.Module):
             kwargs["embed"] = [len(self.model.model) - 2]  # embed second-to-last layer if no indices passed
         return self.predict(source, stream, **kwargs)

-    def predict(self, source=None, stream=False, predictor=None, **kwargs):
+    def predict(
+        self,
+        source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
+        stream: bool = False,
+        predictor=None,
+        **kwargs,
+    ) -> list:
         """
         Performs predictions on the given image source using the YOLO model.
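Illustrative use of the list return type now declared on predict() (not part of the diff; the confidence threshold is arbitrary):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
results = model.predict("https://ultralytics.com/images/bus.jpg", conf=0.5)  # -> list of Results
for r in results:
    print(len(r.boxes), "detections")  # one Results object per image or frame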
@@ -405,7 +429,13 @@ class Model(nn.Module):
             self.predictor.set_prompts(prompts)
         return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)

-    def track(self, source=None, stream=False, persist=False, **kwargs):
+    def track(
+        self,
+        source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
+        stream: bool = False,
+        persist: bool = False,
+        **kwargs,
+    ) -> list:
         """
         Conducts object tracking on the specified input source using the registered trackers.
@@ -438,7 +468,11 @@ class Model(nn.Module):
             kwargs["mode"] = "track"
         return self.predict(source=source, stream=stream, **kwargs)

-    def val(self, validator=None, **kwargs):
+    def val(
+        self,
+        validator=None,
+        **kwargs,
+    ):
         """
         Validates the model using a specified dataset and validation configuration.
@@ -471,7 +505,10 @@ class Model(nn.Module):
         self.metrics = validator.metrics
         return validator.metrics

-    def benchmark(self, **kwargs):
+    def benchmark(
+        self,
+        **kwargs,
+    ):
         """
         Benchmarks the model across various export formats to evaluate performance.
@@ -509,7 +546,10 @@ class Model(nn.Module):
             verbose=kwargs.get("verbose"),
         )

-    def export(self, **kwargs):
+    def export(
+        self,
+        **kwargs,
+    ):
         """
         Exports the model to a different format suitable for deployment.
@@ -537,7 +577,11 @@ class Model(nn.Module):
         args = {**self.overrides, **custom, **kwargs, "mode": "export"}  # highest priority args on the right
         return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)

-    def train(self, trainer=None, **kwargs):
+    def train(
+        self,
+        trainer=None,
+        **kwargs,
+    ):
         """
         Trains the model using the specified dataset and training configuration.
@@ -607,7 +651,13 @@ class Model(nn.Module):
             self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
         return self.metrics

-    def tune(self, use_ray=False, iterations=10, *args, **kwargs):
+    def tune(
+        self,
+        use_ray=False,
+        iterations=10,
+        *args,
+        **kwargs,
+    ):
         """
         Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
@@ -640,7 +690,7 @@ class Model(nn.Module):
         args = {**self.overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
         return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)

-    def _apply(self, fn):
+    def _apply(self, fn) -> "Model":
         """Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers."""
         self._check_is_pytorch_model()
         self = super()._apply(fn)  # noqa
@@ -649,7 +699,7 @@ class Model(nn.Module):
         return self

     @property
-    def names(self):
+    def names(self) -> list:
         """
         Retrieves the class names associated with the loaded model.
@@ -664,7 +714,7 @@ class Model(nn.Module):
         return check_class_names(self.model.names) if hasattr(self.model, "names") else None

     @property
-    def device(self):
+    def device(self) -> torch.device:
         """
         Retrieves the device on which the model's parameters are allocated.
@@ -688,7 +738,7 @@ class Model(nn.Module):
         """
         return self.model.transforms if hasattr(self.model, "transforms") else None

-    def add_callback(self, event: str, func):
+    def add_callback(self, event: str, func) -> None:
         """
         Adds a callback function for a specified event.
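A brief usage sketch for callback registration (not part of the diff; the handler body is illustrative and the registered function receives the running trainer instance):

from ultralytics import YOLO

def on_train_start(trainer):
    print("Training started on device:", trainer.device)

model = YOLO("yolov8n.pt")
model.add_callback("on_train_start", on_train_start)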
@@ -704,7 +754,7 @@ class Model(nn.Module):
         """
         self.callbacks[event].append(func)

-    def clear_callback(self, event: str):
+    def clear_callback(self, event: str) -> None:
         """
         Clears all callback functions registered for a specified event.
@@ -718,7 +768,7 @@ class Model(nn.Module):
         """
         self.callbacks[event] = []

-    def reset_callbacks(self):
+    def reset_callbacks(self) -> None:
         """
         Resets all callbacks to their default functions.
@@ -729,7 +779,7 @@ class Model(nn.Module):
             self.callbacks[event] = [callbacks.default_callbacks[event][0]]

     @staticmethod
-    def _reset_ckpt_args(args):
+    def _reset_ckpt_args(args: dict) -> dict:
         """Reset arguments when loading a PyTorch model."""
         include = {"imgsz", "data", "task", "single_cls"}  # only remember these arguments when loading a PyTorch model
         return {k: v for k, v in args.items() if k in include}
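For clarity (not part of the diff), the filtering that _reset_ckpt_args performs on a hypothetical checkpoint args dict:

ckpt_args = {"imgsz": 640, "data": "coco8.yaml", "task": "detect", "single_cls": False, "lr0": 0.01, "epochs": 100}
include = {"imgsz", "data", "task", "single_cls"}
print({k: v for k, v in ckpt_args.items() if k in include})
# {'imgsz': 640, 'data': 'coco8.yaml', 'task': 'detect', 'single_cls': False}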
@@ -739,7 +789,7 @@ class Model(nn.Module):
    #     name = self.__class__.__name__
    #     raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

-    def _smart_load(self, key):
+    def _smart_load(self, key: str):
         """Load model/trainer/validator/predictor."""
         try:
             return self.task_map[self.task][key]
@@ -751,7 +801,7 @@ class Model(nn.Module):
             ) from e

     @property
-    def task_map(self):
+    def task_map(self) -> dict:
         """
         Map head to model, trainer, validator, and predictor classes.

View File

@@ -761,6 +761,8 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
     for m in ensemble.modules():
         if hasattr(m, "inplace"):
             m.inplace = inplace
+        elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
+            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

     # Return model
     if len(ensemble) == 1:
@@ -794,6 +796,8 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
     for m in model.modules():
         if hasattr(m, "inplace"):
             m.inplace = inplace
+        elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
+            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

     # Return model and ckpt
     return model, ckpt
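The two added branches backfill an attribute that nn.Upsample modules pickled by older torch builds can lack but that newer torch reads in forward(); a standalone sketch of the same shim (not part of the diff):

import torch.nn as nn

def apply_upsample_shim(model: nn.Module) -> None:
    # Older checkpoints may pickle Upsample without recompute_scale_factor;
    # newer torch accesses it in forward(), so backfill it with None.
    for m in model.modules():
        if isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
            m.recompute_scale_factor = None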

View File

@@ -18,9 +18,9 @@ class GMC:
     Attributes:
         method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
         downscale (int): Factor by which to downscale the frames for processing.
-        prevFrame (np.array): Stores the previous frame for tracking.
+        prevFrame (np.ndarray): Stores the previous frame for tracking.
         prevKeyPoints (list): Stores the keypoints from the previous frame.
-        prevDescriptors (np.array): Stores the descriptors from the previous frame.
+        prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
         initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.

     Methods:
@@ -82,11 +82,11 @@ class GMC:
         Apply object detection on a raw frame using specified method.

         Args:
-            raw_frame (np.array): The raw frame to be processed.
+            raw_frame (np.ndarray): The raw frame to be processed.
             detections (list): List of detections to be used in the processing.

         Returns:
-            (np.array): Processed frame.
+            (np.ndarray): Processed frame.

         Examples:
             >>> gmc = GMC()
@@ -108,10 +108,10 @@ class GMC:
         Apply ECC algorithm to a raw frame.

         Args:
-            raw_frame (np.array): The raw frame to be processed.
+            raw_frame (np.ndarray): The raw frame to be processed.

         Returns:
-            (np.array): Processed frame.
+            (np.ndarray): Processed frame.

         Examples:
             >>> gmc = GMC()
@@ -154,11 +154,11 @@ class GMC:
         Apply feature-based methods like ORB or SIFT to a raw frame.

         Args:
-            raw_frame (np.array): The raw frame to be processed.
+            raw_frame (np.ndarray): The raw frame to be processed.
             detections (list): List of detections to be used in the processing.

         Returns:
-            (np.array): Processed frame.
+            (np.ndarray): Processed frame.

         Examples:
             >>> gmc = GMC()
@@ -296,10 +296,10 @@ class GMC:
         Apply Sparse Optical Flow method to a raw frame.

         Args:
-            raw_frame (np.array): The raw frame to be processed.
+            raw_frame (np.ndarray): The raw frame to be processed.

         Returns:
-            (np.array): Processed frame.
+            (np.ndarray): Processed frame.

         Examples:
             >>> gmc = GMC()

View File

@@ -22,13 +22,13 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
     Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.

     Args:
-        box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes.
-        box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes.
+        box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
+        box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
         iou (bool): Calculate the standard iou if True else return inter_area/box2_area.
         eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

     Returns:
-        (np.array): A numpy array of shape (n, m) representing the intersection over box2 area.
+        (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
     """

     # Get the coordinates of bounding boxes
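For readers unfamiliar with the metric, a minimal numpy sketch of intersection-over-box2-area as described above (not part of the diff; the in-repo implementation also handles the iou flag):

import numpy as np

def ioa_sketch(box1: np.ndarray, box2: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    # box1: (n, 4), box2: (m, 4) in x1y1x2y2 format; result: (n, m)
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T

    # Broadcast box1 against box2 and clip negative overlaps to zero
    inter_w = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0)
    inter_h = (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)

    # Normalize by the box2 areas only, hence "intersection over area"
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    return inter_w * inter_h / (box2_area + eps)

a = np.array([[0, 0, 10, 10]], dtype=float)
b = np.array([[5, 5, 15, 15], [20, 20, 30, 30]], dtype=float)
print(ioa_sketch(a, b))  # approximately [[0.25, 0.0]]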
@@ -295,7 +295,7 @@ class ConfusionMatrix:
     Attributes:
         task (str): The type of task, either 'detect' or 'classify'.
-        matrix (np.array): The confusion matrix, with dimensions depending on the task.
+        matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
         nc (int): The number of classes.
         conf (float): The confidence threshold for detections.
         iou_thres (float): The Intersection over Union threshold.

View File

@@ -27,7 +27,7 @@ class Colors:
     Attributes:
         palette (list of tuple): List of RGB color values.
         n (int): The number of colors in the palette.
-        pose_palette (np.array): A specific color palette array with dtype np.uint8.
+        pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
     """

     def __init__(self):