From b02bf58de6b03196ba3c404a2b59787894519159 Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Tue, 15 Oct 2024 11:51:37 -0400
Subject: [PATCH 01/12] added attribution visualization with run_attribution.py

---
 .gitignore                                  |   1 +
 run_attribution.py                          |  34 +
 run_train.py                                |  83 +++
 run_val.py                                  |  96 +++
 ultralytics/engine/model.py                 |  80 +-
 ultralytics/engine/pgt_trainer.py           | 785 ++++++++++++++++++++
 ultralytics/engine/predictor.py             |   7 +-
 ultralytics/models/yolo/detect/__init__.py  |   3 +-
 ultralytics/models/yolo/detect/pgt_train.py | 144 ++++
 ultralytics/models/yolov10/model.py         |   3 +
 ultralytics/models/yolov10/pgt_train.py     |  21 +
 ultralytics/nn/tasks.py                     |  21 +-
 12 files changed, 1266 insertions(+), 12 deletions(-)
 create mode 100644 run_attribution.py
 create mode 100644 run_train.py
 create mode 100644 run_val.py
 create mode 100644 ultralytics/engine/pgt_trainer.py
 create mode 100644 ultralytics/models/yolo/detect/pgt_train.py
 create mode 100644 ultralytics/models/yolov10/pgt_train.py

diff --git a/.gitignore b/.gitignore
index 0854267a..1463e020 100644
--- a/.gitignore
+++ b/.gitignore
@@ -51,6 +51,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 mlruns/
+figures/
 
 # Translations
 *.mo
diff --git a/run_attribution.py b/run_attribution.py
new file mode 100644
index 00000000..8c87e592
--- /dev/null
+++ b/run_attribution.py
@@ -0,0 +1,34 @@
+from ultralytics import YOLOv10, YOLO
+# from ultralytics.engine.pgt_trainer import PGTTrainer
+# from ultralytics import BaseTrainer
+# from ultralytics.engine.trainer import BaseTrainer
+import os
+
+# Set CUDA device (only needed on multi-GPU machines)
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = "4"
+
+# model = YOLOv10()
+# model = YOLO()
+# To fine-tune from pretrained weights, load them as shown below:
+# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
+# or
+# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
+model = YOLOv10('yolov10n.pt')
+
+model.train(data='coco.yaml',
+            trainer=model._smart_load("pgt_trainer"),  # required to generate attributions (later used for PGT training)
+            # TODO: add return_images as an input parameter
+            epochs=500, batch=16, imgsz=640,
+            debug=True,  # if debug=True, attributions are saved under figures/attributions/
+            )
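+
+# With debug=True, attribution overlays (image / gradient / overlay panels) are
+# saved under figures/attributions/ every 250 batches by the PGT trainer; the
+# figures/ directory is gitignored (see the .gitignore change above).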
+
+# Save the trained model
+model.save('yolov10_coco_trained.pt')
+
+# Evaluate the model on the validation set
+results = model.val(data='coco.yaml')
+
+# Print the evaluation results
+print(results)
\ No newline at end of file
diff --git a/run_train.py b/run_train.py
new file mode 100644
index 00000000..012d9d83
--- /dev/null
+++ b/run_train.py
@@ -0,0 +1,83 @@
+from ultralytics import YOLOv10, YOLO
+# from ultralytics.engine.pgt_trainer import PGTTrainer
+# from ultralytics import BaseTrainer
+# from ultralytics.engine.trainer import BaseTrainer
+import os
+
+# Set CUDA device (only needed on multi-GPU machines)
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = "4"
+
+model = YOLOv10()
+# model = YOLO()
+# To fine-tune from pretrained weights, load them as shown below:
+# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
+# or
+# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
+# model = YOLOv10('yolov10m.pt')
+
+model.train(data='coco.yaml',
+            # TODO: add return_images as an input parameter
+            epochs=500, batch=16, imgsz=640,
+            )
+
+# Save the trained model
+model.save('yolov10_coco_trained.pt')
+
+# Evaluate the model on the validation set
+results = model.val(data='coco.yaml')
+
+# Print the evaluation results
+print(results)
+
+# import torch
+# from torch.utils.data import DataLoader
+# from torchvision import datasets, transforms
+
+# # Define the transformation for the dataset
+# transform = transforms.Compose([
+#     transforms.Resize((640, 640)),
+#     transforms.ToTensor()
+# ])
+
+# # Load the COCO dataset
+# train_dataset = datasets.CocoDetection(root='/data/nielseni6/coco/train2017', annFile='/data/nielseni6/coco/annotations/instances_train2017.json', transform=transform)
+# val_dataset = datasets.CocoDetection(root='/data/nielseni6/coco/val2017', annFile='/data/nielseni6/coco/annotations/instances_val2017.json', transform=transform)
+
+# # Create data loaders
+# train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
+# val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)
+
+# model = YOLOv10()
+
+# # Define the optimizer
+# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+# # Training loop
+# for epoch in range(500):
+#     model.train()
+#     for images, targets in train_loader:
+#         images = images.to('cuda')
+#         targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets]
+#         loss = model(images, targets)
+#         loss.backward()
+#         optimizer.step()
+#         optimizer.zero_grad()
+
+#     # Validation loop
+#     model.eval()
+#     with torch.no_grad():
+#         for images, targets in val_loader:
+#             images = images.to('cuda')
+#             targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets]
+#             results = model(images, targets)
+
+# # Save the trained model
+# model.save('yolov10_coco_trained.pt')
+
+# # Evaluate the model on the validation set
+# results = model.val(data='coco.yaml')
+
+# # Print the evaluation results
+# print(results)
\ No newline at end of file
diff --git a/run_val.py b/run_val.py
new file mode 100644
index 00000000..6003d5ee
--- /dev/null
+++ b/run_val.py
@@ -0,0 +1,96 @@
+from ultralytics import YOLOv10
+import torch
+from PIL import Image
+from torchvision import transforms
+
+# Define the device
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
+# model = YOLOv10.from_pretrained('jameslahm/yolov10n')
+# or
+# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
+# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10n.pt
+# model = YOLOv10('yolov10{n/s/m/b/l/x}.pt')
+model = YOLOv10('yolov10n.pt').to(device)
+
+# Load the image
+# path = '/home/nielseni6/PythonScripts/Github/yolov10/images/fat-dog.jpg'
+path = '/home/nielseni6/PythonScripts/Github/yolov10/images/The-Cardinal-Bird.jpg'
+image = Image.open(path)
+
+# Define the transformation to resize the image and convert it to a tensor (normalization left commented out)
+transform = transforms.Compose([
+    transforms.Resize((640, 640)),
+    transforms.ToTensor(),
+    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# Apply the transformation
+image_tensor = transform(image)
+
+# Add a batch dimension
+image_tensor = image_tensor.unsqueeze(0).to(device)
+image_tensor = image_tensor.requires_grad_(True)
+
+
+# Predict for a specific image
+# results = model.predict(image_tensor, save=True)
+# model.requires_grad_(True)
+
+
+# for p in model.parameters():
+#     p.requires_grad = True
+results = model.predict(image_tensor, save=True)
+
+# Display the results
+for result in results:
+    print(result)
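+
+# A minimal input-gradient saliency sketch, kept commented like the exploration
+# below. Hedged assumptions: calling the underlying nn.Module at model.model
+# directly returns differentiable raw predictions (model.predict runs under
+# smart_inference_mode, so its outputs carry no graph), and max-reducing the raw
+# output is a meaningful scalar to attribute:
+#
+# with torch.enable_grad():
+#     raw = model.model(image_tensor)  # differentiable forward pass
+#     score = raw['one2one'][0].max() if isinstance(raw, dict) else raw[0].max()
+#     saliency = torch.autograd.grad(score, image_tensor)[0].abs().amax(dim=1)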
+
+# pred = results[0].boxes[0].conf
+
+# # Hook to store the activations
+# activations = {}
+
+# def get_activation(name):
+#     def hook(model, input, output):
+#         activations[name] = output
+#     return hook
+
+# # Register hooks for each layer you want to inspect
+# for name, layer in model.model.named_modules():
+#     layer.register_forward_hook(get_activation(name))
+
+# # Run the model to get activations
+# results = model.predict(image_tensor, save=True, visualize=True)
+
+# # # Print the activations
+# # for name, activation in activations.items():
+# #     print(f"Activation from layer {name}: {activation}")
+
+# # List activation names separately
+# print("\nActivation layer names:")
+# for name in activations.keys():
+#     print(name)
+# # pred.backward()
+
+# # Assuming 'model.23' is the layer of interest for bbox prediction and confidence
+# activation = activations['model.23']['one2one'][0]
+# act_23 = activations['model.23.cv3.2']
+# act_dfl = activations['model.23.dfl.conv']
+# act_conv = activations['model.0.conv']
+# act_act = activations['model.0.act']
+
+# # with torch.autograd.set_detect_anomaly(True):
+# #     pred.backward()
+# grad = torch.autograd.grad(act_23, im, grad_outputs=torch.ones_like(act_23), create_graph=True, retain_graph=True)[0]
+# # grad = torch.autograd.grad(pred, im, grad_outputs=torch.ones_like(pred), create_graph=True)[0]
+# grad = torch.autograd.grad(activations['model.23']['one2one'][1][0], 
+#                            activations['model.23.one2one_cv3.2'], 
+#                            grad_outputs=torch.ones_like(activations['model.23']['one2one'][1][0]), 
+#                            create_graph=True, retain_graph=True)[0]
+
+# # Print the results
+# print(results)
+
+# model.val(data='coco.yaml', batch=256)
\ No newline at end of file
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index ef5c93c0..be87559c 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -387,6 +387,7 @@ class Model(nn.Module):
         source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
         stream: bool = False,
         predictor=None,
+        return_images: bool = False,
         **kwargs,
     ) -> list:
         """
@@ -438,7 +439,7 @@ class Model(nn.Module):
                 self.predictor.save_dir = get_save_dir(self.predictor.args)
         if prompts and hasattr(self.predictor, "set_prompts"):  # for SAM-type models
             self.predictor.set_prompts(prompts)
-        return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
+        return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream, return_images=return_images)
 
     def track(
         self,
@@ -590,6 +591,81 @@ class Model(nn.Module):
         return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
 
     def train(
+        self,
+        trainer=None,
+        debug=False,
+        **kwargs,
+    ):
+        """
+        Trains the model using the specified dataset and training configuration.
+
+        This method facilitates model training with a range of customizable settings and configurations. It supports
+        training with a custom trainer or the default training approach defined in the method. The method handles
+        different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and
+        updating model and configuration after training.
+
+        When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training
+        arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default
+        configurations, method-specific defaults, and user-provided arguments to configure the training process. After
+        training, it updates the model and its configurations, and optionally attaches metrics.
+
+        Args:
+            trainer (BaseTrainer, optional): An instance of a custom trainer class for training the model. If None, the
+                method uses a default trainer. Defaults to None.
+            debug (bool, optional): If True, saves input-gradient attribution figures under
+                figures/attributions/ during training. Only consumed by PGT-style trainers. Defaults to False.
+            **kwargs (any): Arbitrary keyword arguments representing the training configuration. These arguments are
+                used to customize various aspects of the training process.
+
+        Returns:
+            (dict | None): Training metrics if available and training is successful; otherwise, None.
+
+        Raises:
+            AssertionError: If the model is not a PyTorch model.
+            PermissionError: If there is a permission issue with the HUB session.
+            ModuleNotFoundError: If the HUB SDK is not installed.
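+
+        Example:
+            A minimal usage sketch (mirrors run_attribution.py; debug is only consumed
+            by the PGT trainer):
+            ```python
+            model = YOLOv10('yolov10n.pt')
+            model.train(data='coco.yaml', trainer=model._smart_load("pgt_trainer"),
+                        epochs=100, batch=16, imgsz=640, debug=True)
+            ```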
+        """
+        self._check_is_pytorch_model()
+        if hasattr(self.session, "model") and self.session.model.id:  # Ultralytics HUB session with loaded model
+            if any(kwargs):
+                LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
+            kwargs = self.session.train_args  # overwrite kwargs
+
+        checks.check_pip_update_available()
+
+        overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
+        custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]}  # method defaults
+        args = {**overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
+        if args.get("resume"):
+            args["resume"] = self.ckpt_path
+
+        self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
+        if not args.get("resume"):  # manually set model only if not resuming
+            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
+            self.model = self.trainer.model
+
+            if SETTINGS["hub"] is True and not self.session:
+                # Create a model in HUB
+                try:
+                    self.session = self._get_hub_session(self.model_name)
+                    if self.session:
+                        self.session.create_model(args)
+                        # Check model was created
+                        if not getattr(self.session.model, "id", None):
+                            self.session = None
+                except (PermissionError, ModuleNotFoundError):
+                    # Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
+                    pass
+
+        self.trainer.hub_session = self.session  # attach optional HUB session
+        if debug:
+            self.trainer.train(debug=debug)  # PGT-style trainers accept a debug flag
+        else:
+            self.trainer.train()  # default trainers do not take a debug kwarg
+        # Update model and cfg after training
+        if RANK in (-1, 0):
+            ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
+            self.model, _ = attempt_load_one_weight(ckpt)
+            self.overrides = self.model.args
+            self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
+        return self.metrics
+
+    def train_pgt(
         self,
         trainer=None,
         **kwargs,
@@ -662,7 +738,7 @@ class Model(nn.Module):
             self.overrides = self.model.args
             self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
         return self.metrics
 
     def tune(
         self,
         use_ray=False,
diff --git a/ultralytics/engine/pgt_trainer.py b/ultralytics/engine/pgt_trainer.py
new file mode 100644
index 00000000..3d99ad41
--- /dev/null
+++ b/ultralytics/engine/pgt_trainer.py
@@ -0,0 +1,785 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+"""
+Train a model on a dataset.
+
+Usage:
+    $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
+"""
+
+import math
+import os
+import subprocess
+import time
+import warnings
+from copy import deepcopy
+from datetime import datetime, timedelta
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch import distributed as dist
+from torch import nn, optim
+
+import matplotlib.pyplot as plt
+import torchvision.transforms as T
+
+from ultralytics.cfg import get_cfg, get_save_dir
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
+from ultralytics.utils import (
+    DEFAULT_CFG,
+    LOGGER,
+    RANK,
+    TQDM,
+    __version__,
+    callbacks,
+    clean_url,
+    colorstr,
+    emojis,
+    yaml_save,
+)
+from ultralytics.utils.autobatch import check_train_batch_size
+from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
+from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
+from ultralytics.utils.files import get_latest_run
+from ultralytics.utils.torch_utils import (
+    EarlyStopping,
+    ModelEMA,
+    de_parallel,
+    init_seeds,
+    one_cycle,
+    select_device,
+    strip_optimizer,
+)
+
+
+class PGTTrainer:
+    """
+    PGTTrainer.
+
+    A BaseTrainer variant that keeps input images in the autograd graph so
+    input-gradient attributions can be computed and saved during training.
+
+    Attributes:
+        args (SimpleNamespace): Configuration for the trainer.
+        validator (BaseValidator): Validator instance.
+        model (nn.Module): Model instance.
+        callbacks (defaultdict): Dictionary of callbacks.
+        save_dir (Path): Directory to save results.
+        wdir (Path): Directory to save weights.
+        last (Path): Path to the last checkpoint.
+        best (Path): Path to the best checkpoint.
+        save_period (int): Save checkpoint every x epochs (disabled if < 1).
+        batch_size (int): Batch size for training.
+        epochs (int): Number of epochs to train for.
+        start_epoch (int): Starting epoch for training.
+        device (torch.device): Device to use for training.
+        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
+        scaler (amp.GradScaler): Gradient scaler for AMP.
+        data (str): Path to data.
+        trainset (torch.utils.data.Dataset): Training dataset.
+        testset (torch.utils.data.Dataset): Testing dataset.
+        ema (nn.Module): EMA (Exponential Moving Average) of the model.
+        resume (bool): Resume training from a checkpoint.
+        lf (nn.Module): Loss function.
+        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
+        best_fitness (float): The best fitness value achieved.
+        fitness (float): Current fitness value.
+        loss (float): Current loss value.
+        tloss (float): Total loss value.
+        loss_names (list): List of loss names.
+        csv (Path): Path to results CSV file.
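+
+    Example:
+        A minimal sketch (PGTDetectionTrainer, added in this patch, subclasses this trainer):
+        ```python
+        from ultralytics.models.yolo.detect import PGTDetectionTrainer
+
+        args = dict(model='yolov10n.pt', data='coco.yaml', epochs=3)
+        trainer = PGTDetectionTrainer(overrides=args)
+        trainer.train(debug=True)  # debug=True saves attribution figures
+        ```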
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initializes the PGTTrainer class.
+
+        Args:
+            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
+            overrides (dict, optional): Configuration overrides. Defaults to None.
+        """
+        self.args = get_cfg(cfg, overrides)
+        self.check_resume(overrides)
+        self.device = select_device(self.args.device, self.args.batch)
+        self.validator = None
+        self.metrics = None
+        self.plots = {}
+        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
+
+        # Dirs
+        self.save_dir = get_save_dir(self.args)
+        self.args.name = self.save_dir.name  # update name for loggers
+        self.wdir = self.save_dir / "weights"  # weights dir
+        if RANK in (-1, 0):
+            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
+            self.args.save_dir = str(self.save_dir)
+            yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+        self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
+        self.save_period = self.args.save_period
+
+        self.batch_size = self.args.batch
+        self.epochs = self.args.epochs
+        self.start_epoch = 0
+        if RANK == -1:
+            print_args(vars(self.args))
+
+        # Device
+        if self.device.type in ("cpu", "mps"):
+            self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
+
+        # Model and Dataset
+        self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
+        try:
+            if self.args.task == "classify":
+                self.data = check_cls_dataset(self.args.data)
+            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
+                "detect",
+                "segment",
+                "pose",
+                "obb",
+            ):
+                self.data = check_det_dataset(self.args.data)
+                if "yaml_file" in self.data:
+                    self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
+
+        self.trainset, self.testset = self.get_dataset(self.data)
+        self.ema = None
+
+        # Optimization utils init
+        self.lf = None
+        self.scheduler = None
+
+        # Epoch level metrics
+        self.best_fitness = None
+        self.fitness = None
+        self.loss = None
+        self.tloss = None
+        self.loss_names = ["Loss"]
+        self.csv = self.save_dir / "results.csv"
+        self.plot_idx = [0, 1, 2]
+
+        # Callbacks
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+        if RANK in (-1, 0):
+            callbacks.add_integration_callbacks(self)
+
+    def add_callback(self, event: str, callback):
+        """Appends the given callback."""
+        self.callbacks[event].append(callback)
+
+    def set_callback(self, event: str, callback):
+        """Overrides the existing callbacks with the given callback."""
+        self.callbacks[event] = [callback]
+
+    def run_callbacks(self, event: str):
+        """Run all existing callbacks associated with a particular event."""
+        for callback in self.callbacks.get(event, []):
+            callback(self)
+
+    def train(self, debug=False):
+        """Allow device='', device=None on Multi-GPU systems to default to device=0."""
+        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
+            world_size = len(self.args.device.split(","))
+        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
+            world_size = len(self.args.device)
+        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
+            world_size = 1  # default to device 0
+        else:  # i.e. device='cpu' or 'mps'
+            world_size = 0
+
+        # Run subprocess if DDP training, else train normally
+        if world_size > 1 and "LOCAL_RANK" not in os.environ:
+            # Argument checks
+            if self.args.rect:
+                LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
+                self.args.rect = False
+            if self.args.batch == -1:
+                LOGGER.warning(
+                    "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
+                    "default 'batch=16'"
+                )
+                self.args.batch = 16
+
+            # Command
+            cmd, file = generate_ddp_command(world_size, self)
+            try:
+                LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
+                subprocess.run(cmd, check=True)
+            except Exception as e:
+                raise e
+            finally:
+                ddp_cleanup(self, str(file))
+
+        else:
+            self._do_train(world_size, debug=debug)
+
+    def _setup_scheduler(self):
+        """Initialize training learning rate scheduler."""
+        if self.args.cos_lr:
+            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
+        else:
+            self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf  # linear
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+
+    def _setup_ddp(self, world_size):
+        """Initializes and sets the DistributedDataParallel parameters for training."""
+        torch.cuda.set_device(RANK)
+        self.device = torch.device("cuda", RANK)
+        # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
+        os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
+        dist.init_process_group(
+            backend="nccl" if dist.is_nccl_available() else "gloo",
+            timeout=timedelta(seconds=10800),  # 3 hours
+            rank=RANK,
+            world_size=world_size,
+        )
+
+    def _setup_train(self, world_size):
+        """Builds dataloaders and optimizer on correct rank process."""
+
+        # Model
+        self.run_callbacks("on_pretrain_routine_start")
+        ckpt = self.setup_model()
+        self.model = self.model.to(self.device)
+        self.set_model_attributes()
+
+        # Freeze layers
+        freeze_list = (
+            self.args.freeze
+            if isinstance(self.args.freeze, list)
+            else range(self.args.freeze)
+            if isinstance(self.args.freeze, int)
+            else []
+        )
+        always_freeze_names = [".dfl"]  # always freeze these layers
+        freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
+        for k, v in self.model.named_parameters():
+            # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
+            if any(x in k for x in freeze_layer_names):
+                LOGGER.info(f"Freezing layer '{k}'")
+                v.requires_grad = False
+            elif not v.requires_grad and v.dtype.is_floating_point:  # only floating point Tensor can require gradients
+                LOGGER.info(
+                    f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
+                    "See ultralytics.engine.trainer for customization of frozen layers."
+                )
+                v.requires_grad = True
+
+        # Check AMP
+        self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
+        if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
+            callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
+            self.amp = torch.tensor(check_amp(self.model), device=self.device)
+            callbacks.default_callbacks = callbacks_backup  # restore callbacks
+        if RANK > -1 and world_size > 1:  # DDP
+            dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
+        self.amp = bool(self.amp)  # as boolean
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
+        if world_size > 1:
+            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
+
+        # Check imgsz
+        gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
+        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
+        self.stride = gs  # for multiscale training
+
+        # Batch size
+        if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
+            self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
+
+        # Dataloaders
+        batch_size = self.batch_size // max(world_size, 1)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
+        if RANK in (-1, 0):
+            # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+            self.test_loader = self.get_dataloader(
+                self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
+            )
+            self.validator = self.get_validator()
+            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
+            self.ema = ModelEMA(self.model)
+            if self.args.plots:
+                self.plot_training_labels()
+
+        # Optimizer
+        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
+        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
+        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
+        self.optimizer = self.build_optimizer(
+            model=self.model,
+            name=self.args.optimizer,
+            lr=self.args.lr0,
+            momentum=self.args.momentum,
+            decay=weight_decay,
+            iterations=iterations,
+        )
+        # Scheduler
+        self._setup_scheduler()
+        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
+        self.resume_training(ckpt)
+        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
+        self.run_callbacks("on_pretrain_routine_end")
+
+    def _do_train(self, world_size=1, debug=False):
+        """Train completed, evaluate and plot if specified by arguments."""
+        if world_size > 1:
+            self._setup_ddp(world_size)
+        self._setup_train(world_size)
+
+        nb = len(self.train_loader)  # number of batches
+        nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
+        last_opt_step = -1
+        self.epoch_time = None
+        self.epoch_time_start = time.time()
+        self.train_time_start = time.time()
+        self.run_callbacks("on_train_start")
+        LOGGER.info(
+            f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
+            f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
+            f"Logging results to {colorstr('bold', self.save_dir)}\n"
+            f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
+        )
+        if self.args.close_mosaic:
+            base_idx = (self.epochs - self.args.close_mosaic) * nb
+            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
+        epoch = self.start_epoch
+        while True:
+            self.epoch = epoch
+            self.run_callbacks("on_train_epoch_start")
+            self.model.train()
+            if RANK != -1:
+                self.train_loader.sampler.set_epoch(epoch)
+            pbar = enumerate(self.train_loader)
+            # Update dataloader attributes (optional)
+            if epoch == (self.epochs - self.args.close_mosaic):
+                self._close_dataloader_mosaic()
+                self.train_loader.reset()
+
+            if RANK in (-1, 0):
+                LOGGER.info(self.progress_string())
+                pbar = TQDM(enumerate(self.train_loader), total=nb)
+            self.tloss = None
+            self.optimizer.zero_grad()
+            for i, batch in pbar:
+                self.run_callbacks("on_train_batch_start")
+                # Warmup
+                ni = i + nb * epoch
+                if ni <= nw:
+                    xi = [0, nw]  # x interp
+                    self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
+                    for j, x in enumerate(self.optimizer.param_groups):
+                        # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                        x["lr"] = np.interp(
+                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
+                        )
+                        if "momentum" in x:
+                            x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
+
+                # Forward
+                with torch.cuda.amp.autocast(self.amp):
+                    batch = self.preprocess_batch(batch)
+                    (self.loss, self.loss_items), images = self.model(batch, return_images=True)
+                    
+                    if debug and (i % 250 == 0):  # save attribution figures every 250 batches
+                        grad = torch.autograd.grad(self.loss, images, create_graph=True)[0]
+                        # Convert tensors to numpy arrays
+                        images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                        grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
+
+                        # Normalize grad for visualization (epsilon avoids division by zero
+                        # when the gradient is constant)
+                        grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min() + 1e-12)
+
+                        for ix in range(images_np.shape[0]):
+                            fig, ax = plt.subplots(1, 3, figsize=(15, 5))
+                            ax[0].imshow(images_np[ix])
+                            ax[0].set_title('Image')
+                            ax[1].imshow(grad_np[ix], cmap='jet')
+                            ax[1].set_title('Gradient')
+                            ax[2].imshow(images_np[ix])
+                            ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
+                            ax[2].set_title('Overlay')
+
+                            save_dir_attr = "figures/attributions"
+                            if not os.path.exists(save_dir_attr):
+                                os.makedirs(save_dir_attr)
+                            plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
+                            plt.close(fig)
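+
+                        # Note: gradients are taken w.r.t. the preprocessed batch images,
+                        # so the forward pass must keep them in the graph (return_images=True);
+                        # create_graph=True also retains the graph, so the scaled backward
+                        # pass on self.loss below still works.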
+                    
+                    if RANK != -1:
+                        self.loss *= world_size
+                    self.tloss = (
+                        (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
+                    )
+
+                # Backward
+                self.scaler.scale(self.loss).backward()
+
+                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
+                if ni - last_opt_step >= self.accumulate:
+                    self.optimizer_step()
+                    last_opt_step = ni
+
+                    # Timed stopping
+                    if self.args.time:
+                        self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
+                        if RANK != -1:  # if DDP training
+                            broadcast_list = [self.stop if RANK == 0 else None]
+                            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                            self.stop = broadcast_list[0]
+                        if self.stop:  # training time exceeded
+                            break
+
+                # Log
+                mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
+                loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
+                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
+                if RANK in (-1, 0):
+                    pbar.set_description(
+                        ("%11s" * 2 + "%11.4g" * (2 + loss_len))
+                        % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
+                    )
+                    self.run_callbacks("on_batch_end")
+                    if self.args.plots and ni in self.plot_idx:
+                        self.plot_training_samples(batch, ni)
+
+                self.run_callbacks("on_train_batch_end")
+
+            self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
+            self.run_callbacks("on_train_epoch_end")
+            if RANK in (-1, 0):
+                final_epoch = epoch + 1 == self.epochs
+                self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
+
+                # Validation
+                if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
+                    or final_epoch or self.stopper.possible_stop or self.stop:
+                    self.metrics, self.fitness = self.validate()
+                self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
+                self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
+                if self.args.time:
+                    self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
+
+                # Save model
+                if self.args.save or final_epoch:
+                    self.save_model()
+                    self.run_callbacks("on_model_save")
+
+            # Scheduler
+            t = time.time()
+            self.epoch_time = t - self.epoch_time_start
+            self.epoch_time_start = t
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                if self.args.time:
+                    mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
+                    self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
+                    self._setup_scheduler()
+                    self.scheduler.last_epoch = self.epoch  # do not move
+                    self.stop |= epoch >= self.epochs  # stop if exceeded epochs
+                self.scheduler.step()
+            self.run_callbacks("on_fit_epoch_end")
+            torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
+
+            # Early Stopping
+            if RANK != -1:  # if DDP training
+                broadcast_list = [self.stop if RANK == 0 else None]
+                dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                self.stop = broadcast_list[0]
+            if self.stop:
+                break  # must break all DDP ranks
+            epoch += 1
+
+        if RANK in (-1, 0):
+            # Do final val with best.pt
+            LOGGER.info(
+                f"\n{epoch - self.start_epoch + 1} epochs completed in "
+                f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
+            )
+            self.final_eval()
+            if self.args.plots:
+                self.plot_metrics()
+            self.run_callbacks("on_train_end")
+        torch.cuda.empty_cache()
+        self.run_callbacks("teardown")
+
+    def save_model(self):
+        """Save model training checkpoints with additional metadata."""
+        import pandas as pd  # scope for faster startup
+
+        metrics = {**self.metrics, **{"fitness": self.fitness}}
+        results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
+        ckpt = {
+            "epoch": self.epoch,
+            "best_fitness": self.best_fitness,
+            "model": deepcopy(de_parallel(self.model)).half(),
+            "ema": deepcopy(self.ema.ema).half(),
+            "updates": self.ema.updates,
+            "optimizer": self.optimizer.state_dict(),
+            "train_args": vars(self.args),  # save as dict
+            "train_metrics": metrics,
+            "train_results": results,
+            "date": datetime.now().isoformat(),
+            "version": __version__,
+            "license": "AGPL-3.0 (https://ultralytics.com/license)",
+            "docs": "https://docs.ultralytics.com",
+        }
+
+        # Save last and best
+        torch.save(ckpt, self.last)
+        if self.best_fitness == self.fitness:
+            torch.save(ckpt, self.best)
+        if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
+            torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
+
+    @staticmethod
+    def get_dataset(data):
+        """
+        Get train, val path from data dict if it exists.
+
+        Returns None if data format is not recognized.
+        """
+        return data["train"], data.get("val") or data.get("test")
+
+    def setup_model(self):
+        """Load/create/download model for any task."""
+        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
+            return
+
+        model, weights = self.model, None
+        ckpt = None
+        if str(model).endswith(".pt"):
+            weights, ckpt = attempt_load_one_weight(model)
+            cfg = ckpt["model"].yaml
+        else:
+            cfg = model
+        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
+        return ckpt
+
+    def optimizer_step(self):
+        """Perform a single step of the training optimizer with gradient clipping and EMA update."""
+        self.scaler.unscale_(self.optimizer)  # unscale gradients
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+        self.optimizer.zero_grad()
+        if self.ema:
+            self.ema.update(self.model)
+
+    def preprocess_batch(self, batch):
+        """Allows custom preprocessing model inputs and ground truths depending on task type."""
+        return batch
+
+    def validate(self):
+        """
+        Runs validation on test set using self.validator.
+
+        The returned dict is expected to contain "fitness" key.
+        """
+        metrics = self.validator(self)
+        fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
+        if not self.best_fitness or self.best_fitness < fitness:
+            self.best_fitness = fitness
+        return metrics, fitness
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Get model and raise NotImplementedError for loading cfg files."""
+        raise NotImplementedError("This task trainer doesn't support loading cfg files")
+
+    def get_validator(self):
+        """Returns a NotImplementedError when the get_validator function is called."""
+        raise NotImplementedError("get_validator function not implemented in trainer")
+
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+        """Returns dataloader derived from torch.data.Dataloader."""
+        raise NotImplementedError("get_dataloader function not implemented in trainer")
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """Build dataset."""
+        raise NotImplementedError("build_dataset function not implemented in trainer")
+
+    def label_loss_items(self, loss_items=None, prefix="train"):
+        """
+        Returns a loss dict with labelled training loss items tensor.
+
+        Note:
+            This is not needed for classification but necessary for segmentation & detection
+        """
+        return {"loss": loss_items} if loss_items is not None else ["loss"]
+
+    def set_model_attributes(self):
+        """To set or update model parameters before training."""
+        self.model.names = self.data["names"]
+
+    def build_targets(self, preds, targets):
+        """Builds target tensors for training YOLO model."""
+        pass
+
+    def progress_string(self):
+        """Returns a string describing training progress."""
+        return ""
+
+    # TODO: may need to put these following functions into callback
+    def plot_training_samples(self, batch, ni):
+        """Plots training samples during YOLO training."""
+        pass
+
+    def plot_training_labels(self):
+        """Plots training labels for YOLO model."""
+        pass
+
+    def save_metrics(self, metrics):
+        """Saves training metrics to a CSV file."""
+        keys, vals = list(metrics.keys()), list(metrics.values())
+        n = len(metrics) + 1  # number of cols
+        s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n")  # header
+        with open(self.csv, "a") as f:
+            f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
+
+    def plot_metrics(self):
+        """Plot and display metrics visually."""
+        pass
+
+    def on_plot(self, name, data=None):
+        """Registers plots (e.g. to be consumed in callbacks)"""
+        path = Path(name)
+        self.plots[path] = {"data": data, "timestamp": time.time()}
+
+    def final_eval(self):
+        """Performs final evaluation and validation for object detection YOLO model."""
+        for f in self.last, self.best:
+            if f.exists():
+                strip_optimizer(f)  # strip optimizers
+                if f is self.best:
+                    LOGGER.info(f"\nValidating {f}...")
+                    self.validator.args.plots = self.args.plots
+                    self.metrics = self.validator(model=f)
+                    self.metrics.pop("fitness", None)
+                    self.run_callbacks("on_fit_epoch_end")
+
+    def check_resume(self, overrides):
+        """Check if resume checkpoint exists and update arguments accordingly."""
+        resume = self.args.resume
+        if resume:
+            try:
+                exists = isinstance(resume, (str, Path)) and Path(resume).exists()
+                last = Path(check_file(resume) if exists else get_latest_run())
+
+                # Check that resume data YAML exists, otherwise strip to force re-download of dataset
+                ckpt_args = attempt_load_weights(last).args
+                if not Path(ckpt_args["data"]).exists():
+                    ckpt_args["data"] = self.args.data
+
+                resume = True
+                self.args = get_cfg(ckpt_args)
+                self.args.model = self.args.resume = str(last)  # reinstate model
+                for k in "imgsz", "batch", "device":  # allow arg updates to reduce memory or update device on resume
+                    if k in overrides:
+                        setattr(self.args, k, overrides[k])
+
+            except Exception as e:
+                raise FileNotFoundError(
+                    "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
+                    "i.e. 'yolo train resume model=path/to/last.pt'"
+                ) from e
+        self.resume = resume
+
+    def resume_training(self, ckpt):
+        """Resume YOLO training from given epoch and best fitness."""
+        if ckpt is None or not self.resume:
+            return
+        best_fitness = 0.0
+        start_epoch = ckpt["epoch"] + 1
+        if ckpt["optimizer"] is not None:
+            self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
+            best_fitness = ckpt["best_fitness"]
+        if self.ema and ckpt.get("ema"):
+            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
+            self.ema.updates = ckpt["updates"]
+        assert start_epoch > 0, (
+            f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
+            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+        )
+        LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
+        if self.epochs < start_epoch:
+            LOGGER.info(
+                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
+            )
+            self.epochs += ckpt["epoch"]  # finetune additional epochs
+        self.best_fitness = best_fitness
+        self.start_epoch = start_epoch
+        if start_epoch > (self.epochs - self.args.close_mosaic):
+            self._close_dataloader_mosaic()
+
+    def _close_dataloader_mosaic(self):
+        """Update dataloaders to stop using mosaic augmentation."""
+        if hasattr(self.train_loader.dataset, "mosaic"):
+            self.train_loader.dataset.mosaic = False
+        if hasattr(self.train_loader.dataset, "close_mosaic"):
+            LOGGER.info("Closing dataloader mosaic")
+            self.train_loader.dataset.close_mosaic(hyp=self.args)
+
+    def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
+        """
+        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
+        weight decay, and number of iterations.
+
+        Args:
+            model (torch.nn.Module): The model for which to build an optimizer.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
+                based on the number of iterations. Default: 'auto'.
+            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
+            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
+            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
+            iterations (float, optional): The number of iterations, which determines the optimizer if
+                name is 'auto'. Default: 1e5.
+
+        Returns:
+            (torch.optim.Optimizer): The constructed optimizer.
+        """
+
+        g = [], [], []  # optimizer parameter groups
+        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
+        if name == "auto":
+            LOGGER.info(
+                f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+                f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+                f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
+            )
+            nc = getattr(model, "nc", 10)  # number of classes
+            lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
+            name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
+            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam
+
+        for module_name, module in model.named_modules():
+            for param_name, param in module.named_parameters(recurse=False):
+                fullname = f"{module_name}.{param_name}" if module_name else param_name
+                if "bias" in fullname:  # bias (no decay)
+                    g[2].append(param)
+                elif isinstance(module, bn):  # weight (no decay)
+                    g[1].append(param)
+                else:  # weight (with decay)
+                    g[0].append(param)
+
+        if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
+            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+        elif name == "RMSProp":
+            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
+        elif name == "SGD":
+            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+        else:
+            raise NotImplementedError(
+                f"Optimizer '{name}' not found in list of available optimizers "
+                f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
+                "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
+            )
+
+        optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
+        optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
+        LOGGER.info(
+            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
+            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
+        )
+        return optimizer
diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py
index 9ec803a7..62cc59f8 100644
--- a/ultralytics/engine/predictor.py
+++ b/ultralytics/engine/predictor.py
@@ -206,7 +206,7 @@ class BasePredictor:
         self.vid_writer = {}
 
     @smart_inference_mode()
-    def stream_inference(self, source=None, model=None, *args, **kwargs):
+    def stream_inference(self, source=None, model=None, return_images=False, *args, **kwargs):
         """Streams real-time inference on camera feed and saves results to file."""
         if self.args.verbose:
             LOGGER.info("")
@@ -243,6 +243,9 @@ class BasePredictor:
                 with profilers[0]:
                     im = self.preprocess(im0s)
 
+                if return_images:
+                    im = im.requires_grad_(True)
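+                    # kept in-graph so downstream callers can compute input-gradient
+                    # attributions from the yielded tensor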
+
                 # Inference
                 with profilers[1]:
                     preds = self.inference(im, *args, **kwargs)
@@ -272,7 +275,7 @@ class BasePredictor:
                     LOGGER.info("\n".join(s))
 
                 self.run_callbacks("on_predict_batch_end")
-                yield from self.results
+                if return_images:
+                    yield self.results, im  # also return the (grad-enabled) preprocessed images
+                else:
+                    yield from self.results
 
         # Release assets
         for v in self.vid_writer.values():
diff --git a/ultralytics/models/yolo/detect/__init__.py b/ultralytics/models/yolo/detect/__init__.py
index 5f3e62c1..b4cb6dce 100644
--- a/ultralytics/models/yolo/detect/__init__.py
+++ b/ultralytics/models/yolo/detect/__init__.py
@@ -1,7 +1,8 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
 from .predict import DetectionPredictor
+from .pgt_train import PGTDetectionTrainer
 from .train import DetectionTrainer
 from .val import DetectionValidator
 
-__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"
+__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer"
diff --git a/ultralytics/models/yolo/detect/pgt_train.py b/ultralytics/models/yolo/detect/pgt_train.py
new file mode 100644
index 00000000..be960193
--- /dev/null
+++ b/ultralytics/models/yolo/detect/pgt_train.py
@@ -0,0 +1,144 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import math
+import random
+from copy import copy
+
+import numpy as np
+import torch.nn as nn
+
+from ultralytics.data import build_dataloader, build_yolo_dataset
+from ultralytics.engine.trainer import BaseTrainer
+from ultralytics.engine.pgt_trainer import PGTTrainer
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import DetectionModel
+from ultralytics.utils import LOGGER, RANK
+from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
+from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
+
+
+class PGTDetectionTrainer(PGTTrainer):
+    """
+    A class extending the PGTTrainer class for training based on a detection model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.detect import PGTDetectionTrainer
+
+        args = dict(model='yolov10n.pt', data='coco8.yaml', epochs=3)
+        trainer = PGTDetectionTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
+
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+        """Construct and return dataloader."""
+        assert mode in ["train", "val"]
+        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+            dataset = self.build_dataset(dataset_path, mode, batch_size)
+        shuffle = mode == "train"
+        if getattr(dataset, "rect", False) and shuffle:
+            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
+            shuffle = False
+        workers = self.args.workers if mode == "train" else self.args.workers * 2
+        return build_dataloader(dataset, batch_size, workers, shuffle, rank)  # return dataloader
+
+    def preprocess_batch(self, batch):
+        """Preprocesses a batch of images by scaling and converting to float."""
+        batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
+        if self.args.multi_scale:
+            imgs = batch["img"]
+            sz = (
+                random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))  # randrange requires ints on Python >= 3.10
+                // self.stride
+                * self.stride
+            )  # size
+            sf = sz / max(imgs.shape[2:])  # scale factor
+            if sf != 1:
+                ns = [
+                    math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
+                ]  # new shape (stretched to gs-multiple)
+                imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
+            batch["img"] = imgs
+        return batch
+
+    def set_model_attributes(self):
+        """Nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)."""
+        # self.args.box *= 3 / nl  # scale to layers
+        # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
+        # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
+        self.model.nc = self.data["nc"]  # attach number of classes to model
+        self.model.names = self.data["names"]  # attach class names to model
+        self.model.args = self.args  # attach hyperparameters to model
+        # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return a YOLO detection model."""
+        model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+        return model
+
+    def get_validator(self):
+        """Returns a DetectionValidator for YOLO model validation."""
+        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
+        return yolo.detect.DetectionValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def label_loss_items(self, loss_items=None, prefix="train"):
+        """
+        Returns a loss dict with labelled training loss items tensor.
+
+        Not needed for classification but necessary for segmentation & detection
+        """
+        keys = [f"{prefix}/{x}" for x in self.loss_names]
+        if loss_items is not None:
+            loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
+            return dict(zip(keys, loss_items))
+        else:
+            return keys
+
+    def progress_string(self):
+        """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
+        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+            "Epoch",
+            "GPU_mem",
+            *self.loss_names,
+            "Instances",
+            "Size",
+        )
+
+    def plot_training_samples(self, batch, ni):
+        """Plots training samples with their annotations."""
+        plot_images(
+            images=batch["img"],
+            batch_idx=batch["batch_idx"],
+            cls=batch["cls"].squeeze(-1),
+            bboxes=batch["bboxes"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
+
+    def plot_metrics(self):
+        """Plots metrics from a CSV file."""
+        plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png
+
+    def plot_training_labels(self):
+        """Create a labeled training plot of the YOLO model."""
+        boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
+        cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
+        plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
diff --git a/ultralytics/models/yolov10/model.py b/ultralytics/models/yolov10/model.py
index 09592c83..6e29c3d0 100644
--- a/ultralytics/models/yolov10/model.py
+++ b/ultralytics/models/yolov10/model.py
@@ -3,6 +3,8 @@ from ultralytics.nn.tasks import YOLOv10DetectionModel
 from .val import YOLOv10DetectionValidator
 from .predict import YOLOv10DetectionPredictor
 from .train import YOLOv10DetectionTrainer
+from .pgt_train import YOLOv10PGTDetectionTrainer
+# from .pgt_trainer import YOLOv10DetectionTrainer
 
 from huggingface_hub import PyTorchModelHubMixin
 from .card import card_template_text
@@ -30,6 +32,7 @@ class YOLOv10(Model, PyTorchModelHubMixin, model_card_template=card_template_tex
             "detect": {
                 "model": YOLOv10DetectionModel,
                 "trainer": YOLOv10DetectionTrainer,
+                "pgt_trainer": YOLOv10PGTDetectionTrainer,
                 "validator": YOLOv10DetectionValidator,
                 "predictor": YOLOv10DetectionPredictor,
             },
diff --git a/ultralytics/models/yolov10/pgt_train.py b/ultralytics/models/yolov10/pgt_train.py
new file mode 100644
index 00000000..1ce70145
--- /dev/null
+++ b/ultralytics/models/yolov10/pgt_train.py
@@ -0,0 +1,21 @@
+from ultralytics.models.yolo.detect import PGTDetectionTrainer
+from .val import YOLOv10DetectionValidator
+from .model import YOLOv10DetectionModel
+from copy import copy
+from ultralytics.utils import RANK
+
+
+class YOLOv10PGTDetectionTrainer(PGTDetectionTrainer):
+    def get_validator(self):
+        """Returns a DetectionValidator for YOLO model validation."""
+        self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", 
+        return YOLOv10DetectionValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return a YOLO detection model."""
+        model = YOLOv10DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+        return model
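
A minimal usage sketch for the new trainer, following the pattern of the docstring examples elsewhere in this patch (the coco8.yaml dataset and epoch count are illustrative):

```python
from ultralytics.models.yolov10.pgt_train import YOLOv10PGTDetectionTrainer

# Train a YOLOv10 detection model with the PGT trainer (sketch; dataset and epochs are illustrative)
args = dict(model='yolov10n.pt', data='coco8.yaml', epochs=3)
trainer = YOLOv10PGTDetectionTrainer(overrides=args)
trainer.train()
```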
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 268bd125..00fbadff 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -93,7 +93,7 @@ class BaseModel(nn.Module):
             return self.loss(x, *args, **kwargs)
         return self.predict(x, *args, **kwargs)
 
-    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
+    def predict(self, x, profile=False, visualize=False, augment=False, embed=None, return_images=False):
         """
         Perform a forward pass through the network.
 
@@ -107,9 +107,12 @@ class BaseModel(nn.Module):
         Returns:
             (torch.Tensor): The last output of the model.
         """
+        if return_images:
+            x = x.requires_grad_(True)
         if augment:
             return self._predict_augment(x)
-        return self._predict_once(x, profile, visualize, embed)
+        out = self._predict_once(x, profile, visualize, embed)
+        return (out, x) if return_images else out
 
     def _predict_once(self, x, profile=False, visualize=False, embed=None):
         """
@@ -140,13 +143,13 @@ class BaseModel(nn.Module):
                     return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x
 
-    def _predict_augment(self, x):
+    def _predict_augment(self, x, *args, **kwargs):
         """Perform augmentations on input image x and return augmented inference."""
         LOGGER.warning(
             f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
             f"Reverting to single-scale inference instead."
         )
-        return self._predict_once(x)
+        return self._predict_once(x, *args, **kwargs)
 
     def _profile_one_layer(self, m, x, dt):
         """
@@ -260,7 +263,7 @@ class BaseModel(nn.Module):
         if verbose:
             LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
 
-    def loss(self, batch, preds=None):
+    def loss(self, batch, preds=None, return_images=False):
         """
         Compute loss.
 
@@ -271,8 +274,12 @@ class BaseModel(nn.Module):
         if not hasattr(self, "criterion"):
             self.criterion = self.init_criterion()
 
-        preds = self.forward(batch["img"]) if preds is None else preds
-        return self.criterion(preds, batch)
+        preds = self.forward(batch["img"], return_images=return_images) if preds is None else preds
+        if return_images:
+            preds, im = preds
+        loss = self.criterion(preds, batch)
+        out = loss if not return_images else (loss, im)
+        return out
 
     def init_criterion(self):
         """Initialize the loss criterion for the BaseModel."""

From 2a95a652bdbe60e427c3192a2cad2546ab1c7aba Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Tue, 15 Oct 2024 11:53:27 -0400
Subject: [PATCH 02/12] fixed run_val

---
 run_val.py | 58 ++++++++++++++++--------------------------------------
 1 file changed, 17 insertions(+), 41 deletions(-)

diff --git a/run_val.py b/run_val.py
index 6003d5ee..98c5a476 100644
--- a/run_val.py
+++ b/run_val.py
@@ -1,51 +1,27 @@
-from ultralytics import YOLOv10
-import torch
-from PIL import Image
-from torchvision import transforms
+from ultralytics import YOLOv10, YOLO
+# from ultralytics.engine.pgt_trainer import PGTTrainer
+# from ultralytics import BaseTrainer
+# from ultralytics.engine.trainer import BaseTrainer
+import os
 
-# Define the device
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+# Set CUDA device (only needed for multi-gpu machines) 
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
+os.environ["CUDA_VISIBLE_DEVICES"] = "4" 
 
+# model = YOLOv10()
+# model = YOLO()
+# If you want to finetune the model with pretrained weights, you could load the 
+# pretrained weights like below
 # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
-# model = YOLOv10.from_pretrained('jameslahm/yolov10n')
 # or
 # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
-# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10n.pt
-# model = YOLOv10('yolov10{n/s/m/b/l/x}.pt')
-model = YOLOv10('yolov10n.pt').to(device)
+model = YOLOv10('yolov10n.pt')
 
-# Load the image
-# path = '/home/nielseni6/PythonScripts/Github/yolov10/images/fat-dog.jpg'
-path = '/home/nielseni6/PythonScripts/Github/yolov10/images/The-Cardinal-Bird.jpg'
-image = Image.open(path)
+# Evaluate the model on the validation set
+results = model.val(data='coco.yaml')
 
-# Define the transformation to resize the image, convert it to a tensor, and normalize it
-transform = transforms.Compose([
-    transforms.Resize((640, 640)),
-    transforms.ToTensor(),
-    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-])
-
-# Apply the transformation
-image_tensor = transform(image)
-
-# Add a batch dimension
-image_tensor = image_tensor.unsqueeze(0).to(device)
-image_tensor = image_tensor.requires_grad_(True)
-
-
-# Predict for a specific image
-# results = model.predict(image_tensor, save=True)
-# model.requires_grad_(True)
-
-
-# for p in model.parameters():
-#     p.requires_grad = True
-results = model.predict(image_tensor, save=True)
-
-# Display the results
-for result in results:
-    print(result)
+# Print the evaluation results
+print(results)
 
 # pred = results[0].boxes[0].conf
 

From 38fa59edf2f706ae9fe0c013166b9d6556a685ae Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Wed, 16 Oct 2024 19:26:41 -0400
Subject: [PATCH 03/12] PGT Training now functioning. Run run_pgt_train.py to
 train the model.

---
 run_attribution.py                           |  29 +-
 run_pgt_train.py                             |  47 ++
 ultralytics/engine/model.py                  |   5 +-
 ultralytics/engine/pgt_trainer.py            |  88 +-
 ultralytics/engine/pgt_validator.py          | 345 ++++++++
 ultralytics/models/yolo/detect/__init__.py   |   3 +-
 ultralytics/models/yolo/detect/pgt_val.py    | 300 +++++++
 ultralytics/models/yolo/segment/__init__.py  |   3 +-
 ultralytics/models/yolo/segment/pgt_train.py |  73 ++
 ultralytics/models/yolov10/model.py          |   2 +-
 ultralytics/models/yolov10/val.py            |  23 +-
 ultralytics/nn/tasks.py                      |   6 +-
 ultralytics/utils/loss.py                    |  29 +-
 ultralytics/utils/plaus_functs.py            | 818 +++++++++++++++++++
 ultralytics/utils/plot_functs.py             | 154 ++++
 ultralytics/utils/plotting.py                |   2 +-
 16 files changed, 1895 insertions(+), 32 deletions(-)
 create mode 100644 run_pgt_train.py
 create mode 100644 ultralytics/engine/pgt_validator.py
 create mode 100644 ultralytics/models/yolo/detect/pgt_val.py
 create mode 100644 ultralytics/models/yolo/segment/pgt_train.py
 create mode 100644 ultralytics/utils/plaus_functs.py
 create mode 100644 ultralytics/utils/plot_functs.py
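
A minimal invocation of the new entry point (flag names match the argparse options defined in run_pgt_train.py; the device id, batch size, and epoch count are illustrative):

    python run_pgt_train.py --device 0 --batch_size 16 --epochs 100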

diff --git a/run_attribution.py b/run_attribution.py
index 8c87e592..3a603934 100644
--- a/run_attribution.py
+++ b/run_attribution.py
@@ -3,26 +3,41 @@ from ultralytics import YOLOv10, YOLO
 # from ultralytics import BaseTrainer
 # from ultralytics.engine.trainer import BaseTrainer
 import os
+from ultralytics.models.yolo.segment import PGTSegmentationTrainer
+
 
 # Set CUDA device (only needed for multi-gpu machines) 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
 os.environ["CUDA_VISIBLE_DEVICES"] = "4" 
 
 # model = YOLOv10()
+# model = YOLO('yolov8n-seg.yaml').load('yolov8n.pt')  # build from YAML and transfer weights
+
 # model = YOLO()
 # If you want to finetune the model with pretrained weights, you could load the 
 # pretrained weights like below
 # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
 # or
 # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
-model = YOLOv10('yolov10n.pt')
+model = YOLOv10('yolov10n.pt', task='segment')
 
-model.train(data='coco.yaml', 
-            trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT)
-            # Add return_images as input parameter
-            epochs=500, batch=16, imgsz=640,
-            debug=True, # If debug = True, the attributions will be saved in the figures folder
-            )
+args = dict(model='yolov10n.pt', data='coco128-seg.yaml')
+trainer = PGTSegmentationTrainer(overrides=args)
+trainer.train(
+            # debug=True, 
+            #   args = dict(pgt_coeff=0.1),
+              )
+
+# model.train(
+#             # data='coco.yaml', 
+#             data='coco128-seg.yaml', 
+#             trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT)
+#             # Add return_images as input parameter
+#             epochs=500, batch=16, imgsz=640,
+#             debug=True, # If debug = True, the attributions will be saved in the figures folder
+#             # cfg='/home/nielseni6/PythonScripts/yolov10/ultralytics/cfg/models/v8/yolov8-seg.yaml',
+#             # overrides=dict(task="segment"),
+#             )
 
 # Save the trained model
 model.save('yolov10_coco_trained.pt')
diff --git a/run_pgt_train.py b/run_pgt_train.py
new file mode 100644
index 00000000..e308365d
--- /dev/null
+++ b/run_pgt_train.py
@@ -0,0 +1,47 @@
+from ultralytics import YOLOv10, YOLO
+# from ultralytics.engine.pgt_trainer import PGTTrainer
+import os
+from ultralytics.models.yolo.segment import PGTSegmentationTrainer
+import argparse
+
+
+def main(args):
+  # model = YOLOv10()
+
+  # If you want to finetune the model with pretrained weights, you could load the 
+  # pretrained weights like below
+  # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
+  # or
+  # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
+  model = YOLOv10('yolov10n.pt', task='segment')
+
+  overrides = dict(model='yolov10n.pt', data='coco.yaml',
+                   epochs=args.epochs, batch=args.batch_size,
+                   # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
+                   )  # named 'overrides' to avoid shadowing the argparse 'args'
+  trainer = PGTSegmentationTrainer(overrides=overrides)
+  trainer.train(
+        # debug=True, 
+        #   args = dict(pgt_coeff=0.1), # Should add later to config
+          )
+
+  # Save the trained model
+  model.save('yolov10_coco_trained.pt')
+
+  # Evaluate the model on the validation set
+  results = model.val(data='coco.yaml')
+
+  # Print the evaluation results
+  print(results)
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
+  parser.add_argument('--device', type=str, default='0', help='CUDA device number')
+  parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
+  parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
+  args = parser.parse_args()
+
+  # Set CUDA device (only needed for multi-gpu machines)
+  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+  os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+  main(args)
\ No newline at end of file
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index be87559c..53cb1f40 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -656,7 +656,10 @@ class Model(nn.Module):
                     pass
 
         self.trainer.hub_session = self.session  # attach optional HUB session
-        self.trainer.train(debug=debug)
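+        # Forward debug only when set, presumably so trainers whose train() lacks a debug parameter keep working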
+        if debug:
+            self.trainer.train(debug=debug)
+        else:
+            self.trainer.train()
         # Update model and cfg after training
         if RANK in (-1, 0):
             ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
diff --git a/ultralytics/engine/pgt_trainer.py b/ultralytics/engine/pgt_trainer.py
index 3d99ad41..6b7c973e 100644
--- a/ultralytics/engine/pgt_trainer.py
+++ b/ultralytics/engine/pgt_trainer.py
@@ -52,7 +52,9 @@ from ultralytics.utils.torch_utils import (
     strip_optimizer,
 )
 
-
+from ultralytics.utils.loss import v8DetectionLoss
+from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn
+import matplotlib.path as matplotlib_path
+
+
 class PGTTrainer:
     """
     BaseTrainer.
@@ -159,6 +161,7 @@ class PGTTrainer:
         self.loss_names = ["Loss"]
         self.csv = self.save_dir / "results.csv"
         self.plot_idx = [0, 1, 2]
+        self.num = int(time.time())  # timestamp id used to group this run's attribution figures
 
         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
@@ -328,7 +331,7 @@ class PGTTrainer:
         if world_size > 1:
             self._setup_ddp(world_size)
         self._setup_train(world_size)
 
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
         last_opt_step = -1
@@ -382,37 +385,88 @@ class PGTTrainer:
                 with torch.cuda.amp.autocast(self.amp):
                     batch = self.preprocess_batch(batch)
                     (self.loss, self.loss_items), images = self.model(batch, return_images=True)
+
+                # smask = get_dist_reg(images, batch['masks'])
+
+                # grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
+                # grad = torch.abs(grad)
+
+                # self.args.pgt_coeff = 1.1
+                # plaus_loss = plaus_loss_fn(grad, smask, self.args.pgt_coeff)
+                # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
+                # self.loss += plaus_loss
+
+                if debug and (i % 25 == 0):
+                    # Create a tensor of zeros with the same size as images
+                    mask = torch.zeros_like(images, dtype=torch.float32)
+                    smask = get_dist_reg(images, batch['masks'])
+                    grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
+                    grad = torch.abs(grad)
+
+                    batch_size = images.shape[0]
+                    imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device)
+                    targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
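+                    # v8DetectionLoss.preprocess is borrowed as a helper; it appears to use only self.device, so passing the trainer works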
+                    targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+                    gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+                    mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
                     
-                    if debug and (i % 250):
-                        grad = torch.autograd.grad(self.loss, images, create_graph=True)[0]
+                    # Iterate over each bounding box and set the corresponding pixels to 1
+                    for irx, bboxes in enumerate(gt_bboxes):
+                        for idx in range(len(bboxes)):
+                            x1, y1, x2, y2 = bboxes[idx]
+                            x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+                            mask[irx, :, y1:y2, x1:x2] = 1.0
+
+                    save_imgs = True
+                    if save_imgs:
                         # Convert tensors to numpy arrays
                         images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
                         grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                        mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                        seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
 
                         # Normalize grad for visualization
                         grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
 
                         for ix in range(images_np.shape[0]):
-                            fig, ax = plt.subplots(1, 3, figsize=(15, 5))
-                            ax[0].imshow(images_np[i])
+                            fig, ax = plt.subplots(1, 6, figsize=(30, 5))
+                            ax[0].imshow(images_np[ix])
                             ax[0].set_title('Image')
-                            ax[1].imshow(grad_np[i], cmap='jet')
+                            ax[1].imshow(grad_np[ix], cmap='jet')
                             ax[1].set_title('Gradient')
-                            ax[2].imshow(images_np[i])
-                            ax[2].imshow(grad_np[i], cmap='jet', alpha=0.5)
+                            ax[2].imshow(images_np[ix])
+                            ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
                             ax[2].set_title('Overlay')
+                            ax[3].imshow(mask_np[ix], cmap='gray')
+                            ax[3].set_title('Mask')
+                            ax[4].imshow(seg_mask_np[ix], cmap='gray')
+                            ax[4].set_title('Segmentation Mask')
+                            
+                            # Plot image with bounding boxes
+                            ax[5].imshow(images_np[ix])
+                            for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
+                                x1, y1, x2, y2 = bbox
+                                x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+                                rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
+                                ax[5].add_patch(rect)
+                                ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
+                            ax[5].set_title('Bounding Boxes')
 
-                            save_dir_attr = "figures/attributions"
+                            save_dir_attr = f"figures/attributions/run{self.num}"
                             if not os.path.exists(save_dir_attr):
                                 os.makedirs(save_dir_attr)
                             plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
                             plt.close(fig)
-                    
-                    if RANK != -1:
-                        self.loss *= world_size
-                    self.tloss = (
-                        (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
-                    )
+                    images = images.detach()
+                
+                
+                if RANK != -1:
+                    self.loss *= world_size
+                self.tloss = (
+                    (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
+                )
 
                 # Backward
                 self.scaler.scale(self.loss).backward()
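
For reference, a minimal reconstruction of how the commented-out lines above intend the plausibility term to join the detection loss once enabled (assumes a pgt_coeff value is available on self.args):

```python
# Sketch mirroring the commented-out PGT wiring above; pgt_coeff is assumed to be set on self.args.
smask = get_dist_reg(images, batch["masks"])                         # plausibility target mask
grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]  # gradient of loss w.r.t. input
grad = torch.abs(grad)                                               # input-gradient attribution
plaus_loss = plaus_loss_fn(grad, smask, self.args.pgt_coeff)         # plausibility penalty
self.loss = self.loss + plaus_loss                                   # joint objective
self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
```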
diff --git a/ultralytics/engine/pgt_validator.py b/ultralytics/engine/pgt_validator.py
new file mode 100644
index 00000000..ce411ba9
--- /dev/null
+++ b/ultralytics/engine/pgt_validator.py
@@ -0,0 +1,345 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+"""
+Check a model's accuracy on a test or val split of a dataset.
+
+Usage:
+    $ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
+
+Usage - formats:
+    $ yolo mode=val model=yolov8n.pt                 # PyTorch
+                          yolov8n.torchscript        # TorchScript
+                          yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
+                          yolov8n_openvino_model     # OpenVINO
+                          yolov8n.engine             # TensorRT
+                          yolov8n.mlpackage          # CoreML (macOS-only)
+                          yolov8n_saved_model        # TensorFlow SavedModel
+                          yolov8n.pb                 # TensorFlow GraphDef
+                          yolov8n.tflite             # TensorFlow Lite
+                          yolov8n_edgetpu.tflite     # TensorFlow Edge TPU
+                          yolov8n_paddle_model       # PaddlePaddle
+                          yolov8n_ncnn_model         # NCNN
+"""
+
+import json
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from ultralytics.cfg import get_cfg, get_save_dir
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.nn.autobackend import AutoBackend
+from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
+from ultralytics.utils.checks import check_imgsz
+from ultralytics.utils.ops import Profile
+from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
+
+
+class PGTValidator:
+    """
+    PGTValidator.
+
+    A base class for creating validators that run with input gradients enabled for PGT.
+
+    Attributes:
+        args (SimpleNamespace): Configuration for the validator.
+        dataloader (DataLoader): Dataloader to use for validation.
+        pbar (tqdm): Progress bar to update during validation.
+        model (nn.Module): Model to validate.
+        data (dict): Data dictionary.
+        device (torch.device): Device to use for validation.
+        batch_i (int): Current batch index.
+        training (bool): Whether the model is in training mode.
+        names (dict): Class names.
+        seen: Records the number of images seen so far during validation.
+        stats: Placeholder for statistics during validation.
+        confusion_matrix: Placeholder for a confusion matrix.
+        nc: Number of classes.
+        iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.
+        jdict (dict): Dictionary to store JSON validation results.
+        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
+                      batch processing times in milliseconds.
+        save_dir (Path): Directory to save results.
+        plots (dict): Dictionary to store plots for visualization.
+        callbacks (dict): Dictionary to store various callback functions.
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """
+        Initializes a PGTValidator instance.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (tqdm.tqdm): Progress bar for displaying progress.
+            args (SimpleNamespace): Configuration for the validator.
+            _callbacks (dict): Dictionary to store various callback functions.
+        """
+        self.args = get_cfg(overrides=args)
+        self.dataloader = dataloader
+        self.pbar = pbar
+        self.stride = None
+        self.data = None
+        self.device = None
+        self.batch_i = None
+        self.training = True
+        self.names = None
+        self.seen = None
+        self.stats = None
+        self.confusion_matrix = None
+        self.nc = None
+        self.iouv = None
+        self.jdict = None
+        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+
+        self.save_dir = save_dir or get_save_dir(self.args)
+        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+        if self.args.conf is None:
+            self.args.conf = 0.001  # default conf=0.001
+        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
+
+        self.plots = {}
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+
+    # @smart_inference_mode()  # left disabled here so inputs can keep requires_grad during validation
+    def __call__(self, trainer=None, model=None):
+        """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
+        gets priority).
+        """
+        self.training = trainer is not None
+        augment = self.args.augment and (not self.training)
+        if self.training:
+            self.device = trainer.device
+            self.data = trainer.data
+            # self.args.half = self.device.type != "cpu"  # force FP16 val during training
+            model = trainer.ema.ema or trainer.model
+            model = model.half() if self.args.half else model.float()
+            # self.model = model
+            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
+            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
+            model.eval()
+        else:
+            callbacks.add_integration_callbacks(self)
+            model = AutoBackend(
+                weights=model or self.args.model,
+                device=select_device(self.args.device, self.args.batch),
+                dnn=self.args.dnn,
+                data=self.args.data,
+                fp16=self.args.half,
+            )
+            # self.model = model
+            self.device = model.device  # update device
+            self.args.half = model.fp16  # update half
+            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
+            imgsz = check_imgsz(self.args.imgsz, stride=stride)
+            if engine:
+                self.args.batch = model.batch_size
+            elif not pt and not jit:
+                self.args.batch = 1  # export.py models default to batch-size 1
+                LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
+
+            if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
+                self.data = check_det_dataset(self.args.data)
+            elif self.args.task == "classify":
+                self.data = check_cls_dataset(self.args.data, split=self.args.split)
+            else:
+                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
+
+            if self.device.type in ("cpu", "mps"):
+                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
+            if not pt:
+                self.args.rect = False
+            self.stride = model.stride  # used in get_dataloader() for padding
+            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
+
+            model.eval()
+            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup
+
+        self.run_callbacks("on_val_start")
+        dt = (
+            Profile(device=self.device),
+            Profile(device=self.device),
+            Profile(device=self.device),
+            Profile(device=self.device),
+        )
+        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
+        self.init_metrics(de_parallel(model))
+        self.jdict = []  # empty before each val
+        for batch_i, batch in enumerate(bar):
+            self.run_callbacks("on_val_batch_start")
+            self.batch_i = batch_i
+            # Preprocess
+            with dt[0]:
+                batch = self.preprocess(batch)
+
+            # Inference
+            with dt[1]:
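+                # requires_grad_(True) keeps the input in the autograd graph so attributions remain computable during val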
+                preds = model(batch["img"].requires_grad_(True), augment=augment)
+
+            # Loss
+            with dt[2]:
+                if self.training:
+                    self.loss += model.loss(batch, preds)[1]
+
+            # Postprocess
+            with dt[3]:
+                preds = self.postprocess(preds)
+
+            self.update_metrics(preds, batch)
+            if self.args.plots and batch_i < 3:
+                self.plot_val_samples(batch, batch_i)
+                self.plot_predictions(batch, preds, batch_i)
+
+            self.run_callbacks("on_val_batch_end")
+        stats = self.get_stats()
+        self.check_stats(stats)
+        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
+        self.finalize_metrics()
+        if not (self.args.save_json and self.is_coco and len(self.jdict)):
+            self.print_results()
+        self.run_callbacks("on_val_end")
+        if self.training:
+            model.float()
+            if self.args.save_json and self.jdict:
+                with open(str(self.save_dir / "predictions.json"), "w") as f:
+                    LOGGER.info(f"Saving {f.name}...")
+                    json.dump(self.jdict, f)  # flatten and save
+                stats = self.eval_json(stats)  # update stats
+                stats["fitness"] = stats["metrics/mAP50-95(B)"]
+            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
+            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
+        else:
+            LOGGER.info(
+                "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
+                % tuple(self.speed.values())
+            )
+            if self.args.save_json and self.jdict:
+                with open(str(self.save_dir / "predictions.json"), "w") as f:
+                    LOGGER.info(f"Saving {f.name}...")
+                    json.dump(self.jdict, f)  # flatten and save
+                stats = self.eval_json(stats)  # update stats
+            if self.args.plots or self.args.save_json:
+                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
+            return stats
+
+    def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
+        """
+        Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
+
+        Args:
+            pred_classes (torch.Tensor): Predicted class indices of shape(N,).
+            true_classes (torch.Tensor): Target class indices of shape(M,).
+            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
+            use_scipy (bool): Whether to use scipy for matching (more precise).
+
+        Returns:
+            (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
+        """
+        # Dx10 matrix, where D - detections, 10 - IoU thresholds
+        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
+        # LxD matrix where L - labels (rows), D - detections (columns)
+        correct_class = true_classes[:, None] == pred_classes
+        iou = iou * correct_class  # zero out the wrong classes
+        iou = iou.cpu().numpy()
+        for i, threshold in enumerate(self.iouv.cpu().tolist()):
+            if use_scipy:
+                # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
+                import scipy  # scope import to avoid importing for all commands
+
+                cost_matrix = iou * (iou >= threshold)
+                if cost_matrix.any():
+                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
+                    valid = cost_matrix[labels_idx, detections_idx] > 0
+                    if valid.any():
+                        correct[detections_idx[valid], i] = True
+            else:
+                matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
+                matches = np.array(matches).T
+                if matches.shape[0]:
+                    if matches.shape[0] > 1:
+                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
+                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+                        # matches = matches[matches[:, 2].argsort()[::-1]]
+                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+                    correct[matches[:, 1].astype(int), i] = True
+        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
+
+    def add_callback(self, event: str, callback):
+        """Appends the given callback."""
+        self.callbacks[event].append(callback)
+
+    def run_callbacks(self, event: str):
+        """Runs all callbacks associated with a specified event."""
+        for callback in self.callbacks.get(event, []):
+            callback(self)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        """Get data loader from dataset path and batch size."""
+        raise NotImplementedError("get_dataloader function not implemented for this validator")
+
+    def build_dataset(self, img_path):
+        """Build dataset."""
+        raise NotImplementedError("build_dataset function not implemented in validator")
+
+    def preprocess(self, batch):
+        """Preprocesses an input batch."""
+        return batch
+
+    def postprocess(self, preds):
+        """Describes and summarizes the purpose of 'postprocess()' but no details mentioned."""
+        return preds
+
+    def init_metrics(self, model):
+        """Initialize performance metrics for the YOLO model."""
+        pass
+
+    def update_metrics(self, preds, batch):
+        """Updates metrics based on predictions and batch."""
+        pass
+
+    def finalize_metrics(self, *args, **kwargs):
+        """Finalizes and returns all metrics."""
+        pass
+
+    def get_stats(self):
+        """Returns statistics about the model's performance."""
+        return {}
+
+    def check_stats(self, stats):
+        """Checks statistics."""
+        pass
+
+    def print_results(self):
+        """Prints the results of the model's predictions."""
+        pass
+
+    def get_desc(self):
+        """Get description of the YOLO model."""
+        pass
+
+    @property
+    def metric_keys(self):
+        """Returns the metric keys used in YOLO training/validation."""
+        return []
+
+    def on_plot(self, name, data=None):
+        """Registers plots (e.g. to be consumed in callbacks)"""
+        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
+
+    # TODO: may need to put these following functions into callback
+    def plot_val_samples(self, batch, ni):
+        """Plots validation samples during training."""
+        pass
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots YOLO model predictions on batch images."""
+        pass
+
+    def pred_to_json(self, preds, batch):
+        """Convert predictions to JSON format."""
+        pass
+
+    def eval_json(self, stats):
+        """Evaluate and return JSON format of prediction statistics."""
+        pass
diff --git a/ultralytics/models/yolo/detect/__init__.py b/ultralytics/models/yolo/detect/__init__.py
index b4cb6dce..656966f5 100644
--- a/ultralytics/models/yolo/detect/__init__.py
+++ b/ultralytics/models/yolo/detect/__init__.py
@@ -4,5 +4,6 @@ from .predict import DetectionPredictor
 from .pgt_train import PGTDetectionTrainer
 from .train import DetectionTrainer
 from .val import DetectionValidator
+from .pgt_val import PGTDetectionValidator
 
-__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer"
+__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer", "PGTDetectionValidator"
diff --git a/ultralytics/models/yolo/detect/pgt_val.py b/ultralytics/models/yolo/detect/pgt_val.py
new file mode 100644
index 00000000..04138ff7
--- /dev/null
+++ b/ultralytics/models/yolo/detect/pgt_val.py
@@ -0,0 +1,300 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from ultralytics.data import build_dataloader, build_yolo_dataset, converter
+from ultralytics.engine.validator import BaseValidator
+from ultralytics.engine.pgt_validator import PGTValidator
+from ultralytics.utils import LOGGER, ops
+from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
+from ultralytics.utils.plotting import output_to_target, plot_images
+
+
+class PGTDetectionValidator(PGTValidator):
+    """
+    A class extending the PGTValidator class for validation based on a detection model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.detect import PGTDetectionValidator
+
+        args = dict(model='yolov8n.pt', data='coco8.yaml')
+        validator = PGTDetectionValidator(args=args)
+        validator()
+        ```
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """Initialize detection model with necessary variables and settings."""
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.nt_per_class = None
+        self.is_coco = False
+        self.class_map = None
+        self.args.task = "detect"
+        self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+        self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
+        self.niou = self.iouv.numel()
+        self.lb = []  # for autolabelling
+
+    def preprocess(self, batch):
+        """Preprocesses batch of images for YOLO training."""
+        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
+        for k in ["batch_idx", "cls", "bboxes"]:
+            batch[k] = batch[k].to(self.device)
+
+        if self.args.save_hybrid:
+            height, width = batch["img"].shape[2:]
+            nb = len(batch["img"])
+            bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
+            self.lb = (
+                [
+                    torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
+                    for i in range(nb)
+                ]
+                if self.args.save_hybrid
+                else []
+            )  # for autolabelling
+
+        return batch
+
+    def init_metrics(self, model):
+        """Initialize evaluation metrics for YOLO."""
+        val = self.data.get(self.args.split, "")  # validation path
+        self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt")  # is COCO
+        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
+        self.args.save_json |= self.is_coco  # run on final val if training COCO
+        self.names = model.names
+        self.nc = len(model.names)
+        self.metrics.names = self.names
+        self.metrics.plot = self.args.plots
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
+        self.seen = 0
+        self.jdict = []
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])
+
+    def get_desc(self):
+        """Return a formatted string summarizing class metrics of YOLO model."""
+        return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
+
+    def postprocess(self, preds):
+        """Apply Non-maximum suppression to prediction outputs."""
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=True,
+            agnostic=self.args.single_cls,
+            max_det=self.args.max_det,
+        )
+
+    def _prepare_batch(self, si, batch):
+        """Prepares a batch of images and annotations for validation."""
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
+        if len(cls):
+            bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
+            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
+        return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
+
+    def _prepare_pred(self, pred, pbatch):
+        """Prepares a batch of images and annotations for validation."""
+        predn = pred.clone()
+        ops.scale_boxes(
+            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+        )  # native-space pred
+        return predn
+
+    def update_metrics(self, preds, batch):
+        """Metrics."""
+        for si, pred in enumerate(preds):
+            self.seen += 1
+            npr = len(pred)
+            stat = dict(
+                conf=torch.zeros(0, device=self.device),
+                pred_cls=torch.zeros(0, device=self.device),
+                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+            )
+            pbatch = self._prepare_batch(si, batch)
+            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
+            nl = len(cls)
+            stat["target_cls"] = cls
+            if npr == 0:
+                if nl:
+                    for k in self.stats.keys():
+                        self.stats[k].append(stat[k])
+                    if self.args.plots:
+                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
+                continue
+
+            # Predictions
+            if self.args.single_cls:
+                pred[:, 5] = 0
+            predn = self._prepare_pred(pred, pbatch)
+            stat["conf"] = predn[:, 4]
+            stat["pred_cls"] = predn[:, 5]
+
+            # Evaluate
+            if nl:
+                stat["tp"] = self._process_batch(predn, bbox, cls)
+                if self.args.plots:
+                    self.confusion_matrix.process_batch(predn, bbox, cls)
+            for k in self.stats.keys():
+                self.stats[k].append(stat[k])
+
+            # Save
+            if self.args.save_json:
+                self.pred_to_json(predn, batch["im_file"][si])
+            if self.args.save_txt:
+                file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
+                self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)
+
+    def finalize_metrics(self, *args, **kwargs):
+        """Set final values for metrics speed and confusion matrix."""
+        self.metrics.speed = self.speed
+        self.metrics.confusion_matrix = self.confusion_matrix
+
+    def get_stats(self):
+        """Returns metrics statistics and results dictionary."""
+        stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
+        if len(stats) and stats["tp"].any():
+            self.metrics.process(**stats)
+        self.nt_per_class = np.bincount(
+            stats["target_cls"].astype(int), minlength=self.nc
+        )  # number of targets per class
+        return self.metrics.results_dict
+
+    def print_results(self):
+        """Prints training/validation set metrics per class."""
+        pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
+        LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
+        if self.nt_per_class.sum() == 0:
+            LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")
+
+        # Print results per class
+        if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
+            for i, c in enumerate(self.metrics.ap_class_index):
+                LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
+
+        if self.args.plots:
+            for normalize in True, False:
+                self.confusion_matrix.plot(
+                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+                )
+
+    def _process_batch(self, detections, gt_bboxes, gt_cls):
+        """
+        Return correct prediction matrix.
+
+        Args:
+            detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
+                Each detection is of the format: x1, y1, x2, y2, conf, class.
+            gt_bboxes (torch.Tensor): Tensor of shape [M, 4] of ground-truth boxes in x1, y1, x2, y2 format.
+            gt_cls (torch.Tensor): Tensor of shape [M] of ground-truth class indices.
+
+        Returns:
+            (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
+        """
+        iou = box_iou(gt_bboxes, detections[:, :4])
+        return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+    def build_dataset(self, img_path, mode="val", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode; users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        """Construct and return dataloader."""
+        dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
+        return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader
+
+    def plot_val_samples(self, batch, ni):
+        """Plot validation image samples."""
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots predicted bounding boxes on input images and saves the result."""
+        plot_images(
+            batch["img"],
+            *output_to_target(preds, max_det=self.args.max_det),
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred
+
+    def save_one_txt(self, predn, save_conf, shape, file):
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        gn = torch.tensor(shape)[[1, 0, 1, 0]]  # normalization gain whwh
+        for *xyxy, conf, cls in predn.tolist():
+            xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
+            line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
+            with open(file, "a") as f:
+                f.write(("%g " * len(line)).rstrip() % line + "\n")
+
+    def pred_to_json(self, predn, filename):
+        """Serialize YOLO predictions to COCO json format."""
+        stem = Path(filename).stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        for p, b in zip(predn.tolist(), box.tolist()):
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(p[5])],
+                    "bbox": [round(x, 3) for x in b],
+                    "score": round(p[4], 5),
+                }
+            )
+
+    def eval_json(self, stats):
+        """Evaluates YOLO output in JSON format and returns performance statistics."""
+        if self.args.save_json and self.is_coco and len(self.jdict):
+            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
+            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+                check_requirements("pycocotools>=2.0.6")
+                from pycocotools.coco import COCO  # noqa
+                from pycocotools.cocoeval import COCOeval  # noqa
+
+                for x in anno_json, pred_json:
+                    assert x.is_file(), f"{x} file not found"
+                anno = COCO(str(anno_json))  # init annotations api
+                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                eval = COCOeval(anno, pred, "bbox")
+                if self.is_coco:
+                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
+                eval.evaluate()
+                eval.accumulate()
+                eval.summarize()
+                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50
+            except Exception as e:
+                LOGGER.warning(f"pycocotools unable to run: {e}")
+        return stats
diff --git a/ultralytics/models/yolo/segment/__init__.py b/ultralytics/models/yolo/segment/__init__.py
index ec1ac799..31e23957 100644
--- a/ultralytics/models/yolo/segment/__init__.py
+++ b/ultralytics/models/yolo/segment/__init__.py
@@ -3,5 +3,6 @@
 from .predict import SegmentationPredictor
 from .train import SegmentationTrainer
 from .val import SegmentationValidator
+from .pgt_train import PGTSegmentationTrainer
 
-__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
+__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator", "PGTSegmentationTrainer"
diff --git a/ultralytics/models/yolo/segment/pgt_train.py b/ultralytics/models/yolo/segment/pgt_train.py
new file mode 100644
index 00000000..7e7712b9
--- /dev/null
+++ b/ultralytics/models/yolo/segment/pgt_train.py
@@ -0,0 +1,73 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from copy import copy
+
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import SegmentationModel, DetectionModel
+from ultralytics.utils import DEFAULT_CFG, RANK
+from ultralytics.utils.plotting import plot_images, plot_results
+from ultralytics.models.yolov10.model import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
+from ultralytics.models.yolov10.val import YOLOv10DetectionValidator, YOLOv10PGTDetectionValidator
+
+class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
+    """
+    A class extending PGTDetectionTrainer for plausibility-guided training (PGT) on a segmentation model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.segment import PGTSegmentationTrainer
+
+        args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
+        trainer = PGTSegmentationTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a SegmentationTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        overrides["task"] = "segment"
+        super().__init__(cfg, overrides, _callbacks)
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return SegmentationModel initialized with specified config and weights."""
+        if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']:
+            model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        else:
+            model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+
+        return model
+
+    def get_validator(self):
+        """Return an instance of SegmentationValidator for validation of YOLO model."""
+        
+        if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']:
+            self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", "pgt_loss",
+            return YOLOv10PGTDetectionValidator(
+                self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+            )
+        else:
+            self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
+            return yolo.segment.SegmentationValidator(
+                self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+            )
+
+    def plot_training_samples(self, batch, ni):
+        """Creates a plot of training sample images with labels and box coordinates."""
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            masks=batch["masks"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
+
+    def plot_metrics(self):
+        """Plots training/val metrics."""
+        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
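
For context, a hedged sketch of how this trainer is driven (mirroring the `run_pgt_train.py` script later in this series); the override keys are standard Ultralytics training arguments, and the model/data values are examples only:

```python
from ultralytics.models.yolo.segment import PGTSegmentationTrainer

overrides = dict(model='yolov10n.pt', data='coco128-seg.yaml', epochs=3, batch=16)
trainer = PGTSegmentationTrainer(overrides=overrides)
trainer.train()
```
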
diff --git a/ultralytics/models/yolov10/model.py b/ultralytics/models/yolov10/model.py
index 6e29c3d0..0c8bcb9a 100644
--- a/ultralytics/models/yolov10/model.py
+++ b/ultralytics/models/yolov10/model.py
@@ -1,5 +1,5 @@
 from ultralytics.engine.model import Model
-from ultralytics.nn.tasks import YOLOv10DetectionModel
+from ultralytics.nn.tasks import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
 from .val import YOLOv10DetectionValidator
 from .predict import YOLOv10DetectionPredictor
 from .train import YOLOv10DetectionTrainer
diff --git a/ultralytics/models/yolov10/val.py b/ultralytics/models/yolov10/val.py
index 19a019c8..a96483ca 100644
--- a/ultralytics/models/yolov10/val.py
+++ b/ultralytics/models/yolov10/val.py
@@ -1,4 +1,4 @@
-from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.models.yolo.detect import DetectionValidator, PGTDetectionValidator
 from ultralytics.utils import ops
 import torch
 
@@ -7,6 +7,27 @@ class YOLOv10DetectionValidator(DetectionValidator):
         super().__init__(*args, **kwargs)
         self.args.save_json |= self.is_coco
 
+    def postprocess(self, preds):
+        if isinstance(preds, dict):
+            preds = preds["one2one"]
+
+        if isinstance(preds, (list, tuple)):
+            preds = preds[0]
+        
+        # Acknowledgement: Thanks to sanha9999 in #190 and #181!
+        if preds.shape[-1] == 6:
+            return preds
+        else:
+            preds = preds.transpose(-1, -2)
+            boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
+            bboxes = ops.xywh2xyxy(boxes)
+            return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
+        
+class YOLOv10PGTDetectionValidator(PGTDetectionValidator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.args.save_json |= self.is_coco
+
     def postprocess(self, preds):
         if isinstance(preds, dict):
             preds = preds["one2one"]
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 00fbadff..48242826 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -57,7 +57,7 @@ from ultralytics.nn.modules import (
 )
 from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
-from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss
+from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss, v10PGTDetectLoss
 from ultralytics.utils.plotting import feature_visualization
 from ultralytics.utils.torch_utils import (
     fuse_conv_and_bn,
@@ -651,6 +651,10 @@ class WorldModel(DetectionModel):
 class YOLOv10DetectionModel(DetectionModel):
     def init_criterion(self):
         return v10DetectLoss(self)
+    
+class YOLOv10PGTDetectionModel(DetectionModel):
+    def init_criterion(self):
+        return v10PGTDetectLoss(self)
 
 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index d0ca9c39..e10b949f 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -9,7 +9,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
 from .metrics import bbox_iou, probiou
 from .tal import bbox2dist
-
+from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn
 
 class VarifocalLoss(nn.Module):
     """
@@ -725,3 +725,30 @@ class v10DetectLoss:
         one2one = preds["one2one"]
         loss_one2one = self.one2one(one2one, batch)
         return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
+
+class v10PGTDetectLoss:
+    def __init__(self, model):
+        self.one2many = v8DetectionLoss(model, tal_topk=10)
+        self.one2one = v8DetectionLoss(model, tal_topk=1)
+    
+    def __call__(self, preds, batch):
+        # Enable input gradients so the attribution can be computed w.r.t. the image
+        batch['img'] = batch['img'].requires_grad_(True)
+        one2many = preds["one2many"]
+        loss_one2many = self.one2many(one2many, batch)
+        one2one = preds["one2one"]
+        loss_one2one = self.one2one(one2one, batch)
+
+        loss = loss_one2many[0] + loss_one2one[0]
+
+        # Soft distance mask from the segmentation labels (requires 'masks' in the batch, i.e. a segmentation dataset)
+        smask = get_dist_reg(batch['img'], batch['masks'])
+
+        # Vanilla-gradient attribution of the detection loss w.r.t. the input image
+        grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0]
+        grad = torch.abs(grad)
+
+        pgt_coeff = 3.0  # weight of the plausibility term (hardcoded here)
+        plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
+        loss += plaus_loss
+
+        return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0)))
\ No newline at end of file
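
In isolation, the attribution step of `v10PGTDetectLoss` reduces to a few lines. A minimal, self-contained sketch, with dummy tensors and a toy scalar loss standing in for the real detection losses:

```python
import torch

# Toy stand-ins: a batch of images and a scalar "detection loss" that depends on them
imgs = torch.rand(2, 3, 64, 64, requires_grad=True)
loss = (imgs ** 2).mean()  # placeholder for loss_one2many[0] + loss_one2one[0]

# Input-gradient attribution: |d loss / d image|, keeping the graph for the main backward pass
grad = torch.autograd.grad(loss, imgs, retain_graph=True)[0].abs()
print(grad.shape)  # torch.Size([2, 3, 64, 64])
```

The plausibility term then compares `grad` against the soft mask from `get_dist_reg`, so gradients concentrated on labeled objects yield a smaller penalty.
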
diff --git a/ultralytics/utils/plaus_functs.py b/ultralytics/utils/plaus_functs.py
new file mode 100644
index 00000000..3c5b689a
--- /dev/null
+++ b/ultralytics/utils/plaus_functs.py
@@ -0,0 +1,818 @@
+import torch
+import numpy as np
+# from plot_functs import * 
+from .plot_functs import normalize_tensor, overlay_mask, imshow
+import math   
+import time
+import matplotlib.path as mplPath
+from matplotlib.path import Path
+# from utils.general import non_max_suppression, xyxy2xywh, scale_coords
+from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppression
+from .metrics import bbox_iou
+import torchvision.transforms as T
+
+def plaus_loss_fn(grad, smask, pgt_coeff):
+    ################## Compute the PGT Loss ##################
+    # Regularization terms weighting the attribution map by the distance mask and by its complement
+    dist_attr_pos = attr_reg(grad, (1.0 - smask))
+    dist_attr_neg = attr_reg(grad, smask)
+    # Plausibility regularization: normalized difference of the two terms, mapped to [0, 1]
+    dist_reg = (dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad))
+    plaus_reg = (1.0 + dist_reg) / 2.0
+    # Calculate plausibility loss
+    plaus_loss = (1 - plaus_reg) * pgt_coeff
+    return plaus_loss
+
+def get_dist_reg(images, seg_mask):
+    """Build a soft distance mask: the max over progressively blurred copies of the segmentation mask, ~1 on objects and decaying with distance."""
+    seg_mask = T.Resize((images.shape[2], images.shape[3]), antialias=True)(seg_mask).to(images.device)
+    seg_mask = seg_mask.to(dtype=torch.float32).unsqueeze(1).repeat(1, 3, 1, 1)
+    seg_mask[seg_mask > 0] = 1.0
+    
+    smask = torch.zeros_like(seg_mask)
+    sigmas = [20.0 + (i_sig * 20.0) for i_sig in range(8)]
+    for k_it, sigma in enumerate(sigmas):
+        # Apply Gaussian blur to the mask
+        kernel_size = int(sigma + 50)
+        if kernel_size % 2 == 0:
+            kernel_size += 1
+        seg_mask1 = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask)
+        if torch.max(seg_mask1) > 1.0:
+            seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min())
+        smask = torch.max(smask, seg_mask1)
+    return smask
+
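+# Illustrative usage (shapes assumed): for a (B, 3, H, W) image batch and (B, h, w)
+# segmentation masks, get_dist_reg returns a (B, 3, H, W) soft mask in [0, 1], e.g.:
+#   imgs = torch.zeros(2, 3, 640, 640); masks = torch.zeros(2, 160, 160)
+#   smask = get_dist_reg(imgs, masks)  # torch.Size([2, 3, 640, 640])
+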
+def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False):
+    """
+    Compute the gradient of an image with respect to a given tensor.
+
+    Args:
+        img (torch.Tensor): The input image tensor.
+        grad_wrt (torch.Tensor): The tensor with respect to which the gradient is computed.
+        norm (bool, optional): Whether to normalize the gradient. Defaults to False.
+        absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
+        grayscale (bool, optional): Whether to convert the gradient to grayscale. Defaults to False.
+        keepmean (bool, optional): Whether to keep the mean value of the attribution map. Defaults to False.
+
+    Returns:
+        torch.Tensor: The computed attribution map.
+
+    """
+    if (grad_wrt.shape != torch.Size([1])) and (grad_wrt.shape != torch.Size([])):
+        grad_wrt_outputs = torch.ones_like(grad_wrt).clone().detach()
+    else:
+        grad_wrt_outputs = None
+    attribution_map = torch.autograd.grad(grad_wrt, img,
+                                    grad_outputs=grad_wrt_outputs,
+                                    create_graph=True,  # allows higher-order derivatives but slows computation significantly
+                                    )[0]
+    if absolute:
+        attribution_map = torch.abs(attribution_map) # attribution_map ** 2 # Take absolute values of gradients
+    if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
+        attribution_map = torch.sum(attribution_map, 1, keepdim=True)
+    if norm:
+        if keepmean:
+            attmean = torch.mean(attribution_map)
+            attmin = torch.min(attribution_map)
+            attmax = torch.max(attribution_map)
+        attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
+        if keepmean:
+            attribution_map -= attribution_map.mean()
+            attribution_map += (attmean / (attmax - attmin))
+        
+    return attribution_map
+
+def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False):
+    """
+    Generate Gaussian noise based on the input image.
+
+    Args:
+        img (torch.Tensor): Input image.
+        grad_wrt: Gradient with respect to the input image.
+        norm (bool, optional): Whether to normalize the generated noise. Defaults to True.
+        absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
+        grayscale (bool, optional): Whether to convert the noise to grayscale. Defaults to True.
+        keepmean (bool, optional): Whether to keep the mean of the noise. Defaults to False.
+
+    Returns:
+        torch.Tensor: Generated Gaussian noise.
+    """
+    
+    gaussian_noise = torch.randn_like(img)
+    
+    if absolute:
+        gaussian_noise = torch.abs(gaussian_noise) # Take absolute values of gradients
+    if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
+        gaussian_noise = torch.sum(gaussian_noise, 1, keepdim=True)
+    if norm:
+        if keepmean:
+            attmean = torch.mean(gaussian_noise)
+            attmin = torch.min(gaussian_noise)
+            attmax = torch.max(gaussian_noise)
+        gaussian_noise = normalize_batch(gaussian_noise) # Normalize attribution maps per image in batch
+        if keepmean:
+            gaussian_noise -= gaussian_noise.mean()
+            gaussian_noise += (attmean / (attmax - attmin))
+        
+    return gaussian_noise
+    
+
+def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps=1e-7):
+    # TODO: Remove imgs from this function and only take it as input if debug is True
+    """
+    Calculate the plausibility score: the fraction of total attribution that falls inside the target boxes.
+
+    Args:
+        targets_out (torch.Tensor): The output targets, rows of (batch_idx, cls, x, y, w, h).
+        attr (torch.Tensor): The attribution maps.
+        debug (bool, optional): Whether to enable debug visualization. Defaults to False.
+        imgs (torch.Tensor, optional): The input images, used only when debug is True.
+
+    Returns:
+        torch.Tensor: The plausibility score.
+    """
+    if torch.isnan(attr).any():
+        attr = torch.nan_to_num(attr, nan=0.0)
+
+    coords_map = get_bbox_map(targets_out, attr)
+    plaus_score = torch.sum(attr * coords_map) / torch.sum(attr)
+
+    if debug:
+        for i in range(len(coords_map)):
+            if imgs is None:
+                imgs = torch.zeros_like(attr)
+            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0).to(torch.bool)
+            test_bbox = torch.zeros_like(imgs[i])
+            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
+            imshow(test_bbox, save_path='figs/test_bbox')
+            imshow(imgs[i], save_path='figs/im0')
+            imshow(attr[i], save_path='figs/attr')
+
+    return plaus_score
+
+def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
+    # TODO: Remove imgs from this function and only take it as input if debug is True
+    """
+    Calculates the plausibility score based on the given inputs.
+
+    Args:
+        imgs (torch.Tensor): The input images.
+        targets_out (torch.Tensor): The output targets.
+        attr (torch.Tensor): The attribute tensor.
+        debug (bool, optional): Whether to enable debug mode. Defaults to False.
+
+    Returns:
+        torch.Tensor: The plausibility score.
+    """
+    target_inds = targets_out[:, 0].int()
+    xyxy_batch = targets_out[:, 2:6]
+    num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
+    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
+    co = xyxy_corners
+    if corners:
+        co = targets_out[:, 2:6].int()
+    coords_map = torch.zeros_like(attr, dtype=torch.bool)
+    x1, x2 = co[:, 1], co[:, 3]
+    y1, y2 = co[:, 0], co[:, 2]
+
+    for ic in range(co.shape[0]):  # potential for speedup here with torch indexing instead of for loop
+        coords_map[target_inds[ic], :, x1[ic]:x2[ic], y1[ic]:y2[ic]] = True
+
+    if torch.isnan(attr).any():
+        attr = torch.nan_to_num(attr, nan=0.0)
+    if debug:
+        for i in range(len(coords_map)):
+            if imgs is None:
+                imgs = torch.zeros_like(attr)
+            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
+            test_bbox = torch.zeros_like(imgs[i])
+            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
+            imshow(test_bbox, save_path='figs/test_bbox')
+            imshow(imgs[i], save_path='figs/im0')
+            imshow(attr[i], save_path='figs/attr')
+    
+    # Pixel coordinates of each target center
+    co = (xyxy_batch * num_pixels).int()
+    x1 = co[:, 1] + 1
+    y1 = co[:, 0] + 1
+    attr_ = torch.sum(attr, 1, keepdim=True)  # grayscale attribution
+    corners_attr = None
+    for ic in range(co.shape[0]):
+        attr0 = attr_[target_inds[ic], :,:x1[ic],:y1[ic]]
+        attr1 = attr_[target_inds[ic], :,:x1[ic],y1[ic]:]
+        attr2 = attr_[target_inds[ic], :,x1[ic]:,:y1[ic]]
+        attr3 = attr_[target_inds[ic], :,x1[ic]:,y1[ic]:]
+
+        x_0, y_0 = max_indices_2d(attr0[0])
+        x_1, y_1 = max_indices_2d(attr1[0])
+        x_2, y_2 = max_indices_2d(attr2[0])
+        x_3, y_3 = max_indices_2d(attr3[0])
+
+        y_1 += y1[ic]
+        x_2 += x1[ic]
+        x_3 += x1[ic]
+        y_3 += y1[ic]
+
+        max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2],
+                                    torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3],
+                                    torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2],
+                                    torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]])
+        if corners_attr is None:
+            corners_attr = max_corners
+        else:
+            corners_attr = torch.cat([corners_attr, max_corners], dim=0)
+    corners_attr = corners_attr.view(-1, 4)
+    # CIoU between attribution-derived boxes and targets (ultralytics bbox_iou expects (N, 4) boxes)
+    IoU_ = bbox_iou(corners_attr, xyxy_batch, xywh=True, CIoU=True)
+    plaus_score = IoU_.mean()
+
+    return plaus_score
+
+def max_indices_2d(x_inp):
+    """Return the (row, col) indices of the maximum element of a 2D tensor."""
+    index = torch.argmax(x_inp)
+    x = index // x_inp.shape[1]
+    y = index % x_inp.shape[1]
+    return torch.cat([x.unsqueeze(0), y.unsqueeze(0)])
+
+
+def point_in_polygon(poly, grid):
+    """Even-odd rule test for which grid points fall inside the polygon."""
+    num_points = poly.shape[0]
+    j = num_points - 1
+    oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool)
+    for i in range(num_points):
+        cond1 = (poly[i, 1] < grid[..., 1]) & (poly[j, 1] >= grid[..., 1])
+        cond2 = (poly[j, 1] < grid[..., 1]) & (poly[i, 1] >= grid[..., 1])
+        cond3 = (grid[..., 0] - poly[i, 0]) < (poly[j, 0] - poly[i, 0]) * (grid[..., 1] - poly[i, 1]) / (poly[j, 1] - poly[i, 1])
+        oddNodes = oddNodes ^ ((cond1 | cond2) & cond3)
+        j = i
+    return oddNodes
+    
+def point_in_polygon_gpu(poly, grid):
+    """Vectorized even-odd rule test computing all edge conditions at once on the GPU."""
+    num_points = poly.shape[0]
+    i = torch.arange(num_points)
+    j = (i - 1) % num_points
+    # Expand polygon vertices to the grid shape so all edges are tested in parallel
+    poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid.shape[0], grid.shape[0])
+    cond1 = (poly_expanded[i, 1] < grid[..., 1]) & (poly_expanded[j, 1] >= grid[..., 1])
+    cond2 = (poly_expanded[j, 1] < grid[..., 1]) & (poly_expanded[i, 1] >= grid[..., 1])
+    cond3 = (grid[..., 0] - poly_expanded[i, 0]) < (poly_expanded[j, 0] - poly_expanded[i, 0]) * (grid[..., 1] - poly_expanded[i, 1]) / (poly_expanded[j, 1] - poly_expanded[i, 1])
+    cond = (cond1 | cond2) & cond3
+    # Pairwise (tree) XOR reduction on the GPU, avoiding a slow per-edge Python loop
+    c = []
+    while len(cond) > 1:
+        if len(cond) % 2 == 1:  # odd number of elements: set one aside
+            c.append(cond[-1])
+            cond = cond[:-1]
+        cond = torch.bitwise_xor(cond[:len(cond) // 2], cond[len(cond) // 2:])
+    for c_ in c:
+        cond = torch.bitwise_xor(cond, c_)
+    oddNodes = cond
+    return oddNodes
+
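+# The XOR tree above halves the stack of edge-condition maps each pass; e.g. for
+# five maps [a, b, c, d, e]: e is set aside, [a, b] ^ [c, d] -> [a^c, b^d] -> a^c^b^d,
+# and e is folded back in, matching the parity of a sequential per-edge XOR loop.
+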
+
+def bitmap_for_polygon(poly, h, w):
+    """Rasterize a polygon into a (1, h, w) boolean bitmap."""
+    y = torch.arange(h).to(poly.device).float()
+    x = torch.arange(w).to(poly.device).float()
+    grid_y, grid_x = torch.meshgrid(y, x, indexing="ij")
+    grid = torch.stack((grid_x, grid_y), dim=-1)
+    bitmap = point_in_polygon(poly, grid)
+    return bitmap.unsqueeze(0)
+
+
+def corners_coords(center_xywh):
+    center_x, center_y, w, h = center_xywh
+    x = center_x - w/2
+    y = center_y - h/2
+    return torch.tensor([x, y, x+w, y+h])
+
+def corners_coords_batch(center_xywh):
+    center_x, center_y = center_xywh[:,0], center_xywh[:,1]
+    w, h = center_xywh[:,2], center_xywh[:,3]
+    x = center_x - w/2
+    y = center_y - h/2
+    return torch.stack([x, y, x+w, y+h], dim=1)
+    
+def normalize_batch(x):
+    """
+    Min-max normalize each sample in a batch to [0, 1] (per image, over all channels and pixels).
+
+    Args:
+        x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
+
+    Returns:
+        torch.Tensor: Normalized tensor of the same shape as the input.
+    """
+    dims = tuple(range(1, x.ndim))
+    mins = x.amin(dim=dims, keepdim=True)
+    maxs = x.amax(dim=dims, keepdim=True)
+    return (x - mins) / (maxs - mins)
+
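+# Illustrative usage (shapes assumed): each image is normalized independently, so
+# per-image attribution scales become comparable, e.g.:
+#   x = torch.rand(4, 3, 8, 8) * 100
+#   y = normalize_batch(x)  # y[i].min() == 0 and y[i].max() == 1 for every i
+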
+def get_detections(model_clone, img):
+    """
+    Run inference with a (cloned) model on an input image.
+
+    Args:
+        model_clone (nn.Module): The model to use for detection.
+        img (torch.Tensor): The input image tensor.
+
+    Returns:
+        tuple: The detection outputs and the raw model outputs.
+    """
+    model_clone.eval()  # Set model to evaluation mode
+    with torch.no_grad():
+        det_out, out = model_clone(img)
+    return det_out, out
+
+def get_labels(det_out, imgs, targets, opt):
+    ###################### Get predicted labels ###################### 
+    nb, _, height, width = imgs.shape  # batch size, channels, height, width 
+    targets_ = targets.clone() 
+    targets_[:, 2:] = targets_[:, 2:] * torch.Tensor([width, height, width, height]).to(imgs.device)  # to pixels
+    lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else []  # for autolabelling
+    o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True)
+    pred_labels = [] 
+    for si, pred in enumerate(o):
+        labels = targets_[targets_[:, 0] == si, 1:]
+        nl = len(labels) 
+        predn = pred.clone()
+        # Sort predictions by confidence (column 4) in descending order
+        sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True)
+        # Apply the sorting indices to the tensor
+        sorted_pred = predn[sort_indices]
+        # Keep predictions above 0.1 confidence (plus one)
+        n_conf = int(torch.sum(sorted_pred[:,4]>0.1)) + 1
+        sorted_pred = sorted_pred[:n_conf]
+        new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si
+        preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1)
+        preds[:, 2:] = xyxy2xywh(preds[:, 2:])  # xywh
+        gn = torch.tensor([width, height])[[1, 0, 1, 0]]  # normalization gain whwh
+        preds[:, 2:] /= gn.to(imgs.device)  # from pixels
+        pred_labels.append(preds)
+    pred_labels = torch.cat(pred_labels, 0).to(imgs.device)
+    
+    return pred_labels
+    ##################################################################
+
+
+def get_center_coords(attr):
+    """Return the centroid (x, y) of the brightest pixels in a 2D attribution map."""
+    attr = attr / attr.max()
+
+    # Define a brightness threshold
+    threshold = 0.95
+
+    # Create a binary mask of the bright pixels
+    mask = attr > threshold
+
+    # Get the coordinates of the bright pixels
+    y_coords, x_coords = torch.where(mask)
+
+    # Calculate the centroid of the bright pixels
+    centroid_x = x_coords.float().mean().item()
+    centroid_y = y_coords.float().mean().item()
+
+    return centroid_x, centroid_y
+
+
+def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False):
+    """
+    Compute the distance grids from each pixel to the target coordinates.
+
+    Args:
+        attr (torch.Tensor): Attribution maps.
+        targets (torch.Tensor): Target coordinates.
+        focus_coeff (float, optional): Focus coefficient, smaller means more focused. Defaults to 0.5.
+        debug (bool, optional): Whether to visualize debug information. Defaults to False.
+
+    Returns:
+        torch.Tensor: Distance grids.
+    """
+    
+    # Assign the height and width of the input tensor to variables
+    height, width = attr.shape[-2], attr.shape[-1]
+
+    # Create a grid of (row, col) pixel indices
+    xx, yy = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
+    idx_grid = torch.stack((xx, yy), dim=-1).float().to(attr.device)
+
+    # Expand the grid to match the batch size
+    idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1)
+
+    # Initialize a list to store the distance grids
+    dist_grids_ = [None] * attr.shape[0]
+
+    # Loop over batches
+    for j in range(attr.shape[0]):
+        # Get the rows where the first column is the current batch index
+        rows = targets[targets[:, 0] == j]
+
+        if len(rows) != 0:
+            # Target centers: normalized (x, y) converted to pixel (row, col)
+            xy = rows[:, 2:4]
+            xy[:, 0], xy[:, 1] = xy[:, 1] * height, xy[:, 0] * width
+            xy_center = xy.unsqueeze(1).unsqueeze(1)
+
+            # Compute the Euclidean distance from each pixel to the target coordinates
+            dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1)
+
+            # Pick the closest distance to any target for each pixel
+            dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0)
+            dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_
+        else:
+            # Set grid to zero if no targets are present
+            dist_grid = torch.zeros_like(attr[j])
+
+        dist_grids_[j] = dist_grid
+    # Convert the list of distance grids to a tensor for faster computation
+    dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff
+    if torch.isnan(dist_grids).any():
+        dist_grids = torch.nan_to_num(dist_grids, nan=0.0)
+
+    if debug:
+        for i in range(len(dist_grids)):
+            if ((i % 8) == 0):
+                grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0)
+                imshow(grid_show, save_path='figs/dist_grids')
+                if imgs is None:
+                    imgs = torch.zeros_like(attr)
+                imshow(imgs[i], save_path='figs/im0')
+                img_overlay = (overlay_mask(imgs[i], dist_grids[i][0], alpha = 0.75))
+                imshow(img_overlay, save_path='figs/dist_grid_overlay')
+                weighted_attr = (dist_grids[i] * attr[i])
+                imshow(weighted_attr, save_path='figs/weighted_attr')
+                imshow(attr[i], save_path='figs/attr')
+
+    return dist_grids
+
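+# Illustrative usage (shapes assumed): with targets given as (batch_idx, cls, x, y, w, h)
+# rows normalized to [0, 1], one distance grid is produced per image, e.g.:
+#   attr = torch.rand(2, 3, 64, 64)
+#   targets = torch.tensor([[0.0, 0.0, 0.5, 0.5, 0.2, 0.2]])
+#   grids = get_distance_grids(attr, targets)  # torch.Size([2, 3, 64, 64]), 0 at the target center
+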
+def attr_reg(attribution_map, distance_map):
+    """Mean of the attribution map weighted by the distance map."""
+    dist_attr = torch.mean(distance_map * attribution_map)
+    return dist_attr
+
+def get_bbox_map(targets_out, attr, corners=False):
+    """Build a float mask that is 1 inside each target bounding box and 0 elsewhere."""
+    target_inds = targets_out[:, 0].int()
+    xyxy_batch = targets_out[:, 2:6]
+    num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
+    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
+    co = xyxy_corners
+    if corners:
+        co = targets_out[:, 2:6].int()
+    coords_map = torch.zeros_like(attr, dtype=torch.bool)
+    x1, x2 = co[:, 1], co[:, 3]
+    y1, y2 = co[:, 0], co[:, 2]
+
+    for ic in range(co.shape[0]):  # potential for speedup here with torch indexing instead of for loop
+        coords_map[target_inds[ic], :, x1[ic]:x2[ic], y1[ic]:y2[ic]] = True
+
+    bbox_map = coords_map.to(torch.float32)
+
+    return bbox_map
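+# Illustrative usage (shapes assumed): the plausibility score used throughout this file
+# is the attribution mass inside the boxes divided by the total mass, e.g.:
+#   bbox_map = get_bbox_map(targets, attr)
+#   plaus = (attr * bbox_map).sum() / attr.sum()  # 1.0 if all attribution is on-target
+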
+############################### Plausibility loss #################################
+def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False):
+    """Compute the plausibility loss (and optionally its components) for a batch of attribution maps."""
+    # Calculate distance regularization
+    distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff)
+
+    if opt.dist_x_bbox:
+        # Zero out the distance penalty inside the target boxes
+        bbox_map = get_bbox_map(targets, attribution_map).to(torch.bool)
+        distance_map[bbox_map] = 0.0
+
+    # Positive regularization term for incentivizing pixels near the target to have high attribution
+    dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map))
+    # Negative regularization term for incentivizing pixels far from the target to have low attribution
+    dist_attr_neg = attr_reg(attribution_map, distance_map)
+    # Calculate plausibility regularization term
+    dist_reg = (dist_attr_pos / torch.mean(attribution_map)) - (dist_attr_neg / torch.mean(attribution_map))
+
+    if opt.bbox_coeff != 0.0:
+        bbox_map = get_bbox_map(targets, attribution_map)
+        attr_bbox_pos = attr_reg(attribution_map, bbox_map)
+        attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map))
+        bbox_reg = attr_bbox_pos - attr_bbox_neg
+    else:
+        bbox_reg = 0.0
+
+    # Plausibility IoU: fraction of the attribution mass that falls inside the target boxes
+    bbox_map = get_bbox_map(targets, attribution_map)
+    plaus_score = torch.sum(attribution_map * bbox_map) / torch.sum(attribution_map)
+
+    if not opt.dist_reg_only:
+        dist_reg_loss = (1.0 + dist_reg) / 2.0
+        plaus_reg = (plaus_score * opt.iou_coeff) + \
+                    (dist_reg_loss * opt.dist_coeff) + \
+                    (bbox_reg * opt.bbox_coeff)
+    else:
+        plaus_reg = (1.0 + dist_reg) / 2.0
+    # Calculate plausibility loss
+    plaus_loss = (1 - plaus_reg) * opt.pgt_coeff
+    if only_loss:
+        return plaus_loss
+    if not debug:
+        return plaus_loss, (plaus_score, dist_reg, plaus_reg,)
+    else:
+        return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map
+
+####################################################################################
+#### ALL FUNCTIONS BELOW ARE DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS ####
+####################################################################################
+
+def generate_vanilla_grad(model, input_tensor, loss_func = None, 
+                          targets_list=None, targets=None, metric=None, out_num = 1, 
+                          n_max_labels=3, norm=True, abs=True, grayscale=True, 
+                          class_specific_attr = True, device='cpu'):    
+    """
+    Generate vanilla gradients for the given model and input tensor.
+
+    Args:
+        model (nn.Module): The model to generate gradients for.
+        input_tensor (torch.Tensor): The input tensor for which gradients are computed.
+        loss_func (callable, optional): The loss function to compute gradients with respect to. Defaults to None.
+        targets_list (list, optional): The list of target tensors. Defaults to None.
+        metric (callable, optional): The metric function to evaluate the loss. Defaults to None.
+        out_num (int, optional): The index of the output tensor to compute gradients with respect to. Defaults to 1.
+        n_max_labels (int, optional): The maximum number of labels to consider. Defaults to 3.
+        norm (bool, optional): Whether to normalize the attribution map. Defaults to True.
+        abs (bool, optional): Whether to take the absolute values of gradients. Defaults to True.
+        grayscale (bool, optional): Whether to convert the attribution map to grayscale. Defaults to True.
+        class_specific_attr (bool, optional): Whether to compute class-specific attribution maps. Defaults to True.
+        device (str, optional): The device to use for computation. Defaults to 'cpu'.
+    
+    Returns:
+        torch.Tensor: The generated vanilla gradients.
+    """
+    # Set model.train() at the beginning and revert back to original mode (model.eval() or model.train()) at the end
+    train_mode = model.training
+    if not train_mode:
+        model.train()
+    
+    input_tensor.requires_grad = True # Set requires_grad attribute of tensor. Important for computing gradients
+    model.zero_grad() # Zero gradients
+    inpt = input_tensor
+    # Forward pass
+    train_out = model(inpt) # training outputs (no inference outputs in train mode)
+    
+    # train_out[1] = torch.Size([4, 3, 80, 80, 7]) HxWx(#anchorxC) cls (class probabilities)
+    # train_out[0] = torch.Size([4, 3, 160, 160, 7]) HxWx(#anchorx4) box or reg (location and scaling)
+    # train_out[2] = torch.Size([4, 3, 40, 40, 7]) HxWx(#anchorx1) obj (objectness score or confidence)
+    
+    if class_specific_attr:
+        n_attr_list, index_classes = [], []
+        for i in range(len(input_tensor)):
+            if len(targets_list[i]) > n_max_labels:
+                targets_list[i] = targets_list[i][:n_max_labels]
+            if targets_list[i].numel() != 0:
+                # unique_classes = torch.unique(targets_list[i][:,1])
+                class_numbers = targets_list[i][:,1]
+                index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers])
+                num_attrs = len(targets_list[i])
+                # index_classes.append([0, 1, 2, 3, 4] + [int(uc + 5) for uc in unique_classes])
+                # num_attrs = 1 #len(unique_classes)# if loss_func else len(targets_list[i])
+                n_attr_list.append(num_attrs)
+            else:
+                index_classes.append([0, 1, 2, 3, 4])
+                n_attr_list.append(0)
+    
+        targets_list_filled = [targ.clone().detach() for targ in targets_list]
+        labels_len = [len(targets_list[ih]) for ih in range(len(targets_list))]
+        max_labels = np.max(labels_len)
+        max_index = np.argmax(labels_len)
+        for i in range(len(targets_list)):
+            # targets_list_filled[i] = targets_list[i]
+            if len(targets_list_filled[i]) < max_labels:
+                tlist = [targets_list_filled[i]] * math.ceil(max_labels / len(targets_list_filled[i]))
+                targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0)
+            else:
+                targets_list_filled[i] = targets_list_filled[i].unsqueeze(0)
+        for i in range(len(targets_list_filled)-1,-1,-1):
+            if targets_list_filled[i].numel() == 0:
+                targets_list_filled.pop(i)
+        targets_list_filled = torch.cat(targets_list_filled)
+    
+    n_img_attrs = len(input_tensor) if class_specific_attr else 1
+    n_img_attrs = 1 if loss_func else n_img_attrs
+    
+    attrs_batch = []
+    for i_batch in range(n_img_attrs):
+        if loss_func and class_specific_attr:
+            i_batch = max_index
+        # inpt = input_tensor[i_batch].unsqueeze(0)
+        # ##################################################################
+        # model.zero_grad() # Zero gradients
+        # train_out = model(inpt)  # training outputs (no inference outputs in train mode)
+        # ##################################################################
+        n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1
+        n_label_attrs = 1 if not class_specific_attr else n_label_attrs
+        attrs_img = []
+        for i_attr in range(n_label_attrs):
+            if loss_func is None:
+                grad_wrt = train_out[out_num]
+                if class_specific_attr:
+                    grad_wrt = train_out[out_num][:,:,:,:,index_classes[i_batch][i_attr]]
+                grad_wrt_outputs = torch.ones_like(grad_wrt)
+            else:
+                # if class_specific_attr:
+                #     targets = targets_list[:][i_attr]
+                # n_targets = len(targets_list[i_batch])
+                if class_specific_attr:
+                    target_indiv = targets_list_filled[:,i_attr] # batch image input
+                else:
+                    target_indiv = targets
+                # target_indiv = targets_list[i_batch][i_attr].unsqueeze(0) # single image input
+                # target_indiv[:,0] = 0 # this indicates the batch index of the target, should be 0 since we are only doing one image at a time
+                    
+                try:
+                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)  # loss scaled by batch_size
+                except RuntimeError:
+                    target_indiv = target_indiv.to(device)
+                    inpt = inpt.to(device)
+                    train_out = [tro.to(device) for tro in train_out]  # move each output (reassigning the loop variable would be a no-op)
+                    print("Error in loss function, trying again with device specified")
+                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)
+                grad_wrt = loss
+                grad_wrt_outputs = None
+            
+            model.zero_grad() # Zero gradients
+            gradients = torch.autograd.grad(grad_wrt, inpt, 
+                                                grad_outputs=grad_wrt_outputs, 
+                                                retain_graph=True, 
+                                                # create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly
+                                                )
+
+            # Convert gradients to numpy array and back to ensure full separation from graph
+            # attribution_map = torch.tensor(torch.sum(gradients[0], 1, keepdim=True).clone().detach().cpu().numpy())
+            attribution_map = gradients[0]#.clone().detach() # without converting to numpy
+            
+            if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
+                attribution_map = torch.sum(attribution_map, 1, keepdim=True)
+            if abs:
+                attribution_map = torch.abs(attribution_map) # Take absolute values of gradients
+            if norm:
+                attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
+            attrs_img.append(attribution_map)
+        if len(attrs_img) == 0:
+            attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device))
+        else:
+            attrs_batch.append(torch.stack(attrs_img).to(device))
+
+    # out_attr = torch.tensor(attribution_map).unsqueeze(0).to(device) if ((loss_func) or (not class_specific_attr)) else torch.stack(attrs_batch).to(device)
+    # out_attr = [attrs_batch[0]] * len(input_tensor) if ((loss_func) or (not class_specific_attr)) else attrs_batch
+    out_attr = attrs_batch
+    # Set model back to original mode
+    if not train_mode:
+        model.eval()
+    
+    return out_attr
+
+class RVNonLinearFunc(torch.nn.Module):
+    """
+    Bayesian wrapper that propagates the mean and covariance of a random variable through a nonlinear activation.
+
+    Attributes:
+        func (callable): The elementwise activation function to wrap.
+    """
+    def __init__(self, func):
+        super(RVNonLinearFunc, self).__init__()
+        self.func = func
+
+    def forward(self, mu_in, Sigma_in):
+        """
+        Forward pass of the Bayesian ReLU activation function.
+
+        Args:
+            mu_in (torch.Tensor): A tensor of shape (batch_size, input_size),
+                representing the mean input to the ReLU activation function.
+            Sigma_in (torch.Tensor): A tensor of shape (batch_size, input_size, input_size),
+                representing the covariance input to the ReLU activation function.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors,
+                including the mean of the output and the covariance of the output.
+        """
+        # Collect stats
+        batch_size = mu_in.size(0)
+       
+        # Mean
+        mu_out = self.func(mu_in)
+        
+        # Compute the derivative of the wrapped activation with respect to the input mean
+        gradi = torch.autograd.grad(mu_out, mu_in, grad_outputs=torch.ones_like(mu_out), create_graph=True)[0].view(batch_size,-1)
+
+        # add an extra dimension to gradi at position 2 and 1
+        grad1 = gradi.unsqueeze(dim=2)
+        grad2 = gradi.unsqueeze(dim=1)
+       
+        # compute the outer product of grad1 and grad2
+        outer_product = torch.bmm(grad1, grad2)
+       
+        # element-wise multiply Sigma_in with the outer product
+        # and return the result
+        Sigma_out = torch.mul(Sigma_in, outer_product)
+
+        return mu_out, Sigma_out
+
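
Taken together, the core pieces of `plaus_functs.py` can be exercised end to end on dummy data. A hedged, self-contained sketch (all shapes and values are illustrative; `pgt_coeff=3.0` mirrors the hardcoded value in `v10PGTDetectLoss`):

```python
import torch
from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn

# Dummy batch: 2 images, a segmentation mask marking one object in image 0
imgs = torch.rand(2, 3, 320, 320, requires_grad=True)
masks = torch.zeros(2, 80, 80)
masks[0, 30:50, 30:50] = 1.0

# Stand-in detection loss -> input-gradient attribution
loss = (imgs ** 2).mean()
grad = torch.autograd.grad(loss, imgs, retain_graph=True)[0].abs()

# Soft distance mask and plausibility loss
smask = get_dist_reg(imgs, masks)
plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff=3.0)
print(plaus_loss)  # scalar tensor
```
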
diff --git a/ultralytics/utils/plot_functs.py b/ultralytics/utils/plot_functs.py
new file mode 100644
index 00000000..4b403331
--- /dev/null
+++ b/ultralytics/utils/plot_functs.py
@@ -0,0 +1,154 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+from torch.cuda import amp  # used by returnGrad for mixed-precision backprop
+
+class Subplots:
+    def __init__(self, figsize=(40, 5)):
+        self.fig = plt.figure(figsize=figsize)
+
+    def plot_img_list(self, img_list, savedir='figs/test',
+                    nrows=1, rownum=0,
+                    hold=False, coltitles=None, rowtitle=''):
+        coltitles = coltitles if coltitles is not None else []
+        for i, img in enumerate(img_list):
+            try:
+                npimg = img.clone().detach().cpu().numpy()
+            except AttributeError:  # already a numpy array
+                npimg = img
+            tpimg = np.transpose(npimg, (1, 2, 0))
+            lenrow = int(len(img_list))
+            ax = self.fig.add_subplot(nrows, lenrow, i + 1 + (rownum * lenrow))
+            if len(coltitles) > i:
+                ax.set_title(coltitles[i])
+            if i == 0:
+                ax.annotate(rowtitle, xy=((-0.06 * len(rowtitle)), 0.4),
+                xycoords='axes fraction', textcoords='offset points',
+                size='large', ha='center', va='baseline')
+            ax.imshow(tpimg)
+            ax.axis('off')
+
+        if not hold:
+            self.fig.tight_layout()
+            plt.savefig(f'{savedir}.png')
+            plt.clf()
+            plt.close('all')
+
+def VisualizeNumpyImageGrayscale(image_3d):
+    r"""Min-max normalize a numpy image array to the range [0, 1]."""
+    vmin = np.min(image_3d)
+    image_2d = image_3d - vmin
+    vmax = np.max(image_2d)
+    return (image_2d / vmax)
+
+def normalize_numpy(image_3d):
+    r"""Min-max normalize a numpy image array to the range [0, 1] (same as VisualizeNumpyImageGrayscale)."""
+    vmin = np.min(image_3d)
+    image_2d = image_3d - vmin
+    vmax = np.max(image_2d)
+    return (image_2d / vmax)
+
+def normalize_tensor(image_3d):
+    r"""Min-max normalize a tensor to the range [0, 1]."""
+    image_2d = image_3d - torch.min(image_3d)
+    return image_2d / torch.max(image_2d)
+
+def format_img(img_):
+    np_img = img_.numpy()
+    tp_img = np.transpose(np_img, (1, 2, 0))
+    return tp_img
+
+def imshow(img, save_path=None):
+    try:
+        npimg = img.clone().detach().cpu().numpy()
+    except AttributeError:  # already a numpy array
+        npimg = img
+    tpimg = np.transpose(npimg, (1, 2, 0))
+    plt.imshow(tpimg)
+    # plt.axis('off')
+    plt.tight_layout()
+    if save_path is not None:
+        plt.savefig(f"{save_path}.png")
+    # plt.show()
+
+def imshow_img(img, imsave_path):
+    # works for tensors and numpy arrays
+    try:
+        npimg = VisualizeNumpyImageGrayscale(img.numpy())
+    except AttributeError:  # already a numpy array
+        npimg = VisualizeNumpyImageGrayscale(img)
+    npimg = np.transpose(npimg, (2, 0, 1))
+    imshow(npimg, save_path=imsave_path)
+    print("Saving image as ", imsave_path)
+    
+def returnGrad(img, labels, model, compute_loss, loss_metric, augment=None, device='cpu'):
+    model.train()
+    model.to(device)
+    img = img.to(device)
+    img.requires_grad_(True)
+    labels = labels.to(device)
+    model.requires_grad_(True)
+    cuda = torch.device(device).type != 'cpu'
+    scaler = amp.GradScaler(enabled=cuda)
+    pred = model(img)
+    loss, loss_items = compute_loss(pred, labels, metric=loss_metric)  # box, obj, cls
+
+    with torch.autograd.set_detect_anomaly(True):
+        scaler.scale(loss).backward(inputs=img)
+
+    Sc_dx = img.grad
+    model.eval()
+    Sc_dx = Sc_dx.detach().to(dtype=torch.float32)
+    return Sc_dx
+
+def calculate_snr(img, attr, dB=True):
+    """Signal-to-noise ratio of an image relative to an attribution map, optionally in dB."""
+    try:
+        img_np = img.detach().cpu().numpy()
+        attr_np = attr.detach().cpu().numpy()
+    except AttributeError:  # already numpy arrays
+        img_np = img
+        attr_np = attr
+
+    # Calculate the signal power
+    signal_power = np.mean(img_np**2)
+
+    # Calculate the noise power
+    noise_power = np.mean(attr_np**2)
+
+    if dB:
+        # Calculate SNR in dB
+        snr = 10 * np.log10(signal_power / noise_power)
+    else:
+        # Calculate SNR as a plain ratio
+        snr = signal_power / noise_power
+
+    return snr
+
+def overlay_mask(img, mask, colormap: str = "jet", alpha: float = 0.7):
+    """Overlay a single-channel mask on an image using a matplotlib colormap."""
+    cmap = plt.get_cmap(colormap)
+    npmask = np.array(mask.clone().detach().cpu().squeeze(0))
+    # cmpmask = ((255 * cmap(npmask)[:, :, :3]).astype(np.uint8)).transpose((2, 0, 1))
+    cmpmask = (cmap(npmask)[:, :, :3]).transpose((2, 0, 1))
+    overlayed_imgnp = ((alpha * (np.asarray(img.clone().detach().cpu())) + (1 - alpha) * cmpmask))
+    overlayed_tensor = torch.tensor(overlayed_imgnp, device=img.device)
+    
+    return overlayed_tensor
\ No newline at end of file
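
The plotting helpers compose naturally: `overlay_mask` blends a heat-mask into an image and `imshow` writes the result to disk. A small sketch with dummy data (paths and shapes are illustrative, and the output directory is assumed to exist):

```python
import torch
from ultralytics.utils.plot_functs import overlay_mask, imshow, normalize_tensor

img = torch.rand(3, 320, 320)                      # CHW image in [0, 1]
attr = normalize_tensor(torch.rand(1, 320, 320))   # single-channel attribution map

overlaid = overlay_mask(img, attr, alpha=0.75)
imshow(overlaid, save_path='figures/attr_overlay')  # writes figures/attr_overlay.png
```
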
diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py
index d0215ba5..0e99c722 100644
--- a/ultralytics/utils/plotting.py
+++ b/ultralytics/utils/plotting.py
@@ -717,7 +717,7 @@ def plot_images(
 ):
     """Plot image grid with labels."""
     if isinstance(images, torch.Tensor):
-        images = images.cpu().float().numpy()
+        images = images.detach().cpu().float().numpy()
     if isinstance(cls, torch.Tensor):
         cls = cls.cpu().numpy()
     if isinstance(bboxes, torch.Tensor):

From 3a449d5a6c5c1c2fddb3de11929684a7a906d498 Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Mon, 21 Oct 2024 12:20:24 -0400
Subject: [PATCH 04/12] Fixed PGT by including it in the loss function

---
 run_pgt_train.py                    |  62 +++++++-------
 ultralytics/cfg/pgt_train.yaml      | 127 ++++++++++++++++++++++++++++
 ultralytics/engine/pgt_trainer.py   |  10 ++-
 ultralytics/engine/pgt_validator.py |   2 +
 ultralytics/utils/loss.py           |  24 +++---
 5 files changed, 179 insertions(+), 46 deletions(-)
 create mode 100644 ultralytics/cfg/pgt_train.yaml

diff --git a/run_pgt_train.py b/run_pgt_train.py
index e308365d..aa431ca9 100644
--- a/run_pgt_train.py
+++ b/run_pgt_train.py
@@ -6,42 +6,42 @@ import argparse
 
 
 def main(args):
-  # model = YOLOv10()
+    # model = YOLOv10()
 
-  # If you want to finetune the model with pretrained weights, you could load the 
-  # pretrained weights like below
-  # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
-  # or
-  # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
-  model = YOLOv10('yolov10n.pt', task='segment')
+    # If you want to finetune the model with pretrained weights, you could load the 
+    # pretrained weights like below
+    # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
+    # or
+    # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
+    model = YOLOv10('yolov10n.pt', task='segment')
 
-  args = dict(model='yolov10n.pt', data='coco.yaml', 
-              epochs=args.epochs, batch=args.batch_size,
-              # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
-              )
-  trainer = PGTSegmentationTrainer(overrides=args)
-  trainer.train(
-        # debug=True, 
-        #   args = dict(pgt_coeff=0.1), # Should add later to config
-          )
+    args = dict(model='yolov10n.pt', data='coco128-seg.yaml', 
+                epochs=args.epochs, batch=args.batch_size,
+                # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
+                )
+    trainer = PGTSegmentationTrainer(overrides=args)
+    trainer.train(
+                # debug=True, 
+                # args = dict(pgt_coeff=0.1), # Should add later to config
+                )
 
-  # Save the trained model
-  model.save('yolov10_coco_trained.pt')
+    # Save the trained model
+    model.save('yolov10_coco_trained.pt')
 
-  # Evaluate the model on the validation set
-  results = model.val(data='coco.yaml')
+    # Evaluate the model on the validation set
+    results = model.val(data='coco.yaml')
 
-  # Print the evaluation results
-  print(results)
+    # Print the evaluation results
+    print(results)
 
 if __name__ == "__main__":
-  parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
-  parser.add_argument('--device', type=str, default='0', help='CUDA device number')
-  parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
-  parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
-  args = parser.parse_args()
+    parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
+    parser.add_argument('--device', type=str, default='0', help='CUDA device number')
+    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
+    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
+    args = parser.parse_args()
 
-  # Set CUDA device (only needed for multi-gpu machines)
-  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-  os.environ["CUDA_VISIBLE_DEVICES"] = args.device
-  main(args)
\ No newline at end of file
+    # Set CUDA device (only needed for multi-gpu machines)
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+    main(args)
\ No newline at end of file
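
With the re-indented script above, a typical single-GPU launch is `python run_pgt_train.py --device 0 --batch_size 64 --epochs 100`; the `--device` flag sets `CUDA_VISIBLE_DEVICES` before training starts.
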
diff --git a/ultralytics/cfg/pgt_train.yaml b/ultralytics/cfg/pgt_train.yaml
new file mode 100644
index 00000000..bd074b10
--- /dev/null
+++ b/ultralytics/cfg/pgt_train.yaml
@@ -0,0 +1,127 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Training settings and hyperparameters for PGT training (based on the medium-augmentation COCO defaults)
+
+task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
+mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
+
+# Train settings -------------------------------------------------------------------------------------------------------
+model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
+data: # (str, optional) path to data file, i.e. coco128.yaml
+epochs: 100 # (int) number of epochs to train for
+time: # (float, optional) number of hours to train for, overrides epochs if supplied
+patience: 100 # (int) epochs to wait for no observable improvement for early stopping of training
+batch: 16 # (int) number of images per batch (-1 for AutoBatch)
+imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
+save: True # (bool) save train checkpoints and predict results
+save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
+val_period: 1 # (int) Validation every x epochs
+cache: False # (bool) True/ram, disk or False. Use cache for data loading
+device: # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
+workers: 8 # (int) number of worker threads for data loading (per RANK if DDP)
+project: # (str, optional) project name
+name: # (str, optional) experiment name, results saved to 'project/name' directory
+exist_ok: False # (bool) whether to overwrite existing experiment
+pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str)
+optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
+verbose: True # (bool) whether to print verbose output
+seed: 0 # (int) random seed for reproducibility
+deterministic: True # (bool) whether to enable deterministic mode
+single_cls: False # (bool) train multi-class data as single-class
+rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
+cos_lr: False # (bool) use cosine learning rate scheduler
+close_mosaic: 10 # (int) disable mosaic augmentation for final epochs (0 to disable)
+resume: False # (bool) resume training from last checkpoint
+amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
+fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
+profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
+freeze: None # (int | list, optional) freeze first n layers, or freeze list of layer indices during training
+multi_scale: False # (bool) Whether to use multiscale during training
+# Segmentation
+overlap_mask: True # (bool) masks should overlap during training (segment train only)
+mask_ratio: 4 # (int) mask downsample ratio (segment train only)
+# Classification
+dropout: 0.0 # (float) use dropout regularization (classify train only)
+
+# Val/Test settings ----------------------------------------------------------------------------------------------------
+val: True # (bool) validate/test during training
+split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train'
+save_json: False # (bool) save results to JSON file
+save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions)
+conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val)
+iou: 0.7 # (float) intersection over union (IoU) threshold for NMS
+max_det: 300 # (int) maximum number of detections per image
+half: False # (bool) use half precision (FP16)
+dnn: False # (bool) use OpenCV DNN for ONNX inference
+plots: True # (bool) save plots and images during train/val
+
+# Predict settings -----------------------------------------------------------------------------------------------------
+source: # (str, optional) source directory for images or videos
+vid_stride: 1 # (int) video frame-rate stride
+stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False)
+visualize: False # (bool) visualize model features
+augment: False # (bool) apply image augmentation to prediction sources
+agnostic_nms: False # (bool) class-agnostic NMS
+classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
+retina_masks: False # (bool) use high-resolution segmentation masks
+embed: # (list[int], optional) return feature vectors/embeddings from given layers
+
+# Visualize settings ---------------------------------------------------------------------------------------------------
+show: False # (bool) show predicted images and videos if environment allows
+save_frames: False # (bool) save predicted individual video frames
+save_txt: False # (bool) save results as .txt file
+save_conf: False # (bool) save results with confidence scores
+save_crop: False # (bool) save cropped images with results
+show_labels: True # (bool) show prediction labels, i.e. 'person'
+show_conf: True # (bool) show prediction confidence, i.e. '0.99'
+show_boxes: True # (bool) show prediction boxes
+line_width: # (int, optional) line width of the bounding boxes. Scaled to image size if None.
+
+# Export settings ------------------------------------------------------------------------------------------------------
+format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
+keras: False # (bool) use Keras
+optimize: False # (bool) TorchScript: optimize for mobile
+int8: False # (bool) CoreML/TF INT8 quantization
+dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes
+simplify: False # (bool) ONNX: simplify model using `onnxslim`
+opset: # (int, optional) ONNX: opset version
+workspace: 4 # (int) TensorRT: workspace size (GB)
+nms: False # (bool) CoreML: add NMS
+
+# Hyperparameters ------------------------------------------------------------------------------------------------------
+lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
+lrf: 0.01 # (float) final learning rate (lr0 * lrf)
+momentum: 0.937 # (float) SGD momentum/Adam beta1
+weight_decay: 0.0005 # (float) optimizer weight decay 5e-4
+warmup_epochs: 3.0 # (float) warmup epochs (fractions ok)
+warmup_momentum: 0.8 # (float) warmup initial momentum
+warmup_bias_lr: 0.1 # (float) warmup initial bias lr
+box: 7.5 # (float) box loss gain
+cls: 0.5 # (float) cls loss gain (scale with pixels)
+dfl: 1.5 # (float) dfl loss gain
+pose: 12.0 # (float) pose loss gain
+kobj: 1.0 # (float) keypoint obj loss gain
+label_smoothing: 0.0 # (float) label smoothing (fraction)
+nbs: 64 # (int) nominal batch size
+hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction)
+hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction)
+hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction)
+degrees: 0.0 # (float) image rotation (+/- deg)
+translate: 0.1 # (float) image translation (+/- fraction)
+scale: 0.5 # (float) image scale (+/- gain)
+shear: 0.0 # (float) image shear (+/- deg)
+perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001
+flipud: 0.0 # (float) image flip up-down (probability)
+fliplr: 0.5 # (float) image flip left-right (probability)
+bgr: 0.0 # (float) image channel BGR (probability)
+mosaic: 1.0 # (float) image mosaic (probability)
+mixup: 0.0 # (float) image mixup (probability)
+copy_paste: 0.0 # (float) segment copy-paste (probability)
+auto_augment: randaugment # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix)
+erasing: 0.4 # (float) probability of random erasing during classification training (0-1)
+crop_fraction: 1.0 # (float) image crop fraction for classification evaluation/inference (0-1)
+
+# Custom config.yaml ---------------------------------------------------------------------------------------------------
+cfg: # (str, optional) for overriding defaults.yaml
+
+# Tracker settings ------------------------------------------------------------------------------------------------------
+tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml]
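
This file mirrors default.yaml so PGT runs can diverge from stock training without touching the defaults. A minimal sketch of selecting it (the override values are illustrative; `cfg` is the standard Ultralytics mechanism for swapping in an alternative defaults file):

    from ultralytics.models.yolo.segment import PGTSegmentationTrainer

    # cfg points at pgt_train.yaml instead of default.yaml; keys given
    # explicitly in overrides still take precedence over the file.
    trainer = PGTSegmentationTrainer(
        overrides=dict(model='yolov10n.pt', data='coco128-seg.yaml',
                       cfg='pgt_train.yaml', epochs=100, batch=16),
    )
    trainer.train()
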
diff --git a/ultralytics/engine/pgt_trainer.py b/ultralytics/engine/pgt_trainer.py
index 6b7c973e..559db7c2 100644
--- a/ultralytics/engine/pgt_trainer.py
+++ b/ultralytics/engine/pgt_trainer.py
@@ -380,11 +380,13 @@ class PGTTrainer:
                         )
                         if "momentum" in x:
                             x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
-
+                
                 # Forward
                 with torch.cuda.amp.autocast(self.amp):
                     batch = self.preprocess_batch(batch)
-                    (self.loss, self.loss_items), images = self.model(batch, return_images=True)
+                    batch['img'] = batch['img'].requires_grad_(True)
+                    self.loss, self.loss_items = self.model(batch)
+                    # (self.loss, self.loss_items), images = self.model(batch, return_images=True)
 
                 # smask = get_dist_reg(images, batch['masks'])
 
@@ -418,7 +420,7 @@ class PGTTrainer:
                             x1, y1, x2, y2 = bboxes[idx]
                             x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
                             mask[irx, :, y1:y2, x1:x2] = 1.0
-
+                    
                     save_imgs = True
                     if save_imgs:
                         # Convert tensors to numpy arrays
@@ -498,7 +500,7 @@ class PGTTrainer:
                     self.run_callbacks("on_batch_end")
                     if self.args.plots and ni in self.plot_idx:
                         self.plot_training_samples(batch, ni)
-
+                
                 self.run_callbacks("on_train_batch_end")
 
             self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
diff --git a/ultralytics/engine/pgt_validator.py b/ultralytics/engine/pgt_validator.py
index ce411ba9..b84631f5 100644
--- a/ultralytics/engine/pgt_validator.py
+++ b/ultralytics/engine/pgt_validator.py
@@ -175,12 +175,14 @@ class PGTValidator:
 
             # Inference
             with dt[1]:
+                model.zero_grad()
                 preds = model(batch["img"].requires_grad_(True), augment=augment)
 
             # Loss
             with dt[2]:
                 if self.training:
                     self.loss += model.loss(batch, preds)[1]
+                    model.zero_grad()
 
             # Postprocess
             with dt[3]:
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index e10b949f..f04ad3a1 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -731,7 +731,7 @@ class v10PGTDetectLoss:
         self.one2many = v8DetectionLoss(model, tal_topk=10)
         self.one2one = v8DetectionLoss(model, tal_topk=1)
     
-    def __call__(self, preds, batch):
+    def __call__(self, preds, batch, return_plaus=True):
         batch['img'] = batch['img'].requires_grad_(True)
         one2many = preds["one2many"]
         loss_one2many = self.one2many(one2many, batch)
@@ -739,16 +739,18 @@ class v10PGTDetectLoss:
         loss_one2one = self.one2one(one2one, batch)
 
         loss = loss_one2many[0] + loss_one2one[0]
-        
-        smask = get_dist_reg(batch['img'], batch['masks'])
+        if return_plaus:
+            smask = get_dist_reg(batch['img'], batch['masks'])
 
-        grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0]
-        grad = torch.abs(grad)
+            grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0]
+            grad = torch.abs(grad)
 
-        pgt_coeff = 3.0
-        plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
-        # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
-        loss += plaus_loss
-        
-        return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0)))
+            pgt_coeff = 3.0
+            plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
+            # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
+            loss += plaus_loss
+            
+            return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0)))
+        else:
+            return loss, torch.cat((loss_one2many[1], loss_one2one[1]))
     
\ No newline at end of file
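
The `return_plaus` branch is the heart of PGT: the detection loss is differentiated with respect to the input images, and the absolute input gradient is scored against the distance-regularized mask from `get_dist_reg`. A self-contained sketch of the same mechanics on a toy model (the conv net, mask, and simplified penalty below are stand-ins, not the real helpers):

    import torch
    import torch.nn as nn

    model = nn.Conv2d(3, 8, 3, padding=1)                  # toy detector
    img = torch.rand(2, 3, 16, 16).requires_grad_(True)    # batch['img']
    smask = torch.zeros_like(img)                          # mask stand-in
    smask[:, :, 4:12, 4:12] = 1.0

    loss = model(img).pow(2).mean()                        # detection loss stand-in
    grad = torch.autograd.grad(loss, img, retain_graph=True)[0].abs()

    # Simplified plausibility penalty: attribution mass falling outside the mask.
    plaus_loss = (grad * (1.0 - smask)).mean() / (grad.mean() + 1e-8)
    total = loss + 3.0 * plaus_loss                        # pgt_coeff = 3.0
    total.backward()
    # Caveat: with create_graph=False the input gradient is a constant w.r.t.
    # the weights, so plaus_loss contributes nothing to the parameter update;
    # patch 08 below switches create_graph=True for exactly this reason.
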

From 5953d3c9c62fb634b014a9390908a6a26a65a9b5 Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Mon, 21 Oct 2024 13:04:57 -0400
Subject: [PATCH 05/12] changed save path to not overwrite previous training

---
 run_pgt_train.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/run_pgt_train.py b/run_pgt_train.py
index aa431ca9..9cb6df48 100644
--- a/run_pgt_train.py
+++ b/run_pgt_train.py
@@ -3,6 +3,7 @@ from ultralytics import YOLOv10, YOLO
 import os
 from ultralytics.models.yolo.segment import PGTSegmentationTrainer
 import argparse
+from datetime import datetime
 
 
 def main(args):
@@ -15,7 +16,7 @@ def main(args):
     # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
     model = YOLOv10('yolov10n.pt', task='segment')
 
-    args = dict(model='yolov10n.pt', data='coco128-seg.yaml', 
+    args = dict(model='yolov10n.pt', data=args.data_yaml, 
                 epochs=args.epochs, batch=args.batch_size,
                 # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
                 )
@@ -25,8 +26,16 @@ def main(args):
                 # args = dict(pgt_coeff=0.1), # Should add later to config
                 )
 
-    # Save the trained model
-    model.save('yolov10_coco_trained.pt')
+    # Create a directory to save model weights if it doesn't exist
+    model_weights_dir = 'model_weights'
+    if not os.path.exists(model_weights_dir):
+        os.makedirs(model_weights_dir)
+
+    # Save the trained model with a unique name based on the current date and time
+    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
+    data_yaml_base = os.path.splitext(os.path.basename(args.data_yaml))[0]
+    model_save_path = os.path.join(model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
+    model.save(model_save_path)  
 
     # Evaluate the model on the validation set
     results = model.val(data='coco.yaml')
@@ -39,6 +48,7 @@ if __name__ == "__main__":
     parser.add_argument('--device', type=str, default='0', help='CUDA device number')
     parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
     parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
+    parser.add_argument('--data_yaml', type=str, required=True, default='coco.yaml', help='Path to the data YAML file')
     args = parser.parse_args()
 
     # Set CUDA device (only needed for multi-gpu machines)
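
For reference, the checkpoint path built above resolves as follows (a sketch with illustrative inputs; only the timestamp varies between runs):

    import os
    from datetime import datetime

    data_yaml = 'coco128-seg.yaml'                               # illustrative
    base = os.path.splitext(os.path.basename(data_yaml))[0]      # 'coco128-seg'
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')             # e.g. '20241021_130457'
    print(os.path.join('model_weights', f'yolov10_{base}_trained_{stamp}.pt'))
    # model_weights/yolov10_coco128-seg_trained_20241021_130457.pt
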

From 411157c18ab4f9ecad946a7850cffd32d8c65b3a Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Mon, 21 Oct 2024 13:05:25 -0400
Subject: [PATCH 06/12] fixed yaml default

---
 run_pgt_train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/run_pgt_train.py b/run_pgt_train.py
index 9cb6df48..77d0a85d 100644
--- a/run_pgt_train.py
+++ b/run_pgt_train.py
@@ -48,7 +48,7 @@ if __name__ == "__main__":
     parser.add_argument('--device', type=str, default='0', help='CUDA device number')
     parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
     parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
-    parser.add_argument('--data_yaml', type=str, required=True, default='coco.yaml', help='Path to the data YAML file')
+    parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
     args = parser.parse_args()
 
     # Set CUDA device (only needed for multi-gpu machines)

From 2ccec65edcac6ae9fe2bf6249829dfa3d241aaf5 Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Mon, 21 Oct 2024 13:56:29 -0400
Subject: [PATCH 07/12] fixed val and saving weights

---
 run_pgt_train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/run_pgt_train.py b/run_pgt_train.py
index 77d0a85d..39db22f6 100644
--- a/run_pgt_train.py
+++ b/run_pgt_train.py
@@ -14,7 +14,7 @@ def main(args):
     # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
     # or
     # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
-    model = YOLOv10('yolov10n.pt', task='segment')
+    # model = YOLOv10('yolov10n.pt', task='segment')
 
     args = dict(model='yolov10n.pt', data=args.data_yaml, 
                 epochs=args.epochs, batch=args.batch_size,
@@ -35,10 +35,10 @@ def main(args):
     current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
     data_yaml_base = os.path.splitext(os.path.basename(args.data_yaml))[0]
     model_save_path = os.path.join(model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
-    model.save(model_save_path)  
+    trainer.save(model_save_path)  
 
     # Evaluate the model on the validation set
-    results = model.val(data='coco.yaml')
+    results = trainer.val(data=args.data_yaml)
 
     # Print the evaluation results
     print(results)

From 0efbfe7f4d0e473efcdf8d33996937b740b18932 Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Tue, 22 Oct 2024 23:40:46 -0400
Subject: [PATCH 08/12] updated grad computation to use create_graph during
 training, enabling better pgt_loss optimization

---
 run_pgt_train.py          | 17 +++++++++++------
 ultralytics/utils/loss.py | 30 +++++++++++++++++++++++-------
 2 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/run_pgt_train.py b/run_pgt_train.py
index 77d0a85d..8c8b5487 100644
--- a/run_pgt_train.py
+++ b/run_pgt_train.py
@@ -5,6 +5,7 @@ from ultralytics.models.yolo.segment import PGTSegmentationTrainer
 import argparse
 from datetime import datetime
 
+# nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 & 
 
 def main(args):
     # model = YOLOv10()
@@ -14,13 +15,16 @@ def main(args):
     # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
     # or
     # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
-    model = YOLOv10('yolov10n.pt', task='segment')
+    # model = YOLOv10('yolov10n.pt', task='segment')
 
-    args = dict(model='yolov10n.pt', data=args.data_yaml, 
+    args_dict = dict(
+                model='yolov10n.pt', 
+                data=args.data_yaml, 
                 epochs=args.epochs, batch=args.batch_size,
+                # pgt_coeff=5.0,
                 # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
                 )
-    trainer = PGTSegmentationTrainer(overrides=args)
+    trainer = PGTSegmentationTrainer(overrides=args_dict)
     trainer.train(
                 # debug=True, 
                 # args = dict(pgt_coeff=0.1), # Should add later to config
@@ -35,10 +39,10 @@ def main(args):
     current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
     data_yaml_base = os.path.splitext(os.path.basename(args.data_yaml))[0]
     model_save_path = os.path.join(model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
-    model.save(model_save_path)  
+    trainer.model.save(model_save_path)  
 
     # Evaluate the model on the validation set
-    results = model.val(data='coco.yaml')
+    results = trainer.val(data=args.data_yaml)
 
     # Print the evaluation results
     print(results)
@@ -54,4 +58,5 @@ if __name__ == "__main__":
     # Set CUDA device (only needed for multi-gpu machines)
     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
     os.environ["CUDA_VISIBLE_DEVICES"] = args.device
-    main(args)
\ No newline at end of file
+    main(args)
+    
\ No newline at end of file
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index f04ad3a1..fc3a404a 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -727,11 +727,12 @@ class v10DetectLoss:
         return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
 
 class v10PGTDetectLoss:
-    def __init__(self, model):
+    def __init__(self, model, pgt_coeff=3.0):
         self.one2many = v8DetectionLoss(model, tal_topk=10)
         self.one2one = v8DetectionLoss(model, tal_topk=1)
+        self.pgt_coeff = pgt_coeff
     
-    def __call__(self, preds, batch, return_plaus=True):
+    def __call__(self, preds, batch, return_plaus=True, inference=False):
         batch['img'] = batch['img'].requires_grad_(True)
         one2many = preds["one2many"]
         loss_one2many = self.one2many(one2many, batch)
@@ -740,13 +741,28 @@ class v10PGTDetectLoss:
 
         loss = loss_one2many[0] + loss_one2one[0]
         if return_plaus:
-            smask = get_dist_reg(batch['img'], batch['masks'])
+            smask = get_dist_reg(batch['img'], batch['masks'])#.requires_grad_(True)
 
-            grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0]
-            grad = torch.abs(grad)
+            # graph = False if inference else True
+            # grad = torch.autograd.grad(loss, batch['img'], 
+            #                            retain_graph=True, 
+            #                            create_graph=graph,
+            #                            )[0]
+            try:
+                grad = torch.autograd.grad(loss, batch['img'], 
+                                           retain_graph=True, 
+                                           create_graph=True,
+                                           )[0]
+            except RuntimeError:  # fall back if the second-order graph cannot be built (e.g. under AMP)
+                grad = torch.autograd.grad(loss, batch['img'], 
+                                           retain_graph=True, 
+                                           create_graph=False,
+                                           )[0]
 
-            pgt_coeff = 3.0
-            plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
+
+            grad = grad ** 2
+
+            plaus_loss = plaus_loss_fn(grad, smask, self.pgt_coeff)
             # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
             loss += plaus_loss
             

From 9524af3bfefaba6978fe0de87ab14fdc41b50480 Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Wed, 23 Oct 2024 20:28:51 -0400
Subject: [PATCH 09/12] added model loading via YOLOv10PGT; pgt_coeff is now a
 cfg parameter

---
 run_pgt_train.py                             | 46 ++++++++++++--------
 ultralytics/__init__.py                      |  5 ++-
 ultralytics/cfg/default.yaml                 |  1 +
 ultralytics/cfg/pgt_train.yaml               |  1 +
 ultralytics/engine/model.py                  |  8 +++-
 ultralytics/models/__init__.py               |  4 +-
 ultralytics/models/yolo/segment/pgt_train.py | 13 +++++-
 ultralytics/models/yolov10/__init__.py       |  4 +-
 ultralytics/models/yolov10/model.py          | 33 ++++++++++++++
 ultralytics/nn/tasks.py                      |  2 +-
 ultralytics/utils/loss.py                    | 14 +++---
 ultralytics/utils/plaus_functs.py            | 28 +++++++++++-
 12 files changed, 121 insertions(+), 38 deletions(-)

diff --git a/run_pgt_train.py b/run_pgt_train.py
index 983b9888..6a8f3ab8 100644
--- a/run_pgt_train.py
+++ b/run_pgt_train.py
@@ -1,35 +1,43 @@
-from ultralytics import YOLOv10, YOLO
+from ultralytics import YOLOv10, YOLO, YOLOv10PGT
 # from ultralytics.engine.pgt_trainer import PGTTrainer
 import os
 from ultralytics.models.yolo.segment import PGTSegmentationTrainer
 import argparse
 from datetime import datetime
+import torch
 
 # nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 & 
 
 def main(args):
-    # model = YOLOv10()
-
+    model = YOLOv10PGT('yolov10n.pt')
+    model.train(    
+                data=args.data_yaml, 
+                epochs=args.epochs, 
+                batch=args.batch_size,
+                # amp=False,
+                # pgt_coeff=1.5,
+                # cfg='pgt_train.yaml',  # Load and train model with the config file
+                )
     # If you want to finetune the model with pretrained weights, you could load the 
-    # pretrained weights like below
+    # pretrained weights like below 
     # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
     # or
     # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
     # model = YOLOv10('yolov10n.pt', task='segment')
     # model = YOLOv10('yolov10n.pt', task='segment')
 
-    args_dict = dict(
-                model='yolov10n.pt', 
-                data=args.data_yaml, 
-                epochs=args.epochs, batch=args.batch_size,
-                # pgt_coeff=5.0,
-                # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
-                )
-    trainer = PGTSegmentationTrainer(overrides=args_dict)
-    trainer.train(
-                # debug=True, 
-                # args = dict(pgt_coeff=0.1), # Should add later to config
-                )
+    # args_dict = dict(
+    #             model='yolov10n.pt', 
+    #             data=args.data_yaml, 
+    #             epochs=args.epochs, batch=args.batch_size,
+    #             # pgt_coeff=5.0,
+    #             # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
+    #             )
+    # trainer = PGTSegmentationTrainer(overrides=args_dict)
+    # trainer.train(
+    #             # debug=True, 
+    #             # args = dict(pgt_coeff=0.1), # Should add later to config
+    #             )
 
     # Create a directory to save model weights if it doesn't exist
     model_weights_dir = 'model_weights'
@@ -40,10 +48,12 @@ def main(args):
     current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
     data_yaml_base = os.path.splitext(os.path.basename(args.data_yaml))[0]
     model_save_path = os.path.join(model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
-    trainer.model.save(model_save_path)  
+    model.save(model_save_path)
+    # torch.save(trainer.model.state_dict(), model_save_path)
+    
 
     # Evaluate the model on the validation set
-    results = trainer.val(data=args.data_yaml)
+    results = model.val(data=args.data_yaml)
 
     # Print the evaluation results
     print(results)
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 8ff1b4fb..805308c4 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -3,7 +3,7 @@
 __version__ = "8.1.34"
 
 from ultralytics.data.explorer.explorer import Explorer
-from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld, YOLOv10
+from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld, YOLOv10, YOLOv10PGT
 from ultralytics.models.fastsam import FastSAM
 from ultralytics.models.nas import NAS
 from ultralytics.utils import ASSETS, SETTINGS as settings
@@ -23,5 +23,6 @@ __all__ = (
     "download",
     "settings",
     "Explorer",
-    "YOLOv10"
+    "YOLOv10",
+    "YOLOv10PGT",
 )
diff --git a/ultralytics/cfg/default.yaml b/ultralytics/cfg/default.yaml
index bd074b10..8eb0cdd7 100644
--- a/ultralytics/cfg/default.yaml
+++ b/ultralytics/cfg/default.yaml
@@ -41,6 +41,7 @@ overlap_mask: True # (bool) masks should overlap during training (segment train
 mask_ratio: 4 # (int) mask downsample ratio (segment train only)
 # Classification
 dropout: 0.0 # (float) use dropout regularization (classify train only)
+pgt_coeff: 2.0 # (float) PGT loss coefficient 
 
 # Val/Test settings ----------------------------------------------------------------------------------------------------
 val: True # (bool) validate/test during training
diff --git a/ultralytics/cfg/pgt_train.yaml b/ultralytics/cfg/pgt_train.yaml
index bd074b10..9511a88d 100644
--- a/ultralytics/cfg/pgt_train.yaml
+++ b/ultralytics/cfg/pgt_train.yaml
@@ -41,6 +41,7 @@ overlap_mask: True # (bool) masks should overlap during training (segment train
 mask_ratio: 4 # (int) mask downsample ratio (segment train only)
 # Classification
 dropout: 0.0 # (float) use dropout regularization (classify train only)
+pgt_coeff: 1.0 # (float) PGT loss coefficient 
 
 # Val/Test settings ----------------------------------------------------------------------------------------------------
 val: True # (bool) validate/test during training
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index 53cb1f40..fbb497da 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -668,9 +668,10 @@ class Model(nn.Module):
             self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
         return self.metrics
 
-    def train_pgt(
+    def train_pgt( # Currently unused; revisit if the train() entry point changes
         self,
         trainer=None,
+        debug=False,
         **kwargs,
     ):
         """
@@ -733,7 +734,10 @@ class Model(nn.Module):
                     pass
 
         self.trainer.hub_session = self.session  # attach optional HUB session
-        self.trainer.train()
+        if debug:
+            self.trainer.train(debug=debug)
+        else:
+            self.trainer.train()
         # Update model and cfg after training
         if RANK in (-1, 0):
             ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py
index 42de3fba..efa4b486 100644
--- a/ultralytics/models/__init__.py
+++ b/ultralytics/models/__init__.py
@@ -3,6 +3,6 @@
 from .rtdetr import RTDETR
 from .sam import SAM
 from .yolo import YOLO, YOLOWorld
-from .yolov10 import YOLOv10
+from .yolov10 import YOLOv10, YOLOv10PGT
 
-__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld", "YOLOv10"  # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld", "YOLOv10", "YOLOv10PGT"  # allow simpler import
diff --git a/ultralytics/models/yolo/segment/pgt_train.py b/ultralytics/models/yolo/segment/pgt_train.py
index 7e7712b9..aa0559ee 100644
--- a/ultralytics/models/yolo/segment/pgt_train.py
+++ b/ultralytics/models/yolo/segment/pgt_train.py
@@ -5,9 +5,18 @@ from copy import copy
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import SegmentationModel, DetectionModel
 from ultralytics.utils import DEFAULT_CFG, RANK
+# from ultralytics.utils import yaml_load, IterableSimpleNamespace, ROOT
 from ultralytics.utils.plotting import plot_images, plot_results
-from ultralytics.models.yolov10.model import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
-from ultralytics.models.yolov10.val import YOLOv10DetectionValidator, YOLOv10PGTDetectionValidator
+from ultralytics.models.yolov10.model import YOLOv10PGTDetectionModel
+from ultralytics.models.yolov10.val import YOLOv10PGTDetectionValidator
+
+# # Default configuration
+# DEFAULT_CFG_DICT = yaml_load(ROOT / "cfg/pgt_train.yaml")
+# for k, v in DEFAULT_CFG_DICT.items():
+#     if isinstance(v, str) and v.lower() == "none":
+#         DEFAULT_CFG_DICT[k] = None
+# DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
+# DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
 
 class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
     """
diff --git a/ultralytics/models/yolov10/__init__.py b/ultralytics/models/yolov10/__init__.py
index 97f137f9..4da9ec2d 100644
--- a/ultralytics/models/yolov10/__init__.py
+++ b/ultralytics/models/yolov10/__init__.py
@@ -1,5 +1,5 @@
-from .model import YOLOv10
+from .model import YOLOv10, YOLOv10PGT
 from .predict import YOLOv10DetectionPredictor
 from .val import YOLOv10DetectionValidator
 
-__all__ = "YOLOv10DetectionPredictor", "YOLOv10DetectionValidator", "YOLOv10"
+__all__ = "YOLOv10DetectionPredictor", "YOLOv10DetectionValidator", "YOLOv10", "YOLOv10PGT"
diff --git a/ultralytics/models/yolov10/model.py b/ultralytics/models/yolov10/model.py
index 0c8bcb9a..e746bd8b 100644
--- a/ultralytics/models/yolov10/model.py
+++ b/ultralytics/models/yolov10/model.py
@@ -4,6 +4,7 @@ from .val import YOLOv10DetectionValidator
 from .predict import YOLOv10DetectionPredictor
 from .train import YOLOv10DetectionTrainer
 from .pgt_train import YOLOv10PGTDetectionTrainer
+# from ..yolo.segment import PGTSegmentationTrainer
 # from .pgt_trainer import YOLOv10DetectionTrainer
 
 from huggingface_hub import PyTorchModelHubMixin
@@ -36,4 +37,36 @@ class YOLOv10(Model, PyTorchModelHubMixin, model_card_template=card_template_tex
                 "validator": YOLOv10DetectionValidator,
                 "predictor": YOLOv10DetectionPredictor,
             },
+        }
+
+def _get_pgt_segmentation_trainer():
+    from ..yolo.segment import PGTSegmentationTrainer
+    return PGTSegmentationTrainer
+
+class YOLOv10PGT(Model, PyTorchModelHubMixin, model_card_template=card_template_text):
+
+    def __init__(self, model="yolov10n.pt", task=None, verbose=False, 
+                 names=None):
+        super().__init__(model=model, task=task, verbose=verbose)
+        if names is not None:
+            setattr(self.model, 'names', names)
+
+    def push_to_hub(self, repo_name, **kwargs):
+        config = kwargs.get('config', {})
+        config['names'] = self.names
+        config['model'] = self.model.yaml['yaml_file']
+        config['task'] = self.task
+        kwargs['config'] = config
+        super().push_to_hub(repo_name, **kwargs)
+
+    @property
+    def task_map(self):
+        """Map head to model, trainer, validator, and predictor classes."""
+        return {
+            "detect": {
+                "model": YOLOv10DetectionModel,
+                "trainer": _get_pgt_segmentation_trainer(),
+                "validator": YOLOv10DetectionValidator,
+                "predictor": YOLOv10DetectionPredictor,
+            },
         }
\ No newline at end of file
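
The _get_pgt_segmentation_trainer indirection defers importing PGTSegmentationTrainer until task_map is actually read, presumably to avoid a circular import between model.py and the segment trainer module. With the wrapper class in place, a PGT run reduces to the usual Model API (a sketch; the pgt_coeff override assumes the cfg key added in this patch):

    from ultralytics import YOLOv10PGT

    model = YOLOv10PGT('yolov10n.pt')
    model.train(data='coco128-seg.yaml', epochs=1, batch=16, pgt_coeff=2.0)
    model.save('model_weights/yolov10_pgt_demo.pt')   # illustrative path
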
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 48242826..60b3d27e 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -654,7 +654,7 @@ class YOLOv10DetectionModel(DetectionModel):
     
 class YOLOv10PGTDetectionModel(DetectionModel):
     def init_criterion(self):
-        return v10PGTDetectLoss(self)
+        return v10PGTDetectLoss(self, pgt_coeff=getattr(self.args, 'pgt_coeff', None))
 
 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index fc3a404a..e7e1aca5 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -727,12 +727,14 @@ class v10DetectLoss:
         return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
 
 class v10PGTDetectLoss:
-    def __init__(self, model, pgt_coeff=3.0):
+    def __init__(self, model, pgt_coeff):
         self.one2many = v8DetectionLoss(model, tal_topk=10)
         self.one2one = v8DetectionLoss(model, tal_topk=1)
-        self.pgt_coeff = pgt_coeff
+        self.pgt_coeff = pgt_coeff if pgt_coeff is not None else 2.0
     
-    def __call__(self, preds, batch, return_plaus=True, inference=False):
+    def __call__(self, preds, batch, return_plaus=True, pgt_coeff=None):
+        if pgt_coeff is not None:
+            self.pgt_coeff = pgt_coeff
         batch['img'] = batch['img'].requires_grad_(True)
         one2many = preds["one2many"]
         loss_one2many = self.one2many(one2many, batch)
@@ -743,11 +745,6 @@ class v10PGTDetectLoss:
         if return_plaus:
             smask = get_dist_reg(batch['img'], batch['masks'])#.requires_grad_(True)
 
-            # graph = False if inference else True
-            # grad = torch.autograd.grad(loss, batch['img'], 
-            #                            retain_graph=True, 
-            #                            create_graph=graph,
-            #                            )[0]
             try:
                 grad = torch.autograd.grad(loss, batch['img'], 
                                            retain_graph=True, 
@@ -764,6 +761,7 @@ class v10PGTDetectLoss:
 
             plaus_loss = plaus_loss_fn(grad, smask, self.pgt_coeff)
             # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
+            
             loss += plaus_loss
             
             return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0)))
diff --git a/ultralytics/utils/plaus_functs.py b/ultralytics/utils/plaus_functs.py
index 3c5b689a..96fb8b0b 100644
--- a/ultralytics/utils/plaus_functs.py
+++ b/ultralytics/utils/plaus_functs.py
@@ -39,8 +39,11 @@ def get_dist_reg(images, seg_mask):
             kernel_size += 1
         seg_mask1 = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask)
         if torch.max(seg_mask1) > 1.0:
-            seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min())
+            # seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min())
+            seg_mask1 = normalize_tensor(seg_mask1)
         smask = torch.max(smask, seg_mask1)
+        
+    smask = normalize_tensor(smask)
     return smask
 
 def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False):
@@ -374,6 +377,29 @@ def normalize_batch(x):
     
     return x_
 
+def normalize_batch_nonan(x):
+    """
+    Normalize a batch of tensors along each channel.
+    
+    Args:
+        x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
+        
+    Returns:
+        torch.Tensor: Normalized tensor of the same shape as the input.
+    """
+    mins = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device)
+    maxs = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device)
+    x_ = torch.zeros_like(x)
+    for i in range(x.shape[0]):
+        if torch.all(x[i] == 0):
+            x_[i] = x[i]
+        else:
+            mins[i] = x[i].min()
+            maxs[i] = x[i].max()
+            x_[i] = (x[i] - mins[i]) / (maxs[i] - mins[i])
+    
+    return x_
+
 def get_detections(model_clone, img):
     """
     Get detections from a model given an input image and targets.
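
normalize_batch_nonan guards the per-sample min-max normalization against all-zero samples, whose max - min denominator would otherwise be zero and yield NaNs. A quick check (assuming the function is importable from ultralytics.utils.plaus_functs):

    import torch
    from ultralytics.utils.plaus_functs import normalize_batch_nonan

    x = torch.rand(3, 1, 4, 4)
    x[1] = 0.0                                       # an all-zero sample
    out = normalize_batch_nonan(x)
    print(torch.isnan(out).any().item())             # False
    print(out[0].min().item(), out[0].max().item())  # 0.0 1.0
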

From 44c912647e507a49580ed5288b7c990688adebea Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Wed, 23 Oct 2024 20:33:28 -0400
Subject: [PATCH 10/12] added working validation

---
 run_val.py                                   | 90 ++++++--------------
 ultralytics/models/yolo/segment/pgt_train.py | 22 ++---
 2 files changed, 32 insertions(+), 80 deletions(-)

diff --git a/run_val.py b/run_val.py
index 98c5a476..40d5c499 100644
--- a/run_val.py
+++ b/run_val.py
@@ -1,72 +1,32 @@
-from ultralytics import YOLOv10, YOLO
+from ultralytics import YOLOv10, YOLO, YOLOv10PGT
 # from ultralytics.engine.pgt_trainer import PGTTrainer
-# from ultralytics import BaseTrainer
-# from ultralytics.engine.trainer import BaseTrainer
 import os
+from ultralytics.models.yolo.segment import PGTSegmentationTrainer
+import argparse
+from datetime import datetime
 
-# Set CUDA device (only needed for multi-gpu machines) 
-os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
-os.environ["CUDA_VISIBLE_DEVICES"] = "4" 
+# nohup python run_val.py --device 1 > ./output_logs/gpu1_yolov10_pgt_val.log 2>&1 & 
 
-# model = YOLOv10()
-# model = YOLO()
-# If you want to finetune the model with pretrained weights, you could load the 
-# pretrained weights like below
-# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
-# or
-# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
-model = YOLOv10('yolov10n.pt')
+def main(args):
 
-# Evaluate the model on the validation set
-results = model.val(data='coco.yaml')
+    model = YOLOv10PGT(args.model_path)
+    
+    # Evaluate the model on the validation set
+    results = model.val(data=args.data_yaml)
+    
+    # Print the evaluation results
+    print(results)
 
-# Print the evaluation results
-print(results)
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Evaluate a YOLOv10 PGT model on the validation set.')
+    parser.add_argument('--device', type=str, default='1', help='CUDA device number')
+    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
+    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
+    parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
+    parser.add_argument('--model_path', type=str, default='yolov10n.pt', help='Path to the model file')
+    args = parser.parse_args()
 
-# pred = results[0].boxes[0].conf
-
-# # Hook to store the activations
-# activations = {}
-
-# def get_activation(name):
-#     def hook(model, input, output):
-#         activations[name] = output
-#     return hook
-
-# # Register hooks for each layer you want to inspect
-# for name, layer in model.model.named_modules():
-#     layer.register_forward_hook(get_activation(name))
-
-# # Run the model to get activations
-# results = model.predict(image_tensor, save=True, visualize=True)
-
-# # # Print the activations
-# # for name, activation in activations.items():
-# #     print(f"Activation from layer {name}: {activation}")
-
-# # List activation names separately
-# print("\nActivation layer names:")
-# for name in activations.keys():
-#     print(name)
-# # pred.backward()
-
-# # Assuming 'model.23' is the layer of interest for bbox prediction and confidence
-# activation = activations['model.23']['one2one'][0]
-# act_23 = activations['model.23.cv3.2']
-# act_dfl = activations['model.23.dfl.conv']
-# act_conv = activations['model.0.conv']
-# act_act = activations['model.0.act']
-
-# # with torch.autograd.set_detect_anomaly(True):
-# #     pred.backward()
-# grad = torch.autograd.grad(act_23, im, grad_outputs=torch.ones_like(act_23), create_graph=True, retain_graph=True)[0]
-# # grad = torch.autograd.grad(pred, im, grad_outputs=torch.ones_like(pred), create_graph=True)[0]
-# grad = torch.autograd.grad(activations['model.23']['one2one'][1][0], 
-#                            activations['model.23.one2one_cv3.2'], 
-#                            grad_outputs=torch.ones_like(activations['model.23']['one2one'][1][0]), 
-#                            create_graph=True, retain_graph=True)[0]
-
-# # Print the results
-# print(results)
-
-# model.val(data='coco.yaml', batch=256)
\ No newline at end of file
+    # Set CUDA device (only needed for multi-gpu machines)
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+    main(args)
\ No newline at end of file
diff --git a/ultralytics/models/yolo/segment/pgt_train.py b/ultralytics/models/yolo/segment/pgt_train.py
index aa0559ee..6d3b4150 100644
--- a/ultralytics/models/yolo/segment/pgt_train.py
+++ b/ultralytics/models/yolo/segment/pgt_train.py
@@ -41,10 +41,8 @@ class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
 
     def get_model(self, cfg=None, weights=None, verbose=True):
         """Return SegmentationModel initialized with specified config and weights."""
-        if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']:
-            model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
-        else:
-            model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
+
+        model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
 
@@ -52,17 +50,11 @@ class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
 
     def get_validator(self):
         """Return an instance of SegmentationValidator for validation of YOLO model."""
-        
-        if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']:
-            self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", "pgt_loss",
-            return YOLOv10PGTDetectionValidator(
-                self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
-            )
-        else:
-            self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
-            return yolo.segment.SegmentationValidator(
-                self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
-            )
+
+        self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", "pgt_loss",
+        return YOLOv10PGTDetectionValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
 
     def plot_training_samples(self, batch, ni):
         """Creates a plot of training sample images with labels and box coordinates."""

From 21e660fde9f09ad1fffe17c50e2a9b9db8109bd8 Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Fri, 1 Nov 2024 11:04:12 -0400
Subject: [PATCH 11/12] Fixed issues with pgt training, adding the pgt loss to
 the YOLO loss function. PGT loss can now be tracked during training.

---
 run_pgt_train.py                  |   7 +-
 ultralytics/engine/pgt_trainer.py | 127 ++++++++++++++-------------
 ultralytics/utils/plaus_functs.py |   4 +-
 ultralytics/utils/plotting.py     | 138 ++++++++++++++++++++++++++++++
 4 files changed, 211 insertions(+), 65 deletions(-)

diff --git a/run_pgt_train.py b/run_pgt_train.py
index 6a8f3ab8..9bb8aae6 100644
--- a/run_pgt_train.py
+++ b/run_pgt_train.py
@@ -6,7 +6,7 @@ import argparse
 from datetime import datetime
 import torch
 
-# nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 & 
+# nohup python run_pgt_train.py --device 7 > ./output_logs/gpu7_yolov10_pgt_train.log 2>&1 & 
 
 def main(args):
     model = YOLOv10PGT('yolov10n.pt')
@@ -15,7 +15,7 @@ def main(args):
                 epochs=args.epochs, 
                 batch=args.batch_size,
                 # amp=False,
-                # pgt_coeff=1.5,
+                pgt_coeff=3.0,
                 # cfg='pgt_train.yaml',  # Load and train model with the config file
                 )
     # If you want to finetune the model with pretrained weights, you could load the 
@@ -51,7 +51,6 @@ def main(args):
     model.save(model_save_path)
     # torch.save(trainer.model.state_dict(), model_save_path)
     
-
     # Evaluate the model on the validation set
     results = model.val(data=args.data_yaml)
 
@@ -61,7 +60,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
     parser.add_argument('--device', type=str, default='0', help='CUDA device number')
-    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
+    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
     parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
     parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
     args = parser.parse_args()
diff --git a/ultralytics/engine/pgt_trainer.py b/ultralytics/engine/pgt_trainer.py
index 559db7c2..36a0276d 100644
--- a/ultralytics/engine/pgt_trainer.py
+++ b/ultralytics/engine/pgt_trainer.py
@@ -401,67 +401,9 @@ class PGTTrainer:
                 debug_ = debug
                 if debug_ and (i % 25 == 0):
                     debug_ = False
-                    # Create a tensor of zeros with the same size as images
-                    mask = torch.zeros_like(images, dtype=torch.float32)
-                    smask = get_dist_reg(images, batch['masks'])
-                    grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
-                    grad = torch.abs(grad)
 
-                    batch_size = images.shape[0]
-                    imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device)
-                    targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-                    targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-                    gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-                    mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+                    plot_grads(batch, self, i)
                     
-                    # Iterate over each bounding box and set the corresponding pixels to 1
-                    for irx, bboxes in enumerate(gt_bboxes):
-                        for idx in range(len(bboxes)):
-                            x1, y1, x2, y2 = bboxes[idx]
-                            x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
-                            mask[irx, :, y1:y2, x1:x2] = 1.0
-                    
-                    save_imgs = True
-                    if save_imgs:
-                        # Convert tensors to numpy arrays
-                        images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
-                        grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
-                        mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
-                        seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
-
-                        # Normalize grad for visualization
-                        grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
-                        
-                        for ix in range(images_np.shape[0]):
-                            fig, ax = plt.subplots(1, 6, figsize=(30, 5))
-                            ax[0].imshow(images_np[ix])
-                            ax[0].set_title('Image')
-                            ax[1].imshow(grad_np[ix], cmap='jet')
-                            ax[1].set_title('Gradient')
-                            ax[2].imshow(images_np[ix])
-                            ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
-                            ax[2].set_title('Overlay')
-                            ax[3].imshow(mask_np[ix], cmap='gray')
-                            ax[3].set_title('Mask')
-                            ax[4].imshow(seg_mask_np[ix], cmap='gray')
-                            ax[4].set_title('Segmentation Mask')
-                            
-                            # Plot image with bounding boxes
-                            ax[5].imshow(images_np[ix])
-                            for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
-                                x1, y1, x2, y2 = bbox
-                                x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
-                                rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
-                                ax[5].add_patch(rect)
-                                ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
-                            ax[5].set_title('Bounding Boxes')
-
-                            save_dir_attr = f"figures/attributions/run{self.num}"
-                            if not os.path.exists(save_dir_attr):
-                                os.makedirs(save_dir_attr)
-                            plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
-                            plt.close(fig)
-                    images = images.detach()
                 
                 
                 if RANK != -1:
@@ -500,6 +442,7 @@ class PGTTrainer:
                     self.run_callbacks("on_batch_end")
                     if self.args.plots and ni in self.plot_idx:
                         self.plot_training_samples(batch, ni)
+                        plot_grads(batch, self, ni)
                 
                 self.run_callbacks("on_train_batch_end")
 
@@ -839,3 +782,69 @@ class PGTTrainer:
             f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
         )
         return optimizer
+
+
+def plot_grads(batch, obj, i, nsamples=16):
+    # Create a tensor of zeros with the same size as images
+    images = batch['img'].requires_grad_(True)
+    mask = torch.zeros_like(images, dtype=torch.float32)
+    smask = get_dist_reg(images, batch['masks'])
+    loss, loss_items = obj.model(batch)
+    grad = torch.autograd.grad(loss, images, retain_graph=True)[0]
+    grad = torch.abs(grad)
+
+    batch_size = images.shape[0]
+    imgsz = torch.tensor(batch['resized_shape'][0]).to(obj.device)
+    targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+    targets = v8DetectionLoss.preprocess(obj, targets=targets.to(obj.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+    gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+    mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+    
+    # Iterate over each bounding box and set the corresponding pixels to 1
+    for irx, bboxes in enumerate(gt_bboxes):
+        for idx in range(len(bboxes)):
+            x1, y1, x2, y2 = bboxes[idx]
+            x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+            mask[irx, :, y1:y2, x1:x2] = 1.0
+    
+    save_imgs = True
+    if save_imgs:
+        # Convert tensors to numpy arrays
+        images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
+        grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
+        mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
+        seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
+
+        # Normalize grad for visualization
+        grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
+        
+        range_val = min(nsamples, images_np.shape[0])
+        for ix in range(range_val):
+            fig, ax = plt.subplots(1, 6, figsize=(30, 5))
+            ax[0].imshow(images_np[ix])
+            ax[0].set_title('Image')
+            ax[1].imshow(grad_np[ix], cmap='jet')
+            ax[1].set_title('Gradient')
+            ax[2].imshow(images_np[ix])
+            ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
+            ax[2].set_title('Overlay')
+            ax[3].imshow(mask_np[ix], cmap='gray')
+            ax[3].set_title('Mask')
+            ax[4].imshow(seg_mask_np[ix], cmap='gray')
+            ax[4].set_title('Segmentation Mask')
+            
+            # Plot image with bounding boxes
+            ax[5].imshow(images_np[ix])
+            for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
+                x1, y1, x2, y2 = bbox
+                x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+                rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
+                ax[5].add_patch(rect)
+                ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
+            ax[5].set_title('Bounding Boxes')
+
+            save_dir_attr = f"{obj.save_dir}/attributions"  # save_dir is a Path; f-string renders it
+            if not os.path.exists(save_dir_attr):
+                os.makedirs(save_dir_attr)
+            plt.savefig(f'{save_dir_attr}/debug_epoch_{obj.epoch}_batch_{i}_image_{ix}.png')
+            plt.close(fig)
\ No newline at end of file
diff --git a/ultralytics/utils/plaus_functs.py b/ultralytics/utils/plaus_functs.py
index 96fb8b0b..2ca60ba9 100644
--- a/ultralytics/utils/plaus_functs.py
+++ b/ultralytics/utils/plaus_functs.py
@@ -11,7 +11,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppr
 from .metrics import bbox_iou
 import torchvision.transforms as T
 
-def plaus_loss_fn(grad, smask, pgt_coeff):
+def plaus_loss_fn(grad, smask, pgt_coeff, square=True):
     ################## Compute the PGT Loss ##################
     # Positive regularization term for incentivizing pixels near the target to have high attribution
     dist_attr_pos = attr_reg(grad, (1.0 - smask)) # dist_reg = seg_mask
@@ -22,7 +22,7 @@ def plaus_loss_fn(grad, smask, pgt_coeff):
     dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad)))
     plaus_reg = (((1.0 + dist_reg) / 2.0))
     # Calculate plausibility loss
-    plaus_loss = (1 - plaus_reg) * pgt_coeff
+    plaus_loss = ((1 - plaus_reg) ** 2 if square else (1 - plaus_reg)) * pgt_coeff
     return plaus_loss
 
 def get_dist_reg(images, seg_mask):
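
Squaring the penalty tapers it as plaus_reg approaches 1, so the plausibility gradient shrinks near the optimum rather than staying constant. A scalar illustration with pgt_coeff = 3.0 (the value used in run_pgt_train.py above):

    pgt_coeff = 3.0
    for plaus_reg in (0.2, 0.9):
        linear = (1 - plaus_reg) * pgt_coeff
        squared = (1 - plaus_reg) ** 2 * pgt_coeff
        print(f'plaus_reg={plaus_reg}: linear={linear:.3f}, squared={squared:.3f}')
    # plaus_reg=0.2: linear=2.400, squared=1.920
    # plaus_reg=0.9: linear=0.300, squared=0.030
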
diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py
index 0e99c722..e94b8e63 100644
--- a/ultralytics/utils/plotting.py
+++ b/ultralytics/utils/plotting.py
@@ -837,6 +837,144 @@ def plot_images(
     if on_plot:
         on_plot(fname)
 
+@threaded
+def plot_gradients(
+    images,
+    batch_idx,
+    cls,
+    bboxes=np.zeros(0, dtype=np.float32),
+    confs=None,
+    masks=np.zeros(0, dtype=np.uint8),
+    kpts=np.zeros((0, 51), dtype=np.float32),
+    paths=None,
+    fname="images.jpg",
+    names=None,
+    on_plot=None,
+    max_subplots=16,
+    save=True,
+    conf_thres=0.25,
+):
+    """Plot image grid with labels."""
+    if isinstance(images, torch.Tensor):
+        images = images.detach().cpu().float().numpy()
+    if isinstance(cls, torch.Tensor):
+        cls = cls.cpu().numpy()
+    if isinstance(bboxes, torch.Tensor):
+        bboxes = bboxes.cpu().numpy()
+    if isinstance(masks, torch.Tensor):
+        masks = masks.cpu().numpy().astype(int)
+    if isinstance(kpts, torch.Tensor):
+        kpts = kpts.cpu().numpy()
+    if isinstance(batch_idx, torch.Tensor):
+        batch_idx = batch_idx.cpu().numpy()
+
+    max_size = 1920  # max image size
+    bs, _, h, w = images.shape  # batch size, _, height, width
+    bs = min(bs, max_subplots)  # limit plot images
+    ns = np.ceil(bs**0.5)  # number of subplots (square)
+    if np.max(images[0]) <= 1:
+        images *= 255  # de-normalise (optional)
+
+    # Build Image
+    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
+    for i in range(bs):
+        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
+        mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
+
+    # Resize (optional)
+    scale = max_size / ns / max(h, w)
+    if scale < 1:
+        h = math.ceil(scale * h)
+        w = math.ceil(scale * w)
+        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
+
+    # Annotate
+    fs = int((h + w) * ns * 0.01)  # font size
+    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
+    for i in range(bs):
+        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
+        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
+        if paths:
+            annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
+        if len(cls) > 0:
+            idx = batch_idx == i
+            classes = cls[idx].astype("int")
+            labels = confs is None
+
+            if len(bboxes):
+                boxes = bboxes[idx]
+                conf = confs[idx] if confs is not None else None  # check for confidence presence (label vs pred)
+                is_obb = boxes.shape[-1] == 5  # xywhr
+                boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
+                if len(boxes):
+                    if boxes[:, :4].max() <= 1.1:  # if normalized with tolerance 0.1
+                        boxes[..., 0::2] *= w  # scale to pixels
+                        boxes[..., 1::2] *= h
+                    elif scale < 1:  # absolute coords need scale if image scales
+                        boxes[..., :4] *= scale
+                boxes[..., 0::2] += x
+                boxes[..., 1::2] += y
+                for j, box in enumerate(boxes.astype(np.int64).tolist()):
+                    c = classes[j]
+                    color = colors(c)
+                    c = names.get(c, c) if names else c
+                    if labels or conf[j] > conf_thres:
+                        label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
+                        annotator.box_label(box, label, color=color, rotated=is_obb)
+
+            elif len(classes):
+                for c in classes:
+                    color = colors(c)
+                    c = names.get(c, c) if names else c
+                    annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
+
+            # Plot keypoints
+            if len(kpts):
+                kpts_ = kpts[idx].copy()
+                if len(kpts_):
+                    if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01:  # if normalized with tolerance .01
+                        kpts_[..., 0] *= w  # scale to pixels
+                        kpts_[..., 1] *= h
+                    elif scale < 1:  # absolute coords need scale if image scales
+                        kpts_ *= scale
+                kpts_[..., 0] += x
+                kpts_[..., 1] += y
+                for j in range(len(kpts_)):
+                    if labels or conf[j] > conf_thres:
+                        annotator.kpts(kpts_[j])
+
+            # Plot masks
+            if len(masks):
+                if idx.shape[0] == masks.shape[0]:  # overlap_masks=False
+                    image_masks = masks[idx]
+                else:  # overlap_masks=True
+                    image_masks = masks[[i]]  # (1, 640, 640)
+                    nl = idx.sum()
+                    index = np.arange(nl).reshape((nl, 1, 1)) + 1
+                    image_masks = np.repeat(image_masks, nl, axis=0)
+                    image_masks = np.where(image_masks == index, 1.0, 0.0)
+
+                im = np.asarray(annotator.im).copy()
+                for j in range(len(image_masks)):
+                    if labels or conf[j] > conf_thres:
+                        color = colors(classes[j])
+                        mh, mw = image_masks[j].shape
+                        if mh != h or mw != w:
+                            mask = image_masks[j].astype(np.uint8)
+                            mask = cv2.resize(mask, (w, h))
+                            mask = mask.astype(bool)
+                        else:
+                            mask = image_masks[j].astype(bool)
+                        with contextlib.suppress(Exception):
+                            im[y : y + h, x : x + w, :][mask] = (
+                                im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
+                            )
+                annotator.fromarray(im)
+    if not save:
+        return np.asarray(annotator.im)
+    annotator.im.save(fname)  # save
+    if on_plot:
+        on_plot(fname)
 
 @plt_settings()
 def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):

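A hedged usage sketch for the new helper (shapes, class names, and output path are illustrative; the signature mirrors plot_images, so label/prediction conventions carry over, and the figures/ directory is assumed to exist):

    import torch
    from ultralytics.utils.plotting import plot_gradients

    # Illustrative batch: 4 gradient maps scaled to [0, 1], two boxes on image 0.
    grads = torch.rand(4, 3, 640, 640)
    batch_idx = torch.tensor([0, 0])                   # box-to-image assignment
    cls = torch.tensor([0.0, 2.0])
    bboxes = torch.tensor([[0.50, 0.50, 0.20, 0.30],
                           [0.25, 0.25, 0.10, 0.10]])  # normalized xywh
    plot_gradients(grads, batch_idx, cls, bboxes=bboxes,
                   fname='figures/gradients.jpg', names={0: 'person', 2: 'car'})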
From 1ef6cbcec52d28b0a9fe67276115d4ff45ddffed Mon Sep 17 00:00:00 2001
From: nielseni6 <nielseni6@students.rowan.edu>
Date: Fri, 1 Nov 2024 11:19:22 -0400
Subject: [PATCH 12/12] added pgt_coeff to argparser

---
 run_pgt_train.py | 35 +++++++++++++----------------------
 1 file changed, 13 insertions(+), 22 deletions(-)

diff --git a/run_pgt_train.py b/run_pgt_train.py
index 9bb8aae6..a89763a4 100644
--- a/run_pgt_train.py
+++ b/run_pgt_train.py
@@ -10,34 +10,24 @@ import torch
 
 def main(args):
     model = YOLOv10PGT('yolov10n.pt')
-    model.train(    
-                data=args.data_yaml, 
-                epochs=args.epochs, 
-                batch=args.batch_size,
-                # amp=False,
-                pgt_coeff=3.0,
-                # cfg='pgt_train.yaml',  # Load and train model with the config file
-                )
+
+    # Forward pgt_coeff only when it is given on the command line, so the
+    # trainer's built-in default applies otherwise.
+    train_kwargs = dict(data=args.data_yaml, epochs=args.epochs, batch=args.batch_size)
+    if args.pgt_coeff is not None:
+        train_kwargs['pgt_coeff'] = args.pgt_coeff
+    model.train(**train_kwargs)
     # If you want to finetune the model with pretrained weights, you could load the 
     # pretrained weights like below 
     # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
     # or
     # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
     # model = YOLOv10('yolov10n.pt', task='segment')
-    # model = YOLOv10('yolov10n.pt', task='segment')
-
-    # args_dict = dict(
-    #             model='yolov10n.pt', 
-    #             data=args.data_yaml, 
-    #             epochs=args.epochs, batch=args.batch_size,
-    #             # pgt_coeff=5.0,
-    #             # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
-    #             )
-    # trainer = PGTSegmentationTrainer(overrides=args_dict)
-    # trainer.train(
-    #             # debug=True, 
-    #             # args = dict(pgt_coeff=0.1), # Should add later to config
-    #             )
 
     # Create a directory to save model weights if it doesn't exist
     model_weights_dir = 'model_weights'
@@ -63,6 +53,7 @@ if __name__ == "__main__":
     parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
     parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
     parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
+    parser.add_argument('--pgt_coeff', type=float, default=None, help='Coefficient weighting the PGT plausibility loss; omit to use the trainer default')
     args = parser.parse_args()
 
     # Set CUDA device (only needed for multi-gpu machines)
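With the flag wired into the argparser, an invocation might look like the following (values illustrative); omitting --pgt_coeff falls back to the trainer's built-in default:

    python run_pgt_train.py --data_yaml coco.yaml --epochs 100 --batch_size 32 --pgt_coeff 3.0
    python run_pgt_train.py --data_yaml coco.yaml --epochs 100 --batch_size 32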