PGT Training now functioning. Run run_pgt_train.py to train the model.

nielseni6 2024-10-16 19:26:41 -04:00
parent 2a95a652bd
commit 38fa59edf2
16 changed files with 1895 additions and 32 deletions


@@ -3,26 +3,41 @@ from ultralytics import YOLOv10, YOLO
 # from ultralytics import BaseTrainer
 # from ultralytics.engine.trainer import BaseTrainer
 import os
+from ultralytics.models.yolo.segment import PGTSegmentationTrainer

 # Set CUDA device (only needed for multi-gpu machines)
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 os.environ["CUDA_VISIBLE_DEVICES"] = "4"

 # model = YOLOv10()
+# model = YOLO('yolov8n-seg.yaml').load('yolov8n.pt')  # build from YAML and transfer weights
 # model = YOLO()

 # If you want to finetune the model with pretrained weights, you could load the
 # pretrained weights like below
 # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
 # or
 # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt

-model = YOLOv10('yolov10n.pt')
-model.train(data='coco.yaml',
-            trainer=model._smart_load("pgt_trainer"),  # This is needed to generate attributions (will be used later to train via PGT)
-            # Add return_images as input parameter
-            epochs=500, batch=16, imgsz=640,
-            debug=True,  # If debug = True, the attributions will be saved in the figures folder
-            )
+model = YOLOv10('yolov10n.pt', task='segment')
+args = dict(model='yolov10n.pt', data='coco128-seg.yaml')
+trainer = PGTSegmentationTrainer(overrides=args)
+trainer.train(
+    # debug=True,
+    # args = dict(pgt_coeff=0.1),
+)
+# model.train(
+#     # data='coco.yaml',
+#     data='coco128-seg.yaml',
+#     trainer=model._smart_load("pgt_trainer"),  # This is needed to generate attributions (will be used later to train via PGT)
+#     # Add return_images as input parameter
+#     epochs=500, batch=16, imgsz=640,
+#     debug=True,  # If debug = True, the attributions will be saved in the figures folder
+#     # cfg='/home/nielseni6/PythonScripts/yolov10/ultralytics/cfg/models/v8/yolov8-seg.yaml',
+#     # overrides=dict(task="segment"),
+# )

 # Save the trained model
 model.save('yolov10_coco_trained.pt')

run_pgt_train.py (new file, 47 lines)

@@ -0,0 +1,47 @@
from ultralytics import YOLOv10, YOLO
# from ultralytics.engine.pgt_trainer import PGTTrainer
import os
from ultralytics.models.yolo.segment import PGTSegmentationTrainer
import argparse


def main(args):
    # model = YOLOv10()
    # If you want to finetune the model with pretrained weights, you could load the
    # pretrained weights like below
    # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
    # or
    # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
    model = YOLOv10('yolov10n.pt', task='segment')

    args = dict(model='yolov10n.pt', data='coco.yaml',
                epochs=args.epochs, batch=args.batch_size,
                # cfg='pgt_train.yaml',  # This can be edited for full control of the training process
                )
    trainer = PGTSegmentationTrainer(overrides=args)
    trainer.train(
        # debug=True,
        # args = dict(pgt_coeff=0.1),  # Should add later to config
    )

    # Save the trained model
    model.save('yolov10_coco_trained.pt')

    # Evaluate the model on the validation set
    results = model.val(data='coco.yaml')

    # Print the evaluation results
    print(results)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
    parser.add_argument('--device', type=str, default='0', help='CUDA device number')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
    args = parser.parse_args()

    # Set CUDA device (only needed for multi-gpu machines)
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    main(args)
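A typical invocation of the new script, using the flags defined by the argparser above (the values shown are illustrative):

```
$ python run_pgt_train.py --device 0 --batch_size 16 --epochs 100
```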


@@ -656,7 +656,10 @@ class Model(nn.Module):
             pass

         self.trainer.hub_session = self.session  # attach optional HUB session
-        self.trainer.train(debug=debug)
+        if debug:
+            self.trainer.train(debug=debug)
+        else:
+            self.trainer.train()
         # Update model and cfg after training
         if RANK in (-1, 0):
             ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last


@@ -52,7 +52,9 @@ from ultralytics.utils.torch_utils import (
     strip_optimizer,
 )
+from ultralytics.utils.loss import v8DetectionLoss
+from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn
+import matplotlib.path as matplotlib_path

 class PGTTrainer:
     """
     BaseTrainer.
@@ -159,6 +161,7 @@ class PGTTrainer:
         self.loss_names = ["Loss"]
         self.csv = self.save_dir / "results.csv"
         self.plot_idx = [0, 1, 2]
+        self.num = int(time.time())

         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
@@ -328,7 +331,7 @@
         if world_size > 1:
             self._setup_ddp(world_size)
         self._setup_train(world_size)

         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
         last_opt_step = -1
@@ -382,37 +385,88 @@
                 with torch.cuda.amp.autocast(self.amp):
                     batch = self.preprocess_batch(batch)
                     (self.loss, self.loss_items), images = self.model(batch, return_images=True)
+                    # smask = get_dist_reg(images, batch['masks'])
+                    # grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
+                    # grad = torch.abs(grad)
+                    # self.args.pgt_coeff = 1.1
+                    # plaus_loss = plaus_loss_fn(grad, smask, self.args.pgt_coeff)
+                    # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
+                    # self.loss += plaus_loss
-                if debug and (i % 250):
-                    grad = torch.autograd.grad(self.loss, images, create_graph=True)[0]
-                    # Convert tensors to numpy arrays
-                    images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
-                    grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
-                    # Normalize grad for visualization
-                    grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
-                    for ix in range(images_np.shape[0]):
-                        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
-                        ax[0].imshow(images_np[i])
-                        ax[0].set_title('Image')
-                        ax[1].imshow(grad_np[i], cmap='jet')
-                        ax[1].set_title('Gradient')
-                        ax[2].imshow(images_np[i])
-                        ax[2].imshow(grad_np[i], cmap='jet', alpha=0.5)
-                        ax[2].set_title('Overlay')
-                        save_dir_attr = "figures/attributions"
-                        if not os.path.exists(save_dir_attr):
-                            os.makedirs(save_dir_attr)
-                        plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
-                        plt.close(fig)
+                debug_ = debug
+                if debug_ and (i % 25 == 0):
+                    debug_ = False
+                    # Create a tensor of zeros with the same size as images
+                    mask = torch.zeros_like(images, dtype=torch.float32)
+                    smask = get_dist_reg(images, batch['masks'])
+                    grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
+                    grad = torch.abs(grad)
+                    batch_size = images.shape[0]
+                    imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device)
+                    targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+                    targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+                    gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+                    mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+                    # Iterate over each bounding box and set the corresponding pixels to 1
+                    for irx, bboxes in enumerate(gt_bboxes):
+                        for idx in range(len(bboxes)):
+                            x1, y1, x2, y2 = bboxes[idx]
+                            x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+                            mask[irx, :, y1:y2, x1:x2] = 1.0
+                    save_imgs = True
+                    if save_imgs:
+                        # Convert tensors to numpy arrays
+                        images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                        grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                        mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                        seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                        # Normalize grad for visualization
+                        grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
+                        for ix in range(images_np.shape[0]):
+                            fig, ax = plt.subplots(1, 6, figsize=(30, 5))
+                            ax[0].imshow(images_np[ix])
+                            ax[0].set_title('Image')
+                            ax[1].imshow(grad_np[ix], cmap='jet')
+                            ax[1].set_title('Gradient')
+                            ax[2].imshow(images_np[ix])
+                            ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
+                            ax[2].set_title('Overlay')
+                            ax[3].imshow(mask_np[ix], cmap='gray')
+                            ax[3].set_title('Mask')
+                            ax[4].imshow(seg_mask_np[ix], cmap='gray')
+                            ax[4].set_title('Segmentation Mask')
+                            # Plot image with bounding boxes
+                            ax[5].imshow(images_np[ix])
+                            for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
+                                x1, y1, x2, y2 = bbox
+                                x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+                                rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
+                                ax[5].add_patch(rect)
+                                ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
+                            ax[5].set_title('Bounding Boxes')
+                            save_dir_attr = f"figures/attributions/run{self.num}"
+                            if not os.path.exists(save_dir_attr):
+                                os.makedirs(save_dir_attr)
+                            plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
+                            plt.close(fig)
+                images = images.detach()

                 if RANK != -1:
                     self.loss *= world_size
                 self.tloss = (
                     (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
                 )

                 # Backward
                 self.scaler.scale(self.loss).backward()
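The commented-out lines above mark where the plausibility term is meant to be folded into the training loss once enabled. As a minimal, self-contained sketch of the attribution step that both the commented code and the debug block rely on (the loss and image tensors here are stand-ins, not the trainer's real ones):

```python
import torch

# Stand-in image batch and scalar loss; in the trainer these come from the model.
images = torch.rand(2, 3, 64, 64, requires_grad=True)
loss = (images ** 2).sum()

# Input-gradient attribution, |dL/dx|, as in the debug block above.
grad = torch.autograd.grad(loss, images, retain_graph=True)[0]
attribution = grad.abs()
print(attribution.shape)  # torch.Size([2, 3, 64, 64])
```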


@@ -0,0 +1,345 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Check a model's accuracy on a test or val split of a dataset.

Usage:
    $ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640

Usage - formats:
    $ yolo mode=val model=yolov8n.pt                 # PyTorch
                          yolov8n.torchscript        # TorchScript
                          yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
                          yolov8n_openvino_model     # OpenVINO
                          yolov8n.engine             # TensorRT
                          yolov8n.mlpackage          # CoreML (macOS-only)
                          yolov8n_saved_model        # TensorFlow SavedModel
                          yolov8n.pb                 # TensorFlow GraphDef
                          yolov8n.tflite             # TensorFlow Lite
                          yolov8n_edgetpu.tflite     # TensorFlow Edge TPU
                          yolov8n_paddle_model       # PaddlePaddle
                          yolov8n_ncnn_model         # NCNN
"""

import json
import time
from pathlib import Path

import numpy as np
import torch

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode


class PGTValidator:
    """
    BaseValidator.

    A base class for creating validators.

    Attributes:
        args (SimpleNamespace): Configuration for the validator.
        dataloader (DataLoader): Dataloader to use for validation.
        pbar (tqdm): Progress bar to update during validation.
        model (nn.Module): Model to validate.
        data (dict): Data dictionary.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether the model is in training mode.
        names (dict): Class names.
        seen: Records the number of images seen so far during validation.
        stats: Placeholder for statistics during validation.
        confusion_matrix: Placeholder for a confusion matrix.
        nc: Number of classes.
        iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.
        jdict (dict): Dictionary to store JSON validation results.
        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
            batch processing times in milliseconds.
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization.
        callbacks (dict): Dictionary to store various callback functions.
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """
        Initializes a BaseValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            pbar (tqdm.tqdm): Progress bar for displaying progress.
            args (SimpleNamespace): Configuration for the validator.
            _callbacks (dict): Dictionary to store various callback functions.
        """
        self.args = get_cfg(overrides=args)
        self.dataloader = dataloader
        self.pbar = pbar
        self.stride = None
        self.data = None
        self.device = None
        self.batch_i = None
        self.training = True
        self.names = None
        self.seen = None
        self.stats = None
        self.confusion_matrix = None
        self.nc = None
        self.iouv = None
        self.jdict = None
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

        self.save_dir = save_dir or get_save_dir(self.args)
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.001  # default conf=0.001
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

        self.plots = {}
        self.callbacks = _callbacks or callbacks.get_default_callbacks()

    # @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
        gets priority).
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            # self.args.half = self.device.type != "cpu"  # force FP16 val during training
            model = trainer.ema.ema or trainer.model
            model = model.half() if self.args.half else model.float()
            # self.model = model
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(
                weights=model or self.args.model,
                device=select_device(self.args.device, self.args.batch),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            # self.model = model
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if engine:
                self.args.batch = model.batch_size
            elif not pt and not jit:
                self.args.batch = 1  # export.py models default to batch-size 1
                LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")

            if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in ("cpu", "mps"):
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not pt:
                self.args.rect = False
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(de_parallel(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                preds = model(batch["img"].requires_grad_(True), augment=augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")
        stats = self.get_stats()
        self.check_stats(stats)
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
        self.finalize_metrics()
        if not (self.args.save_json and self.is_coco and len(self.jdict)):
            self.print_results()
        self.run_callbacks("on_val_end")
        if self.training:
            model.float()
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            stats['fitness'] = stats['metrics/mAP50-95(B)']
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            LOGGER.info(
                "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
                % tuple(self.speed.values())
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats

    def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
        """
        Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

        Args:
            pred_classes (torch.Tensor): Predicted class indices of shape (N,).
            true_classes (torch.Tensor): Target class indices of shape (M,).
            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
            use_scipy (bool): Whether to use scipy for matching (more precise).

        Returns:
            (torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
        """
        # Dx10 matrix, where D - detections, 10 - IoU thresholds
        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
        # LxD matrix where L - labels (rows), D - detections (columns)
        correct_class = true_classes[:, None] == pred_classes
        iou = iou * correct_class  # zero out the wrong classes
        iou = iou.cpu().numpy()
        for i, threshold in enumerate(self.iouv.cpu().tolist()):
            if use_scipy:
                # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
                import scipy  # scope import to avoid importing for all commands

                cost_matrix = iou * (iou >= threshold)
                if cost_matrix.any():
                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
                    valid = cost_matrix[labels_idx, detections_idx] > 0
                    if valid.any():
                        correct[detections_idx[valid], i] = True
            else:
                matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
                matches = np.array(matches).T
                if matches.shape[0]:
                    if matches.shape[0] > 1:
                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                        # matches = matches[matches[:, 2].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                    correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)

    def add_callback(self, event: str, callback):
        """Appends the given callback."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Runs all callbacks associated with a specified event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset."""
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocesses an input batch."""
        return batch

    def postprocess(self, preds):
        """Post-processes the predictions."""
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass

    def update_metrics(self, preds, batch):
        """Updates metrics based on predictions and batch."""
        pass

    def finalize_metrics(self, *args, **kwargs):
        """Finalizes and returns all metrics."""
        pass

    def get_stats(self):
        """Returns statistics about the model's performance."""
        return {}

    def check_stats(self, stats):
        """Checks statistics."""
        pass

    def print_results(self):
        """Prints the results of the model's predictions."""
        pass

    def get_desc(self):
        """Get description of the YOLO model."""
        pass

    @property
    def metric_keys(self):
        """Returns the metric keys used in YOLO training/validation."""
        return []

    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)."""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    # TODO: may need to put these following functions into callback
    def plot_val_samples(self, batch, ni):
        """Plots validation samples during training."""
        pass

    def plot_predictions(self, batch, preds, ni):
        """Plots YOLO model predictions on batch images."""
        pass

    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format."""
        pass

    def eval_json(self, stats):
        """Evaluate and return JSON format of prediction statistics."""
        pass


@@ -4,5 +4,6 @@ from .predict import DetectionPredictor
 from .pgt_train import PGTDetectionTrainer
 from .train import DetectionTrainer
 from .val import DetectionValidator
+from .pgt_val import PGTDetectionValidator

-__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer"
+__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer", "PGTDetectionValidator"


@@ -0,0 +1,300 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
from pathlib import Path

import numpy as np
import torch

from ultralytics.data import build_dataloader, build_yolo_dataset, converter
from ultralytics.engine.validator import BaseValidator
from ultralytics.engine.pgt_validator import PGTValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.utils.plotting import output_to_target, plot_images


class PGTDetectionValidator(PGTValidator):
    """
    A class extending the BaseValidator class for validation based on a detection model.

    Example:
        ```python
        from ultralytics.models.yolo.detect import DetectionValidator

        args = dict(model='yolov8n.pt', data='coco8.yaml')
        validator = DetectionValidator(args=args)
        validator()
        ```
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize detection model with necessary variables and settings."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.nt_per_class = None
        self.is_coco = False
        self.class_map = None
        self.args.task = "detect"
        self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
        self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
        self.niou = self.iouv.numel()
        self.lb = []  # for autolabelling

    def preprocess(self, batch):
        """Preprocesses batch of images for YOLO training."""
        batch["img"] = batch["img"].to(self.device, non_blocking=True)
        batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
        for k in ["batch_idx", "cls", "bboxes"]:
            batch[k] = batch[k].to(self.device)

        if self.args.save_hybrid:
            height, width = batch["img"].shape[2:]
            nb = len(batch["img"])
            bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
            self.lb = (
                [
                    torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
                    for i in range(nb)
                ]
                if self.args.save_hybrid
                else []
            )  # for autolabelling

        return batch

    def init_metrics(self, model):
        """Initialize evaluation metrics for YOLO."""
        val = self.data.get(self.args.split, "")  # validation path
        self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt")  # is COCO
        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
        self.args.save_json |= self.is_coco  # run on final val if training COCO
        self.names = model.names
        self.nc = len(model.names)
        self.metrics.names = self.names
        self.metrics.plot = self.args.plots
        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
        self.seen = 0
        self.jdict = []
        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])

    def get_desc(self):
        """Return a formatted string summarizing class metrics of YOLO model."""
        return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")

    def postprocess(self, preds):
        """Apply Non-maximum suppression to prediction outputs."""
        return ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            labels=self.lb,
            multi_label=True,
            agnostic=self.args.single_cls,
            max_det=self.args.max_det,
        )

    def _prepare_batch(self, si, batch):
        """Prepares a batch of images and annotations for validation."""
        idx = batch["batch_idx"] == si
        cls = batch["cls"][idx].squeeze(-1)
        bbox = batch["bboxes"][idx]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]
        ratio_pad = batch["ratio_pad"][si]
        if len(cls):
            bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels

        return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)

    def _prepare_pred(self, pred, pbatch):
        """Prepares a batch of images and annotations for validation."""
        predn = pred.clone()
        ops.scale_boxes(
            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
        )  # native-space pred

        return predn

    def update_metrics(self, preds, batch):
        """Metrics."""
        for si, pred in enumerate(preds):
            self.seen += 1
            npr = len(pred)
            stat = dict(
                conf=torch.zeros(0, device=self.device),
                pred_cls=torch.zeros(0, device=self.device),
                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
            )
            pbatch = self._prepare_batch(si, batch)
            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
            nl = len(cls)
            stat["target_cls"] = cls
            if npr == 0:
                if nl:
                    for k in self.stats.keys():
                        self.stats[k].append(stat[k])
                    if self.args.plots:
                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
                continue

            # Predictions
            if self.args.single_cls:
                pred[:, 5] = 0
            predn = self._prepare_pred(pred, pbatch)
            stat["conf"] = predn[:, 4]
            stat["pred_cls"] = predn[:, 5]

            # Evaluate
            if nl:
                stat["tp"] = self._process_batch(predn, bbox, cls)
                if self.args.plots:
                    self.confusion_matrix.process_batch(predn, bbox, cls)
            for k in self.stats.keys():
                self.stats[k].append(stat[k])

            # Save
            if self.args.save_json:
                self.pred_to_json(predn, batch["im_file"][si])
            if self.args.save_txt:
                file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
                self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)

    def finalize_metrics(self, *args, **kwargs):
        """Set final values for metrics speed and confusion matrix."""
        self.metrics.speed = self.speed
        self.metrics.confusion_matrix = self.confusion_matrix

    def get_stats(self):
        """Returns metrics statistics and results dictionary."""
        stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
        if len(stats) and stats["tp"].any():
            self.metrics.process(**stats)
        self.nt_per_class = np.bincount(
            stats["target_cls"].astype(int), minlength=self.nc
        )  # number of targets per class

        return self.metrics.results_dict

    def print_results(self):
        """Prints training/validation set metrics per class."""
        pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
        LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
        if self.nt_per_class.sum() == 0:
            LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")

        # Print results per class
        if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
            for i, c in enumerate(self.metrics.ap_class_index):
                LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))

        if self.args.plots:
            for normalize in True, False:
                self.confusion_matrix.plot(
                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
                )

    def _process_batch(self, detections, gt_bboxes, gt_cls):
        """
        Return correct prediction matrix.

        Args:
            detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
                Each detection is of the format: x1, y1, x2, y2, conf, class.
            labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
                Each label is of the format: class, x1, y1, x2, y2.

        Returns:
            (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
        """
        iou = box_iou(gt_bboxes, detections[:, :4])
        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build YOLO Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)

    def get_dataloader(self, dataset_path, batch_size):
        """Construct and return dataloader."""
        dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
        return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader

    def plot_val_samples(self, batch, ni):
        """Plot validation image samples."""
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )

    def plot_predictions(self, batch, preds, ni):
        """Plots predicted bounding boxes on input images and saves the result."""
        plot_images(
            batch["img"],
            *output_to_target(preds, max_det=self.args.max_det),
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred

    def save_one_txt(self, predn, save_conf, shape, file):
        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
        gn = torch.tensor(shape)[[1, 0, 1, 0]]  # normalization gain whwh
        for *xyxy, conf, cls in predn.tolist():
            xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
            line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
            with open(file, "a") as f:
                f.write(("%g " * len(line)).rstrip() % line + "\n")

    def pred_to_json(self, predn, filename):
        """Serialize YOLO predictions to COCO json format."""
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = ops.xyxy2xywh(predn[:, :4])  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
        for p, b in zip(predn.tolist(), box.tolist()):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "category_id": self.class_map[int(p[5])],
                    "bbox": [round(x, 3) for x in b],
                    "score": round(p[4], 5),
                }
            )

    def eval_json(self, stats):
        """Evaluates YOLO output in JSON format and returns performance statistics."""
        if self.args.save_json and self.is_coco and len(self.jdict):
            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
            pred_json = self.save_dir / "predictions.json"  # predictions
            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                check_requirements("pycocotools>=2.0.6")
                from pycocotools.coco import COCO  # noqa
                from pycocotools.cocoeval import COCOeval  # noqa

                for x in anno_json, pred_json:
                    assert x.is_file(), f"{x} file not found"
                anno = COCO(str(anno_json))  # init annotations api
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
                eval = COCOeval(anno, pred, "bbox")
                if self.is_coco:
                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
                eval.evaluate()
                eval.accumulate()
                eval.summarize()
                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50
            except Exception as e:
                LOGGER.warning(f"pycocotools unable to run: {e}")

        return stats
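Usage mirrors the DetectionValidator example in the class docstring above; a hypothetical standalone run might look like:

```python
from ultralytics.models.yolo.detect import PGTDetectionValidator

args = dict(model='yolov8n.pt', data='coco8.yaml')
validator = PGTDetectionValidator(args=args)
validator()
```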


@@ -3,5 +3,6 @@
 from .predict import SegmentationPredictor
 from .train import SegmentationTrainer
 from .val import SegmentationValidator
+from .pgt_train import PGTSegmentationTrainer

-__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
+__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator", "PGTSegmentationTrainer"


@@ -0,0 +1,73 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy

from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel, DetectionModel
from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.models.yolov10.model import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
from ultralytics.models.yolov10.val import YOLOv10DetectionValidator, YOLOv10PGTDetectionValidator


class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a segmentation model.

    Example:
        ```python
        from ultralytics.models.yolo.segment import SegmentationTrainer

        args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
        trainer = SegmentationTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a SegmentationTrainer object with given arguments."""
        if overrides is None:
            overrides = {}
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return SegmentationModel initialized with specified config and weights."""
        if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']:
            model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        else:
            model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """Return an instance of SegmentationValidator for validation of YOLO model."""
        if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']:
            self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", "pgt_loss",
            return YOLOv10PGTDetectionValidator(
                self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
            )
        else:
            self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
            return yolo.segment.SegmentationValidator(
                self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
            )

    def plot_training_samples(self, batch, ni):
        """Creates a plot of training sample images with labels and box coordinates."""
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            masks=batch["masks"],
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plots training/val metrics."""
        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png


@@ -1,5 +1,5 @@
 from ultralytics.engine.model import Model
-from ultralytics.nn.tasks import YOLOv10DetectionModel
+from ultralytics.nn.tasks import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
 from .val import YOLOv10DetectionValidator
 from .predict import YOLOv10DetectionPredictor
 from .train import YOLOv10DetectionTrainer


@@ -1,4 +1,4 @@
-from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.models.yolo.detect import DetectionValidator, PGTDetectionValidator
 from ultralytics.utils import ops
 import torch

@@ -7,6 +7,27 @@ class YOLOv10DetectionValidator(DetectionValidator):
         super().__init__(*args, **kwargs)
         self.args.save_json |= self.is_coco

+    def postprocess(self, preds):
+        if isinstance(preds, dict):
+            preds = preds["one2one"]
+
+        if isinstance(preds, (list, tuple)):
+            preds = preds[0]
+
+        # Acknowledgement: Thanks to sanha9999 in #190 and #181!
+        if preds.shape[-1] == 6:
+            return preds
+        else:
+            preds = preds.transpose(-1, -2)
+            boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
+            bboxes = ops.xywh2xyxy(boxes)
+            return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
+
+class YOLOv10PGTDetectionValidator(PGTDetectionValidator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.args.save_json |= self.is_coco
+
     def postprocess(self, preds):
         if isinstance(preds, dict):
             preds = preds["one2one"]


@@ -57,7 +57,7 @@ from ultralytics.nn.modules import (
 )
 from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
-from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss
+from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss, v10PGTDetectLoss
 from ultralytics.utils.plotting import feature_visualization
 from ultralytics.utils.torch_utils import (
     fuse_conv_and_bn,

@@ -651,6 +651,10 @@ class WorldModel(DetectionModel):
 class YOLOv10DetectionModel(DetectionModel):
     def init_criterion(self):
         return v10DetectLoss(self)

+class YOLOv10PGTDetectionModel(DetectionModel):
+    def init_criterion(self):
+        return v10PGTDetectLoss(self)
+
 class Ensemble(nn.ModuleList):
     """Ensemble of models."""


@@ -9,7 +9,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
 from .metrics import bbox_iou, probiou
 from .tal import bbox2dist
+from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn

 class VarifocalLoss(nn.Module):
     """

@@ -725,3 +725,30 @@ class v10DetectLoss:
         one2one = preds["one2one"]
         loss_one2one = self.one2one(one2one, batch)
         return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
+
+class v10PGTDetectLoss:
+    def __init__(self, model):
+        self.one2many = v8DetectionLoss(model, tal_topk=10)
+        self.one2one = v8DetectionLoss(model, tal_topk=1)
+
+    def __call__(self, preds, batch):
+        batch['img'] = batch['img'].requires_grad_(True)
+        one2many = preds["one2many"]
+        loss_one2many = self.one2many(one2many, batch)
+        one2one = preds["one2one"]
+        loss_one2one = self.one2one(one2one, batch)
+        loss = loss_one2many[0] + loss_one2one[0]
+
+        smask = get_dist_reg(batch['img'], batch['masks'])
+
+        grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0]
+        grad = torch.abs(grad)
+
+        pgt_coeff = 3.0
+        plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
+        # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
+        loss += plaus_loss
+
+        return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0)))
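Taken together, v10PGTDetectLoss adds a gradient-plausibility penalty on top of the usual one2many/one2one detection losses. A rough standalone sketch of that extra term, with stand-in tensors in place of a real batch (the helpers are the ones imported from ultralytics.utils.plaus_functs above):

```python
import torch
from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn

# Stand-ins: a real batch supplies images and low-resolution instance masks,
# and det_loss comes from the v8DetectionLoss branches above.
imgs = torch.rand(2, 3, 64, 64, requires_grad=True)
masks = (torch.rand(2, 16, 16) > 0.5).float()
det_loss = (imgs ** 2).sum()

smask = get_dist_reg(imgs, masks)             # blurred mask prior around objects
grad = torch.autograd.grad(det_loss, imgs, retain_graph=True)[0].abs()
plaus_loss = plaus_loss_fn(grad, smask, 3.0)  # pgt_coeff = 3.0, as above
total_loss = det_loss + plaus_loss
```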


@@ -0,0 +1,818 @@
import torch
import numpy as np
# from plot_functs import *
from .plot_functs import normalize_tensor, overlay_mask, imshow
import math
import time
import matplotlib.path as mplPath
from matplotlib.path import Path
# from utils.general import non_max_suppression, xyxy2xywh, scale_coords
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppression
from .metrics import bbox_iou
import torchvision.transforms as T


def plaus_loss_fn(grad, smask, pgt_coeff):
    ################## Compute the PGT Loss ##################
    # Positive regularization term for incentivizing pixels near the target to have high attribution
    dist_attr_pos = attr_reg(grad, (1.0 - smask))  # dist_reg = seg_mask
    # Negative regularization term for incentivizing pixels far from the target to have low attribution
    dist_attr_neg = attr_reg(grad, smask)
    # Calculate plausibility regularization term
    # dist_reg = dist_attr_pos - dist_attr_neg
    dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad)))
    plaus_reg = (((1.0 + dist_reg) / 2.0))
    # Calculate plausibility loss
    plaus_loss = (1 - plaus_reg) * pgt_coeff
    return plaus_loss


def get_dist_reg(images, seg_mask):
    seg_mask = T.Resize((images.shape[2], images.shape[3]), antialias=True)(seg_mask).to(images.device)
    seg_mask = seg_mask.to(dtype=torch.float32).unsqueeze(1).repeat(1, 3, 1, 1)
    seg_mask[seg_mask > 0] = 1.0
    smask = torch.zeros_like(seg_mask)
    sigmas = [20.0 + (i_sig * 20.0) for i_sig in range(8)]
    for k_it, sigma in enumerate(sigmas):
        # Apply Gaussian blur to the mask
        kernel_size = int(sigma + 50)
        if kernel_size % 2 == 0:
            kernel_size += 1
        seg_mask1 = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask)
        if torch.max(seg_mask1) > 1.0:
            seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min())
        smask = torch.max(smask, seg_mask1)
    return smask
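# Illustrative usage (shapes hypothetical): for images (B, 3, H, W) and
# low-resolution masks (B, h, w),
#   imgs = torch.rand(2, 3, 64, 64)
#   masks = (torch.rand(2, 16, 16) > 0.5).float()
#   smask = get_dist_reg(imgs, masks)  # (2, 3, 64, 64); ~1 on objects, decaying with distance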
def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False):
    """
    Compute the gradient of an image with respect to a given tensor.

    Args:
        img (torch.Tensor): The input image tensor.
        grad_wrt (torch.Tensor): The tensor with respect to which the gradient is computed.
        norm (bool, optional): Whether to normalize the gradient. Defaults to False.
        absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
        grayscale (bool, optional): Whether to convert the gradient to grayscale. Defaults to False.
        keepmean (bool, optional): Whether to keep the mean value of the attribution map. Defaults to False.

    Returns:
        torch.Tensor: The computed attribution map.
    """
    if (grad_wrt.shape != torch.Size([1])) and (grad_wrt.shape != torch.Size([])):
        grad_wrt_outputs = torch.ones_like(grad_wrt).clone().detach()  # .requires_grad_(True) # .retains_grad_(True)
    else:
        grad_wrt_outputs = None
    attribution_map = torch.autograd.grad(grad_wrt, img,
                                          grad_outputs=grad_wrt_outputs,
                                          create_graph=True,  # Create graph to allow for higher order derivatives but slows down computation significantly
                                          )[0]
    if absolute:
        attribution_map = torch.abs(attribution_map)  # attribution_map ** 2 # Take absolute values of gradients
    if grayscale:  # Convert to grayscale, saves vram and computation time for plaus_eval
        attribution_map = torch.sum(attribution_map, 1, keepdim=True)
    if norm:
        if keepmean:
            attmean = torch.mean(attribution_map)
            attmin = torch.min(attribution_map)
            attmax = torch.max(attribution_map)
        attribution_map = normalize_batch(attribution_map)  # Normalize attribution maps per image in batch
        if keepmean:
            attribution_map -= attribution_map.mean()
            attribution_map += (attmean / (attmax - attmin))

    return attribution_map


def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False):
    """
    Generate Gaussian noise based on the input image.

    Args:
        img (torch.Tensor): Input image.
        grad_wrt: Gradient with respect to the input image.
        norm (bool, optional): Whether to normalize the generated noise. Defaults to True.
        absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
        grayscale (bool, optional): Whether to convert the noise to grayscale. Defaults to True.
        keepmean (bool, optional): Whether to keep the mean of the noise. Defaults to False.

    Returns:
        torch.Tensor: Generated Gaussian noise.
    """
    gaussian_noise = torch.randn_like(img)
    if absolute:
        gaussian_noise = torch.abs(gaussian_noise)  # Take absolute values of gradients
    if grayscale:  # Convert to grayscale, saves vram and computation time for plaus_eval
        gaussian_noise = torch.sum(gaussian_noise, 1, keepdim=True)
    if norm:
        if keepmean:
            attmean = torch.mean(gaussian_noise)
            attmin = torch.min(gaussian_noise)
            attmax = torch.max(gaussian_noise)
        gaussian_noise = normalize_batch(gaussian_noise)  # Normalize attribution maps per image in batch
        if keepmean:
            gaussian_noise -= gaussian_noise.mean()
            gaussian_noise += (attmean / (attmax - attmin))

    return gaussian_noise


def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps=1e-7):
    # TODO: Remove imgs from this function and only take it as input if debug is True
    """
    Calculates the plausibility score based on the given inputs.

    Args:
        imgs (torch.Tensor): The input images.
        targets_out (torch.Tensor): The output targets.
        attr (torch.Tensor): The attribution tensor.
        debug (bool, optional): Whether to enable debug mode. Defaults to False.

    Returns:
        torch.Tensor: The plausibility score.
    """
    # # if imgs is None:
    # #     imgs = torch.zeros_like(attr)
    # # with torch.no_grad():
    # target_inds = targets_out[:, 0].int()
    # xyxy_batch = targets_out[:, 2:6]  # * pre_gen_gains[out_num]
    # num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
    # # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1))
    # xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
    # co = xyxy_corners
    # if corners:
    #     co = targets_out[:, 2:6].int()
    # coords_map = torch.zeros_like(attr, dtype=torch.bool)
    # # rows = np.arange(co.shape[0])
    # x1, x2 = co[:, 1], co[:, 3]
    # y1, y2 = co[:, 0], co[:, 2]
    # for ic in range(co.shape[0]):  # potential for speedup here with torch indexing instead of for loop
    #     coords_map[target_inds[ic], :, x1[ic]:x2[ic], y1[ic]:y2[ic]] = True

    if torch.isnan(attr).any():
        attr = torch.nan_to_num(attr, nan=0.0)

    coords_map = get_bbox_map(targets_out, attr)
    plaus_score = ((torch.sum((attr * coords_map))) / (torch.sum(attr)))

    if debug:
        for i in range(len(coords_map)):
            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
            test_bbox = torch.zeros_like(imgs[i])
            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
            imshow(test_bbox, save_path='figs/test_bbox')
            if imgs is None:
                imgs = torch.zeros_like(attr)
            imshow(imgs[i], save_path='figs/im0')
            imshow(attr[i], save_path='figs/attr')

    # with torch.no_grad():
    #     # att_select = attr[coords_map]
    #     att_select = attr * coords_map.to(torch.float32)
    #     att_total = attr
    # IoU_num = torch.sum(att_select)
    # IoU_denom = torch.sum(att_total)
    # IoU_ = (IoU_num / IoU_denom)
    # plaus_score = IoU_
    # # plaus_score = ((torch.sum(attr[coords_map])) / (torch.sum(attr)))

    return plaus_score
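# Worked example of the ratio above (hypothetical numbers): if the attribution
# map sums to 1.0 overall and the mass inside ground-truth boxes sums to 0.6,
# plaus_score = 0.6 / 1.0 = 0.6, i.e. 60% of attribution lands on labeled objects.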
def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps=1e-7):
    # TODO: Remove imgs from this function and only take it as input if debug is True
    """
    Calculates the plausibility score based on the given inputs.

    Args:
        imgs (torch.Tensor): The input images.
        targets_out (torch.Tensor): The output targets.
        attr (torch.Tensor): The attribution tensor.
        debug (bool, optional): Whether to enable debug mode. Defaults to False.

    Returns:
        torch.Tensor: The plausibility score.
    """
    # if imgs is None:
    #     imgs = torch.zeros_like(attr)
    # with torch.no_grad():
    target_inds = targets_out[:, 0].int()
    xyxy_batch = targets_out[:, 2:6]  # * pre_gen_gains[out_num]
    num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
    # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1))
    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
    co = xyxy_corners
    if corners:
        co = targets_out[:, 2:6].int()
    coords_map = torch.zeros_like(attr, dtype=torch.bool)
    # rows = np.arange(co.shape[0])
    x1, x2 = co[:, 1], co[:, 3]
    y1, y2 = co[:, 0], co[:, 2]
    for ic in range(co.shape[0]):  # potential for speedup here with torch indexing instead of for loop
        coords_map[target_inds[ic], :, x1[ic]:x2[ic], y1[ic]:y2[ic]] = True

    if torch.isnan(attr).any():
        attr = torch.nan_to_num(attr, nan=0.0)

    if debug:
        for i in range(len(coords_map)):
            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
            test_bbox = torch.zeros_like(imgs[i])
            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
            imshow(test_bbox, save_path='figs/test_bbox')
            imshow(imgs[i], save_path='figs/im0')
            imshow(attr[i], save_path='figs/attr')

    # att_select = attr[coords_map]
    # with torch.no_grad():
    # IoU_num = (torch.sum(attr[coords_map]))
    # IoU_denom = torch.sum(attr)
    # IoU_ = (IoU_num / (IoU_denom))
    # IoU_ = torch.max(attr[coords_map]) - torch.max(attr[~coords_map])

    co = (xyxy_batch * num_pixels).int()
    x1 = co[:, 1] + 1
    y1 = co[:, 0] + 1
    # with torch.no_grad():
    attr_ = torch.sum(attr, 1, keepdim=True)
    corners_attr = None  # torch.zeros(len(xyxy_batch), 4, device=attr.device)
    for ic in range(co.shape[0]):
        attr0 = attr_[target_inds[ic], :, :x1[ic], :y1[ic]]
        attr1 = attr_[target_inds[ic], :, :x1[ic], y1[ic]:]
        attr2 = attr_[target_inds[ic], :, x1[ic]:, :y1[ic]]
        attr3 = attr_[target_inds[ic], :, x1[ic]:, y1[ic]:]
        x_0, y_0 = max_indices_2d(attr0[0])
        x_1, y_1 = max_indices_2d(attr1[0])
        x_2, y_2 = max_indices_2d(attr2[0])
        x_3, y_3 = max_indices_2d(attr3[0])
        y_1 += y1[ic]
        x_2 += x1[ic]
        x_3 += x1[ic]
        y_3 += y1[ic]
        max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2],
                                 torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3],
                                 torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2],
                                 torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]])
        if corners_attr is None:
            corners_attr = max_corners
        else:
            corners_attr = torch.cat([corners_attr, max_corners], dim=0)
        # corners_attr[ic] = max_corners
    # corners_attr = attr[:, 0, :4, 0]
    corners_attr = corners_attr.view(-1, 4)
    # corners_attr = torch.stack(corners_attr, dim=0)
    IoU_ = bbox_iou(corners_attr.T, xyxy_batch, x1y1x2y2=False, metric='CIoU')
    plaus_score = IoU_.mean()

    return plaus_score


def max_indices_2d(x_inp):
    # values, indices = x.reshape(x.size(0), -1).max(dim=-1)
    torch.max(x_inp,)
    index = torch.argmax(x_inp)
    x = index // x_inp.shape[1]
    y = index % x_inp.shape[1]
    # x, y = divmod(index.item(), x_inp.shape[1])
    return torch.cat([x.unsqueeze(0), y.unsqueeze(0)])


def point_in_polygon(poly, grid):
    # t0 = time.time()
    num_points = poly.shape[0]
    j = num_points - 1
    oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool)
    for i in range(num_points):
        cond1 = (poly[i, 1] < grid[..., 1]) & (poly[j, 1] >= grid[..., 1])
        cond2 = (poly[j, 1] < grid[..., 1]) & (poly[i, 1] >= grid[..., 1])
        cond3 = (grid[..., 0] - poly[i, 0]) < (poly[j, 0] - poly[i, 0]) * (grid[..., 1] - poly[i, 1]) / (poly[j, 1] - poly[i, 1])
        oddNodes = oddNodes ^ (cond1 | cond2) & cond3
        j = i
    # t1 = time.time()
    # print(f'point in polygon time: {t1-t0}')
    return oddNodes


def point_in_polygon_gpu(poly, grid):
    num_points = poly.shape[0]
    i = torch.arange(num_points)
    j = (i - 1) % num_points
    # Expand dimensions
    # t0 = time.time()
    poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid.shape[0], grid.shape[0])
    # t1 = time.time()
    cond1 = (poly_expanded[i, 1] < grid[..., 1]) & (poly_expanded[j, 1] >= grid[..., 1])
    cond2 = (poly_expanded[j, 1] < grid[..., 1]) & (poly_expanded[i, 1] >= grid[..., 1])
    cond3 = (grid[..., 0] - poly_expanded[i, 0]) < (poly_expanded[j, 0] - poly_expanded[i, 0]) * (grid[..., 1] - poly_expanded[i, 1]) / (poly_expanded[j, 1] - poly_expanded[i, 1])
    # t2 = time.time()
    oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool)
    cond = (cond1 | cond2) & cond3
    # t3 = time.time()
    # efficiently perform xor using gpu and avoiding cpu as much as possible
    c = []
    while len(cond) > 1:
        if len(cond) % 2 == 1:  # odd number of elements
            c.append(cond[-1])
            cond = cond[:-1]
        cond = torch.bitwise_xor(cond[:int(len(cond) / 2)], cond[int(len(cond) / 2):])
    for c_ in c:
        cond = torch.bitwise_xor(cond, c_)
    oddNodes = cond
    # t4 = time.time()
    # for c in cond:
    #     oddNodes = oddNodes ^ c
    # print(f'expand time: {t1-t0} | cond123 time: {t2-t1} | cond logic time: {t3-t2} | bitwise xor time: {t4-t3}')
    # print(f'point in polygon time gpu: {t4-t0}')
    # oddNodes = oddNodes ^ (cond1 | cond2) & cond3
    return oddNodes


def bitmap_for_polygon(poly, h, w):
    y = torch.arange(h).to(poly.device).float()
    x = torch.arange(w).to(poly.device).float()
    grid_y, grid_x = torch.meshgrid(y, x)
    grid = torch.stack((grid_x, grid_y), dim=-1)
    bitmap = point_in_polygon(poly, grid)
    return bitmap.unsqueeze(0)


def corners_coords(center_xywh):
    center_x, center_y, w, h = center_xywh
    x = center_x - w / 2
    y = center_y - h / 2
    return torch.tensor([x, y, x + w, y + h])


def corners_coords_batch(center_xywh):
    center_x, center_y = center_xywh[:, 0], center_xywh[:, 1]
    w, h = center_xywh[:, 2], center_xywh[:, 3]
    x = center_x - w / 2
    y = center_y - h / 2
    return torch.stack([x, y, x + w, y + h], dim=1)
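# Example (illustrative): corners_coords_batch(torch.tensor([[0.5, 0.5, 0.2, 0.2]]))
# returns tensor([[0.4000, 0.4000, 0.6000, 0.6000]]), i.e. center-format xywh
# boxes become xyxy corner boxes.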
def normalize_batch(x):
"""
Normalize a batch of tensors along each channel.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
Returns:
torch.Tensor: Normalized tensor of the same shape as the input.
"""
mins = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device)
maxs = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device)
for i in range(x.shape[0]):
mins[i] = x[i].min()
maxs[i] = x[i].max()
x_ = (x - mins) / (maxs - mins)
return x_
def get_detections(model_clone, img):
"""
Get detections from a model given an input image and targets.
Args:
model (nn.Module): The model to use for detection.
img (torch.Tensor): The input image tensor.
Returns:
torch.Tensor: The detected bounding boxes.
"""
model_clone.eval() # Set model to evaluation mode
# Run inference
with torch.no_grad():
det_out, out = model_clone(img)
# model_.train()
del img
return det_out, out
def get_labels(det_out, imgs, targets, opt):
###################### Get predicted labels ######################
nb, _, height, width = imgs.shape # batch size, channels, height, width
targets_ = targets.clone()
targets_[:, 2:] = targets_[:, 2:] * torch.Tensor([width, height, width, height]).to(imgs.device) # to pixels
lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else [] # for autolabelling
o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True)
pred_labels = []
for si, pred in enumerate(o):
labels = targets_[targets_[:, 0] == si, 1:]
nl = len(labels)
predn = pred.clone()
        # Sort detections by confidence (column 4), descending
        sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True)
# Apply the sorting indices to the tensor
sorted_pred = predn[sort_indices]
        # Keep predictions with confidence above 0.1, plus one so at least one
        # prediction survives even when all confidences are low
        n_conf = int(torch.sum(sorted_pred[:, 4] > 0.1)) + 1
sorted_pred = sorted_pred[:n_conf]
new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si
preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1)
preds[:, 2:] = xyxy2xywh(preds[:, 2:]) # xywh
gn = torch.tensor([width, height])[[1, 0, 1, 0]] # normalization gain whwh
preds[:, 2:] /= gn.to(imgs.device) # from pixels
pred_labels.append(preds)
pred_labels = torch.cat(pred_labels, 0).to(imgs.device)
return pred_labels
##################################################################
from torchvision.utils import make_grid
def get_center_coords(attr):
    # Normalize so the brightest pixel is 1, threshold, and return the centroid
    # of the bright region. Fixes the previously undefined `img_tensor`
    # reference; `attr` is assumed to be a 2D (H, W) map here.
    attr = attr / attr.max()
    # Define a brightness threshold
    threshold = 0.95
    # Create a binary mask of the bright pixels
    mask = attr > threshold
    # Get the coordinates of the bright pixels
    y_coords, x_coords = torch.where(mask)
    # Calculate the centroid of the bright pixels
    centroid_x = x_coords.float().mean().item()
    centroid_y = y_coords.float().mean().item()
    print(f'The central bright point is at ({centroid_x}, {centroid_y})')
    return centroid_x, centroid_y
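# Illustrative use of get_center_coords on a map with one bright blob
# (values are made up; the centroid should come out near (10.5, 4.5)):
def _demo_get_center_coords():
    attr = torch.zeros(16, 16)
    attr[4:6, 10:12] = 1.0
    return get_center_coords(attr)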
def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False):
    """
    Compute the distance from each pixel to the nearest target coordinate.
    Args:
        attr (torch.Tensor): Attribution maps, shape (batch, channels, height, width).
        targets (torch.Tensor): Target rows (img_idx, cls, cx, cy, w, h), normalized.
        imgs (torch.Tensor, optional): Images, used only for debug visualization. Defaults to None.
        focus_coeff (float, optional): Focus coefficient, smaller means more focused. Defaults to 0.5.
        debug (bool, optional): Whether to visualize debug information. Defaults to False.
    Returns:
        torch.Tensor: Distance grids, min-max normalized and raised to focus_coeff.
    """
    # Height is the second-to-last dim and width the last (previously swapped,
    # which only worked for square inputs)
    height, width = attr.shape[-2], attr.shape[-1]
    # attr = torch.abs(attr) # Take absolute values of gradients
    # attr = normalize_batch(attr) # Normalize attribution maps per image in batch
    # Create a grid of (row, col) pixel indices
    yy, xx = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
    idx_grid = torch.stack((yy, xx), dim=-1).float().to(attr.device)  # (..., 0)=row/y, (..., 1)=col/x
# Expand the grid to match the batch size
idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1)
    # Initialize a list to store the per-image distance grids
    dist_grids_ = [None] * attr.shape[0]
# Loop over batches
for j in range(attr.shape[0]):
# Get the rows where the first column is the current unique value
rows = targets[targets[:, 0] == j]
if len(rows) != 0:
            # Target centres, cloned so the normalized `targets` tensor is not
            # mutated in place (rows[:, 2:4] is a view into it)
            xy = rows[:, 2:4].clone()  # (cx, cy), normalized
            # Reorder to (row, col) = (y, x) and scale to pixel coordinates to
            # match idx_grid's layout
            xy[:, 0], xy[:, 1] = xy[:, 1] * height, xy[:, 0] * width
            xy_center = xy.unsqueeze(1).unsqueeze(1)  # .requires_grad_(True)
# Compute the Euclidean distance from each pixel to the target coordinates
dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1)
# Pick the closest distance to any target for each pixel
dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0)
dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_
else:
# Set grid to zero if no targets are present
dist_grid = torch.zeros_like(attr[j])
dist_grids_[j] = dist_grid
# Convert the list of distance grids to a tensor for faster computation
dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff
if torch.isnan(dist_grids).any():
dist_grids = torch.nan_to_num(dist_grids, nan=0.0)
if debug:
for i in range(len(dist_grids)):
if ((i % 8) == 0):
grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0)
imshow(grid_show, save_path='figs/dist_grids')
if imgs is None:
imgs = torch.zeros_like(attr)
imshow(imgs[i], save_path='figs/im0')
img_overlay = (overlay_mask(imgs[i], dist_grids[i][0], alpha = 0.75))
imshow(img_overlay, save_path='figs/dist_grid_overlay')
weighted_attr = (dist_grids[i] * attr[i])
imshow(weighted_attr, save_path='figs/weighted_attr')
imshow(attr[i], save_path='figs/attr')
return dist_grids
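# Minimal sketch of the distance grid on a single 8x8 map with one centred
# target; the row layout (img_idx, cls, cx, cy, w, h) matches how
# get_distance_grids reads targets above. Values are illustrative.
def _demo_get_distance_grids():
    attr = torch.rand(1, 1, 8, 8)
    targets = torch.tensor([[0.0, 0.0, 0.5, 0.5, 0.2, 0.2]])
    grids = get_distance_grids(attr, targets)
    print(grids.shape)  # torch.Size([1, 1, 8, 8]); 0 at the centre, 1 far away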
def attr_reg(attribution_map, distance_map):
# dist_attr = distance_map * attribution_map
dist_attr = torch.mean(distance_map * attribution_map)#, dim=(1, 2, 3))
# del distance_map, attribution_map
return dist_attr
def get_bbox_map(targets_out, attr, corners=False):
    target_inds = targets_out[:, 0].int()
    xyxy_batch = targets_out[:, 2:6]  # (cx, cy, w, h), normalized
    # Scale x by width (attr.shape[3]) and y by height (attr.shape[2]); the
    # previous ordering only worked for square inputs
    num_pixels = torch.tile(torch.tensor([attr.shape[3], attr.shape[2], attr.shape[3], attr.shape[2]], device=attr.device), (xyxy_batch.shape[0], 1))
    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
    co = xyxy_corners
    if corners:
        co = targets_out[:, 2:6].int()
    coords_map = torch.zeros_like(attr, dtype=torch.bool)
    x1, x2 = co[:, 0], co[:, 2]
    y1, y2 = co[:, 1], co[:, 3]
    for ic in range(co.shape[0]):  # potential for speedup here with torch indexing instead of for loop
        coords_map[target_inds[ic], :, y1[ic]:y2[ic], x1[ic]:x2[ic]] = True
    bbox_map = coords_map.to(torch.float32)
    return bbox_map
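# Illustrative bbox map: one target covering the central half of an 8x8 map.
def _demo_get_bbox_map():
    attr = torch.zeros(1, 1, 8, 8)
    targets = torch.tensor([[0.0, 0.0, 0.5, 0.5, 0.5, 0.5]])  # (img, cls, cx, cy, w, h)
    print(get_bbox_map(targets, attr)[0, 0])  # 1s in rows/cols 2..5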
################################## Plausibility loss #################################
def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False):
# if imgs is None:
# imgs = torch.zeros_like(attribution_map)
# Calculate Plausibility IoU with attribution maps
# attribution_map.retains_grad = True
    if not only_loss:
        plaus_score = get_plaus_score(targets_out=targets, attr=attribution_map.clone().detach().requires_grad_(True), imgs=imgs)
    else:
        plaus_score = torch.tensor(0.0, device=attribution_map.device)
    # NOTE: plaus_score is unconditionally recomputed from the bbox overlap
    # below, so the value above is only a placeholder
# attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
# Calculate distance regularization
distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff)
# distance_map = torch.ones_like(attribution_map)
if opt.dist_x_bbox:
bbox_map = get_bbox_map(targets, attribution_map).to(torch.bool)
distance_map[bbox_map] = 0.0
# distance_map = distance_map * (1 - bbox_map)
# Positive regularization term for incentivizing pixels near the target to have high attribution
dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map))
# Negative regularization term for incentivizing pixels far from the target to have low attribution
dist_attr_neg = attr_reg(attribution_map, distance_map)
# Calculate plausibility regularization term
# dist_reg = dist_attr_pos - dist_attr_neg
dist_reg = ((dist_attr_pos / torch.mean(attribution_map)) - (dist_attr_neg / torch.mean(attribution_map)))
# dist_reg = torch.mean((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3))) - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3))))
# dist_reg = (torch.mean(torch.exp((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3)))) + \
# torch.exp(1 - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3)))))) \
# / 2.5
if opt.bbox_coeff != 0.0:
bbox_map = get_bbox_map(targets, attribution_map)
attr_bbox_pos = attr_reg(attribution_map, bbox_map)
attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map))
bbox_reg = attr_bbox_pos - attr_bbox_neg
# bbox_reg = (attr_bbox_pos / torch.mean(attribution_map)) - (attr_bbox_neg / torch.mean(attribution_map))
    else:
        bbox_reg = 0.0
        bbox_map = get_bbox_map(targets, attribution_map)
    # Plausibility score: fraction of total attribution inside ground-truth boxes
    plaus_score = (torch.sum(attribution_map * bbox_map)) / (torch.sum(attribution_map))
# iou_loss = (1.0 - plaus_score)
    if not opt.dist_reg_only:
        dist_reg_loss = (1.0 + dist_reg) / 2.0  # map dist_reg into [0, 1]
        plaus_reg = (plaus_score * opt.iou_coeff) + \
                    (dist_reg_loss * opt.dist_coeff) + \
                    (bbox_reg * opt.bbox_coeff)
    else:
        plaus_reg = (1.0 + dist_reg) / 2.0
# Calculate plausibility loss
plaus_loss = (1 - plaus_reg) * opt.pgt_coeff
# plaus_loss = (plaus_reg) * opt.pgt_coeff
if only_loss:
return plaus_loss
if not debug:
return plaus_loss, (plaus_score, dist_reg, plaus_reg,)
else:
return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map
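# Minimal sketch of calling get_plaus_loss on random attributions. The option
# names mirror the attributes read above; the values are placeholders, not
# tuned settings.
def _demo_get_plaus_loss():
    from types import SimpleNamespace
    opt = SimpleNamespace(focus_coeff=0.5, dist_x_bbox=False, bbox_coeff=0.0,
                          dist_reg_only=False, iou_coeff=1.0, dist_coeff=1.0,
                          pgt_coeff=0.1)
    attr = torch.rand(1, 1, 8, 8, requires_grad=True)
    targets = torch.tensor([[0.0, 0.0, 0.5, 0.5, 0.25, 0.25]])
    plaus_loss = get_plaus_loss(targets, attr, opt, only_loss=True)
    print(float(plaus_loss))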
####################################################################################
#### ALL FUNCTIONS BELOW ARE DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS ####
####################################################################################
def generate_vanilla_grad(model, input_tensor, loss_func = None,
targets_list=None, targets=None, metric=None, out_num = 1,
n_max_labels=3, norm=True, abs=True, grayscale=True,
class_specific_attr = True, device='cpu'):
"""
Generate vanilla gradients for the given model and input tensor.
Args:
model (nn.Module): The model to generate gradients for.
input_tensor (torch.Tensor): The input tensor for which gradients are computed.
loss_func (callable, optional): The loss function to compute gradients with respect to. Defaults to None.
        targets_list (list, optional): The list of per-image target tensors. Defaults to None.
        targets (torch.Tensor, optional): Stacked targets, used when loss_func is given without class-specific attribution. Defaults to None.
        metric (callable, optional): The metric function to evaluate the loss. Defaults to None.
out_num (int, optional): The index of the output tensor to compute gradients with respect to. Defaults to 1.
n_max_labels (int, optional): The maximum number of labels to consider. Defaults to 3.
norm (bool, optional): Whether to normalize the attribution map. Defaults to True.
abs (bool, optional): Whether to take the absolute values of gradients. Defaults to True.
grayscale (bool, optional): Whether to convert the attribution map to grayscale. Defaults to True.
class_specific_attr (bool, optional): Whether to compute class-specific attribution maps. Defaults to True.
device (str, optional): The device to use for computation. Defaults to 'cpu'.
Returns:
torch.Tensor: The generated vanilla gradients.
"""
# Set model.train() at the beginning and revert back to original mode (model.eval() or model.train()) at the end
train_mode = model.training
if not train_mode:
model.train()
input_tensor.requires_grad = True # Set requires_grad attribute of tensor. Important for computing gradients
model.zero_grad() # Zero gradients
inpt = input_tensor
# Forward pass
train_out = model(inpt) # training outputs (no inference outputs in train mode)
# train_out[1] = torch.Size([4, 3, 80, 80, 7]) HxWx(#anchorxC) cls (class probabilities)
# train_out[0] = torch.Size([4, 3, 160, 160, 7]) HxWx(#anchorx4) box or reg (location and scaling)
# train_out[2] = torch.Size([4, 3, 40, 40, 7]) HxWx(#anchorx1) obj (objectness score or confidence)
if class_specific_attr:
n_attr_list, index_classes = [], []
for i in range(len(input_tensor)):
if len(targets_list[i]) > n_max_labels:
targets_list[i] = targets_list[i][:n_max_labels]
if targets_list[i].numel() != 0:
# unique_classes = torch.unique(targets_list[i][:,1])
class_numbers = targets_list[i][:,1]
index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers])
num_attrs = len(targets_list[i])
# index_classes.append([0, 1, 2, 3, 4] + [int(uc + 5) for uc in unique_classes])
# num_attrs = 1 #len(unique_classes)# if loss_func else len(targets_list[i])
n_attr_list.append(num_attrs)
else:
index_classes.append([0, 1, 2, 3, 4])
n_attr_list.append(0)
targets_list_filled = [targ.clone().detach() for targ in targets_list]
labels_len = [len(targets_list[ih]) for ih in range(len(targets_list))]
max_labels = np.max(labels_len)
max_index = np.argmax(labels_len)
for i in range(len(targets_list)):
# targets_list_filled[i] = targets_list[i]
if len(targets_list_filled[i]) < max_labels:
tlist = [targets_list_filled[i]] * math.ceil(max_labels / len(targets_list_filled[i]))
targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0)
else:
targets_list_filled[i] = targets_list_filled[i].unsqueeze(0)
for i in range(len(targets_list_filled)-1,-1,-1):
if targets_list_filled[i].numel() == 0:
targets_list_filled.pop(i)
targets_list_filled = torch.cat(targets_list_filled)
n_img_attrs = len(input_tensor) if class_specific_attr else 1
n_img_attrs = 1 if loss_func else n_img_attrs
attrs_batch = []
for i_batch in range(n_img_attrs):
if loss_func and class_specific_attr:
i_batch = max_index
# inpt = input_tensor[i_batch].unsqueeze(0)
# ##################################################################
# model.zero_grad() # Zero gradients
# train_out = model(inpt) # training outputs (no inference outputs in train mode)
# ##################################################################
n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1
n_label_attrs = 1 if not class_specific_attr else n_label_attrs
attrs_img = []
for i_attr in range(n_label_attrs):
if loss_func is None:
grad_wrt = train_out[out_num]
if class_specific_attr:
grad_wrt = train_out[out_num][:,:,:,:,index_classes[i_batch][i_attr]]
grad_wrt_outputs = torch.ones_like(grad_wrt)
else:
# if class_specific_attr:
# targets = targets_list[:][i_attr]
# n_targets = len(targets_list[i_batch])
if class_specific_attr:
target_indiv = targets_list_filled[:,i_attr] # batch image input
else:
target_indiv = targets
# target_indiv = targets_list[i_batch][i_attr].unsqueeze(0) # single image input
# target_indiv[:,0] = 0 # this indicates the batch index of the target, should be 0 since we are only doing one image at a time
                try:
                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)  # loss scaled by batch_size
                except Exception:
                    # Move everything to the target device and retry; the old
                    # loop rebound the loop variable without updating the list
                    target_indiv = target_indiv.to(device)
                    inpt = inpt.to(device)
                    train_out = [tro.to(device) for tro in train_out]
                    print("Error in loss function, trying again with device specified")
                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)
grad_wrt = loss
grad_wrt_outputs = None
model.zero_grad() # Zero gradients
gradients = torch.autograd.grad(grad_wrt, inpt,
grad_outputs=grad_wrt_outputs,
retain_graph=True,
# create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly
)
# Convert gradients to numpy array and back to ensure full separation from graph
# attribution_map = torch.tensor(torch.sum(gradients[0], 1, keepdim=True).clone().detach().cpu().numpy())
attribution_map = gradients[0]#.clone().detach() # without converting to numpy
if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
attribution_map = torch.sum(attribution_map, 1, keepdim=True)
if abs:
attribution_map = torch.abs(attribution_map) # Take absolute values of gradients
if norm:
attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
attrs_img.append(attribution_map)
if len(attrs_img) == 0:
attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device))
else:
attrs_batch.append(torch.stack(attrs_img).to(device))
# out_attr = torch.tensor(attribution_map).unsqueeze(0).to(device) if ((loss_func) or (not class_specific_attr)) else torch.stack(attrs_batch).to(device)
# out_attr = [attrs_batch[0]] * len(input_tensor) if ((loss_func) or (not class_specific_attr)) else attrs_batch
out_attr = attrs_batch
# Set model back to original mode
if not train_mode:
model.eval()
return out_attr
class RVNonLinearFunc(torch.nn.Module):
    """
    Wraps an elementwise activation (e.g. torch.relu) and propagates a random
    variable through it: the mean is passed through the function and the
    covariance is transformed with a first-order (delta-method) approximation.
    """
def __init__(self, func):
super(RVNonLinearFunc, self).__init__()
self.func = func
def forward(self, mu_in, Sigma_in):
"""
Forward pass of the Bayesian ReLU activation function.
Args:
mu_in (torch.Tensor): A tensor of shape (batch_size, input_size),
representing the mean input to the ReLU activation function.
Sigma_in (torch.Tensor): A tensor of shape (batch_size, input_size, input_size),
representing the covariance input to the ReLU activation function.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors,
including the mean of the output and the covariance of the output.
"""
# Collect stats
batch_size = mu_in.size(0)
# Mean
mu_out = self.func(mu_in)
# Compute the derivative of the ReLU activation function with respect to the input mean
gradi = torch.autograd.grad(mu_out, mu_in, grad_outputs=torch.ones_like(mu_out), create_graph=True)[0].view(batch_size,-1)
# add an extra dimension to gradi at position 2 and 1
grad1 = gradi.unsqueeze(dim=2)
grad2 = gradi.unsqueeze(dim=1)
# compute the outer product of grad1 and grad2
outer_product = torch.bmm(grad1, grad2)
# element-wise multiply Sigma_in with the outer product
# and return the result
Sigma_out = torch.mul(Sigma_in, outer_product)
return mu_out, Sigma_out
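# Sketch: push a mean/covariance pair through ReLU with the first-order
# propagation implemented above (shapes are illustrative).
def _demo_rv_relu():
    act = RVNonLinearFunc(torch.relu)
    mu = torch.randn(2, 4, requires_grad=True)
    Sigma = torch.eye(4).expand(2, 4, 4).clone()
    mu_out, Sigma_out = act(mu, Sigma)
    print(mu_out.shape, Sigma_out.shape)  # torch.Size([2, 4]) torch.Size([2, 4, 4])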

View File

@@ -0,0 +1,154 @@
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.cuda import amp  # used by returnGrad's GradScaler below
class Subplots:
def __init__(self, figsize = (40, 5)):
self.fig = plt.figure(figsize=figsize)
def plot_img_list(self, img_list, savedir='figs/test',
nrows = 1, rownum = 0,
hold = False, coltitles=[], rowtitle=''):
for i, img in enumerate(img_list):
            try:
                npimg = img.clone().detach().cpu().numpy()
            except AttributeError:  # already a numpy array
                npimg = img
tpimg = np.transpose(npimg, (1, 2, 0))
lenrow = int((len(img_list)))
ax = self.fig.add_subplot(nrows, lenrow, i+1+(rownum*lenrow))
if len(coltitles) > i:
ax.set_title(coltitles[i])
if i == 0:
ax.annotate(rowtitle, xy=((-0.06 * len(rowtitle)), 0.4),# xytext=(-ax.yaxis.labelpad - pad, 0),
xycoords='axes fraction', textcoords='offset points',
size='large', ha='center', va='baseline')
# ax.set_ylabel(rowtitle, rotation=90)
ax.imshow(tpimg)
ax.axis('off')
if not hold:
self.fig.tight_layout()
plt.savefig(f'{savedir}.png')
plt.clf()
plt.close('all')
def VisualizeNumpyImageGrayscale(image_3d):
r"""Returns a 3D tensor as a grayscale normalized between 0 and 1 2D tensor.
"""
vmin = np.min(image_3d)
image_2d = image_3d - vmin
vmax = np.max(image_2d)
return (image_2d / vmax)
def normalize_numpy(image_3d):
r"""Returns a 3D tensor as a grayscale normalized between 0 and 1 2D tensor.
"""
vmin = np.min(image_3d)
image_2d = image_3d - vmin
vmax = np.max(image_2d)
return (image_2d / vmax)
def normalize_tensor(image_3d):
r"""Returns a 3D tensor as a grayscale normalized between 0 and 1 2D tensor.
"""
image_2d = (image_3d - torch.min(image_3d))
return (image_2d / torch.max(image_2d))
def format_img(img_):
np_img = img_.numpy()
tp_img = np.transpose(np_img, (1, 2, 0))
return tp_img
def imshow(img, save_path=None):
    try:
        npimg = img.clone().detach().cpu().numpy()
    except AttributeError:  # already a numpy array
        npimg = img
    tpimg = np.transpose(npimg, (1, 2, 0))  # CHW -> HWC for matplotlib
    plt.imshow(tpimg)
    # plt.axis('off')
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(str(save_path) + ".png")
    # plt.show()
def imshow_img(img, imsave_path):
    # works for tensors and numpy arrays; expects an (H, W, C) image
    try:
        npimg = VisualizeNumpyImageGrayscale(img.numpy())
    except AttributeError:  # already a numpy array
        npimg = VisualizeNumpyImageGrayscale(img)
    npimg = np.transpose(npimg, (2, 0, 1))  # HWC -> CHW, undone inside imshow
    imshow(npimg, save_path=imsave_path)
    print("Saving image as ", imsave_path)
def returnGrad(img, labels, model, compute_loss, loss_metric, augment=None, device='cpu'):
    model.train()
    model.to(device)
    img = img.to(device)
    img.requires_grad_(True)
    labels = labels.to(device)  # previously the moved tensor was discarded
    model.requires_grad_(True)
    cuda = torch.device(device).type != 'cpu'  # accept both str and torch.device
    scaler = amp.GradScaler(enabled=cuda)
    pred = model(img)
    # out, train_out = model(img, augment=augment)  # inference and training outputs
    loss, loss_items = compute_loss(pred, labels, metric=loss_metric)  # box, obj, cls
    with torch.autograd.set_detect_anomaly(True):
        scaler.scale(loss).backward(inputs=img)
    # NOTE: the gradient still carries the GradScaler's scale factor when cuda is enabled
    Sc_dx = img.grad
    model.eval()
    Sc_dx = Sc_dx.detach().clone().to(torch.float32)  # avoids the torch.tensor(tensor) copy warning
    return Sc_dx
def calculate_snr(img, attr, dB=True):
    try:
        img_np = img.detach().cpu().numpy()
        attr_np = attr.detach().cpu().numpy()
    except AttributeError:  # already numpy arrays
        img_np = img
        attr_np = attr
    # Signal power: mean squared image intensity
    signal_power = np.mean(img_np ** 2)
    # Noise power: mean squared attribution magnitude
    noise_power = np.mean(attr_np ** 2)
    if dB:
        # SNR in decibels
        snr = 10 * np.log10(signal_power / noise_power)
    else:
        # Plain power ratio
        snr = signal_power / noise_power
    return snr
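# Tiny sanity check for the SNR helper (illustrative values):
def _demo_calculate_snr():
    img = torch.ones(3, 8, 8)
    attr = torch.full((3, 8, 8), 0.1)
    print(calculate_snr(img, attr))  # 10 * log10(1 / 0.01) = 20.0 dB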
def overlay_mask(img, mask, colormap: str = "jet", alpha: float = 0.7):
cmap = plt.get_cmap(colormap)
npmask = np.array(mask.clone().detach().cpu().squeeze(0))
# cmpmask = ((255 * cmap(npmask)[:, :, :3]).astype(np.uint8)).transpose((2, 0, 1))
cmpmask = (cmap(npmask)[:, :, :3]).transpose((2, 0, 1))
overlayed_imgnp = ((alpha * (np.asarray(img.clone().detach().cpu())) + (1 - alpha) * cmpmask))
overlayed_tensor = torch.tensor(overlayed_imgnp, device=img.device)
return overlayed_tensor
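# Illustrative overlay: blend a random mask onto a random image and confirm
# the output shape matches the input image.
def _demo_overlay_mask():
    img = torch.rand(3, 32, 32)
    mask = torch.rand(1, 32, 32)
    print(overlay_mask(img, mask, alpha=0.75).shape)  # torch.Size([3, 32, 32])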

View File

@@ -717,7 +717,7 @@ def plot_images(
):
    """Plot image grid with labels."""
    if isinstance(images, torch.Tensor):
-        images = images.cpu().float().numpy()
+        images = images.detach().cpu().float().numpy()
    if isinstance(cls, torch.Tensor):
        cls = cls.cpu().numpy()
    if isinstance(bboxes, torch.Tensor):