Mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-05-23 21:44:22 +08:00)
Docstring additions (#122)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent: c9f3e469cb
commit: df4fc14c10
@@ -436,7 +436,9 @@ class LetterBox:
         self.scaleup = scaleup
         self.stride = stride

-    def __call__(self, labels={}, image=None):
+    def __call__(self, labels=None, image=None):
+        if labels is None:
+            labels = {}
         img = labels.get("img") if image is None else image
         shape = img.shape[:2]  # current shape [height, width]
         new_shape = labels.pop("rect_shape", self.new_shape)

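Note: the labels={} to labels=None change above sidesteps Python's shared-mutable-default pitfall, where a default dict or list is created once at function definition time and then reused across calls. A minimal, self-contained sketch of the difference (illustrative only, not the LetterBox code):

    def bad_append(item, bucket=[]):      # default list is created once and shared between calls
        bucket.append(item)
        return bucket

    def good_append(item, bucket=None):   # sentinel default, fresh list per call
        if bucket is None:
            bucket = []
        bucket.append(item)
        return bucket

    print(bad_append(1), bad_append(2))    # [1, 2] [1, 2]  <- state leaks between calls
    print(good_append(1), good_append(2))  # [1] [2]
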
@@ -116,14 +116,31 @@ def try_export(inner_func):


 class Exporter:
-    def __init__(self, config=DEFAULT_CONFIG, overrides={}):
+    """
+    Exporter
+
+    A class for exporting a model.
+
+    Attributes:
+        args (OmegaConf): Configuration for the exporter.
+        save_dir (Path): Directory to save results.
+    """
+
+    def __init__(self, config=DEFAULT_CONFIG, overrides=None):
+        """
+        Initializes the Exporter class.
+
+        Args:
+            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
+            overrides (dict, optional): Configuration overrides. Defaults to None.
+        """
+        if overrides is None:
+            overrides = {}
         self.args = get_config(config, overrides)
         project = self.args.project or f"runs/{self.args.task}"
         name = self.args.name or "exp"  # hardcode mode as export doesn't require it
         self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
         self.save_dir.mkdir(parents=True, exist_ok=True)
         self.imgsz = self.args.imgsz

     @smart_inference_mode()
     def __call__(self, model=None):

@@ -143,7 +160,7 @@ class Exporter:
         assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'

         # Checks
-        self.imgsz = check_imgsz(self.imgsz, stride=model.stride, min_dim=2)  # check image size
+        self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
         if self.args.optimize:
             assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'

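For context, the initializer documented above merges an optional overrides dict into DEFAULT_CONFIG via get_config and writes results under an incremented runs/{task} directory. A hedged usage sketch; the 'format' override key is an assumption for illustration and is not taken from this diff:

    # Hypothetical usage sketch; keys other than the documented `overrides` parameter are assumptions.
    from ultralytics.yolo.engine.exporter import Exporter

    exporter = Exporter(overrides={"format": "onnx"})  # merged into DEFAULT_CONFIG by get_config
    # exporter.save_dir resolves to something like runs/<task>/exp, runs/<task>/exp2, ... via increment_path
    # exporter(model) then runs the export under @smart_inference_mode()
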
@@ -1,6 +1,6 @@
 import torch

-from ultralytics import yolo  # noqa required for python usage
+from ultralytics import yolo  # noqa
 from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
 from ultralytics.yolo.configs import get_config
 from ultralytics.yolo.engine.exporter import Exporter

@@ -9,7 +9,7 @@ from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
 from ultralytics.yolo.utils.files import yaml_load
 from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode

-# map head: [model, trainer, validator, predictor]
+# Map head to model, trainer, validator, and predictor classes
 MODEL_MAP = {
     "classify": [
         ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',

@@ -24,39 +24,44 @@ MODEL_MAP = {

 class YOLO:
     """
-    Python interface which emulates a model-like behaviour by wrapping trainers.
+    YOLO
+
+    A python interface which emulates a model-like behaviour by wrapping trainers.
     """
-    __init_key = object()
+    __init_key = object()  # used to ensure proper initialization

     def __init__(self, init_key=None, type="v8") -> None:
         """
         Initializes the YOLO object.

         Args:
-            type (str): Type/version of models to use
+            init_key (object): used to ensure proper initialization. Defaults to None.
+            type (str): Type/version of models to use. Defaults to "v8".
         """
         if init_key != YOLO.__init_key:
             raise SyntaxError(HELP_MSG)

         self.type = type
-        self.ModelClass = None
-        self.TrainerClass = None
-        self.ValidatorClass = None
-        self.PredictorClass = None
-        self.model = None
-        self.trainer = None
-        self.task = None
+        self.ModelClass = None  # model class
+        self.TrainerClass = None  # trainer class
+        self.ValidatorClass = None  # validator class
+        self.PredictorClass = None  # predictor class
+        self.model = None  # model object
+        self.trainer = None  # trainer object
+        self.task = None  # task type
         self.ckpt = None  # if loaded from *.pt
         self.cfg = None  # if loaded from *.yaml
-        self.overrides = {}
-        self.init_disabled = False
+        self.overrides = {}  # overrides for trainer object
+        self.init_disabled = False  # disable model initialization

     @classmethod
     def new(cls, cfg: str, verbose=True):
         """
-        Initializes a new model and infers the task type from the model definitions
+        Initializes a new model and infers the task type from the model definitions.

         Args:
             cfg (str): model configuration file
-            verbsoe (bool): display model info on load
+            verbose (bool): display model info on load
         """
         cfg = check_yaml(cfg)  # check YAML
         cfg_dict = yaml_load(cfg)  # model dict

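The __init_key guard documented above is a sentinel-object pattern: the class-private key is only reachable from inside the class, so direct construction raises and callers are funnelled through factory classmethods such as YOLO.new(). A minimal standalone sketch of the same idea (names here are illustrative, not the ultralytics code):

    class Guarded:
        __init_key = object()  # private sentinel, name-mangled to _Guarded__init_key

        def __init__(self, init_key=None):
            if init_key is not Guarded.__init_key:
                raise SyntaxError("use Guarded.new(...) instead of Guarded()")

        @classmethod
        def new(cls, cfg: str):
            obj = cls(init_key=cls.__init_key)  # only the class itself can pass the key
            obj.cfg = cfg
            return obj

    g = Guarded.new("model.yaml")   # ok
    # Guarded()                     # would raise SyntaxError
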
@@ -41,8 +41,36 @@ from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode


 class BasePredictor:
-    def __init__(self, config=DEFAULT_CONFIG, overrides={}):
+    """
+    BasePredictor
+
+    A base class for creating predictors.
+
+    Attributes:
+        args (OmegaConf): Configuration for the predictor.
+        save_dir (Path): Directory to save results.
+        done_setup (bool): Whether the predictor has finished setup.
+        model (nn.Module): Model used for prediction.
+        data (dict): Data configuration.
+        device (torch.device): Device used for prediction.
+        dataset (Dataset): Dataset used for prediction.
+        vid_path (str): Path to video file.
+        vid_writer (cv2.VideoWriter): Video writer for saving video output.
+        view_img (bool): Whether to view image output.
+        annotator (Annotator): Annotator used for prediction.
+        data_path (str): Path to data.
+    """
+
+    def __init__(self, config=DEFAULT_CONFIG, overrides=None):
+        """
+        Initializes the BasePredictor class.
+
+        Args:
+            config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
+            overrides (dict, optional): Configuration overrides. Defaults to None.
+        """
+        if overrides is None:
+            overrides = {}
         self.args = get_config(config, overrides)
         project = self.args.project or f"runs/{self.args.task}"
         name = self.args.name or f"{self.args.mode}"

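Among the attributes documented above, vid_path and vid_writer handle streaming results to disk. A minimal OpenCV sketch of how such a writer is typically created and reused per frame; the file name, codec, fps, and frame size here are assumptions for illustration, not the ultralytics implementation:

    import cv2
    import numpy as np

    fps, w, h = 30, 640, 480
    vid_writer = cv2.VideoWriter("out.mp4", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

    for _ in range(fps):                          # write one second of dummy frames
        frame = np.zeros((h, w, 3), dtype=np.uint8)
        vid_writer.write(frame)

    vid_writer.release()                          # flush and close the output file
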
@@ -33,9 +33,53 @@ from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds


 class BaseTrainer:
-    def __init__(self, cfg=DEFAULT_CONFIG, overrides={}):
-        self.args = get_config(cfg, overrides)
+    """
+    BaseTrainer
+
+    A base class for creating trainers.
+
+    Attributes:
+        args (OmegaConf): Configuration for the trainer.
+        check_resume (method): Method to check if training should be resumed from a saved checkpoint.
+        console (logging.Logger): Logger instance.
+        validator (BaseValidator): Validator instance.
+        model (nn.Module): Model instance.
+        callbacks (defaultdict): Dictionary of callbacks.
+        save_dir (Path): Directory to save results.
+        wdir (Path): Directory to save weights.
+        last (Path): Path to last checkpoint.
+        best (Path): Path to best checkpoint.
+        batch_size (int): Batch size for training.
+        epochs (int): Number of epochs to train for.
+        start_epoch (int): Starting epoch for training.
+        device (torch.device): Device to use for training.
+        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
+        scaler (amp.GradScaler): Gradient scaler for AMP.
+        data (str): Path to data.
+        trainset (torch.utils.data.Dataset): Training dataset.
+        testset (torch.utils.data.Dataset): Testing dataset.
+        ema (nn.Module): EMA (Exponential Moving Average) of the model.
+        lf (nn.Module): Loss function.
+        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
+        best_fitness (float): The best fitness value achieved.
+        fitness (float): Current fitness value.
+        loss (float): Current loss value.
+        tloss (float): Total loss value.
+        loss_names (list): List of loss names.
+        csv (Path): Path to results CSV file.
+    """
+
+    def __init__(self, config=DEFAULT_CONFIG, overrides=None):
+        """
+        Initializes the BaseTrainer class.
+
+        Args:
+            config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
+            overrides (dict, optional): Configuration overrides. Defaults to None.
+        """
+        if overrides is None:
+            overrides = {}
+        self.args = get_config(config, overrides)
         self.check_resume()
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

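Several of the trainer attributes documented above (amp, scaler, scheduler) describe a standard mixed-precision training step. A generic PyTorch sketch of how a GradScaler-based step fits together, shown only to make those attributes concrete; it is not the BaseTrainer code, and the model, loss, and schedule here are placeholder assumptions:

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    amp = device == "cuda"                                   # AMP only makes sense on GPU here
    model = nn.Linear(10, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.9 ** epoch)
    scaler = torch.cuda.amp.GradScaler(enabled=amp)          # no-op when amp is False

    x = torch.randn(4, 10, device=device)
    y = torch.randn(4, 1, device=device)

    with torch.cuda.amp.autocast(enabled=amp):
        loss = nn.functional.mse_loss(model(x), y)           # forward pass in mixed precision
    scaler.scale(loss).backward()                            # scaled backward pass
    scaler.step(optimizer)                                   # unscales gradients, then optimizer.step()
    scaler.update()                                          # adjust the loss scale for the next step
    optimizer.zero_grad()
    scheduler.step()                                         # advance the LR schedule (once per epoch)
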
@@ -464,6 +508,19 @@ class BaseTrainer:

     @staticmethod
     def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
+        """
+        Builds an optimizer with the specified parameters and parameter groups.
+
+        Args:
+            model (nn.Module): model to optimize
+            name (str): name of the optimizer to use
+            lr (float): learning rate
+            momentum (float): momentum
+            decay (float): weight decay
+
+        Returns:
+            torch.optim.Optimizer: the built optimizer
+        """
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
         for v in model.modules():

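The build_optimizer docstring above describes the usual three-way split of parameters: weights that receive weight decay, normalization weights that don't, and biases. A self-contained sketch of that grouping, written as a generic illustration of the technique rather than the exact ultralytics implementation:

    import torch
    import torch.nn as nn

    def build_optimizer(model, name="Adam", lr=0.001, momentum=0.9, decay=1e-5):
        g = [], [], []  # parameter groups: (weights w/ decay, norm weights w/o decay, biases)
        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # BatchNorm2d, LayerNorm, ...
        for m in model.modules():
            for p_name, p in m.named_parameters(recurse=False):
                if p_name == "bias":
                    g[2].append(p)
                elif isinstance(m, bn):           # normalization layer weight: no decay
                    g[1].append(p)
                else:
                    g[0].append(p)

        if name == "Adam":
            optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))
        else:
            optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
        optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # weights with decay
        optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})    # norm weights, no decay
        return optimizer

    opt = build_optimizer(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
    print([len(pg["params"]) for pg in opt.param_groups])  # -> [2, 1, 1]
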
@@ -16,10 +16,36 @@ from ultralytics.yolo.utils.torch_utils import de_parallel, select_device, smart_inference_mode

 class BaseValidator:
     """
-    Base validator class.
+    BaseValidator
+
+    A base class for creating validators.
+
+    Attributes:
+        dataloader (DataLoader): Dataloader to use for validation.
+        pbar (tqdm): Progress bar to update during validation.
+        logger (logging.Logger): Logger to use for validation.
+        args (OmegaConf): Configuration for the validator.
+        model (nn.Module): Model to validate.
+        data (dict): Data dictionary.
+        device (torch.device): Device to use for validation.
+        batch_i (int): Current batch index.
+        training (bool): Whether the model is in training mode.
+        speed (float): Batch processing speed in seconds.
+        jdict (dict): Dictionary to store validation results.
+        save_dir (Path): Directory to save results.
     """

     def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
+        """
+        Initializes a BaseValidator instance.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+            save_dir (Path): Directory to save results.
+            pbar (tqdm.tqdm): Progress bar for displaying progress.
+            logger (logging.Logger): Logger to log messages.
+            args (OmegaConf): Configuration for the validator.
+        """
         self.dataloader = dataloader
         self.pbar = pbar
         self.logger = logger or LOGGER

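The attributes listed for BaseValidator (dataloader, device, speed, training) outline a plain evaluation loop. A generic PyTorch sketch of such a loop, shown only to illustrate the shape of the class, with dummy data standing in for a real dataloader:

    import time
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    model = nn.Linear(10, 2)
    data = TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,)))
    dataloader = DataLoader(data, batch_size=8)

    model.eval()                          # the validator's `training` flag would be False here
    speed, correct = 0.0, 0
    with torch.inference_mode():          # roughly what a smart_inference_mode() style decorator provides
        for batch_i, (x, y) in enumerate(dataloader):
            t0 = time.time()
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            speed += time.time() - t0     # accumulate per-batch processing time
    print(f"accuracy {correct / len(data):.2f}, {speed / (batch_i + 1):.4f} s/batch")
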
@@ -4,11 +4,11 @@ import logging.config
 import os
 import platform
 import sys
 import tempfile
 import threading
 from pathlib import Path

 import cv2
 import IPython
 import pandas as pd

 # Constants

|
||||
LOGGING_NAME = 'yolov5'
|
||||
HELP_MSG = \
|
||||
"""
|
||||
Please refer to below Usage examples for help running YOLOv8:
|
||||
Usage examples for running YOLOv8:
|
||||
|
||||
1. Install the ultralytics package:
|
||||
|
||||
Install:
|
||||
pip install ultralytics
|
||||
|
||||
Python SDK:
|
||||
2. Use the Python SDK:
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO.new('yolov8n.yaml') # create a new model from scratch
|
||||
model = YOLO.load('yolov8n.pt') # load a pretrained model (recommended for best training results)
|
||||
results = model.train(data='coco128.yaml')
|
||||
results = model.val()
|
||||
results = model.predict(source='bus.jpg')
|
||||
success = model.export(format='onnx')
|
||||
results = model.train(data='coco128.yaml') # train the model
|
||||
results = model.val() # evaluate model performance on the validation set
|
||||
results = model.predict(source='bus.jpg') # predict on an image
|
||||
success = model.export(format='onnx') # export the model to ONNX format
|
||||
|
||||
3. Use the command line interface (CLI):
|
||||
|
||||
CLI:
|
||||
yolo task=detect mode=train model=yolov8n.yaml args...
|
||||
classify predict yolov8n-cls.yaml args...
|
||||
segment val yolov8n-seg.yaml args...
|
||||
@@ -60,41 +63,67 @@ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads


 def is_colab():
-    # Is environment a Google Colab instance?
+    """
+    Check if the current script is running inside a Google Colab notebook.
+
+    Returns:
+        bool: True if running inside a Colab notebook, False otherwise.
+    """
+    # Check if the google.colab module is present in sys.modules
     return 'google.colab' in sys.modules


 def is_kaggle():
-    # Is environment a Kaggle Notebook?
+    """
+    Check if the current script is running inside a Kaggle kernel.
+
+    Returns:
+        bool: True if running inside a Kaggle kernel, False otherwise.
+    """
     return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'


-def is_notebook():
-    # Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
-    ipython_type = str(type(IPython.get_ipython()))
-    return 'colab' in ipython_type or 'zmqshell' in ipython_type
+def is_jupyter_notebook():
+    """
+    Check if the current script is running inside a Jupyter Notebook.
+    Verified on Colab, Jupyterlab, Kaggle, Paperspace.
+
+    Returns:
+        bool: True if running inside a Jupyter Notebook, False otherwise.
+    """
+    # Check if the get_ipython function exists
+    # (it does not exist when running as a standalone script)
+    try:
+        from IPython import get_ipython
+        return get_ipython() is not None
+    except ImportError:
+        return False


-def is_docker() -> bool:
-    """Check if the process runs inside a docker container."""
-    if Path("/.dockerenv").exists():
-        return True
-    try:  # check if docker is in control groups
-        with open("/proc/self/cgroup") as file:
-            return any("docker" in line for line in file)
-    except OSError:
-        return False
+def is_docker() -> bool:
+    """
+    Determine if the script is running inside a Docker container.
+
+    Returns:
+        bool: True if the script is running inside a Docker container, False otherwise.
+    """
+    with open('/proc/self/cgroup') as f:
+        return 'docker' in f.read()


-def is_writeable(dir, test=False):
-    # Return True if directory has write permissions, test opening a file with write permissions if test=True
-    if not test:
-        return os.access(dir, os.W_OK)  # possible issues on Windows
-    file = Path(dir) / 'tmp.txt'
+def is_dir_writeable(dir_path: str) -> bool:
+    """
+    Check if a directory is writeable.
+
+    Args:
+        dir_path (str): The path to the directory.
+
+    Returns:
+        bool: True if the directory is writeable, False otherwise.
+    """
     try:
-        with open(file, 'w'):  # open file with write permissions
+        with tempfile.TemporaryFile(dir=dir_path):
             pass
-        file.unlink()  # remove file
         return True
     except OSError:
         return False

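Note the behavioural difference between the two is_docker variants shown above: the removed version also checked /.dockerenv and tolerated a missing /proc/self/cgroup, while the new one raises if that file is absent (for example on Windows or macOS). A small defensive wrapper in the same spirit, combining both signals from the diff (a sketch, not part of this commit):

    from pathlib import Path

    def is_docker_safe() -> bool:
        """Best-effort Docker detection: /.dockerenv marker or a docker entry in cgroups."""
        if Path("/.dockerenv").exists():
            return True
        try:
            with open("/proc/self/cgroup") as f:
                return "docker" in f.read()
        except OSError:          # file missing outside Linux, or unreadable
            return False

    print(is_docker_safe())
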
@@ -106,20 +135,40 @@ def get_default_args(func):
     return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}


-def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
-    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
-    env = os.getenv(env_var)
-    if env:
-        path = Path(env)  # use environment variable
-    else:
-        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
-        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
-    path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
-    path.mkdir(exist_ok=True)  # make if required
-    return path
+def get_user_config_dir(sub_dir='Ultralytics'):
+    """
+    Get the user config directory.
+
+    Args:
+        sub_dir (str): The name of the subdirectory to create.
+
+    Returns:
+        Path: The path to the user config directory.
+    """
+    # Get the operating system name
+    os_name = platform.system()
+
+    # Return the appropriate config directory for each operating system
+    if os_name == 'Windows':
+        path = Path.home() / 'AppData' / 'Roaming' / sub_dir
+    elif os_name == 'Darwin':  # macOS
+        path = Path.home() / 'Library' / 'Application Support' / sub_dir
+    elif os_name == 'Linux':
+        path = Path.home() / '.config' / sub_dir
+    else:
+        raise ValueError(f'Unsupported operating system: {os_name}')
+
+    # GCP and AWS lambda fix, only /tmp is writeable
+    if not is_dir_writeable(path.parent):
+        path = Path('/tmp') / sub_dir
+
+    # Create the subdirectory if it does not exist
+    path.mkdir(parents=True, exist_ok=True)
+
+    return path


-USER_CONFIG_DIR = user_config_dir()  # Ultralytics settings dir
+USER_CONFIG_DIR = get_user_config_dir()  # Ultralytics settings dir


 def emojis(str=''):

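The USER_CONFIG_DIR constant defined above resolves to a per-OS settings directory (e.g. ~/.config/Ultralytics on Linux). A quick standalone sketch of the same mapping; the settings.yaml file name at the end is only an illustrative assumption, not taken from this diff:

    import platform
    from pathlib import Path

    # Mirrors the per-OS mapping in get_user_config_dir() above.
    home = Path.home()
    expected = {
        'Windows': home / 'AppData' / 'Roaming' / 'Ultralytics',
        'Darwin': home / 'Library' / 'Application Support' / 'Ultralytics',
        'Linux': home / '.config' / 'Ultralytics',
    }.get(platform.system())                # None on systems where the helper raises ValueError
    print(expected)

    settings_file = expected / 'settings.yaml' if expected else None  # hypothetical file under the config dir
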
@@ -12,7 +12,7 @@ import pkg_resources as pkg
 import torch

 from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
-                                    is_docker, is_notebook)
+                                    is_docker, is_jupyter_notebook)
 from ultralytics.yolo.utils.ops import make_divisible

@@ -160,7 +160,7 @@ def check_yaml(file, suffix=('.yaml', '.yml')):
 def check_imshow(warn=False):
     # Check if environment supports image displays
     try:
-        assert not is_notebook()
+        assert not is_jupyter_notebook()
         assert not is_docker()
         cv2.imshow('test', np.zeros((1, 1, 3)))
         cv2.waitKey(1)

@@ -24,8 +24,21 @@ class WorkingDirectory(contextlib.ContextDecorator):

 def increment_path(path, exist_ok=False, sep='', mkdir=False):
     """
-    Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
-    # TODO: docs
+    Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
+
+    If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
+    the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
+    number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
+    directory if it does not already exist.
+
+    Args:
+        path (str or pathlib.Path): Path to increment.
+        exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
+        sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
+        mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.
+
+    Returns:
+        pathlib.Path: Incremented path.
     """
     path = Path(path)  # os-agnostic
     if path.exists() and not exist_ok:

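A short standalone sketch of the incrementing behaviour documented above (the sep='' case), run against a temporary directory so it is self-contained; it re-creates the naming scheme for illustration rather than calling the library function:

    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / 'exp').mkdir()        # pretend runs/exp already exists
        (Path(tmp) / 'exp2').mkdir()       # ...and so does runs/exp2

        # Append the next free number, as described in the docstring.
        path, n = Path(tmp) / 'exp', 2
        while path.exists():
            path = Path(tmp) / f'exp{n}'
            n += 1
        print(path.name)                   # -> exp3
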
@@ -100,10 +100,31 @@ def non_max_suppression(
         max_det=300,
         nm=0,  # number of masks
 ):
-    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
+    """
+    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
+
+    Arguments:
+        prediction (torch.Tensor): A tensor of shape (batch_size, num_boxes, num_classes + 4 + num_masks)
+            containing the predicted boxes, classes, and masks. The tensor should be in the format
+            output by a model, such as YOLO.
+        conf_thres (float): The confidence threshold below which boxes will be filtered out.
+            Valid values are between 0.0 and 1.0.
+        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
+            Valid values are between 0.0 and 1.0.
+        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
+        agnostic (bool): If True, the model is agnostic to the number of classes, and all
+            classes will be considered as one.
+        multi_label (bool): If True, each box may have multiple labels.
+        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
+            list contains the apriori labels for a given image. The list should be in the format
+            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
+        max_det (int): The maximum number of boxes to keep after NMS.
+        nm (int): The number of masks output by the model.

     Returns:
-        list of detections, on (n,6) tensor per image [xyxy, conf, cls]
+        List[torch.Tensor]: A list of length batch_size, where each element is a tensor of
+            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
+            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
     """

     # Checks

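To make the documented box/score/IoU interplay concrete, here is a minimal class-agnostic NMS on a few hand-made boxes using torchvision.ops.nms. This is a simplified illustration of the technique, not the ultralytics non_max_suppression implementation (which also handles confidence filtering, masks, multi-label boxes, and batching):

    import torch
    from torchvision.ops import nms

    # Three xyxy boxes: the first two overlap heavily, the third is separate.
    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],
                          [20., 20., 30., 30.]])
    scores = torch.tensor([0.9, 0.8, 0.7])

    keep = nms(boxes, scores, iou_threshold=0.45)  # indices of the boxes kept, ordered by score
    print(keep)          # tensor([0, 2]) -> the lower-scoring overlapping box is suppressed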