From 8c158823e10aa185cc8842bbafbb924c6bb803b4 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 30 Jan 2024 01:06:14 +0100 Subject: [PATCH] `ultralytics 8.1.8` new `model.save('filename.pt')` method (#7886) Signed-off-by: Glenn Jocher --- docs/overrides/partials/comments.html | 3 +- mkdocs.yml | 2 +- ultralytics/__init__.py | 2 +- ultralytics/engine/model.py | 436 +++++++++++++++++++++----- ultralytics/nn/tasks.py | 13 +- 5 files changed, 370 insertions(+), 86 deletions(-) diff --git a/docs/overrides/partials/comments.html b/docs/overrides/partials/comments.html index 57050a15..0479b37f 100644 --- a/docs/overrides/partials/comments.html +++ b/docs/overrides/partials/comments.html @@ -9,7 +9,8 @@ data-emit-metadata="0" data-input-position="top" data-lang="en" - data-mapping="title" + data-loading="lazy" + data-mapping="pathname" data-reactions-enabled="1" data-repo="ultralytics/ultralytics" data-repo-id="R_kgDOH-jzvQ" diff --git a/mkdocs.yml b/mkdocs.yml index 71cb5e10..ae95cbc8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -179,7 +179,7 @@ nav: - NEW 🚀 Explorer: - datasets/explorer/index.md - Languages: - - đŸ‡Ŧ🇧  English: https://docs.ultralytics.com/ + - đŸ‡Ŧ🇧  English: https://ultralytics.com/docs/ - đŸ‡¨đŸ‡ŗ  įŽ€äŊ“中文: https://docs.ultralytics.com/zh/ - 🇰🇷  한ęĩ­ė–´: https://docs.ultralytics.com/ko/ - đŸ‡¯đŸ‡ĩ  æ—ĨæœŦčĒž: https://docs.ultralytics.com/ja/ diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index f7b77615..543cd034 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.1.7" +__version__ = "8.1.8" from ultralytics.data.explorer.explorer import Explorer from ultralytics.models import RTDETR, SAM, YOLO diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 62560703..1a85913e 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -6,60 +6,98 @@ from pathlib import Path from typing import Union from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir +from ultralytics.hub.utils import HUB_WEB_ROOT from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load -from ultralytics.hub.utils import HUB_WEB_ROOT class Model(nn.Module): """ - A base class to unify APIs for all models. + A base class for implementing YOLO models, unifying APIs across different model types. + + This class provides a common interface for various operations related to YOLO models, such as training, + validation, prediction, exporting, and benchmarking. It handles different types of models, including those + loaded from local files, Ultralytics HUB, or Triton Server. The class is designed to be flexible and + extendable for different tasks and model configurations. Args: - model (str, Path): Path to the model file to load or create. - task (Any, optional): Task type for the YOLO model. Defaults to None. + model (Union[str, Path], optional): Path or name of the model to load or create. This can be a local file + path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'. + task (Any, optional): The task type associated with the YOLO model. This can be used to specify the model's + application domain, such as object detection, segmentation, etc. Defaults to None. + verbose (bool, optional): If True, enables verbose output during the model's operations. Defaults to False. Attributes: - predictor (Any): The predictor object. - model (Any): The model object. - trainer (Any): The trainer object. - task (str): The type of model task. - ckpt (Any): The checkpoint object if the model loaded from *.pt file. - cfg (str): The model configuration if loaded from *.yaml file. - ckpt_path (str): The checkpoint file path. - overrides (dict): Overrides for the trainer object. - metrics (Any): The data for metrics. + callbacks (dict): A dictionary of callback functions for various events during model operations. + predictor (BasePredictor): The predictor object used for making predictions. + model (nn.Module): The underlying PyTorch model. + trainer (BaseTrainer): The trainer object used for training the model. + ckpt (dict): The checkpoint data if the model is loaded from a *.pt file. + cfg (str): The configuration of the model if loaded from a *.yaml file. + ckpt_path (str): The path to the checkpoint file. + overrides (dict): A dictionary of overrides for model configuration. + metrics (dict): The latest training/validation metrics. + session (HUBTrainingSession): The Ultralytics HUB session, if applicable. + task (str): The type of task the model is intended for. + model_name (str): The name of the model. Methods: - __call__(source=None, stream=False, **kwargs): - Alias for the predict method. - _new(cfg:str, verbose:bool=True) -> None: - Initializes a new model and infers the task type from the model definitions. - _load(weights:str, task:str='') -> None: - Initializes a new model and infers the task type from the model head. - _check_is_pytorch_model() -> None: - Raises TypeError if the model is not a PyTorch model. - reset() -> None: - Resets the model modules. - info(verbose:bool=False) -> None: - Logs the model info. - fuse() -> None: - Fuses the model for faster inference. - predict(source=None, stream=False, **kwargs) -> List[ultralytics.engine.results.Results]: - Performs prediction using the YOLO model. + __call__: Alias for the predict method, enabling the model instance to be callable. + _new: Initializes a new model based on a configuration file. + _load: Loads a model from a checkpoint file. + _check_is_pytorch_model: Ensures that the model is a PyTorch model. + reset_weights: Resets the model's weights to their initial state. + load: Loads model weights from a specified file. + save: Saves the current state of the model to a file. + info: Logs or returns information about the model. + fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference. + predict: Performs object detection predictions. + track: Performs object tracking. + val: Validates the model on a dataset. + benchmark: Benchmarks the model on various export formats. + export: Exports the model to different formats. + train: Trains the model on a dataset. + tune: Performs hyperparameter tuning. + _apply: Applies a function to the model's tensors. + add_callback: Adds a callback function for an event. + clear_callback: Clears all callbacks for an event. + reset_callbacks: Resets all callbacks to their default functions. + _get_hub_session: Retrieves or creates an Ultralytics HUB session. + is_triton_model: Checks if a model is a Triton Server model. + is_hub_model: Checks if a model is an Ultralytics HUB model. + _reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model. + _smart_load: Loads the appropriate module based on the model task. + task_map: Provides a mapping from model tasks to corresponding classes. - Returns: - list(ultralytics.engine.results.Results): The prediction results. + Raises: + FileNotFoundError: If the specified model file does not exist or is inaccessible. + ValueError: If the model file or configuration is invalid or unsupported. + ImportError: If required dependencies for specific model types (like HUB SDK) are not installed. + TypeError: If the model is not a PyTorch model when required. + AttributeError: If required attributes or methods are not implemented or available. + NotImplementedError: If a specific model task or mode is not supported. """ def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None, verbose=False) -> None: """ - Initializes the YOLO model. + Initializes a new instance of the YOLO model class. + + This constructor sets up the model based on the provided model path or name. It handles various types of model + sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several + important attributes of the model and prepares it for operations like training, prediction, or export. Args: - model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'. - task (Any, optional): Task type for the YOLO model. Defaults to None. - verbose (bool, optional): Whether to enable verbose mode. + model (Union[str, Path], optional): The path or model file to load or create. This can be a local + file path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'. + task (Any, optional): The task type associated with the YOLO model, specifying its application domain. + Defaults to None. + verbose (bool, optional): If True, enables verbose output during the model's initialization and subsequent + operations. Defaults to False. + + Raises: + FileNotFoundError: If the specified model file does not exist or is inaccessible. + ValueError: If the model file or configuration is invalid or unsupported. + ImportError: If required dependencies for specific model types (like HUB SDK) are not installed. """ super().__init__() self.callbacks = callbacks.get_default_callbacks() @@ -98,7 +136,22 @@ class Model(nn.Module): self.model_name = model def __call__(self, source=None, stream=False, **kwargs): - """Calls the predict() method with given arguments to perform object detection.""" + """ + An alias for the predict method, enabling the model instance to be callable. + + This method simplifies the process of making predictions by allowing the model instance to be called directly + with the required arguments for prediction. + + Args: + source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions. + Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to None. + stream (bool, optional): If True, treats the input source as a continuous stream for predictions. + Defaults to False. + **kwargs (dict): Additional keyword arguments for configuring the prediction process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class. + """ return self.predict(source, stream, **kwargs) @staticmethod @@ -185,7 +238,19 @@ class Model(nn.Module): ) def reset_weights(self): - """Resets the model modules parameters to randomly initialized values, losing all training information.""" + """ + Resets the model parameters to randomly initialized values, effectively discarding all training information. + + This method iterates through all modules in the model and resets their parameters if they have a + 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them + to be updated during training. + + Returns: + self (ultralytics.engine.model.Model): The instance of the class with reset weights. + + Raises: + AssertionError: If the model is not a PyTorch model. + """ self._check_is_pytorch_model() for m in self.model.modules(): if hasattr(m, "reset_parameters"): @@ -195,42 +260,94 @@ class Model(nn.Module): return self def load(self, weights="yolov8n.pt"): - """Transfers parameters with matching names and shapes from 'weights' to model.""" + """ + Loads parameters from the specified weights file into the model. + + This method supports loading weights from a file or directly from a weights object. It matches parameters by + name and shape and transfers them to the model. + + Args: + weights (str | Path): Path to the weights file or a weights object. Defaults to 'yolov8n.pt'. + + Returns: + self (ultralytics.engine.model.Model): The instance of the class with loaded weights. + + Raises: + AssertionError: If the model is not a PyTorch model. + """ self._check_is_pytorch_model() if isinstance(weights, (str, Path)): weights, self.ckpt = attempt_load_one_weight(weights) self.model.load(weights) return self - def info(self, detailed=False, verbose=True): + def save(self, filename="model.pt"): """ - Logs model info. + Saves the current model state to a file. + + This method exports the model's checkpoint (ckpt) to the specified filename. Args: - detailed (bool): Show detailed information about model. - verbose (bool): Controls verbosity. + filename (str): The name of the file to save the model to. Defaults to 'model.pt'. + + Raises: + AssertionError: If the model is not a PyTorch model. + """ + self._check_is_pytorch_model() + import torch + + torch.save(self.ckpt, filename) + + def info(self, detailed=False, verbose=True): + """ + Logs or returns model information. + + This method provides an overview or detailed information about the model, depending on the arguments passed. + It can control the verbosity of the output. + + Args: + detailed (bool): If True, shows detailed information about the model. Defaults to False. + verbose (bool): If True, prints the information. If False, returns the information. Defaults to True. + + Returns: + (list): Various types of information about the model, depending on the 'detailed' and 'verbose' parameters. + + Raises: + AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() return self.model.info(detailed=detailed, verbose=verbose) def fuse(self): - """Fuse PyTorch Conv2d and BatchNorm2d layers.""" + """ + Fuses Conv2d and BatchNorm2d layers in the model. + + This method optimizes the model by fusing Conv2d and BatchNorm2d layers, which can improve inference speed. + + Raises: + AssertionError: If the model is not a PyTorch model. + """ self._check_is_pytorch_model() self.model.fuse() def embed(self, source=None, stream=False, **kwargs): """ - Calls the predict() method and returns image embeddings. + Generates image embeddings based on the provided source. + + This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image source. + It allows customization of the embedding process through various keyword arguments. Args: - source (str | int | PIL | np.ndarray): The source of the image to make predictions on. - Accepts all source types accepted by the YOLO model. - stream (bool): Whether to stream the predictions or not. Defaults to False. - **kwargs : Additional keyword arguments passed to the predictor. - Check the 'configuration' section in the documentation for all available options. + source (str | int | PIL.Image | np.ndarray): The source of the image for generating embeddings. + The source can be a file path, URL, PIL image, numpy array, etc. Defaults to None. + stream (bool): If True, predictions are streamed. Defaults to False. + **kwargs (dict): Additional keyword arguments for configuring the embedding process. Returns: - (List[torch.Tensor]): A list of image embeddings. + (List[torch.Tensor]): A list containing the image embeddings. + + Raises: + AssertionError: If the model is not a PyTorch model. """ if not kwargs.get("embed"): kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed @@ -238,18 +355,32 @@ class Model(nn.Module): def predict(self, source=None, stream=False, predictor=None, **kwargs): """ - Perform prediction using the YOLO model. + Performs predictions on the given image source using the YOLO model. + + This method facilitates the prediction process, allowing various configurations through keyword arguments. + It supports predictions with custom predictors or the default predictor method. The method handles different + types of image sources and can operate in a streaming mode. It also provides support for SAM-type models + through 'prompts'. + + The method sets up a new predictor if not already present and updates its arguments with each call. + It also issues a warning and uses default assets if the 'source' is not provided. The method determines if it + is being called from the command line interface and adjusts its behavior accordingly, including setting defaults + for confidence threshold and saving behavior. Args: - source (str | int | PIL | np.ndarray): The source of the image to make predictions on. - Accepts all source types accepted by the YOLO model. - stream (bool): Whether to stream the predictions or not. Defaults to False. - predictor (BasePredictor): Customized predictor. - **kwargs : Additional keyword arguments passed to the predictor. - Check the 'configuration' section in the documentation for all available options. + source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions. + Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to ASSETS. + stream (bool, optional): Treats the input source as a continuous stream for predictions. Defaults to False. + predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions. + If None, the method uses a default predictor. Defaults to None. + **kwargs (dict): Additional keyword arguments for configuring the prediction process. These arguments allow + for further customization of the prediction behavior. Returns: - (List[ultralytics.engine.results.Results]): The prediction results. + (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class. + + Raises: + AttributeError: If the predictor is not properly set up. """ if source is None: source = ASSETS @@ -276,16 +407,28 @@ class Model(nn.Module): def track(self, source=None, stream=False, persist=False, **kwargs): """ - Perform object tracking on the input source using the registered trackers. + Conducts object tracking on the specified input source using the registered trackers. + + This method performs object tracking using the model's predictors and optionally registered trackers. It is + capable of handling different types of input sources such as file paths or video streams. The method supports + customization of the tracking process through various keyword arguments. It registers trackers if they are not + already present and optionally persists them based on the 'persist' flag. + + The method sets a default confidence threshold specifically for ByteTrack-based tracking, which requires low + confidence predictions as input. The tracking mode is explicitly set in the keyword arguments. Args: - source (str, optional): The input source for object tracking. Can be a file path or a video stream. - stream (bool, optional): Whether the input source is a video stream. Defaults to False. - persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. - **kwargs (optional): Additional keyword arguments for the tracking process. + source (str, optional): The input source for object tracking. It can be a file path, URL, or video stream. + stream (bool, optional): Treats the input source as a continuous video stream. Defaults to False. + persist (bool, optional): Persists the trackers between different calls to this method. Defaults to False. + **kwargs (dict): Additional keyword arguments for configuring the tracking process. These arguments allow + for further customization of the tracking behavior. Returns: - (List[ultralytics.engine.results.Results]): The tracking results. + (List[ultralytics.engine.results.Results]): A list of tracking results, encapsulated in the Results class. + + Raises: + AttributeError: If the predictor does not have registered trackers. """ if not hasattr(self.predictor, "trackers"): from ultralytics.trackers import register_tracker @@ -297,11 +440,28 @@ class Model(nn.Module): def val(self, validator=None, **kwargs): """ - Validate a model on a given dataset. + Validates the model using a specified dataset and validation configuration. + + This method facilitates the model validation process, allowing for a range of customization through various + settings and configurations. It supports validation with a custom validator or the default validation approach. + The method combines default configurations, method-specific defaults, and user-provided arguments to configure + the validation process. After validation, it updates the model's metrics with the results obtained from the + validator. + + The method supports various arguments that allow customization of the validation process. For a comprehensive + list of all configurable options, users should refer to the 'configuration' section in the documentation. Args: - validator (BaseValidator): Customized validator. - **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs + validator (BaseValidator, optional): An instance of a custom validator class for validating the model. If + None, the method uses a default validator. Defaults to None. + **kwargs (dict): Arbitrary keyword arguments representing the validation configuration. These arguments are + used to customize various aspects of the validation process. + + Returns: + (dict): Validation metrics obtained from the validation process. + + Raises: + AssertionError: If the model is not a PyTorch model. """ custom = {"rect": True} # method defaults args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right @@ -313,10 +473,26 @@ class Model(nn.Module): def benchmark(self, **kwargs): """ - Benchmark a model on all export formats. + Benchmarks the model across various export formats to evaluate performance. + + This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc. + It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured + using a combination of default configuration values, model-specific arguments, method-specific defaults, and + any additional user-provided keyword arguments. + + The method supports various arguments that allow customization of the benchmarking process, such as dataset + choice, image size, precision modes, device selection, and verbosity. For a comprehensive list of all + configurable options, users should refer to the 'configuration' section in the documentation. Args: - **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs + **kwargs (dict): Arbitrary keyword arguments to customize the benchmarking process. These are combined with + default configurations, model-specific arguments, and method defaults. + + Returns: + (dict): A dictionary containing the results of the benchmarking process. + + Raises: + AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() from ultralytics.utils.benchmarks import benchmark @@ -335,10 +511,24 @@ class Model(nn.Module): def export(self, **kwargs): """ - Export model. + Exports the model to a different format suitable for deployment. + + This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment + purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method + defaults, and any additional arguments provided. The combined arguments are used to configure export settings. + + The method supports a wide range of arguments to customize the export process. For a comprehensive list of all + possible arguments, refer to the 'configuration' section in the documentation. Args: - **kwargs : Any other args accepted by the Exporter. To see all args check 'configuration' section in docs. + **kwargs (dict): Arbitrary keyword arguments to customize the export process. These are combined with the + model's overrides and method defaults. + + Returns: + (object): The exported model in the specified format, or an object related to the export process. + + Raises: + AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() from .exporter import Exporter @@ -349,11 +539,31 @@ class Model(nn.Module): def train(self, trainer=None, **kwargs): """ - Trains the model on a given dataset. + Trains the model using the specified dataset and training configuration. + + This method facilitates model training with a range of customizable settings and configurations. It supports + training with a custom trainer or the default training approach defined in the method. The method handles + different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and + updating model and configuration after training. + + When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training + arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default + configurations, method-specific defaults, and user-provided arguments to configure the training process. After + training, it updates the model and its configurations, and optionally attaches metrics. Args: - trainer (BaseTrainer, optional): Customized trainer. - **kwargs (Any): Any number of arguments representing the training configuration. + trainer (BaseTrainer, optional): An instance of a custom trainer class for training the model. If None, the + method uses a default trainer. Defaults to None. + **kwargs (dict): Arbitrary keyword arguments representing the training configuration. These arguments are + used to customize various aspects of the training process. + + Returns: + (dict | None): Training metrics if available and training is successful; otherwise, None. + + Raises: + AssertionError: If the model is not a PyTorch model. + PermissionError: If there is a permission issue with the HUB session. + ModuleNotFoundError: If the HUB SDK is not installed. """ self._check_is_pytorch_model() if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model @@ -399,10 +609,24 @@ class Model(nn.Module): def tune(self, use_ray=False, iterations=10, *args, **kwargs): """ - Runs hyperparameter tuning, optionally using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args. + Conducts hyperparameter tuning for the model, with an option to use Ray Tune. + + This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method. + When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. + Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default, overridden, and + custom arguments to configure the tuning process. + + Args: + use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False. + iterations (int): The number of tuning iterations to perform. Defaults to 10. + *args (list): Variable length argument list for additional arguments. + **kwargs (dict): Arbitrary keyword arguments. These are combined with the model's overrides and defaults. Returns: (dict): A dictionary containing the results of the hyperparameter search. + + Raises: + AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() if use_ray: @@ -426,31 +650,81 @@ class Model(nn.Module): @property def names(self): - """Returns class names of the loaded model.""" + """ + Retrieves the class names associated with the loaded model. + + This property returns the class names if they are defined in the model. It checks the class names for validity + using the 'check_class_names' function from the ultralytics.nn.autobackend module. + + Returns: + (list | None): The class names of the model if available, otherwise None. + """ from ultralytics.nn.autobackend import check_class_names return check_class_names(self.model.names) if hasattr(self.model, "names") else None @property def device(self): - """Returns device if PyTorch model.""" + """ + Retrieves the device on which the model's parameters are allocated. + + This property is used to determine whether the model's parameters are on CPU or GPU. It only applies to models + that are instances of nn.Module. + + Returns: + (torch.device | None): The device (CPU/GPU) of the model if it is a PyTorch model, otherwise None. + """ return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None @property def transforms(self): - """Returns transform of the loaded model.""" + """ + Retrieves the transformations applied to the input data of the loaded model. + + This property returns the transformations if they are defined in the model. + + Returns: + (object | None): The transform object of the model if available, otherwise None. + """ return self.model.transforms if hasattr(self.model, "transforms") else None def add_callback(self, event: str, func): - """Add a callback.""" + """ + Adds a callback function for a specified event. + + This method allows the user to register a custom callback function that is triggered on a specific event during + model training or inference. + + Args: + event (str): The name of the event to attach the callback to. + func (callable): The callback function to be registered. + + Raises: + ValueError: If the event name is not recognized. + """ self.callbacks[event].append(func) def clear_callback(self, event: str): - """Clear all event callbacks.""" + """ + Clears all callback functions registered for a specified event. + + This method removes all custom and default callback functions associated with the given event. + + Args: + event (str): The name of the event for which to clear the callbacks. + + Raises: + ValueError: If the event name is not recognized. + """ self.callbacks[event] = [] def reset_callbacks(self): - """Reset all registered callbacks.""" + """ + Resets all callbacks to their default functions. + + This method reinstates the default callback functions for all events, removing any custom callbacks that were + added previously. + """ for event in callbacks.default_callbacks.keys(): self.callbacks[event] = [callbacks.default_callbacks[event][0]] diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 2aba0629..1739bad3 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -631,7 +631,7 @@ def torch_safe_load(weight): "ultralytics.yolo.data": "ultralytics.data", } ): # for legacy 8.0 Classify and Pose models - return torch.load(file, map_location="cpu"), file # load + ckpt = torch.load(file, map_location="cpu") except ModuleNotFoundError as e: # e.name is missing module name if e.name == "models": @@ -651,8 +651,17 @@ def torch_safe_load(weight): f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'" ) check_requirements(e.name) # install missing module + ckpt = torch.load(file, map_location="cpu") - return torch.load(file, map_location="cpu"), file # load + if not isinstance(ckpt, dict): + # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt") + LOGGER.warning( + f"WARNING âš ī¸ The file '{weight}' appears to be improperly saved or formatted. " + f"For optimal results, use model.save('filename.pt') to correctly save YOLO models." + ) + ckpt = {"model": ckpt.model} + + return ckpt, file # load def attempt_load_weights(weights, device=None, inplace=True, fuse=False):