mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Expand Model
method type hinting (#8279)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
fbed8499da
commit
42744a1717
@ -418,8 +418,8 @@ def min_index(arr1, arr2):
|
||||
Find a pair of indexes with the shortest distance between two arrays of 2D points.
|
||||
|
||||
Args:
|
||||
arr1 (np.array): A NumPy array of shape (N, 2) representing N 2D points.
|
||||
arr2 (np.array): A NumPy array of shape (M, 2) representing M 2D points.
|
||||
arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points.
|
||||
arr2 (np.ndarray): A NumPy array of shape (M, 2) representing M 2D points.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively.
|
||||
|
@ -5,6 +5,10 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||
@ -78,7 +82,12 @@ class Model(nn.Module):
|
||||
NotImplementedError: If a specific model task or mode is not supported.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None, verbose=False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, Path] = "yolov8n.pt",
|
||||
task: str = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes a new instance of the YOLO model class.
|
||||
|
||||
@ -135,7 +144,12 @@ class Model(nn.Module):
|
||||
|
||||
self.model_name = model
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
def __call__(
|
||||
self,
|
||||
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
"""
|
||||
An alias for the predict method, enabling the model instance to be callable.
|
||||
|
||||
@ -143,8 +157,9 @@ class Model(nn.Module):
|
||||
with the required arguments for prediction.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions.
|
||||
Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to None.
|
||||
source (str | Path | int | PIL.Image | np.ndarray, optional): The source of the image for making
|
||||
predictions. Accepts various types, including file paths, URLs, PIL images, and numpy arrays.
|
||||
Defaults to None.
|
||||
stream (bool, optional): If True, treats the input source as a continuous stream for predictions.
|
||||
Defaults to False.
|
||||
**kwargs (dict): Additional keyword arguments for configuring the prediction process.
|
||||
@ -163,7 +178,7 @@ class Model(nn.Module):
|
||||
return session if session.client.authenticated else None
|
||||
|
||||
@staticmethod
|
||||
def is_triton_model(model):
|
||||
def is_triton_model(model: str) -> bool:
|
||||
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
@ -171,7 +186,7 @@ class Model(nn.Module):
|
||||
return url.netloc and url.path and url.scheme in {"http", "grpc"}
|
||||
|
||||
@staticmethod
|
||||
def is_hub_model(model):
|
||||
def is_hub_model(model: str) -> bool:
|
||||
"""Check if the provided model is a HUB model."""
|
||||
return any(
|
||||
(
|
||||
@ -181,7 +196,7 @@ class Model(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
def _new(self, cfg: str, task=None, model=None, verbose=False):
|
||||
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
|
||||
@ -202,7 +217,7 @@ class Model(nn.Module):
|
||||
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
|
||||
self.model.task = self.task
|
||||
|
||||
def _load(self, weights: str, task=None):
|
||||
def _load(self, weights: str, task=None) -> None:
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model head.
|
||||
|
||||
@ -224,7 +239,7 @@ class Model(nn.Module):
|
||||
self.overrides["model"] = weights
|
||||
self.overrides["task"] = self.task
|
||||
|
||||
def _check_is_pytorch_model(self):
|
||||
def _check_is_pytorch_model(self) -> None:
|
||||
"""Raises TypeError if model is not a PyTorch model."""
|
||||
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
|
||||
pt_module = isinstance(self.model, nn.Module)
|
||||
@ -237,7 +252,7 @@ class Model(nn.Module):
|
||||
f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
|
||||
)
|
||||
|
||||
def reset_weights(self):
|
||||
def reset_weights(self) -> "Model":
|
||||
"""
|
||||
Resets the model parameters to randomly initialized values, effectively discarding all training information.
|
||||
|
||||
@ -259,7 +274,7 @@ class Model(nn.Module):
|
||||
p.requires_grad = True
|
||||
return self
|
||||
|
||||
def load(self, weights="yolov8n.pt"):
|
||||
def load(self, weights: Union[str, Path] = "yolov8n.pt") -> "Model":
|
||||
"""
|
||||
Loads parameters from the specified weights file into the model.
|
||||
|
||||
@ -281,24 +296,22 @@ class Model(nn.Module):
|
||||
self.model.load(weights)
|
||||
return self
|
||||
|
||||
def save(self, filename="model.pt"):
|
||||
def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
|
||||
"""
|
||||
Saves the current model state to a file.
|
||||
|
||||
This method exports the model's checkpoint (ckpt) to the specified filename.
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to save the model to. Defaults to 'model.pt'.
|
||||
filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'.
|
||||
|
||||
Raises:
|
||||
TypeError: If the model is not a PyTorch model.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
import torch
|
||||
|
||||
torch.save(self.ckpt, filename)
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
def info(self, detailed: bool = False, verbose: bool = True):
|
||||
"""
|
||||
Logs or returns model information.
|
||||
|
||||
@ -330,7 +343,12 @@ class Model(nn.Module):
|
||||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
def embed(self, source=None, stream=False, **kwargs):
|
||||
def embed(
|
||||
self,
|
||||
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
"""
|
||||
Generates image embeddings based on the provided source.
|
||||
|
||||
@ -353,7 +371,13 @@ class Model(nn.Module):
|
||||
kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def predict(self, source=None, stream=False, predictor=None, **kwargs):
|
||||
def predict(
|
||||
self,
|
||||
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
|
||||
stream: bool = False,
|
||||
predictor=None,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
"""
|
||||
Performs predictions on the given image source using the YOLO model.
|
||||
|
||||
@ -405,7 +429,13 @@ class Model(nn.Module):
|
||||
self.predictor.set_prompts(prompts)
|
||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
||||
|
||||
def track(self, source=None, stream=False, persist=False, **kwargs):
|
||||
def track(
|
||||
self,
|
||||
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
|
||||
stream: bool = False,
|
||||
persist: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
"""
|
||||
Conducts object tracking on the specified input source using the registered trackers.
|
||||
|
||||
@ -438,7 +468,11 @@ class Model(nn.Module):
|
||||
kwargs["mode"] = "track"
|
||||
return self.predict(source=source, stream=stream, **kwargs)
|
||||
|
||||
def val(self, validator=None, **kwargs):
|
||||
def val(
|
||||
self,
|
||||
validator=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Validates the model using a specified dataset and validation configuration.
|
||||
|
||||
@ -471,7 +505,10 @@ class Model(nn.Module):
|
||||
self.metrics = validator.metrics
|
||||
return validator.metrics
|
||||
|
||||
def benchmark(self, **kwargs):
|
||||
def benchmark(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Benchmarks the model across various export formats to evaluate performance.
|
||||
|
||||
@ -509,7 +546,10 @@ class Model(nn.Module):
|
||||
verbose=kwargs.get("verbose"),
|
||||
)
|
||||
|
||||
def export(self, **kwargs):
|
||||
def export(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Exports the model to a different format suitable for deployment.
|
||||
|
||||
@ -537,7 +577,11 @@ class Model(nn.Module):
|
||||
args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
|
||||
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
||||
|
||||
def train(self, trainer=None, **kwargs):
|
||||
def train(
|
||||
self,
|
||||
trainer=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Trains the model using the specified dataset and training configuration.
|
||||
|
||||
@ -607,7 +651,13 @@ class Model(nn.Module):
|
||||
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
|
||||
return self.metrics
|
||||
|
||||
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
|
||||
def tune(
|
||||
self,
|
||||
use_ray=False,
|
||||
iterations=10,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
|
||||
|
||||
@ -640,7 +690,7 @@ class Model(nn.Module):
|
||||
args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
|
||||
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
|
||||
|
||||
def _apply(self, fn):
|
||||
def _apply(self, fn) -> "Model":
|
||||
"""Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers."""
|
||||
self._check_is_pytorch_model()
|
||||
self = super()._apply(fn) # noqa
|
||||
@ -649,7 +699,7 @@ class Model(nn.Module):
|
||||
return self
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
def names(self) -> list:
|
||||
"""
|
||||
Retrieves the class names associated with the loaded model.
|
||||
|
||||
@ -664,7 +714,7 @@ class Model(nn.Module):
|
||||
return check_class_names(self.model.names) if hasattr(self.model, "names") else None
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
def device(self) -> torch.device:
|
||||
"""
|
||||
Retrieves the device on which the model's parameters are allocated.
|
||||
|
||||
@ -688,7 +738,7 @@ class Model(nn.Module):
|
||||
"""
|
||||
return self.model.transforms if hasattr(self.model, "transforms") else None
|
||||
|
||||
def add_callback(self, event: str, func):
|
||||
def add_callback(self, event: str, func) -> None:
|
||||
"""
|
||||
Adds a callback function for a specified event.
|
||||
|
||||
@ -704,7 +754,7 @@ class Model(nn.Module):
|
||||
"""
|
||||
self.callbacks[event].append(func)
|
||||
|
||||
def clear_callback(self, event: str):
|
||||
def clear_callback(self, event: str) -> None:
|
||||
"""
|
||||
Clears all callback functions registered for a specified event.
|
||||
|
||||
@ -718,7 +768,7 @@ class Model(nn.Module):
|
||||
"""
|
||||
self.callbacks[event] = []
|
||||
|
||||
def reset_callbacks(self):
|
||||
def reset_callbacks(self) -> None:
|
||||
"""
|
||||
Resets all callbacks to their default functions.
|
||||
|
||||
@ -729,7 +779,7 @@ class Model(nn.Module):
|
||||
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
def _reset_ckpt_args(args: dict) -> dict:
|
||||
"""Reset arguments when loading a PyTorch model."""
|
||||
include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
|
||||
return {k: v for k, v in args.items() if k in include}
|
||||
@ -739,7 +789,7 @@ class Model(nn.Module):
|
||||
# name = self.__class__.__name__
|
||||
# raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def _smart_load(self, key):
|
||||
def _smart_load(self, key: str):
|
||||
"""Load model/trainer/validator/predictor."""
|
||||
try:
|
||||
return self.task_map[self.task][key]
|
||||
@ -751,7 +801,7 @@ class Model(nn.Module):
|
||||
) from e
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
def task_map(self) -> dict:
|
||||
"""
|
||||
Map head to model, trainer, validator, and predictor classes.
|
||||
|
||||
|
@ -761,6 +761,8 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
for m in ensemble.modules():
|
||||
if hasattr(m, "inplace"):
|
||||
m.inplace = inplace
|
||||
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model
|
||||
if len(ensemble) == 1:
|
||||
@ -794,6 +796,8 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||
for m in model.modules():
|
||||
if hasattr(m, "inplace"):
|
||||
m.inplace = inplace
|
||||
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model and ckpt
|
||||
return model, ckpt
|
||||
|
@ -18,9 +18,9 @@ class GMC:
|
||||
Attributes:
|
||||
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
||||
downscale (int): Factor by which to downscale the frames for processing.
|
||||
prevFrame (np.array): Stores the previous frame for tracking.
|
||||
prevFrame (np.ndarray): Stores the previous frame for tracking.
|
||||
prevKeyPoints (list): Stores the keypoints from the previous frame.
|
||||
prevDescriptors (np.array): Stores the descriptors from the previous frame.
|
||||
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
|
||||
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
|
||||
|
||||
Methods:
|
||||
@ -82,11 +82,11 @@ class GMC:
|
||||
Apply object detection on a raw frame using specified method.
|
||||
|
||||
Args:
|
||||
raw_frame (np.array): The raw frame to be processed.
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
detections (list): List of detections to be used in the processing.
|
||||
|
||||
Returns:
|
||||
(np.array): Processed frame.
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
@ -108,10 +108,10 @@ class GMC:
|
||||
Apply ECC algorithm to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.array): The raw frame to be processed.
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
|
||||
Returns:
|
||||
(np.array): Processed frame.
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
@ -154,11 +154,11 @@ class GMC:
|
||||
Apply feature-based methods like ORB or SIFT to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.array): The raw frame to be processed.
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
detections (list): List of detections to be used in the processing.
|
||||
|
||||
Returns:
|
||||
(np.array): Processed frame.
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
@ -296,10 +296,10 @@ class GMC:
|
||||
Apply Sparse Optical Flow method to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.array): The raw frame to be processed.
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
|
||||
Returns:
|
||||
(np.array): Processed frame.
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
|
@ -22,13 +22,13 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
||||
Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
|
||||
|
||||
Args:
|
||||
box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes.
|
||||
box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes.
|
||||
box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
|
||||
box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
|
||||
iou (bool): Calculate the standard iou if True else return inter_area/box2_area.
|
||||
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
|
||||
|
||||
Returns:
|
||||
(np.array): A numpy array of shape (n, m) representing the intersection over box2 area.
|
||||
(np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
|
||||
"""
|
||||
|
||||
# Get the coordinates of bounding boxes
|
||||
@ -295,7 +295,7 @@ class ConfusionMatrix:
|
||||
|
||||
Attributes:
|
||||
task (str): The type of task, either 'detect' or 'classify'.
|
||||
matrix (np.array): The confusion matrix, with dimensions depending on the task.
|
||||
matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
|
||||
nc (int): The number of classes.
|
||||
conf (float): The confidence threshold for detections.
|
||||
iou_thres (float): The Intersection over Union threshold.
|
||||
|
@ -27,7 +27,7 @@ class Colors:
|
||||
Attributes:
|
||||
palette (list of tuple): List of RGB color values.
|
||||
n (int): The number of colors in the palette.
|
||||
pose_palette (np.array): A specific color palette array with dtype np.uint8.
|
||||
pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user