From b02bf58de6b03196ba3c404a2b59787894519159 Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Tue, 15 Oct 2024 11:51:37 -0400
Subject: [PATCH] added attribution visualization with run_attribution.py

---
 .gitignore                                  |   1 +
 run_attribution.py                          |  34 +
 run_train.py                                |  83 +++
 run_val.py                                  |  96 +++
 ultralytics/engine/model.py                 |  80 +-
 ultralytics/engine/pgt_trainer.py           | 785 ++++++++++++++++++++
 ultralytics/engine/predictor.py             |   7 +-
 ultralytics/models/yolo/detect/__init__.py  |   3 +-
 ultralytics/models/yolo/detect/pgt_train.py | 144 ++++
 ultralytics/models/yolov10/model.py         |   3 +
 ultralytics/models/yolov10/pgt_train.py     |  21 +
 ultralytics/nn/tasks.py                     |  21 +-
 12 files changed, 1266 insertions(+), 12 deletions(-)
 create mode 100644 run_attribution.py
 create mode 100644 run_train.py
 create mode 100644 run_val.py
 create mode 100644 ultralytics/engine/pgt_trainer.py
 create mode 100644 ultralytics/models/yolo/detect/pgt_train.py
 create mode 100644 ultralytics/models/yolov10/pgt_train.py

diff --git a/.gitignore b/.gitignore
index 0854267a..1463e020 100644
--- a/.gitignore
+++ b/.gitignore
@@ -51,6 +51,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 mlruns/
+figures/
 
 # Translations
 *.mo
diff --git a/run_attribution.py b/run_attribution.py
new file mode 100644
index 00000000..8c87e592
--- /dev/null
+++ b/run_attribution.py
@@ -0,0 +1,34 @@
+from ultralytics import YOLOv10, YOLO
+# from ultralytics.engine.pgt_trainer import PGTTrainer
+# from ultralytics import BaseTrainer
+# from ultralytics.engine.trainer import BaseTrainer
+import os
+
+# Set CUDA device (only needed for multi-gpu machines) 
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
+os.environ["CUDA_VISIBLE_DEVICES"] = "4" 
+
+# model = YOLOv10()
+# model = YOLO()
+# If you want to finetune the model with pretrained weights, you could load the 
+# pretrained weights like below
+# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
+# or
+# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
+model = YOLOv10('yolov10n.pt')
+
+model.train(data='coco.yaml', 
+            trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT)
+            # Add return_images as input parameter
+            epochs=500, batch=16, imgsz=640,
+            debug=True, # If debug = True, the attributions will be saved in the figures folder
+            )
+
+# Save the trained model
+model.save('yolov10_coco_trained.pt')
+
+# Evaluate the model on the validation set
+results = model.val(data='coco.yaml')
+
+# Print the evaluation results
+print(results)
\ No newline at end of file
diff --git a/run_train.py b/run_train.py
new file mode 100644
index 00000000..012d9d83
--- /dev/null
+++ b/run_train.py
@@ -0,0 +1,83 @@
+from ultralytics import YOLOv10, YOLO
+# from ultralytics.engine.pgt_trainer import PGTTrainer
+# from ultralytics import BaseTrainer
+# from ultralytics.engine.trainer import BaseTrainer
+import os
+
+# Set CUDA device (only needed for multi-gpu machines) 
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
+os.environ["CUDA_VISIBLE_DEVICES"] = "4" 
+
+model = YOLOv10()
+# model = YOLO()
+# If you want to finetune the model with pretrained weights, you could load the 
+# pretrained weights like below
+# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
+# or
+# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
+# model = YOLOv10('yolov10m.pt')
+
+model.train(data='coco.yaml', 
+            # Add return_images as input parameter
+            epochs=500, batch=16, imgsz=640,
+            )
+
+# Save the trained model
+model.save('yolov10_coco_trained.pt')
+
+# Evaluate the model on the validation set
+results = model.val(data='coco.yaml')
+
+# Print the evaluation results
+print(results)
+
+# import torch
+# from torch.utils.data import DataLoader
+# from torchvision import datasets, transforms
+
+# # Define the transformation for the dataset
+# transform = transforms.Compose([
+#     transforms.Resize((640, 640)),
+#     transforms.ToTensor()
+# ])
+
+# # Load the COCO dataset
+# train_dataset = datasets.CocoDetection(root='data/nielseni6/coco/train2017', annFile='/data/nielseni6/coco/annotations/instances_train2017.json', transform=transform)
+# val_dataset = datasets.CocoDetection(root='data/nielseni6/coco/val2017', annFile='/data/nielseni6/coco/annotations/instances_val2017.json', transform=transform)
+
+# # Create data loaders
+# train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
+# val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)
+
+# model = YOLOv10()
+
+# # Define the optimizer
+# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+# # Training loop
+# for epoch in range(500):
+#     model.train()
+#     for images, targets in train_loader:
+#         images = images.to('cuda')
+#         targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets]
+#         loss = model(images, targets)
+#         loss.backward()
+#         optimizer.step()
+#         optimizer.zero_grad()
+
+#     # Validation loop
+#     model.eval()
+#     with torch.no_grad():
+#         for images, targets in val_loader:
+#             images = images.to('cuda')
+#             targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets]
+#             results = model(images, targets)
+
+# # Save the trained model
+# model.save('yolov10_coco_trained.pt')
+
+# # Evaluate the model on the validation set
+# results = model.val(data='coco.yaml')
+
+# # Print the evaluation results
+# print(results)
\ No newline at end of file
diff --git a/run_val.py b/run_val.py
new file mode 100644
index 00000000..6003d5ee
--- /dev/null
+++ b/run_val.py
@@ -0,0 +1,96 @@
+from ultralytics import YOLOv10
+import torch
+from PIL import Image
+from torchvision import transforms
+
+# Define the device
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
+# model = YOLOv10.from_pretrained('jameslahm/yolov10n')
+# or
+# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
+# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10n.pt
+# model = YOLOv10('yolov10{n/s/m/b/l/x}.pt')
+model = YOLOv10('yolov10n.pt').to(device)
+
+# Load the image
+# path = '/home/nielseni6/PythonScripts/Github/yolov10/images/fat-dog.jpg'
+path = '/home/nielseni6/PythonScripts/Github/yolov10/images/The-Cardinal-Bird.jpg'
+image = Image.open(path)
+
+# Define the transformation to resize the image, convert it to a tensor, and normalize it
+transform = transforms.Compose([
+    transforms.Resize((640, 640)),
+    transforms.ToTensor(),
+    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# Apply the transformation
+image_tensor = transform(image)
+
+# Add a batch dimension
+image_tensor = image_tensor.unsqueeze(0).to(device)
+image_tensor = image_tensor.requires_grad_(True)
+
+
+# Predict for a specific image
+# results = model.predict(image_tensor, save=True)
+# model.requires_grad_(True)
+
+
+# for p in model.parameters():
+#     p.requires_grad = True
+results = model.predict(image_tensor, save=True)
+
+# Display the results
+for result in results:
+    print(result)
+
+# pred = results[0].boxes[0].conf
+
+# # Hook to store the activations
+# activations = {}
+
+# def get_activation(name):
+#     def hook(model, input, output):
+#         activations[name] = output
+#     return hook
+
+# # Register hooks for each layer you want to inspect
+# for name, layer in model.model.named_modules():
+#     layer.register_forward_hook(get_activation(name))
+
+# # Run the model to get activations
+# results = model.predict(image_tensor, save=True, visualize=True)
+
+# # # Print the activations
+# # for name, activation in activations.items():
+# #     print(f"Activation from layer {name}: {activation}")
+
+# # List activation names separately
+# print("\nActivation layer names:")
+# for name in activations.keys():
+#     print(name)
+# # pred.backward()
+
+# # Assuming 'model.23' is the layer of interest for bbox prediction and confidence
+# activation = activations['model.23']['one2one'][0]
+# act_23 = activations['model.23.cv3.2']
+# act_dfl = activations['model.23.dfl.conv']
+# act_conv = activations['model.0.conv']
+# act_act = activations['model.0.act']
+
+# # with torch.autograd.set_detect_anomaly(True):
+# #     pred.backward()
+# grad = torch.autograd.grad(act_23, im, grad_outputs=torch.ones_like(act_23), create_graph=True, retain_graph=True)[0]
+# # grad = torch.autograd.grad(pred, im, grad_outputs=torch.ones_like(pred), create_graph=True)[0]
+# grad = torch.autograd.grad(activations['model.23']['one2one'][1][0], 
+#                            activations['model.23.one2one_cv3.2'], 
+#                            grad_outputs=torch.ones_like(activations['model.23']['one2one'][1][0]), 
+#                            create_graph=True, retain_graph=True)[0]
+
+# # Print the results
+# print(results)
+
+# model.val(data='coco.yaml', batch=256)
\ No newline at end of file
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index ef5c93c0..be87559c 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -387,6 +387,7 @@ class Model(nn.Module):
         source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
         stream: bool = False,
         predictor=None,
+        return_images: bool = False,
         **kwargs,
     ) -> list:
         """
@@ -438,7 +439,7 @@ class Model(nn.Module):
                 self.predictor.save_dir = get_save_dir(self.predictor.args)
         if prompts and hasattr(self.predictor, "set_prompts"):  # for SAM-type models
             self.predictor.set_prompts(prompts)
-        return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
+        return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream, return_images=return_images)
 
     def track(
         self,
@@ -590,6 +591,81 @@ class Model(nn.Module):
         return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
 
     def train(
+        self,
+        trainer=None,
+        debug=False,
+        **kwargs,
+    ):
+        """
+        Trains the model using the specified dataset and training configuration.
+
+        This method facilitates model training with a range of customizable settings and configurations. It supports
+        training with a custom trainer or the default training approach defined in the method. The method handles
+        different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and
+        updating model and configuration after training.
+
+        When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training
+        arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default
+        configurations, method-specific defaults, and user-provided arguments to configure the training process. After
+        training, it updates the model and its configurations, and optionally attaches metrics.
+
+        Args:
+            trainer (BaseTrainer, optional): An instance of a custom trainer class for training the model. If None, the
+                method uses a default trainer. Defaults to None.
+            **kwargs (any): Arbitrary keyword arguments representing the training configuration. These arguments are
+                used to customize various aspects of the training process.
+
+        Returns:
+            (dict | None): Training metrics if available and training is successful; otherwise, None.
+
+        Raises:
+            AssertionError: If the model is not a PyTorch model.
+            PermissionError: If there is a permission issue with the HUB session.
+            ModuleNotFoundError: If the HUB SDK is not installed.
+        """
+        self._check_is_pytorch_model()
+        if hasattr(self.session, "model") and self.session.model.id:  # Ultralytics HUB session with loaded model
+            if any(kwargs):
+                LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
+            kwargs = self.session.train_args  # overwrite kwargs
+
+        checks.check_pip_update_available()
+
+        overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
+        custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]}  # method defaults
+        args = {**overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
+        if args.get("resume"):
+            args["resume"] = self.ckpt_path
+
+        self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
+        if not args.get("resume"):  # manually set model only if not resuming
+            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
+            self.model = self.trainer.model
+
+            if SETTINGS["hub"] is True and not self.session:
+                # Create a model in HUB
+                try:
+                    self.session = self._get_hub_session(self.model_name)
+                    if self.session:
+                        self.session.create_model(args)
+                        # Check model was created
+                        if not getattr(self.session.model, "id", None):
+                            self.session = None
+                except (PermissionError, ModuleNotFoundError):
+                    # Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
+                    pass
+
+        self.trainer.hub_session = self.session  # attach optional HUB session
+        self.trainer.train(debug=debug)
+        # Update model and cfg after training
+        if RANK in (-1, 0):
+            ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
+            self.model, _ = attempt_load_one_weight(ckpt)
+            self.overrides = self.model.args
+            self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
+        return self.metrics
+
+    def train_pgt(
         self,
         trainer=None,
         **kwargs,
@@ -662,7 +738,7 @@ class Model(nn.Module):
             self.overrides = self.model.args
             self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
         return self.metrics
-
+    
     def tune(
         self,
         use_ray=False,
diff --git a/ultralytics/engine/pgt_trainer.py b/ultralytics/engine/pgt_trainer.py
new file mode 100644
index 00000000..3d99ad41
--- /dev/null
+++ b/ultralytics/engine/pgt_trainer.py
@@ -0,0 +1,785 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+"""
+Train a model on a dataset.
+
+Usage:
+    $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
+"""
+
+import math
+import os
+import subprocess
+import time
+import warnings
+from copy import deepcopy
+from datetime import datetime, timedelta
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch import distributed as dist
+from torch import nn, optim
+
+import matplotlib.pyplot as plt
+import torchvision.transforms as T
+
+from ultralytics.cfg import get_cfg, get_save_dir
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
+from ultralytics.utils import (
+    DEFAULT_CFG,
+    LOGGER,
+    RANK,
+    TQDM,
+    __version__,
+    callbacks,
+    clean_url,
+    colorstr,
+    emojis,
+    yaml_save,
+)
+from ultralytics.utils.autobatch import check_train_batch_size
+from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
+from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
+from ultralytics.utils.files import get_latest_run
+from ultralytics.utils.torch_utils import (
+    EarlyStopping,
+    ModelEMA,
+    de_parallel,
+    init_seeds,
+    one_cycle,
+    select_device,
+    strip_optimizer,
+)
+
+
+class PGTTrainer:
+    """
+    BaseTrainer.
+
+    A base class for creating trainers.
+
+    Attributes:
+        args (SimpleNamespace): Configuration for the trainer.
+        validator (BaseValidator): Validator instance.
+        model (nn.Module): Model instance.
+        callbacks (defaultdict): Dictionary of callbacks.
+        save_dir (Path): Directory to save results.
+        wdir (Path): Directory to save weights.
+        last (Path): Path to the last checkpoint.
+        best (Path): Path to the best checkpoint.
+        save_period (int): Save checkpoint every x epochs (disabled if < 1).
+        batch_size (int): Batch size for training.
+        epochs (int): Number of epochs to train for.
+        start_epoch (int): Starting epoch for training.
+        device (torch.device): Device to use for training.
+        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
+        scaler (amp.GradScaler): Gradient scaler for AMP.
+        data (str): Path to data.
+        trainset (torch.utils.data.Dataset): Training dataset.
+        testset (torch.utils.data.Dataset): Testing dataset.
+        ema (nn.Module): EMA (Exponential Moving Average) of the model.
+        resume (bool): Resume training from a checkpoint.
+        lf (nn.Module): Loss function.
+        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
+        best_fitness (float): The best fitness value achieved.
+        fitness (float): Current fitness value.
+        loss (float): Current loss value.
+        tloss (float): Total loss value.
+        loss_names (list): List of loss names.
+        csv (Path): Path to results CSV file.
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initializes the BaseTrainer class.
+
+        Args:
+            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
+            overrides (dict, optional): Configuration overrides. Defaults to None.
+        """
+        self.args = get_cfg(cfg, overrides)
+        self.check_resume(overrides)
+        self.device = select_device(self.args.device, self.args.batch)
+        self.validator = None
+        self.metrics = None
+        self.plots = {}
+        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
+
+        # Dirs
+        self.save_dir = get_save_dir(self.args)
+        self.args.name = self.save_dir.name  # update name for loggers
+        self.wdir = self.save_dir / "weights"  # weights dir
+        if RANK in (-1, 0):
+            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
+            self.args.save_dir = str(self.save_dir)
+            yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+        self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
+        self.save_period = self.args.save_period
+
+        self.batch_size = self.args.batch
+        self.epochs = self.args.epochs
+        self.start_epoch = 0
+        if RANK == -1:
+            print_args(vars(self.args))
+
+        # Device
+        if self.device.type in ("cpu", "mps"):
+            self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
+
+        # Model and Dataset
+        self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
+        try:
+            if self.args.task == "classify":
+                self.data = check_cls_dataset(self.args.data)
+            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
+                "detect",
+                "segment",
+                "pose",
+                "obb",
+            ):
+                self.data = check_det_dataset(self.args.data)
+                if "yaml_file" in self.data:
+                    self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
+
+        self.trainset, self.testset = self.get_dataset(self.data)
+        self.ema = None
+
+        # Optimization utils init
+        self.lf = None
+        self.scheduler = None
+
+        # Epoch level metrics
+        self.best_fitness = None
+        self.fitness = None
+        self.loss = None
+        self.tloss = None
+        self.loss_names = ["Loss"]
+        self.csv = self.save_dir / "results.csv"
+        self.plot_idx = [0, 1, 2]
+
+        # Callbacks
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+        if RANK in (-1, 0):
+            callbacks.add_integration_callbacks(self)
+
+    def add_callback(self, event: str, callback):
+        """Appends the given callback."""
+        self.callbacks[event].append(callback)
+
+    def set_callback(self, event: str, callback):
+        """Overrides the existing callbacks with the given callback."""
+        self.callbacks[event] = [callback]
+
+    def run_callbacks(self, event: str):
+        """Run all existing callbacks associated with a particular event."""
+        for callback in self.callbacks.get(event, []):
+            callback(self)
+
+    def train(self, debug=False):
+        """Allow device='', device=None on Multi-GPU systems to default to device=0."""
+        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
+            world_size = len(self.args.device.split(","))
+        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
+            world_size = len(self.args.device)
+        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
+            world_size = 1  # default to device 0
+        else:  # i.e. device='cpu' or 'mps'
+            world_size = 0
+
+        # Run subprocess if DDP training, else train normally
+        if world_size > 1 and "LOCAL_RANK" not in os.environ:
+            # Argument checks
+            if self.args.rect:
+                LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
+                self.args.rect = False
+            if self.args.batch == -1:
+                LOGGER.warning(
+                    "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
+                    "default 'batch=16'"
+                )
+                self.args.batch = 16
+
+            # Command
+            cmd, file = generate_ddp_command(world_size, self)
+            try:
+                LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
+                subprocess.run(cmd, check=True)
+            except Exception as e:
+                raise e
+            finally:
+                ddp_cleanup(self, str(file))
+
+        else:
+            self._do_train(world_size, debug=debug)
+
+    def _setup_scheduler(self):
+        """Initialize training learning rate scheduler."""
+        if self.args.cos_lr:
+            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
+        else:
+            self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf  # linear
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+
+    def _setup_ddp(self, world_size):
+        """Initializes and sets the DistributedDataParallel parameters for training."""
+        torch.cuda.set_device(RANK)
+        self.device = torch.device("cuda", RANK)
+        # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
+        os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
+        dist.init_process_group(
+            backend="nccl" if dist.is_nccl_available() else "gloo",
+            timeout=timedelta(seconds=10800),  # 3 hours
+            rank=RANK,
+            world_size=world_size,
+        )
+
+    def _setup_train(self, world_size):
+        """Builds dataloaders and optimizer on correct rank process."""
+
+        # Model
+        self.run_callbacks("on_pretrain_routine_start")
+        ckpt = self.setup_model()
+        self.model = self.model.to(self.device)
+        self.set_model_attributes()
+
+        # Freeze layers
+        freeze_list = (
+            self.args.freeze
+            if isinstance(self.args.freeze, list)
+            else range(self.args.freeze)
+            if isinstance(self.args.freeze, int)
+            else []
+        )
+        always_freeze_names = [".dfl"]  # always freeze these layers
+        freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
+        for k, v in self.model.named_parameters():
+            # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
+            if any(x in k for x in freeze_layer_names):
+                LOGGER.info(f"Freezing layer '{k}'")
+                v.requires_grad = False
+            elif not v.requires_grad and v.dtype.is_floating_point:  # only floating point Tensor can require gradients
+                LOGGER.info(
+                    f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
+                    "See ultralytics.engine.trainer for customization of frozen layers."
+                )
+                v.requires_grad = True
+
+        # Check AMP
+        self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
+        if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
+            callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
+            self.amp = torch.tensor(check_amp(self.model), device=self.device)
+            callbacks.default_callbacks = callbacks_backup  # restore callbacks
+        if RANK > -1 and world_size > 1:  # DDP
+            dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
+        self.amp = bool(self.amp)  # as boolean
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
+        if world_size > 1:
+            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
+
+        # Check imgsz
+        gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
+        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
+        self.stride = gs  # for multiscale training
+
+        # Batch size
+        if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
+            self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
+
+        # Dataloaders
+        batch_size = self.batch_size // max(world_size, 1)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
+        if RANK in (-1, 0):
+            # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+            self.test_loader = self.get_dataloader(
+                self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
+            )
+            self.validator = self.get_validator()
+            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
+            self.ema = ModelEMA(self.model)
+            if self.args.plots:
+                self.plot_training_labels()
+
+        # Optimizer
+        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
+        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
+        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
+        self.optimizer = self.build_optimizer(
+            model=self.model,
+            name=self.args.optimizer,
+            lr=self.args.lr0,
+            momentum=self.args.momentum,
+            decay=weight_decay,
+            iterations=iterations,
+        )
+        # Scheduler
+        self._setup_scheduler()
+        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
+        self.resume_training(ckpt)
+        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
+        self.run_callbacks("on_pretrain_routine_end")
+
+    def _do_train(self, world_size=1, debug=False):
+        """Train completed, evaluate and plot if specified by arguments."""
+        if world_size > 1:
+            self._setup_ddp(world_size)
+        self._setup_train(world_size)
+
+        nb = len(self.train_loader)  # number of batches
+        nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
+        last_opt_step = -1
+        self.epoch_time = None
+        self.epoch_time_start = time.time()
+        self.train_time_start = time.time()
+        self.run_callbacks("on_train_start")
+        LOGGER.info(
+            f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
+            f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
+            f"Logging results to {colorstr('bold', self.save_dir)}\n"
+            f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
+        )
+        if self.args.close_mosaic:
+            base_idx = (self.epochs - self.args.close_mosaic) * nb
+            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
+        epoch = self.start_epoch
+        while True:
+            self.epoch = epoch
+            self.run_callbacks("on_train_epoch_start")
+            self.model.train()
+            if RANK != -1:
+                self.train_loader.sampler.set_epoch(epoch)
+            pbar = enumerate(self.train_loader)
+            # Update dataloader attributes (optional)
+            if epoch == (self.epochs - self.args.close_mosaic):
+                self._close_dataloader_mosaic()
+                self.train_loader.reset()
+
+            if RANK in (-1, 0):
+                LOGGER.info(self.progress_string())
+                pbar = TQDM(enumerate(self.train_loader), total=nb)
+            self.tloss = None
+            self.optimizer.zero_grad()
+            for i, batch in pbar:
+                self.run_callbacks("on_train_batch_start")
+                # Warmup
+                ni = i + nb * epoch
+                if ni <= nw:
+                    xi = [0, nw]  # x interp
+                    self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
+                    for j, x in enumerate(self.optimizer.param_groups):
+                        # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                        x["lr"] = np.interp(
+                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
+                        )
+                        if "momentum" in x:
+                            x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
+
+                # Forward
+                with torch.cuda.amp.autocast(self.amp):
+                    batch = self.preprocess_batch(batch)
+                    (self.loss, self.loss_items), images = self.model(batch, return_images=True)
+                    
+                    if debug and (i % 250):
+                        grad = torch.autograd.grad(self.loss, images, create_graph=True)[0]
+                        # Convert tensors to numpy arrays
+                        images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                        grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
+
+                        # Normalize grad for visualization
+                        grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
+
+                        for ix in range(images_np.shape[0]):
+                            fig, ax = plt.subplots(1, 3, figsize=(15, 5))
+                            ax[0].imshow(images_np[i])
+                            ax[0].set_title('Image')
+                            ax[1].imshow(grad_np[i], cmap='jet')
+                            ax[1].set_title('Gradient')
+                            ax[2].imshow(images_np[i])
+                            ax[2].imshow(grad_np[i], cmap='jet', alpha=0.5)
+                            ax[2].set_title('Overlay')
+
+                            save_dir_attr = "figures/attributions"
+                            if not os.path.exists(save_dir_attr):
+                                os.makedirs(save_dir_attr)
+                            plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
+                            plt.close(fig)
+                    
+                    if RANK != -1:
+                        self.loss *= world_size
+                    self.tloss = (
+                        (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
+                    )
+
+                # Backward
+                self.scaler.scale(self.loss).backward()
+
+                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
+                if ni - last_opt_step >= self.accumulate:
+                    self.optimizer_step()
+                    last_opt_step = ni
+
+                    # Timed stopping
+                    if self.args.time:
+                        self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
+                        if RANK != -1:  # if DDP training
+                            broadcast_list = [self.stop if RANK == 0 else None]
+                            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                            self.stop = broadcast_list[0]
+                        if self.stop:  # training time exceeded
+                            break
+
+                # Log
+                mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
+                loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
+                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
+                if RANK in (-1, 0):
+                    pbar.set_description(
+                        ("%11s" * 2 + "%11.4g" * (2 + loss_len))
+                        % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
+                    )
+                    self.run_callbacks("on_batch_end")
+                    if self.args.plots and ni in self.plot_idx:
+                        self.plot_training_samples(batch, ni)
+
+                self.run_callbacks("on_train_batch_end")
+
+            self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
+            self.run_callbacks("on_train_epoch_end")
+            if RANK in (-1, 0):
+                final_epoch = epoch + 1 == self.epochs
+                self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
+
+                # Validation
+                if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
+                    or final_epoch or self.stopper.possible_stop or self.stop:
+                    self.metrics, self.fitness = self.validate()
+                self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
+                self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
+                if self.args.time:
+                    self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
+
+                # Save model
+                if self.args.save or final_epoch:
+                    self.save_model()
+                    self.run_callbacks("on_model_save")
+
+            # Scheduler
+            t = time.time()
+            self.epoch_time = t - self.epoch_time_start
+            self.epoch_time_start = t
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                if self.args.time:
+                    mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
+                    self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
+                    self._setup_scheduler()
+                    self.scheduler.last_epoch = self.epoch  # do not move
+                    self.stop |= epoch >= self.epochs  # stop if exceeded epochs
+                self.scheduler.step()
+            self.run_callbacks("on_fit_epoch_end")
+            torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
+
+            # Early Stopping
+            if RANK != -1:  # if DDP training
+                broadcast_list = [self.stop if RANK == 0 else None]
+                dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                self.stop = broadcast_list[0]
+            if self.stop:
+                break  # must break all DDP ranks
+            epoch += 1
+
+        if RANK in (-1, 0):
+            # Do final val with best.pt
+            LOGGER.info(
+                f"\n{epoch - self.start_epoch + 1} epochs completed in "
+                f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
+            )
+            self.final_eval()
+            if self.args.plots:
+                self.plot_metrics()
+            self.run_callbacks("on_train_end")
+        torch.cuda.empty_cache()
+        self.run_callbacks("teardown")
+
+    def save_model(self):
+        """Save model training checkpoints with additional metadata."""
+        import pandas as pd  # scope for faster startup
+
+        metrics = {**self.metrics, **{"fitness": self.fitness}}
+        results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
+        ckpt = {
+            "epoch": self.epoch,
+            "best_fitness": self.best_fitness,
+            "model": deepcopy(de_parallel(self.model)).half(),
+            "ema": deepcopy(self.ema.ema).half(),
+            "updates": self.ema.updates,
+            "optimizer": self.optimizer.state_dict(),
+            "train_args": vars(self.args),  # save as dict
+            "train_metrics": metrics,
+            "train_results": results,
+            "date": datetime.now().isoformat(),
+            "version": __version__,
+            "license": "AGPL-3.0 (https://ultralytics.com/license)",
+            "docs": "https://docs.ultralytics.com",
+        }
+
+        # Save last and best
+        torch.save(ckpt, self.last)
+        if self.best_fitness == self.fitness:
+            torch.save(ckpt, self.best)
+        if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
+            torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
+
+    @staticmethod
+    def get_dataset(data):
+        """
+        Get train, val path from data dict if it exists.
+
+        Returns None if data format is not recognized.
+        """
+        return data["train"], data.get("val") or data.get("test")
+
+    def setup_model(self):
+        """Load/create/download model for any task."""
+        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
+            return
+
+        model, weights = self.model, None
+        ckpt = None
+        if str(model).endswith(".pt"):
+            weights, ckpt = attempt_load_one_weight(model)
+            cfg = ckpt["model"].yaml
+        else:
+            cfg = model
+        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
+        return ckpt
+
+    def optimizer_step(self):
+        """Perform a single step of the training optimizer with gradient clipping and EMA update."""
+        self.scaler.unscale_(self.optimizer)  # unscale gradients
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+        self.optimizer.zero_grad()
+        if self.ema:
+            self.ema.update(self.model)
+
+    def preprocess_batch(self, batch):
+        """Allows custom preprocessing model inputs and ground truths depending on task type."""
+        return batch
+
+    def validate(self):
+        """
+        Runs validation on test set using self.validator.
+
+        The returned dict is expected to contain "fitness" key.
+        """
+        metrics = self.validator(self)
+        fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
+        if not self.best_fitness or self.best_fitness < fitness:
+            self.best_fitness = fitness
+        return metrics, fitness
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Get model and raise NotImplementedError for loading cfg files."""
+        raise NotImplementedError("This task trainer doesn't support loading cfg files")
+
+    def get_validator(self):
+        """Returns a NotImplementedError when the get_validator function is called."""
+        raise NotImplementedError("get_validator function not implemented in trainer")
+
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+        """Returns dataloader derived from torch.data.Dataloader."""
+        raise NotImplementedError("get_dataloader function not implemented in trainer")
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """Build dataset."""
+        raise NotImplementedError("build_dataset function not implemented in trainer")
+
+    def label_loss_items(self, loss_items=None, prefix="train"):
+        """
+        Returns a loss dict with labelled training loss items tensor.
+
+        Note:
+            This is not needed for classification but necessary for segmentation & detection
+        """
+        return {"loss": loss_items} if loss_items is not None else ["loss"]
+
+    def set_model_attributes(self):
+        """To set or update model parameters before training."""
+        self.model.names = self.data["names"]
+
+    def build_targets(self, preds, targets):
+        """Builds target tensors for training YOLO model."""
+        pass
+
+    def progress_string(self):
+        """Returns a string describing training progress."""
+        return ""
+
+    # TODO: may need to put these following functions into callback
+    def plot_training_samples(self, batch, ni):
+        """Plots training samples during YOLO training."""
+        pass
+
+    def plot_training_labels(self):
+        """Plots training labels for YOLO model."""
+        pass
+
+    def save_metrics(self, metrics):
+        """Saves training metrics to a CSV file."""
+        keys, vals = list(metrics.keys()), list(metrics.values())
+        n = len(metrics) + 1  # number of cols
+        s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n")  # header
+        with open(self.csv, "a") as f:
+            f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
+
+    def plot_metrics(self):
+        """Plot and display metrics visually."""
+        pass
+
+    def on_plot(self, name, data=None):
+        """Registers plots (e.g. to be consumed in callbacks)"""
+        path = Path(name)
+        self.plots[path] = {"data": data, "timestamp": time.time()}
+
+    def final_eval(self):
+        """Performs final evaluation and validation for object detection YOLO model."""
+        for f in self.last, self.best:
+            if f.exists():
+                strip_optimizer(f)  # strip optimizers
+                if f is self.best:
+                    LOGGER.info(f"\nValidating {f}...")
+                    self.validator.args.plots = self.args.plots
+                    self.metrics = self.validator(model=f)
+                    self.metrics.pop("fitness", None)
+                    self.run_callbacks("on_fit_epoch_end")
+
+    def check_resume(self, overrides):
+        """Check if resume checkpoint exists and update arguments accordingly."""
+        resume = self.args.resume
+        if resume:
+            try:
+                exists = isinstance(resume, (str, Path)) and Path(resume).exists()
+                last = Path(check_file(resume) if exists else get_latest_run())
+
+                # Check that resume data YAML exists, otherwise strip to force re-download of dataset
+                ckpt_args = attempt_load_weights(last).args
+                if not Path(ckpt_args["data"]).exists():
+                    ckpt_args["data"] = self.args.data
+
+                resume = True
+                self.args = get_cfg(ckpt_args)
+                self.args.model = self.args.resume = str(last)  # reinstate model
+                for k in "imgsz", "batch", "device":  # allow arg updates to reduce memory or update device on resume
+                    if k in overrides:
+                        setattr(self.args, k, overrides[k])
+
+            except Exception as e:
+                raise FileNotFoundError(
+                    "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
+                    "i.e. 'yolo train resume model=path/to/last.pt'"
+                ) from e
+        self.resume = resume
+
+    def resume_training(self, ckpt):
+        """Resume YOLO training from given epoch and best fitness."""
+        if ckpt is None or not self.resume:
+            return
+        best_fitness = 0.0
+        start_epoch = ckpt["epoch"] + 1
+        if ckpt["optimizer"] is not None:
+            self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
+            best_fitness = ckpt["best_fitness"]
+        if self.ema and ckpt.get("ema"):
+            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
+            self.ema.updates = ckpt["updates"]
+        assert start_epoch > 0, (
+            f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
+            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+        )
+        LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
+        if self.epochs < start_epoch:
+            LOGGER.info(
+                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
+            )
+            self.epochs += ckpt["epoch"]  # finetune additional epochs
+        self.best_fitness = best_fitness
+        self.start_epoch = start_epoch
+        if start_epoch > (self.epochs - self.args.close_mosaic):
+            self._close_dataloader_mosaic()
+
+    def _close_dataloader_mosaic(self):
+        """Update dataloaders to stop using mosaic augmentation."""
+        if hasattr(self.train_loader.dataset, "mosaic"):
+            self.train_loader.dataset.mosaic = False
+        if hasattr(self.train_loader.dataset, "close_mosaic"):
+            LOGGER.info("Closing dataloader mosaic")
+            self.train_loader.dataset.close_mosaic(hyp=self.args)
+
+    def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
+        """
+        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
+        weight decay, and number of iterations.
+
+        Args:
+            model (torch.nn.Module): The model for which to build an optimizer.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
+                based on the number of iterations. Default: 'auto'.
+            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
+            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
+            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
+            iterations (float, optional): The number of iterations, which determines the optimizer if
+                name is 'auto'. Default: 1e5.
+
+        Returns:
+            (torch.optim.Optimizer): The constructed optimizer.
+        """
+
+        g = [], [], []  # optimizer parameter groups
+        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
+        if name == "auto":
+            LOGGER.info(
+                f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+                f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+                f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
+            )
+            nc = getattr(model, "nc", 10)  # number of classes
+            lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
+            name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
+            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam
+
+        for module_name, module in model.named_modules():
+            for param_name, param in module.named_parameters(recurse=False):
+                fullname = f"{module_name}.{param_name}" if module_name else param_name
+                if "bias" in fullname:  # bias (no decay)
+                    g[2].append(param)
+                elif isinstance(module, bn):  # weight (no decay)
+                    g[1].append(param)
+                else:  # weight (with decay)
+                    g[0].append(param)
+
+        if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
+            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+        elif name == "RMSProp":
+            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
+        elif name == "SGD":
+            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+        else:
+            raise NotImplementedError(
+                f"Optimizer '{name}' not found in list of available optimizers "
+                f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
+                "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
+            )
+
+        optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
+        optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
+        LOGGER.info(
+            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
+            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
+        )
+        return optimizer
diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py
index 9ec803a7..62cc59f8 100644
--- a/ultralytics/engine/predictor.py
+++ b/ultralytics/engine/predictor.py
@@ -206,7 +206,7 @@ class BasePredictor:
         self.vid_writer = {}
 
     @smart_inference_mode()
-    def stream_inference(self, source=None, model=None, *args, **kwargs):
+    def stream_inference(self, source=None, model=None, return_images = False, *args, **kwargs):
         """Streams real-time inference on camera feed and saves results to file."""
         if self.args.verbose:
             LOGGER.info("")
@@ -243,6 +243,9 @@ class BasePredictor:
                 with profilers[0]:
                     im = self.preprocess(im0s)
 
+                if return_images:
+                    im = im.requires_grad_(True)
+
                 # Inference
                 with profilers[1]:
                     preds = self.inference(im, *args, **kwargs)
@@ -272,7 +275,7 @@ class BasePredictor:
                     LOGGER.info("\n".join(s))
 
                 self.run_callbacks("on_predict_batch_end")
-                yield from self.results
+                yield from (self.results, im)
 
         # Release assets
         for v in self.vid_writer.values():
diff --git a/ultralytics/models/yolo/detect/__init__.py b/ultralytics/models/yolo/detect/__init__.py
index 5f3e62c1..b4cb6dce 100644
--- a/ultralytics/models/yolo/detect/__init__.py
+++ b/ultralytics/models/yolo/detect/__init__.py
@@ -1,7 +1,8 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
 from .predict import DetectionPredictor
+from .pgt_train import PGTDetectionTrainer
 from .train import DetectionTrainer
 from .val import DetectionValidator
 
-__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"
+__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer"
diff --git a/ultralytics/models/yolo/detect/pgt_train.py b/ultralytics/models/yolo/detect/pgt_train.py
new file mode 100644
index 00000000..be960193
--- /dev/null
+++ b/ultralytics/models/yolo/detect/pgt_train.py
@@ -0,0 +1,144 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import math
+import random
+from copy import copy
+
+import numpy as np
+import torch.nn as nn
+
+from ultralytics.data import build_dataloader, build_yolo_dataset
+from ultralytics.engine.trainer import BaseTrainer
+from ultralytics.engine.pgt_trainer import PGTTrainer
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import DetectionModel
+from ultralytics.utils import LOGGER, RANK
+from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
+from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
+
+
+class PGTDetectionTrainer(PGTTrainer):
+    """
+    A class extending the BaseTrainer class for training based on a detection model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.detect import DetectionTrainer
+
+        args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
+        trainer = DetectionTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
+
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+        """Construct and return dataloader."""
+        assert mode in ["train", "val"]
+        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+            dataset = self.build_dataset(dataset_path, mode, batch_size)
+        shuffle = mode == "train"
+        if getattr(dataset, "rect", False) and shuffle:
+            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
+            shuffle = False
+        workers = self.args.workers if mode == "train" else self.args.workers * 2
+        return build_dataloader(dataset, batch_size, workers, shuffle, rank)  # return dataloader
+
+    def preprocess_batch(self, batch):
+        """Preprocesses a batch of images by scaling and converting to float."""
+        batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
+        if self.args.multi_scale:
+            imgs = batch["img"]
+            sz = (
+                random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride)
+                // self.stride
+                * self.stride
+            )  # size
+            sf = sz / max(imgs.shape[2:])  # scale factor
+            if sf != 1:
+                ns = [
+                    math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
+                ]  # new shape (stretched to gs-multiple)
+                imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
+            batch["img"] = imgs
+        return batch
+
+    def set_model_attributes(self):
+        """Nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)."""
+        # self.args.box *= 3 / nl  # scale to layers
+        # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
+        # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
+        self.model.nc = self.data["nc"]  # attach number of classes to model
+        self.model.names = self.data["names"]  # attach class names to model
+        self.model.args = self.args  # attach hyperparameters to model
+        # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return a YOLO detection model."""
+        model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+        return model
+
+    def get_validator(self):
+        """Returns a DetectionValidator for YOLO model validation."""
+        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
+        return yolo.detect.DetectionValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def label_loss_items(self, loss_items=None, prefix="train"):
+        """
+        Returns a loss dict with labelled training loss items tensor.
+
+        Not needed for classification but necessary for segmentation & detection
+        """
+        keys = [f"{prefix}/{x}" for x in self.loss_names]
+        if loss_items is not None:
+            loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
+            return dict(zip(keys, loss_items))
+        else:
+            return keys
+
+    def progress_string(self):
+        """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
+        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+            "Epoch",
+            "GPU_mem",
+            *self.loss_names,
+            "Instances",
+            "Size",
+        )
+
+    def plot_training_samples(self, batch, ni):
+        """Plots training samples with their annotations."""
+        plot_images(
+            images=batch["img"],
+            batch_idx=batch["batch_idx"],
+            cls=batch["cls"].squeeze(-1),
+            bboxes=batch["bboxes"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
+
+    def plot_metrics(self):
+        """Plots metrics from a CSV file."""
+        plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png
+
+    def plot_training_labels(self):
+        """Create a labeled training plot of the YOLO model."""
+        boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
+        cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
+        plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
diff --git a/ultralytics/models/yolov10/model.py b/ultralytics/models/yolov10/model.py
index 09592c83..6e29c3d0 100644
--- a/ultralytics/models/yolov10/model.py
+++ b/ultralytics/models/yolov10/model.py
@@ -3,6 +3,8 @@ from ultralytics.nn.tasks import YOLOv10DetectionModel
 from .val import YOLOv10DetectionValidator
 from .predict import YOLOv10DetectionPredictor
 from .train import YOLOv10DetectionTrainer
+from .pgt_train import YOLOv10PGTDetectionTrainer
+# from .pgt_trainer import YOLOv10DetectionTrainer
 
 from huggingface_hub import PyTorchModelHubMixin
 from .card import card_template_text
@@ -30,6 +32,7 @@ class YOLOv10(Model, PyTorchModelHubMixin, model_card_template=card_template_tex
             "detect": {
                 "model": YOLOv10DetectionModel,
                 "trainer": YOLOv10DetectionTrainer,
+                "pgt_trainer": YOLOv10PGTDetectionTrainer,
                 "validator": YOLOv10DetectionValidator,
                 "predictor": YOLOv10DetectionPredictor,
             },
diff --git a/ultralytics/models/yolov10/pgt_train.py b/ultralytics/models/yolov10/pgt_train.py
new file mode 100644
index 00000000..1ce70145
--- /dev/null
+++ b/ultralytics/models/yolov10/pgt_train.py
@@ -0,0 +1,21 @@
+from ultralytics.models.yolo.detect import DetectionTrainer
+from ultralytics.models.yolo.detect import PGTDetectionTrainer
+from .val import YOLOv10DetectionValidator
+from .model import YOLOv10DetectionModel
+from copy import copy
+from ultralytics.utils import RANK
+
+class YOLOv10PGTDetectionTrainer(PGTDetectionTrainer):
+    def get_validator(self):
+        """Returns a DetectionValidator for YOLO model validation."""
+        self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", 
+        return YOLOv10DetectionValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return a YOLO detection model."""
+        model = YOLOv10DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+        return model
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 268bd125..00fbadff 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -93,7 +93,7 @@ class BaseModel(nn.Module):
             return self.loss(x, *args, **kwargs)
         return self.predict(x, *args, **kwargs)
 
-    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
+    def predict(self, x, profile=False, visualize=False, augment=False, embed=None, return_images=False):
         """
         Perform a forward pass through the network.
 
@@ -107,9 +107,12 @@ class BaseModel(nn.Module):
         Returns:
             (torch.Tensor): The last output of the model.
         """
+        if return_images:
+            x = x.requires_grad_(True)
         if augment:
             return self._predict_augment(x)
-        return self._predict_once(x, profile, visualize, embed)
+        out = self._predict_once(x, profile, visualize, embed)
+        return (out, x) if return_images else out
 
     def _predict_once(self, x, profile=False, visualize=False, embed=None):
         """
@@ -140,13 +143,13 @@ class BaseModel(nn.Module):
                     return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x
 
-    def _predict_augment(self, x):
+    def _predict_augment(self, x, *args, **kwargs):
         """Perform augmentations on input image x and return augmented inference."""
         LOGGER.warning(
             f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
             f"Reverting to single-scale inference instead."
         )
-        return self._predict_once(x)
+        return self._predict_once(x, *args, **kwargs)
 
     def _profile_one_layer(self, m, x, dt):
         """
@@ -260,7 +263,7 @@ class BaseModel(nn.Module):
         if verbose:
             LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
 
-    def loss(self, batch, preds=None):
+    def loss(self, batch, preds=None, return_images=False):
         """
         Compute loss.
 
@@ -271,8 +274,12 @@ class BaseModel(nn.Module):
         if not hasattr(self, "criterion"):
             self.criterion = self.init_criterion()
 
-        preds = self.forward(batch["img"]) if preds is None else preds
-        return self.criterion(preds, batch)
+        preds = self.forward(batch["img"], return_images=return_images) if preds is None else preds
+        if return_images:
+            preds, im = preds
+        loss = self.criterion(preds, batch)
+        out = loss if not return_images else (loss, im)
+        return out
 
     def init_criterion(self):
         """Initialize the loss criterion for the BaseModel."""