Expand Model method type hinting (#8279)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-02-19 14:24:30 +01:00 committed by GitHub
parent fbed8499da
commit 42744a1717
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 104 additions and 50 deletions

View File

@ -418,8 +418,8 @@ def min_index(arr1, arr2):
Find a pair of indexes with the shortest distance between two arrays of 2D points.
Args:
arr1 (np.array): A NumPy array of shape (N, 2) representing N 2D points.
arr2 (np.array): A NumPy array of shape (M, 2) representing M 2D points.
arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points.
arr2 (np.ndarray): A NumPy array of shape (M, 2) representing M 2D points.
Returns:
(tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively.

View File

@ -5,6 +5,10 @@ import sys
from pathlib import Path
from typing import Union
import PIL
import numpy as np
import torch
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
@ -78,7 +82,12 @@ class Model(nn.Module):
NotImplementedError: If a specific model task or mode is not supported.
"""
def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None, verbose=False) -> None:
def __init__(
self,
model: Union[str, Path] = "yolov8n.pt",
task: str = None,
verbose: bool = False,
) -> None:
"""
Initializes a new instance of the YOLO model class.
@ -135,7 +144,12 @@ class Model(nn.Module):
self.model_name = model
def __call__(self, source=None, stream=False, **kwargs):
def __call__(
self,
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
stream: bool = False,
**kwargs,
) -> list:
"""
An alias for the predict method, enabling the model instance to be callable.
@ -143,8 +157,9 @@ class Model(nn.Module):
with the required arguments for prediction.
Args:
source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions.
Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to None.
source (str | Path | int | PIL.Image | np.ndarray, optional): The source of the image for making
predictions. Accepts various types, including file paths, URLs, PIL images, and numpy arrays.
Defaults to None.
stream (bool, optional): If True, treats the input source as a continuous stream for predictions.
Defaults to False.
**kwargs (dict): Additional keyword arguments for configuring the prediction process.
@ -163,7 +178,7 @@ class Model(nn.Module):
return session if session.client.authenticated else None
@staticmethod
def is_triton_model(model):
def is_triton_model(model: str) -> bool:
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
from urllib.parse import urlsplit
@ -171,7 +186,7 @@ class Model(nn.Module):
return url.netloc and url.path and url.scheme in {"http", "grpc"}
@staticmethod
def is_hub_model(model):
def is_hub_model(model: str) -> bool:
"""Check if the provided model is a HUB model."""
return any(
(
@ -181,7 +196,7 @@ class Model(nn.Module):
)
)
def _new(self, cfg: str, task=None, model=None, verbose=False):
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
"""
Initializes a new model and infers the task type from the model definitions.
@ -202,7 +217,7 @@ class Model(nn.Module):
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
self.model.task = self.task
def _load(self, weights: str, task=None):
def _load(self, weights: str, task=None) -> None:
"""
Initializes a new model and infers the task type from the model head.
@ -224,7 +239,7 @@ class Model(nn.Module):
self.overrides["model"] = weights
self.overrides["task"] = self.task
def _check_is_pytorch_model(self):
def _check_is_pytorch_model(self) -> None:
"""Raises TypeError is model is not a PyTorch model."""
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
pt_module = isinstance(self.model, nn.Module)
@ -237,7 +252,7 @@ class Model(nn.Module):
f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
)
def reset_weights(self):
def reset_weights(self) -> "Model":
"""
Resets the model parameters to randomly initialized values, effectively discarding all training information.
@ -259,7 +274,7 @@ class Model(nn.Module):
p.requires_grad = True
return self
def load(self, weights="yolov8n.pt"):
def load(self, weights: Union[str, Path] = "yolov8n.pt") -> "Model":
"""
Loads parameters from the specified weights file into the model.
@ -281,24 +296,22 @@ class Model(nn.Module):
self.model.load(weights)
return self
def save(self, filename="model.pt"):
def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
"""
Saves the current model state to a file.
This method exports the model's checkpoint (ckpt) to the specified filename.
Args:
filename (str): The name of the file to save the model to. Defaults to 'model.pt'.
filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'.
Raises:
AssertionError: If the model is not a PyTorch model.
"""
self._check_is_pytorch_model()
import torch
torch.save(self.ckpt, filename)
def info(self, detailed=False, verbose=True):
def info(self, detailed: bool = False, verbose: bool = True):
"""
Logs or returns model information.
@ -330,7 +343,12 @@ class Model(nn.Module):
self._check_is_pytorch_model()
self.model.fuse()
def embed(self, source=None, stream=False, **kwargs):
def embed(
self,
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
stream: bool = False,
**kwargs,
) -> list:
"""
Generates image embeddings based on the provided source.
@ -353,7 +371,13 @@ class Model(nn.Module):
kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
return self.predict(source, stream, **kwargs)
def predict(self, source=None, stream=False, predictor=None, **kwargs):
def predict(
self,
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
stream: bool = False,
predictor=None,
**kwargs,
) -> list:
"""
Performs predictions on the given image source using the YOLO model.
@ -405,7 +429,13 @@ class Model(nn.Module):
self.predictor.set_prompts(prompts)
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
def track(self, source=None, stream=False, persist=False, **kwargs):
def track(
self,
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
stream: bool = False,
persist: bool = False,
**kwargs,
) -> list:
"""
Conducts object tracking on the specified input source using the registered trackers.
@ -438,7 +468,11 @@ class Model(nn.Module):
kwargs["mode"] = "track"
return self.predict(source=source, stream=stream, **kwargs)
def val(self, validator=None, **kwargs):
def val(
self,
validator=None,
**kwargs,
):
"""
Validates the model using a specified dataset and validation configuration.
@ -471,7 +505,10 @@ class Model(nn.Module):
self.metrics = validator.metrics
return validator.metrics
def benchmark(self, **kwargs):
def benchmark(
self,
**kwargs,
):
"""
Benchmarks the model across various export formats to evaluate performance.
@ -509,7 +546,10 @@ class Model(nn.Module):
verbose=kwargs.get("verbose"),
)
def export(self, **kwargs):
def export(
self,
**kwargs,
):
"""
Exports the model to a different format suitable for deployment.
@ -537,7 +577,11 @@ class Model(nn.Module):
args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
def train(self, trainer=None, **kwargs):
def train(
self,
trainer=None,
**kwargs,
):
"""
Trains the model using the specified dataset and training configuration.
@ -607,7 +651,13 @@ class Model(nn.Module):
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
return self.metrics
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
def tune(
self,
use_ray=False,
iterations=10,
*args,
**kwargs,
):
"""
Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
@ -640,7 +690,7 @@ class Model(nn.Module):
args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
def _apply(self, fn):
def _apply(self, fn) -> "Model":
"""Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers."""
self._check_is_pytorch_model()
self = super()._apply(fn) # noqa
@ -649,7 +699,7 @@ class Model(nn.Module):
return self
@property
def names(self):
def names(self) -> list:
"""
Retrieves the class names associated with the loaded model.
@ -664,7 +714,7 @@ class Model(nn.Module):
return check_class_names(self.model.names) if hasattr(self.model, "names") else None
@property
def device(self):
def device(self) -> torch.device:
"""
Retrieves the device on which the model's parameters are allocated.
@ -688,7 +738,7 @@ class Model(nn.Module):
"""
return self.model.transforms if hasattr(self.model, "transforms") else None
def add_callback(self, event: str, func):
def add_callback(self, event: str, func) -> None:
"""
Adds a callback function for a specified event.
@ -704,7 +754,7 @@ class Model(nn.Module):
"""
self.callbacks[event].append(func)
def clear_callback(self, event: str):
def clear_callback(self, event: str) -> None:
"""
Clears all callback functions registered for a specified event.
@ -718,7 +768,7 @@ class Model(nn.Module):
"""
self.callbacks[event] = []
def reset_callbacks(self):
def reset_callbacks(self) -> None:
"""
Resets all callbacks to their default functions.
@ -729,7 +779,7 @@ class Model(nn.Module):
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
@staticmethod
def _reset_ckpt_args(args):
def _reset_ckpt_args(args: dict) -> dict:
"""Reset arguments when loading a PyTorch model."""
include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
return {k: v for k, v in args.items() if k in include}
@ -739,7 +789,7 @@ class Model(nn.Module):
# name = self.__class__.__name__
# raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def _smart_load(self, key):
def _smart_load(self, key: str):
"""Load model/trainer/validator/predictor."""
try:
return self.task_map[self.task][key]
@ -751,7 +801,7 @@ class Model(nn.Module):
) from e
@property
def task_map(self):
def task_map(self) -> dict:
"""
Map head to model, trainer, validator, and predictor classes.

View File

@ -761,6 +761,8 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
for m in ensemble.modules():
if hasattr(m, "inplace"):
m.inplace = inplace
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model
if len(ensemble) == 1:
@ -794,6 +796,8 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
for m in model.modules():
if hasattr(m, "inplace"):
m.inplace = inplace
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model and ckpt
return model, ckpt

View File

@ -18,9 +18,9 @@ class GMC:
Attributes:
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
downscale (int): Factor by which to downscale the frames for processing.
prevFrame (np.array): Stores the previous frame for tracking.
prevFrame (np.ndarray): Stores the previous frame for tracking.
prevKeyPoints (list): Stores the keypoints from the previous frame.
prevDescriptors (np.array): Stores the descriptors from the previous frame.
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
Methods:
@ -82,11 +82,11 @@ class GMC:
Apply object detection on a raw frame using specified method.
Args:
raw_frame (np.array): The raw frame to be processed.
raw_frame (np.ndarray): The raw frame to be processed.
detections (list): List of detections to be used in the processing.
Returns:
(np.array): Processed frame.
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
@ -108,10 +108,10 @@ class GMC:
Apply ECC algorithm to a raw frame.
Args:
raw_frame (np.array): The raw frame to be processed.
raw_frame (np.ndarray): The raw frame to be processed.
Returns:
(np.array): Processed frame.
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
@ -154,11 +154,11 @@ class GMC:
Apply feature-based methods like ORB or SIFT to a raw frame.
Args:
raw_frame (np.array): The raw frame to be processed.
raw_frame (np.ndarray): The raw frame to be processed.
detections (list): List of detections to be used in the processing.
Returns:
(np.array): Processed frame.
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()
@ -296,10 +296,10 @@ class GMC:
Apply Sparse Optical Flow method to a raw frame.
Args:
raw_frame (np.array): The raw frame to be processed.
raw_frame (np.ndarray): The raw frame to be processed.
Returns:
(np.array): Processed frame.
(np.ndarray): Processed frame.
Examples:
>>> gmc = GMC()

View File

@ -22,13 +22,13 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
Args:
box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes.
box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes.
box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
iou (bool): Calculate the standard iou if True else return inter_area/box2_area.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(np.array): A numpy array of shape (n, m) representing the intersection over box2 area.
(np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
"""
# Get the coordinates of bounding boxes
@ -295,7 +295,7 @@ class ConfusionMatrix:
Attributes:
task (str): The type of task, either 'detect' or 'classify'.
matrix (np.array): The confusion matrix, with dimensions depending on the task.
matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
nc (int): The number of classes.
conf (float): The confidence threshold for detections.
iou_thres (float): The Intersection over Union threshold.

View File

@ -27,7 +27,7 @@ class Colors:
Attributes:
palette (list of tuple): List of RGB color values.
n (int): The number of colors in the palette.
pose_palette (np.array): A specific color palette array with dtype np.uint8.
pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
"""
def __init__(self):