PGT Training now functioning. Run run_pgt_train.py to train the model.

nielseni6 2024-10-16 19:26:41 -04:00
parent 2a95a652bd
commit 38fa59edf2
16 changed files with 1895 additions and 32 deletions
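The new entry point boils down to the following minimal sketch (mirroring the code added below; `coco128-seg.yaml` keeps a smoke test small):

```python
from ultralytics.models.yolo.segment import PGTSegmentationTrainer

overrides = dict(model='yolov10n.pt', data='coco128-seg.yaml')
trainer = PGTSegmentationTrainer(overrides=overrides)
trainer.train()
```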

View File

@@ -3,26 +3,41 @@ from ultralytics import YOLOv10, YOLO
# from ultralytics import BaseTrainer
# from ultralytics.engine.trainer import BaseTrainer
import os
from ultralytics.models.yolo.segment import PGTSegmentationTrainer
# Set CUDA device (only needed for multi-gpu machines)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
# model = YOLOv10()
# model = YOLO('yolov8n-seg.yaml').load('yolov8n.pt') # build from YAML and transfer weights
# model = YOLO()
# If you want to finetune the model with pretrained weights, you could load the
# pretrained weights like below
# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
# or
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
model = YOLOv10('yolov10n.pt')
model = YOLOv10('yolov10n.pt', task='segment')
model.train(data='coco.yaml',
trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT)
# Add return_images as input parameter
epochs=500, batch=16, imgsz=640,
debug=True, # If debug = True, the attributions will be saved in the figures folder
)
args = dict(model='yolov10n.pt', data='coco128-seg.yaml')
trainer = PGTSegmentationTrainer(overrides=args)
trainer.train(
# debug=True,
# args = dict(pgt_coeff=0.1),
)
# model.train(
# # data='coco.yaml',
# data='coco128-seg.yaml',
# trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT)
# # Add return_images as input parameter
# epochs=500, batch=16, imgsz=640,
# debug=True, # If debug = True, the attributions will be saved in the figures folder
# # cfg='/home/nielseni6/PythonScripts/yolov10/ultralytics/cfg/models/v8/yolov8-seg.yaml',
# # overrides=dict(task="segment"),
# )
# Save the trained model
model.save('yolov10_coco_trained.pt')

run_pgt_train.py Normal file
View File

@@ -0,0 +1,47 @@
from ultralytics import YOLOv10, YOLO
# from ultralytics.engine.pgt_trainer import PGTTrainer
import os
from ultralytics.models.yolo.segment import PGTSegmentationTrainer
import argparse
def main(args):
# model = YOLOv10()
# If you want to finetune the model with pretrained weights, you could load the
# pretrained weights like below
# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
# or
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
model = YOLOv10('yolov10n.pt', task='segment')
overrides = dict(model='yolov10n.pt', data='coco.yaml', # renamed from args to avoid shadowing the argparse namespace
epochs=args.epochs, batch=args.batch_size,
# cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
)
trainer = PGTSegmentationTrainer(overrides=overrides)
trainer.train(
# debug=True,
# args = dict(pgt_coeff=0.1), # Should add later to config
)
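# NOTE: PGTSegmentationTrainer builds and trains its own model from the overrides;
# the `model` object above is not updated in place, so the save below stores the
# original pretrained weights. The trainer writes its own best/last checkpoints.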
# Save the trained model
model.save('yolov10_coco_trained.pt')
# Evaluate the model on the validation set
results = model.val(data='coco.yaml')
# Print the evaluation results
print(results)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
parser.add_argument('--device', type=str, default='0', help='CUDA device number')
parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
args = parser.parse_args()
# Set CUDA device (only needed for multi-gpu machines)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
main(args)
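Typical invocation, using the argparse flags defined above: `python run_pgt_train.py --device 4 --batch_size 16 --epochs 100`.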

View File

@@ -656,7 +656,10 @@ class Model(nn.Module):
pass
self.trainer.hub_session = self.session # attach optional HUB session
self.trainer.train(debug=debug)
if debug:
self.trainer.train(debug=debug)
else:
self.trainer.train()
# Update model and cfg after training
if RANK in (-1, 0):
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last

View File

@@ -52,7 +52,9 @@ from ultralytics.utils.torch_utils import (
strip_optimizer,
)
from ultralytics.utils.loss import v8DetectionLoss
from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn
import matplotlib.path as matplotlib_path
class PGTTrainer:
"""
BaseTrainer.
@@ -159,6 +161,7 @@ class PGTTrainer:
self.loss_names = ["Loss"]
self.csv = self.save_dir / "results.csv"
self.plot_idx = [0, 1, 2]
self.num = int(time.time())
# Callbacks
self.callbacks = _callbacks or callbacks.get_default_callbacks()
@@ -383,36 +386,87 @@ class PGTTrainer:
batch = self.preprocess_batch(batch)
(self.loss, self.loss_items), images = self.model(batch, return_images=True)
if debug and (i % 250):
grad = torch.autograd.grad(self.loss, images, create_graph=True)[0]
# smask = get_dist_reg(images, batch['masks'])
# grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
# grad = torch.abs(grad)
# self.args.pgt_coeff = 1.1
# plaus_loss = plaus_loss_fn(grad, smask, self.args.pgt_coeff)
# self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
# self.loss += plaus_loss
debug_ = debug
if debug_ and (i % 25 == 0):
debug_ = False
# Create a tensor of zeros with the same size as images
mask = torch.zeros_like(images, dtype=torch.float32)
smask = get_dist_reg(images, batch['masks'])
grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
grad = torch.abs(grad)
batch_size = images.shape[0]
imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device)
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# Iterate over each bounding box and set the corresponding pixels to 1
for irx, bboxes in enumerate(gt_bboxes):
for idx in range(len(bboxes)):
x1, y1, x2, y2 = bboxes[idx]
x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
mask[irx, :, y1:y2, x1:x2] = 1.0
save_imgs = True
if save_imgs:
# Convert tensors to numpy arrays
images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
# Normalize grad for visualization
grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
for ix in range(images_np.shape[0]):
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(images_np[i])
fig, ax = plt.subplots(1, 6, figsize=(30, 5))
ax[0].imshow(images_np[ix])
ax[0].set_title('Image')
ax[1].imshow(grad_np[i], cmap='jet')
ax[1].imshow(grad_np[ix], cmap='jet')
ax[1].set_title('Gradient')
ax[2].imshow(images_np[i])
ax[2].imshow(grad_np[i], cmap='jet', alpha=0.5)
ax[2].imshow(images_np[ix])
ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
ax[2].set_title('Overlay')
ax[3].imshow(mask_np[ix], cmap='gray')
ax[3].set_title('Mask')
ax[4].imshow(seg_mask_np[ix], cmap='gray')
ax[4].set_title('Segmentation Mask')
save_dir_attr = "figures/attributions"
# Plot image with bounding boxes
ax[5].imshow(images_np[ix])
for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
x1, y1, x2, y2 = bbox
x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
ax[5].add_patch(rect)
ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
ax[5].set_title('Bounding Boxes')
save_dir_attr = f"figures/attributions/run{self.num}"
if not os.path.exists(save_dir_attr):
os.makedirs(save_dir_attr)
plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
plt.close(fig)
images = images.detach()
if RANK != -1:
self.loss *= world_size
self.tloss = (
(self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
)
if RANK != -1:
self.loss *= world_size
self.tloss = (
(self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
)
# Backward
self.scaler.scale(self.loss).backward()
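The commented-out block at the top of this hunk sketches how the plausibility term would be folded into the training loss. A standalone version of that step might look like this (a sketch; `pgt_coeff=0.1` follows the commented suggestion in the run scripts, and the helpers come from the `ultralytics.utils.plaus_functs` import above):

```python
import torch
from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn

def pgt_penalty(loss, images, masks, pgt_coeff=0.1):
    """Plausibility penalty scoring input gradients against the blurred GT masks."""
    smask = get_dist_reg(images, masks)                             # soft prior from segmentation masks
    grad = torch.autograd.grad(loss, images, retain_graph=True)[0]  # attribution w.r.t. the input
    return plaus_loss_fn(torch.abs(grad), smask, pgt_coeff)
```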

View File

@@ -0,0 +1,345 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Check a model's accuracy on a test or val split of a dataset.
Usage:
$ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
Usage - formats:
$ yolo mode=val model=yolov8n.pt # PyTorch
yolov8n.torchscript # TorchScript
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
yolov8n_openvino_model # OpenVINO
yolov8n.engine # TensorRT
yolov8n.mlpackage # CoreML (macOS-only)
yolov8n_saved_model # TensorFlow SavedModel
yolov8n.pb # TensorFlow GraphDef
yolov8n.tflite # TensorFlow Lite
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
yolov8n_paddle_model # PaddlePaddle
yolov8n_ncnn_model # NCNN
"""
import json
import time
from pathlib import Path
import numpy as np
import torch
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
class PGTValidator:
"""
BaseValidator.
A base class for creating validators.
Attributes:
args (SimpleNamespace): Configuration for the validator.
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
device (torch.device): Device to use for validation.
batch_i (int): Current batch index.
training (bool): Whether the model is in training mode.
names (dict): Class names.
seen: Records the number of images seen so far during validation.
stats: Placeholder for statistics during validation.
confusion_matrix: Placeholder for a confusion matrix.
nc: Number of classes.
iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.
jdict (dict): Dictionary to store JSON validation results.
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
batch processing times in milliseconds.
save_dir (Path): Directory to save results.
plots (dict): Dictionary to store plots for visualization.
callbacks (dict): Dictionary to store various callback functions.
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""
Initializes a BaseValidator instance.
Args:
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
save_dir (Path, optional): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
args (SimpleNamespace): Configuration for the validator.
_callbacks (dict): Dictionary to store various callback functions.
"""
self.args = get_cfg(overrides=args)
self.dataloader = dataloader
self.pbar = pbar
self.stride = None
self.data = None
self.device = None
self.batch_i = None
self.training = True
self.names = None
self.seen = None
self.stats = None
self.confusion_matrix = None
self.nc = None
self.iouv = None
self.jdict = None
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
self.save_dir = save_dir or get_save_dir(self.args)
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001
self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
self.plots = {}
self.callbacks = _callbacks or callbacks.get_default_callbacks()
# @smart_inference_mode()
def __call__(self, trainer=None, model=None):
"""Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
gets priority).
"""
self.training = trainer is not None
augment = self.args.augment and (not self.training)
if self.training:
self.device = trainer.device
self.data = trainer.data
# self.args.half = self.device.type != "cpu" # force FP16 val during training
model = trainer.ema.ema or trainer.model
model = model.half() if self.args.half else model.float()
# self.model = model
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
model.eval()
else:
callbacks.add_integration_callbacks(self)
model = AutoBackend(
weights=model or self.args.model,
device=select_device(self.args.device, self.args.batch),
dnn=self.args.dnn,
data=self.args.data,
fp16=self.args.half,
)
# self.model = model
self.device = model.device # update device
self.args.half = model.fp16 # update half
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
imgsz = check_imgsz(self.args.imgsz, stride=stride)
if engine:
self.args.batch = model.batch_size
elif not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1
LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
self.data = check_det_dataset(self.args.data)
elif self.args.task == "classify":
self.data = check_cls_dataset(self.args.data, split=self.args.split)
else:
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
if self.device.type in ("cpu", "mps"):
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
if not pt:
self.args.rect = False
self.stride = model.stride # used in get_dataloader() for padding
self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
model.eval()
model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
self.run_callbacks("on_val_start")
dt = (
Profile(device=self.device),
Profile(device=self.device),
Profile(device=self.device),
Profile(device=self.device),
)
bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
self.init_metrics(de_parallel(model))
self.jdict = [] # empty before each val
for batch_i, batch in enumerate(bar):
self.run_callbacks("on_val_batch_start")
self.batch_i = batch_i
# Preprocess
with dt[0]:
batch = self.preprocess(batch)
# Inference
with dt[1]:
preds = model(batch["img"].requires_grad_(True), augment=augment)
# Loss
with dt[2]:
if self.training:
self.loss += model.loss(batch, preds)[1]
# Postprocess
with dt[3]:
preds = self.postprocess(preds)
self.update_metrics(preds, batch)
if self.args.plots and batch_i < 3:
self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i)
self.run_callbacks("on_val_batch_end")
stats = self.get_stats()
self.check_stats(stats)
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
self.finalize_metrics()
if not (self.args.save_json and self.is_coco and len(self.jdict)):
self.print_results()
self.run_callbacks("on_val_end")
if self.training:
model.float()
if self.args.save_json and self.jdict:
with open(str(self.save_dir / "predictions.json"), "w") as f:
LOGGER.info(f"Saving {f.name}...")
json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
stats['fitness'] = stats['metrics/mAP50-95(B)']
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
else:
LOGGER.info(
"Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
% tuple(self.speed.values())
)
if self.args.save_json and self.jdict:
with open(str(self.save_dir / "predictions.json"), "w") as f:
LOGGER.info(f"Saving {f.name}...")
json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
if self.args.plots or self.args.save_json:
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
return stats
def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
"""
Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
Args:
pred_classes (torch.Tensor): Predicted class indices of shape(N,).
true_classes (torch.Tensor): Target class indices of shape(M,).
iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
use_scipy (bool): Whether to use scipy for matching (more precise).
Returns:
(torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
"""
# Dx10 matrix, where D - detections, 10 - IoU thresholds
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
# LxD matrix where L - labels (rows), D - detections (columns)
correct_class = true_classes[:, None] == pred_classes
iou = iou * correct_class # zero out the wrong classes
iou = iou.cpu().numpy()
for i, threshold in enumerate(self.iouv.cpu().tolist()):
if use_scipy:
# WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
import scipy # scope import to avoid importing for all commands
cost_matrix = iou * (iou >= threshold)
if cost_matrix.any():
labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
valid = cost_matrix[labels_idx, detections_idx] > 0
if valid.any():
correct[detections_idx[valid], i] = True
else:
matches = np.nonzero(iou >= threshold) # IoU > threshold and classes match
matches = np.array(matches).T
if matches.shape[0]:
if matches.shape[0] > 1:
matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
def add_callback(self, event: str, callback):
"""Appends the given callback."""
self.callbacks[event].append(callback)
def run_callbacks(self, event: str):
"""Runs all callbacks associated with a specified event."""
for callback in self.callbacks.get(event, []):
callback(self)
def get_dataloader(self, dataset_path, batch_size):
"""Get data loader from dataset path and batch size."""
raise NotImplementedError("get_dataloader function not implemented for this validator")
def build_dataset(self, img_path):
"""Build dataset."""
raise NotImplementedError("build_dataset function not implemented in validator")
def preprocess(self, batch):
"""Preprocesses an input batch."""
return batch
def postprocess(self, preds):
"""Describes and summarizes the purpose of 'postprocess()' but no details mentioned."""
return preds
def init_metrics(self, model):
"""Initialize performance metrics for the YOLO model."""
pass
def update_metrics(self, preds, batch):
"""Updates metrics based on predictions and batch."""
pass
def finalize_metrics(self, *args, **kwargs):
"""Finalizes and returns all metrics."""
pass
def get_stats(self):
"""Returns statistics about the model's performance."""
return {}
def check_stats(self, stats):
"""Checks statistics."""
pass
def print_results(self):
"""Prints the results of the model's predictions."""
pass
def get_desc(self):
"""Get description of the YOLO model."""
pass
@property
def metric_keys(self):
"""Returns the metric keys used in YOLO training/validation."""
return []
def on_plot(self, name, data=None):
"""Registers plots (e.g. to be consumed in callbacks)"""
self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
# TODO: may need to put these following functions into callback
def plot_val_samples(self, batch, ni):
"""Plots validation samples during training."""
pass
def plot_predictions(self, batch, preds, ni):
"""Plots YOLO model predictions on batch images."""
pass
def pred_to_json(self, preds, batch):
"""Convert predictions to JSON format."""
pass
def eval_json(self, stats):
"""Evaluate and return JSON format of prediction statistics."""
pass

View File

@@ -4,5 +4,6 @@ from .predict import DetectionPredictor
from .pgt_train import PGTDetectionTrainer
from .train import DetectionTrainer
from .val import DetectionValidator
from .pgt_val import PGTDetectionValidator
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer"
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer", "PGTDetectionValidator"

View File

@@ -0,0 +1,300 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
from pathlib import Path
import numpy as np
import torch
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
from ultralytics.engine.validator import BaseValidator
from ultralytics.engine.pgt_validator import PGTValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class PGTDetectionValidator(PGTValidator):
"""
A class extending the BaseValidator class for validation based on a detection model.
Example:
```python
from ultralytics.models.yolo.detect import DetectionValidator
args = dict(model='yolov8n.pt', data='coco8.yaml')
validator = DetectionValidator(args=args)
validator()
```
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize detection model with necessary variables and settings."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.nt_per_class = None
self.is_coco = False
self.class_map = None
self.args.task = "detect"
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
self.lb = [] # for autolabelling
def preprocess(self, batch):
"""Preprocesses batch of images for YOLO training."""
batch["img"] = batch["img"].to(self.device, non_blocking=True)
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
for k in ["batch_idx", "cls", "bboxes"]:
batch[k] = batch[k].to(self.device)
if self.args.save_hybrid:
height, width = batch["img"].shape[2:]
nb = len(batch["img"])
bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
self.lb = (
[
torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
for i in range(nb)
]
if self.args.save_hybrid
else []
) # for autolabelling
return batch
def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""
val = self.data.get(self.args.split, "") # validation path
self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
self.args.save_json |= self.is_coco # run on final val if training COCO
self.names = model.names
self.nc = len(model.names)
self.metrics.names = self.names
self.metrics.plot = self.args.plots
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
self.seen = 0
self.jdict = []
self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])
def get_desc(self):
"""Return a formatted string summarizing class metrics of YOLO model."""
return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
return ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
)
def _prepare_batch(self, si, batch):
"""Prepares a batch of images and annotations for validation."""
idx = batch["batch_idx"] == si
cls = batch["cls"][idx].squeeze(-1)
bbox = batch["bboxes"][idx]
ori_shape = batch["ori_shape"][si]
imgsz = batch["img"].shape[2:]
ratio_pad = batch["ratio_pad"][si]
if len(cls):
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels
return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
def _prepare_pred(self, pred, pbatch):
"""Prepares a batch of images and annotations for validation."""
predn = pred.clone()
ops.scale_boxes(
pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
) # native-space pred
return predn
def update_metrics(self, preds, batch):
"""Metrics."""
for si, pred in enumerate(preds):
self.seen += 1
npr = len(pred)
stat = dict(
conf=torch.zeros(0, device=self.device),
pred_cls=torch.zeros(0, device=self.device),
tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
)
pbatch = self._prepare_batch(si, batch)
cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
stat["target_cls"] = cls
if npr == 0:
if nl:
for k in self.stats.keys():
self.stats[k].append(stat[k])
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
continue
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = self._prepare_pred(pred, pbatch)
stat["conf"] = predn[:, 4]
stat["pred_cls"] = predn[:, 5]
# Evaluate
if nl:
stat["tp"] = self._process_batch(predn, bbox, cls)
if self.args.plots:
self.confusion_matrix.process_batch(predn, bbox, cls)
for k in self.stats.keys():
self.stats[k].append(stat[k])
# Save
if self.args.save_json:
self.pred_to_json(predn, batch["im_file"][si])
if self.args.save_txt:
file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)
def finalize_metrics(self, *args, **kwargs):
"""Set final values for metrics speed and confusion matrix."""
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
if len(stats) and stats["tp"].any():
self.metrics.process(**stats)
self.nt_per_class = np.bincount(
stats["target_cls"].astype(int), minlength=self.nc
) # number of targets per class
return self.metrics.results_dict
def print_results(self):
"""Prints training/validation set metrics per class."""
pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0:
LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")
# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
for i, c in enumerate(self.metrics.ap_class_index):
LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(
save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
)
def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
Return correct prediction matrix.
Args:
detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class.
gt_bboxes (torch.Tensor): Tensor of shape [M, 4] of ground-truth boxes in x1, y1, x2, y2 format.
gt_cls (torch.Tensor): Tensor of shape [M] of ground-truth class indices.
Returns:
(torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
"""
iou = box_iou(gt_bboxes, detections[:, :4])
return self.match_predictions(detections[:, 5], gt_cls, iou)
def build_dataset(self, img_path, mode="val", batch=None):
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
def get_dataloader(self, dataset_path, batch_size):
"""Construct and return dataloader."""
dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(
batch["img"],
batch["batch_idx"],
batch["cls"].squeeze(-1),
batch["bboxes"],
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot,
)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(
batch["img"],
*output_to_target(preds, max_det=self.args.max_det),
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
for *xyxy, conf, cls in predn.tolist():
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
with open(file, "a") as f:
f.write(("%g " * len(line)).rstrip() % line + "\n")
def pred_to_json(self, predn, filename):
"""Serialize YOLO predictions to COCO json format."""
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for p, b in zip(predn.tolist(), box.tolist()):
self.jdict.append(
{
"image_id": image_id,
"category_id": self.class_map[int(p[5])],
"bbox": [round(x, 3) for x in b],
"score": round(p[4], 5),
}
)
def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
pred_json = self.save_dir / "predictions.json" # predictions
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements("pycocotools>=2.0.6")
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
eval = COCOeval(anno, pred, "bbox")
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f"pycocotools unable to run: {e}")
return stats
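By analogy with the DetectionValidator example in the class docstring above, standalone usage would presumably be (a sketch, not exercised in this commit):

```python
from ultralytics.models.yolo.detect import PGTDetectionValidator

args = dict(model='yolov8n.pt', data='coco8.yaml')
validator = PGTDetectionValidator(args=args)
validator()
```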

View File

@@ -3,5 +3,6 @@
from .predict import SegmentationPredictor
from .train import SegmentationTrainer
from .val import SegmentationValidator
from .pgt_train import PGTSegmentationTrainer
__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator", "PGTSegmentationTrainer"

View File

@@ -0,0 +1,73 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel, DetectionModel
from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.models.yolov10.model import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
from ultralytics.models.yolov10.val import YOLOv10DetectionValidator, YOLOv10PGTDetectionValidator
class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on a segmentation model.
Example:
```python
from ultralytics.models.yolo.segment import SegmentationTrainer
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
trainer = SegmentationTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a SegmentationTrainer object with given arguments."""
if overrides is None:
overrides = {}
overrides["task"] = "segment"
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return SegmentationModel initialized with specified config and weights."""
if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']:
model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
else:
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Return an instance of SegmentationValidator for validation of YOLO model."""
if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']:
self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", "pgt_loss",
return YOLOv10PGTDetectionValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
else:
self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
return yolo.segment.SegmentationValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def plot_training_samples(self, batch, ni):
"""Creates a plot of training sample images with labels and box coordinates."""
plot_images(
batch["img"],
batch["batch_idx"],
batch["cls"].squeeze(-1),
batch["bboxes"],
masks=batch["masks"],
paths=batch["im_file"],
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png

View File

@@ -1,5 +1,5 @@
from ultralytics.engine.model import Model
from ultralytics.nn.tasks import YOLOv10DetectionModel
from ultralytics.nn.tasks import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
from .val import YOLOv10DetectionValidator
from .predict import YOLOv10DetectionPredictor
from .train import YOLOv10DetectionTrainer

View File

@@ -1,4 +1,4 @@
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.models.yolo.detect import DetectionValidator, PGTDetectionValidator
from ultralytics.utils import ops
import torch
@@ -22,3 +22,24 @@ class YOLOv10DetectionValidator(DetectionValidator):
boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
bboxes = ops.xywh2xyxy(boxes)
return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
class YOLOv10PGTDetectionValidator(PGTDetectionValidator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.args.save_json |= self.is_coco
def postprocess(self, preds):
if isinstance(preds, dict):
preds = preds["one2one"]
if isinstance(preds, (list, tuple)):
preds = preds[0]
# Acknowledgement: Thanks to sanha9999 in #190 and #181!
if preds.shape[-1] == 6:
return preds
else:
preds = preds.transpose(-1, -2)
boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
bboxes = ops.xywh2xyxy(boxes)
return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)

View File

@@ -57,7 +57,7 @@ from ultralytics.nn.modules import (
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss, v10PGTDetectLoss
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (
fuse_conv_and_bn,
@@ -652,6 +652,10 @@ class YOLOv10DetectionModel(DetectionModel):
def init_criterion(self):
return v10DetectLoss(self)
class YOLOv10PGTDetectionModel(DetectionModel):
def init_criterion(self):
return v10PGTDetectLoss(self)
class Ensemble(nn.ModuleList):
"""Ensemble of models."""

View File

@@ -9,7 +9,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
from .metrics import bbox_iou, probiou
from .tal import bbox2dist
from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn
class VarifocalLoss(nn.Module):
"""
@@ -725,3 +725,30 @@ class v10DetectLoss:
one2one = preds["one2one"]
loss_one2one = self.one2one(one2one, batch)
return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
class v10PGTDetectLoss:
def __init__(self, model):
self.one2many = v8DetectionLoss(model, tal_topk=10)
self.one2one = v8DetectionLoss(model, tal_topk=1)
def __call__(self, preds, batch):
batch['img'] = batch['img'].requires_grad_(True)
one2many = preds["one2many"]
loss_one2many = self.one2many(one2many, batch)
one2one = preds["one2one"]
loss_one2one = self.one2one(one2one, batch)
loss = loss_one2many[0] + loss_one2one[0]
smask = get_dist_reg(batch['img'], batch['masks'])
grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0]
grad = torch.abs(grad)
pgt_coeff = 3.0
plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
# self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
loss += plaus_loss
return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0)))
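The extra element appended to the returned loss items is what surfaces as the `pgt_loss` entry in PGTSegmentationTrainer.loss_names earlier in this diff. Note that `pgt_coeff` is hardcoded to 3.0 here rather than read from the config; the commented `pgt_coeff=0.1` lines in the run scripts suggest it is meant to move there later.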

View File

@@ -0,0 +1,818 @@
import torch
import numpy as np
# from plot_functs import *
from .plot_functs import normalize_tensor, overlay_mask, imshow
import math
import time
import matplotlib.path as mplPath
from matplotlib.path import Path
# from utils.general import non_max_suppression, xyxy2xywh, scale_coords
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppression
from .metrics import bbox_iou
import torchvision.transforms as T
def plaus_loss_fn(grad, smask, pgt_coeff):
################## Compute the PGT Loss ##################
# Term weighting attribution mass outside the blurred target mask
dist_attr_pos = attr_reg(grad, (1.0 - smask)) # smask: blurred segmentation-mask prior from get_dist_reg
# Term weighting attribution mass inside the blurred target mask
dist_attr_neg = attr_reg(grad, smask)
# Calculate plausibility regularization term
# dist_reg = dist_attr_pos - dist_attr_neg
dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad)))
plaus_reg = (((1.0 + dist_reg) / 2.0))
# Calculate plausibility loss
plaus_loss = (1 - plaus_reg) * pgt_coeff
return plaus_loss
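In equation form (a direct transcription of `plaus_loss_fn` above, writing $G$ for the absolute input gradient, $S$ for the blurred-mask prior from `get_dist_reg`, and $\lambda$ for `pgt_coeff`):

$$\mathrm{dist\_reg} = \frac{\mathbb{E}[G\,(1-S)] - \mathbb{E}[G\,S]}{\mathbb{E}[G]}, \qquad L_{\mathrm{plaus}} = \lambda\left(1 - \frac{1 + \mathrm{dist\_reg}}{2}\right)$$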
def get_dist_reg(images, seg_mask):
seg_mask = T.Resize((images.shape[2], images.shape[3]), antialias=True)(seg_mask).to(images.device)
seg_mask = seg_mask.to(dtype=torch.float32).unsqueeze(1).repeat(1, 3, 1, 1)
seg_mask[seg_mask > 0] = 1.0
smask = torch.zeros_like(seg_mask)
sigmas = [20.0 + (i_sig * 20.0) for i_sig in range(8)]
for k_it, sigma in enumerate(sigmas):
# Apply Gaussian blur to the mask
kernel_size = int(sigma + 50)
if kernel_size % 2 == 0:
kernel_size += 1
seg_mask1 = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask)
if torch.max(seg_mask1) > 1.0:
seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min())
smask = torch.max(smask, seg_mask1)
return smask
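A quick shape check for `get_dist_reg` (the resolutions below are stand-ins; any `(B, h, w)` float mask works, and the blur cascade makes this slow at full resolution):

```python
import torch

images = torch.rand(1, 3, 256, 256)
masks = (torch.rand(1, 64, 64) > 0.7).float()  # stand-in GT segmentation masks
smask = get_dist_reg(images, masks)
assert smask.shape == images.shape  # soft proximity prior, one value per pixel and channel
```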
def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False):
"""
Compute the gradient of an image with respect to a given tensor.
Args:
img (torch.Tensor): The input image tensor.
grad_wrt (torch.Tensor): The tensor with respect to which the gradient is computed.
norm (bool, optional): Whether to normalize the gradient. Defaults to False.
absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
grayscale (bool, optional): Whether to convert the gradient to grayscale. Defaults to False.
keepmean (bool, optional): Whether to keep the mean value of the attribution map. Defaults to False.
Returns:
torch.Tensor: The computed attribution map.
"""
if (grad_wrt.shape != torch.Size([1])) and (grad_wrt.shape != torch.Size([])):
grad_wrt_outputs = torch.ones_like(grad_wrt).clone().detach()#.requires_grad_(True)#.retains_grad_(True)
else:
grad_wrt_outputs = None
attribution_map = torch.autograd.grad(grad_wrt, img,
grad_outputs=grad_wrt_outputs,
create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly
)[0]
if absolute:
attribution_map = torch.abs(attribution_map) # attribution_map ** 2 # Take absolute values of gradients
if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
attribution_map = torch.sum(attribution_map, 1, keepdim=True)
if norm:
if keepmean:
attmean = torch.mean(attribution_map)
attmin = torch.min(attribution_map)
attmax = torch.max(attribution_map)
attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
if keepmean:
attribution_map -= attribution_map.mean()
attribution_map += (attmean / (attmax - attmin))
return attribution_map
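Usage sketch for `get_gradient`, with a toy scalar standing in for the detector loss:

```python
import torch

x = torch.rand(1, 3, 64, 64, requires_grad=True)
loss = (x ** 2).sum()  # toy scalar loss
attr = get_gradient(x, loss, norm=True, grayscale=True)
print(attr.shape)  # torch.Size([1, 1, 64, 64])
```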
def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False):
"""
Generate Gaussian noise based on the input image.
Args:
img (torch.Tensor): Input image.
grad_wrt: Gradient with respect to the input image.
norm (bool, optional): Whether to normalize the generated noise. Defaults to True.
absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
grayscale (bool, optional): Whether to convert the noise to grayscale. Defaults to True.
keepmean (bool, optional): Whether to keep the mean of the noise. Defaults to False.
Returns:
torch.Tensor: Generated Gaussian noise.
"""
gaussian_noise = torch.randn_like(img)
if absolute:
gaussian_noise = torch.abs(gaussian_noise) # Take absolute values of gradients
if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
gaussian_noise = torch.sum(gaussian_noise, 1, keepdim=True)
if norm:
if keepmean:
attmean = torch.mean(gaussian_noise)
attmin = torch.min(gaussian_noise)
attmax = torch.max(gaussian_noise)
gaussian_noise = normalize_batch(gaussian_noise) # Normalize attribution maps per image in batch
if keepmean:
gaussian_noise -= gaussian_noise.mean()
gaussian_noise += (attmean / (attmax - attmin))
return gaussian_noise
def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
# TODO: Remove imgs from this function and only take it as input if debug is True
"""
Calculates the plausibility score based on the given inputs.
Args:
imgs (torch.Tensor): The input images.
targets_out (torch.Tensor): The output targets.
attr (torch.Tensor): The attribution map tensor.
debug (bool, optional): Whether to enable debug mode. Defaults to False.
Returns:
torch.Tensor: The plausibility score.
"""
# # if imgs is None:
# # imgs = torch.zeros_like(attr)
# # with torch.no_grad():
# target_inds = targets_out[:, 0].int()
# xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num]
# num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
# # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1))
# xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
# co = xyxy_corners
# if corners:
# co = targets_out[:, 2:6].int()
# coords_map = torch.zeros_like(attr, dtype=torch.bool)
# # rows = np.arange(co.shape[0])
# x1, x2 = co[:,1], co[:,3]
# y1, y2 = co[:,0], co[:,2]
# for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop
# coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True
if torch.isnan(attr).any():
attr = torch.nan_to_num(attr, nan=0.0)
coords_map = get_bbox_map(targets_out, attr)
plaus_score = ((torch.sum((attr * coords_map))) / (torch.sum(attr)))
if debug:
for i in range(len(coords_map)):
coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
test_bbox = torch.zeros_like(imgs[i])
test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
imshow(test_bbox, save_path='figs/test_bbox')
if imgs is None:
imgs = torch.zeros_like(attr)
imshow(imgs[i], save_path='figs/im0')
imshow(attr[i], save_path='figs/attr')
# with torch.no_grad():
# # att_select = attr[coords_map]
# att_select = attr * coords_map.to(torch.float32)
# att_total = attr
# IoU_num = torch.sum(att_select)
# IoU_denom = torch.sum(att_total)
# IoU_ = (IoU_num / IoU_denom)
# plaus_score = IoU_
# # plaus_score = ((torch.sum(attr[coords_map])) / (torch.sum(attr)))
return plaus_score
def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
# TODO: Remove imgs from this function and only take it as input if debug is True
"""
Estimates bounding-box corners from the attribution maps and scores them against the targets.
Args:
imgs (torch.Tensor): The input images.
targets_out (torch.Tensor): The output targets.
attr (torch.Tensor): The attribution map tensor.
debug (bool, optional): Whether to enable debug mode. Defaults to False.
Returns:
torch.Tensor: The mean CIoU between the attribution-derived boxes and the target boxes.
"""
# if imgs is None:
# imgs = torch.zeros_like(attr)
# with torch.no_grad():
target_inds = targets_out[:, 0].int()
xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num]
num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
# num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1))
xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
co = xyxy_corners
if corners:
co = targets_out[:, 2:6].int()
coords_map = torch.zeros_like(attr, dtype=torch.bool)
# rows = np.arange(co.shape[0])
x1, x2 = co[:,1], co[:,3]
y1, y2 = co[:,0], co[:,2]
for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop
coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True
if torch.isnan(attr).any():
attr = torch.nan_to_num(attr, nan=0.0)
if debug:
for i in range(len(coords_map)):
coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
test_bbox = torch.zeros_like(imgs[i])
test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
imshow(test_bbox, save_path='figs/test_bbox')
imshow(imgs[i], save_path='figs/im0')
imshow(attr[i], save_path='figs/attr')
# att_select = attr[coords_map]
# with torch.no_grad():
# IoU_num = (torch.sum(attr[coords_map]))
# IoU_denom = torch.sum(attr)
# IoU_ = (IoU_num / (IoU_denom))
# IoU_ = torch.max(attr[coords_map]) - torch.max(attr[~coords_map])
co = (xyxy_batch * num_pixels).int()
x1 = co[:,1] + 1
y1 = co[:,0] + 1
# with torch.no_grad():
attr_ = torch.sum(attr, 1, keepdim=True)
corners_attr = None #torch.zeros(len(xyxy_batch), 4, device=attr.device)
for ic in range(co.shape[0]):
attr0 = attr_[target_inds[ic], :,:x1[ic],:y1[ic]]
attr1 = attr_[target_inds[ic], :,:x1[ic],y1[ic]:]
attr2 = attr_[target_inds[ic], :,x1[ic]:,:y1[ic]]
attr3 = attr_[target_inds[ic], :,x1[ic]:,y1[ic]:]
x_0, y_0 = max_indices_2d(attr0[0])
x_1, y_1 = max_indices_2d(attr1[0])
x_2, y_2 = max_indices_2d(attr2[0])
x_3, y_3 = max_indices_2d(attr3[0])
y_1 += y1[ic]
x_2 += x1[ic]
x_3 += x1[ic]
y_3 += y1[ic]
max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2],
torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3],
torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2],
torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]])
if corners_attr is None:
corners_attr = max_corners
else:
corners_attr = torch.cat([corners_attr, max_corners], dim=0)
# corners_attr[ic] = max_corners
# corners_attr = attr[:,0,:4,0]
corners_attr = corners_attr.view(-1, 4)
# corners_attr = torch.stack(corners_attr, dim=0)
IoU_ = bbox_iou(corners_attr.T, xyxy_batch, x1y1x2y2=False, metric='CIoU')
plaus_score = IoU_.mean()
return plaus_score
def max_indices_2d(x_inp):
# values, indices = x.reshape(x.size(0), -1).max(dim=-1)
index = torch.argmax(x_inp)
x = index // x_inp.shape[1]
y = index % x_inp.shape[1]
# x, y = divmod(index.item(), x_inp.shape[1])
return torch.cat([x.unsqueeze(0), y.unsqueeze(0)])
def point_in_polygon(poly, grid):
# t0 = time.time()
num_points = poly.shape[0]
j = num_points - 1
oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool)
for i in range(num_points):
cond1 = (poly[i, 1] < grid[..., 1]) & (poly[j, 1] >= grid[..., 1])
cond2 = (poly[j, 1] < grid[..., 1]) & (poly[i, 1] >= grid[..., 1])
cond3 = (grid[..., 0] - poly[i, 0]) < (poly[j, 0] - poly[i, 0]) * (grid[..., 1] - poly[i, 1]) / (poly[j, 1] - poly[i, 1])
oddNodes = oddNodes ^ (cond1 | cond2) & cond3
j = i
# t1 = time.time()
# print(f'point in polygon time: {t1-t0}')
return oddNodes
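A sanity check of the even-odd (ray-casting) rule implemented above, on the unit square:

```python
import torch

poly = torch.tensor([[0., 0.], [1., 0.], [1., 1.], [0., 1.]])
grid = torch.tensor([[[0.5, 0.5], [2.0, 2.0]]])  # (..., 2) points
print(point_in_polygon(poly, grid))  # tensor([[ True, False]])
```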
def point_in_polygon_gpu(poly, grid):
num_points = poly.shape[0]
i = torch.arange(num_points)
j = (i - 1) % num_points
# Expand dimensions
# t0 = time.time()
poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid.shape[0], grid.shape[0])
# t1 = time.time()
cond1 = (poly_expanded[i, 1] < grid[..., 1]) & (poly_expanded[j, 1] >= grid[..., 1])
cond2 = (poly_expanded[j, 1] < grid[..., 1]) & (poly_expanded[i, 1] >= grid[..., 1])
cond3 = (grid[..., 0] - poly_expanded[i, 0]) < (poly_expanded[j, 0] - poly_expanded[i, 0]) * (grid[..., 1] - poly_expanded[i, 1]) / (poly_expanded[j, 1] - poly_expanded[i, 1])
# t2 = time.time()
oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool)
cond = (cond1 | cond2) & cond3
# t3 = time.time()
# efficiently perform xor using gpu and avoiding cpu as much as possible
c = []
while len(cond) > 1:
if len(cond) % 2 == 1: # odd number of elements
c.append(cond[-1])
cond = cond[:-1]
cond = torch.bitwise_xor(cond[:int(len(cond)/2)], cond[int(len(cond)/2):])
for c_ in c:
cond = torch.bitwise_xor(cond, c_)
oddNodes = cond
# t4 = time.time()
# for c in cond:
# oddNodes = oddNodes ^ c
# print(f'expand time: {t1-t0} | cond123 time: {t2-t1} | cond logic time: {t3-t2} | bitwise xor time: {t4-t3}')
# print(f'point in polygon time gpu: {t4-t0}')
# oddNodes = oddNodes ^ (cond1 | cond2) & cond3
return oddNodes
def bitmap_for_polygon(poly, h, w):
y = torch.arange(h).to(poly.device).float()
x = torch.arange(w).to(poly.device).float()
grid_y, grid_x = torch.meshgrid(y, x)
grid = torch.stack((grid_x, grid_y), dim=-1)
bitmap = point_in_polygon(poly, grid)
return bitmap.unsqueeze(0)
def corners_coords(center_xywh):
center_x, center_y, w, h = center_xywh
x = center_x - w/2
y = center_y - h/2
return torch.tensor([x, y, x+w, y+h])
def corners_coords_batch(center_xywh):
center_x, center_y = center_xywh[:,0], center_xywh[:,1]
w, h = center_xywh[:,2], center_xywh[:,3]
x = center_x - w/2
y = center_y - h/2
return torch.stack([x, y, x+w, y+h], dim=1)
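`corners_coords_batch` converts `(cx, cy, w, h)` boxes to `(x1, y1, x2, y2)`, the same convention as `ops.xywh2xyxy`:

```python
import torch

boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # cx, cy, w, h
print(corners_coords_batch(boxes))  # tensor([[0.4000, 0.3000, 0.6000, 0.7000]])
```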
def normalize_batch(x):
"""
Normalize a batch of tensors along each channel.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
Returns:
torch.Tensor: Normalized tensor of the same shape as the input.
"""
mins = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device)
maxs = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device)
for i in range(x.shape[0]):
mins[i] = x[i].min()
maxs[i] = x[i].max()
x_ = (x - mins) / (maxs - mins)
return x_
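The per-image loop above could equivalently be vectorized with `amin`/`amax` (a sketch, not part of the original code):

```python
import torch

def normalize_batch_vectorized(x):
    dims = tuple(range(1, x.ndim))
    mins = x.amin(dim=dims, keepdim=True)
    maxs = x.amax(dim=dims, keepdim=True)
    return (x - mins) / (maxs - mins)
```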
def get_detections(model_clone, img):
"""
Get detections from a model given an input image and targets.
Args:
model (nn.Module): The model to use for detection.
img (torch.Tensor): The input image tensor.
Returns:
torch.Tensor: The detected bounding boxes.
"""
model_clone.eval() # Set model to evaluation mode
# Run inference
with torch.no_grad():
det_out, out = model_clone(img)
# model_.train()
del img
return det_out, out
def get_labels(det_out, imgs, targets, opt):
###################### Get predicted labels ######################
nb, _, height, width = imgs.shape # batch size, channels, height, width
targets_ = targets.clone()
targets_[:, 2:] = targets_[:, 2:] * torch.Tensor([width, height, width, height]).to(imgs.device) # to pixels
lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else [] # for autolabelling
o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True)
pred_labels = []
for si, pred in enumerate(o):
labels = targets_[targets_[:, 0] == si, 1:]
nl = len(labels)
predn = pred.clone()
# Get the indices that sort the values in column 5 in ascending order
sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True)
# Apply the sorting indices to the tensor
sorted_pred = predn[sort_indices]
# Remove predictions with less than 0.1 confidence
n_conf = int(torch.sum(sorted_pred[:,4]>0.1)) + 1
sorted_pred = sorted_pred[:n_conf]
new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si
preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1)
preds[:, 2:] = xyxy2xywh(preds[:, 2:]) # xywh
gn = torch.tensor([width, height])[[1, 0, 1, 0]] # normalization gain whwh
preds[:, 2:] /= gn.to(imgs.device) # from pixels
pred_labels.append(preds)
pred_labels = torch.cat(pred_labels, 0).to(imgs.device)
return pred_labels
##################################################################
from torchvision.utils import make_grid
def get_center_coords(img_tensor):
img_tensor = img_tensor / img_tensor.max()
# Define a brightness threshold
threshold = 0.95
# Create a binary mask of the bright pixels
mask = img_tensor > threshold
# Get the coordinates of the bright pixels
y_coords, x_coords = torch.where(mask)
# Calculate the centroid of the bright pixels
centroid_x = x_coords.float().mean().item()
centroid_y = y_coords.float().mean().item()
print(f'The central bright point is at ({centroid_x}, {centroid_y})')
return centroid_x, centroid_y
def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False):
"""
Compute the distance grids from each pixel to the target coordinates.
Args:
attr (torch.Tensor): Attribution maps.
targets (torch.Tensor): Target coordinates.
imgs (torch.Tensor, optional): Images, used only for debug visualization. Defaults to None.
focus_coeff (float, optional): Focus coefficient, smaller means more focused. Defaults to 0.5.
debug (bool, optional): Whether to visualize debug information. Defaults to False.
Returns:
torch.Tensor: Distance grids.
"""
# Assign the height and width of the input tensor to variables
height, width = attr.shape[-1], attr.shape[-2]
# attr = torch.abs(attr) # Take absolute values of gradients
# attr = normalize_batch(attr) # Normalize attribution maps per image in batch
# Create a grid of indices
xx, yy = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width))).to(attr.device)
idx_grid = torch.stack((xx, yy), dim=-1).float()
# Expand the grid to match the batch size
idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1)
# Initialize a list to store the distance grids
dist_grids_ = [[]] * attr.shape[0]
# Loop over batches
for j in range(attr.shape[0]):
# Get the rows where the first column is the current unique value
rows = targets[targets[:, 0] == j]
if len(rows) != 0:
# Create a tensor for the target coordinates
xy = rows[:,2:4] # y, x
# Flip the x and y coordinates and scale them to the image size
xy[:, 0], xy[:, 1] = xy[:, 1] * width, xy[:, 0] * height # y, x to x, y
xy_center = xy.unsqueeze(1).unsqueeze(1)#.requires_grad_(True)
# Compute the Euclidean distance from each pixel to the target coordinates
dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1)
# Pick the closest distance to any target for each pixel
dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0)
dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_
else:
# Set grid to zero if no targets are present
dist_grid = torch.zeros_like(attr[j])
dist_grids_[j] = dist_grid
# Convert the list of distance grids to a tensor for faster computation
dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff
if torch.isnan(dist_grids).any():
dist_grids = torch.nan_to_num(dist_grids, nan=0.0)
if debug:
for i in range(len(dist_grids)):
if ((i % 8) == 0):
grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0)
imshow(grid_show, save_path='figs/dist_grids')
if imgs is None:
imgs = torch.zeros_like(attr)
imshow(imgs[i], save_path='figs/im0')
img_overlay = (overlay_mask(imgs[i], dist_grids[i][0], alpha = 0.75))
imshow(img_overlay, save_path='figs/dist_grid_overlay')
weighted_attr = (dist_grids[i] * attr[i])
imshow(weighted_attr, save_path='figs/weighted_attr')
imshow(attr[i], save_path='figs/attr')
return dist_grids
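# Illustrative sketch of get_distance_grids on a dummy batch. Targets follow
# the (batch_idx, class, x, y, w, h) normalized layout assumed above, and
# normalize_batch is the helper already used in this file.
def _demo_get_distance_grids():
    attr = torch.rand(2, 3, 64, 64)
    targets = torch.tensor([[0, 0, 0.50, 0.50, 0.2, 0.2],
                            [1, 0, 0.25, 0.75, 0.1, 0.1]])
    # Pixels near each target center come out near 0, far pixels near 1
    dist_grids = get_distance_grids(attr, targets, focus_coeff=0.5)
    return dist_grids.shape  # torch.Size([2, 3, 64, 64])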
def attr_reg(attribution_map, distance_map):
# Mean of the distance-weighted attribution, used as a regularization term
dist_attr = torch.mean(distance_map * attribution_map)
return dist_attr
def get_bbox_map(targets_out, attr, corners=False):
target_inds = targets_out[:, 0].int()
xyxy_batch = targets_out[:, 2:6]
# Scale normalized (x1, y1, x2, y2) corners by (W, H, W, H) to get pixel coordinates
num_pixels = torch.tile(torch.tensor([attr.shape[3], attr.shape[2], attr.shape[3], attr.shape[2]], device=attr.device), (xyxy_batch.shape[0], 1))
xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
co = xyxy_corners
if corners:
co = targets_out[:, 2:6].int()
coords_map = torch.zeros_like(attr, dtype=torch.bool)
# attr is (batch, channel, row, col), so rows are indexed by y and columns by x
x1, x2 = co[:,0], co[:,2]
y1, y2 = co[:,1], co[:,3]
for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop
coords_map[target_inds[ic], :, y1[ic]:y2[ic], x1[ic]:x2[ic]] = True
bbox_map = coords_map.to(torch.float32)
return bbox_map
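# Illustrative sketch: the bbox map is a binary (float) mask that is 1 inside
# each target's box. corners_coords_batch (defined elsewhere in this file) is
# assumed to convert (x, y, w, h) centers to (x1, y1, x2, y2) corners.
def _demo_get_bbox_map():
    attr = torch.zeros(1, 3, 64, 64)
    targets = torch.tensor([[0, 0, 0.5, 0.5, 0.25, 0.25]])  # centered box
    bbox_map = get_bbox_map(targets, attr)
    return bbox_map.sum()  # 16 * 16 * 3 = 768 pixels inside the box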
################################## Plausibility Loss ##################################
def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False):
"""
Compute the plausibility loss and related metrics for a batch of attribution maps.
Args:
targets (torch.Tensor): Targets in (batch_idx, class, x, y, w, h) normalized format.
attribution_map (torch.Tensor): Attribution maps, shape (batch, channel, height, width).
opt: Options with iou_coeff, dist_coeff, bbox_coeff, focus_coeff, dist_x_bbox, dist_reg_only, and pgt_coeff attributes.
imgs (torch.Tensor, optional): Input images, used only for visualization. Defaults to None.
debug (bool, optional): Whether to also return the distance map. Defaults to False.
only_loss (bool, optional): Whether to return only the loss. Defaults to False.
"""
# Calculate Plausibility IoU with attribution maps
if not only_loss:
plaus_score = get_plaus_score(targets_out = targets, attr = attribution_map.clone().detach().requires_grad_(True), imgs = imgs)
else:
plaus_score = torch.tensor(0.0)
# Calculate distance regularization
distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff)
if opt.dist_x_bbox:
# Exclude pixels inside ground truth boxes from the distance penalty
bbox_map = get_bbox_map(targets, attribution_map).to(torch.bool)
distance_map[bbox_map] = 0.0
# Positive regularization term for incentivizing pixels near the target to have high attribution
dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map))
# Negative regularization term for incentivizing pixels far from the target to have low attribution
dist_attr_neg = attr_reg(attribution_map, distance_map)
# Calculate plausibility regularization term, normalizing by the mean attribution
mean_attr = torch.mean(attribution_map)
dist_reg = (dist_attr_pos / mean_attr) - (dist_attr_neg / mean_attr)
if opt.bbox_coeff != 0.0:
bbox_map = get_bbox_map(targets, attribution_map)
attr_bbox_pos = attr_reg(attribution_map, bbox_map)
attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map))
bbox_reg = attr_bbox_pos - attr_bbox_neg
else:
bbox_reg = 0.0
bbox_map = get_bbox_map(targets, attribution_map)
# Plausibility score: fraction of total attribution mass inside ground truth boxes
plaus_score = torch.sum(attribution_map * bbox_map) / torch.sum(attribution_map)
if not opt.dist_reg_only:
# Shift dist_reg from roughly [-1, 1] to [0, 1] before weighting
dist_reg_loss = (1.0 + dist_reg) / 2.0
plaus_reg = (plaus_score * opt.iou_coeff) + (dist_reg_loss * opt.dist_coeff) + (bbox_reg * opt.bbox_coeff)
else:
plaus_reg = (1.0 + dist_reg) / 2.0
# Calculate plausibility loss
plaus_loss = (1 - plaus_reg) * opt.pgt_coeff
if only_loss:
return plaus_loss
if not debug:
return plaus_loss, (plaus_score, dist_reg, plaus_reg,)
else:
return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map
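# Illustrative sketch of wiring get_plaus_loss into a training step. The opt
# namespace below is hypothetical; the real coefficients come from the PGT
# training options.
def _demo_get_plaus_loss():
    from types import SimpleNamespace
    opt = SimpleNamespace(focus_coeff=0.5, dist_x_bbox=False, bbox_coeff=0.0,
                          iou_coeff=1.0, dist_coeff=1.0, dist_reg_only=False,
                          pgt_coeff=0.1)
    attribution_map = torch.rand(2, 3, 64, 64, requires_grad=True)
    targets = torch.tensor([[0, 0, 0.50, 0.50, 0.2, 0.2],
                            [1, 0, 0.25, 0.75, 0.1, 0.1]])
    plaus_loss = get_plaus_loss(targets, attribution_map, opt, only_loss=True)
    # total_loss = detection_loss + plaus_loss  # added to the task loss
    return plaus_loss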
####################################################################################
#### ALL FUNCTIONS BELOW ARE DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS ####
####################################################################################
def generate_vanilla_grad(model, input_tensor, loss_func = None,
targets_list=None, targets=None, metric=None, out_num = 1,
n_max_labels=3, norm=True, abs=True, grayscale=True,
class_specific_attr = True, device='cpu'):
"""
Generate vanilla gradients for the given model and input tensor.
Args:
model (nn.Module): The model to generate gradients for.
input_tensor (torch.Tensor): The input tensor for which gradients are computed.
loss_func (callable, optional): The loss function to compute gradients with respect to. Defaults to None.
targets_list (list, optional): The list of per-image target tensors. Defaults to None.
targets (torch.Tensor, optional): Full targets tensor, used with loss_func when class_specific_attr is False. Defaults to None.
metric (callable, optional): The metric function to evaluate the loss. Defaults to None.
out_num (int, optional): The index of the output tensor to compute gradients with respect to. Defaults to 1.
n_max_labels (int, optional): The maximum number of labels to consider. Defaults to 3.
norm (bool, optional): Whether to normalize the attribution map. Defaults to True.
abs (bool, optional): Whether to take the absolute values of gradients. Defaults to True.
grayscale (bool, optional): Whether to convert the attribution map to grayscale. Defaults to True.
class_specific_attr (bool, optional): Whether to compute class-specific attribution maps. Defaults to True.
device (str, optional): The device to use for computation. Defaults to 'cpu'.
Returns:
torch.Tensor: The generated vanilla gradients.
"""
# Set model.train() at the beginning and revert back to original mode (model.eval() or model.train()) at the end
train_mode = model.training
if not train_mode:
model.train()
input_tensor.requires_grad = True # Set requires_grad attribute of tensor. Important for computing gradients
model.zero_grad() # Zero gradients
inpt = input_tensor
# Forward pass
train_out = model(inpt) # training outputs (no inference outputs in train mode)
# train_out[1] = torch.Size([4, 3, 80, 80, 7]) HxWx(#anchorxC) cls (class probabilities)
# train_out[0] = torch.Size([4, 3, 160, 160, 7]) HxWx(#anchorx4) box or reg (location and scaling)
# train_out[2] = torch.Size([4, 3, 40, 40, 7]) HxWx(#anchorx1) obj (objectness score or confidence)
if class_specific_attr:
n_attr_list, index_classes = [], []
for i in range(len(input_tensor)):
if len(targets_list[i]) > n_max_labels:
targets_list[i] = targets_list[i][:n_max_labels]
if targets_list[i].numel() != 0:
class_numbers = targets_list[i][:,1]
index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers])
num_attrs = len(targets_list[i])
n_attr_list.append(num_attrs)
else:
index_classes.append([0, 1, 2, 3, 4])
n_attr_list.append(0)
targets_list_filled = [targ.clone().detach() for targ in targets_list]
labels_len = [len(targets_list[ih]) for ih in range(len(targets_list))]
max_labels = np.max(labels_len)
max_index = np.argmax(labels_len)
for i in range(len(targets_list)):
if len(targets_list_filled[i]) == 0:
continue # empty target sets are dropped below; also avoids division by zero
if len(targets_list_filled[i]) < max_labels:
tlist = [targets_list_filled[i]] * math.ceil(max_labels / len(targets_list_filled[i]))
targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0)
else:
targets_list_filled[i] = targets_list_filled[i].unsqueeze(0)
for i in range(len(targets_list_filled)-1,-1,-1):
if targets_list_filled[i].numel() == 0:
targets_list_filled.pop(i)
targets_list_filled = torch.cat(targets_list_filled)
n_img_attrs = len(input_tensor) if class_specific_attr else 1
n_img_attrs = 1 if loss_func else n_img_attrs
attrs_batch = []
for i_batch in range(n_img_attrs):
if loss_func and class_specific_attr:
i_batch = max_index
n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1
attrs_img = []
for i_attr in range(n_label_attrs):
if loss_func is None:
grad_wrt = train_out[out_num]
if class_specific_attr:
grad_wrt = train_out[out_num][:,:,:,:,index_classes[i_batch][i_attr]]
grad_wrt_outputs = torch.ones_like(grad_wrt)
else:
if class_specific_attr:
target_indiv = targets_list_filled[:,i_attr] # batch image input
else:
target_indiv = targets
try:
loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric) # loss scaled by batch_size
except RuntimeError:
# Retry with everything moved to the requested device
print("Error in loss function, trying again with device specified")
target_indiv = target_indiv.to(device)
inpt = inpt.to(device)
train_out = [tro.to(device) for tro in train_out]
loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)
grad_wrt = loss
grad_wrt_outputs = None
model.zero_grad() # Zero gradients
gradients = torch.autograd.grad(grad_wrt, inpt,
grad_outputs=grad_wrt_outputs,
retain_graph=True,
# create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly
)
attribution_map = gradients[0]
if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
attribution_map = torch.sum(attribution_map, 1, keepdim=True)
if abs:
attribution_map = torch.abs(attribution_map) # Take absolute values of gradients
if norm:
attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
attrs_img.append(attribution_map)
if len(attrs_img) == 0:
attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device))
else:
attrs_batch.append(torch.stack(attrs_img).to(device))
out_attr = attrs_batch
# Set model back to original mode
if not train_mode:
model.eval()
return out_attr
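# Illustrative sketch for the deprecated vanilla-gradient path: with no loss
# function and class_specific_attr=False, gradients are taken w.r.t. the raw
# output head selected by out_num. Assumes model(imgs) returns the train-mode
# output list described above.
def _demo_generate_vanilla_grad(model, imgs):
    attrs = generate_vanilla_grad(model, imgs, loss_func=None,
                                  class_specific_attr=False,
                                  device=imgs.device)
    return attrs  # list with one (1, B, 1, H, W) grayscale attribution tensor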
class RVNonLinearFunc(torch.nn.Module):
"""
Bayesian wrapper that propagates a mean and covariance through a
nonlinear activation function (e.g., ReLU) for random variables.
Attributes:
func (callable): The deterministic activation function to wrap.
"""
def __init__(self, func):
super(RVNonLinearFunc, self).__init__()
self.func = func
def forward(self, mu_in, Sigma_in):
"""
Forward pass of the Bayesian ReLU activation function.
Args:
mu_in (torch.Tensor): A tensor of shape (batch_size, input_size),
representing the mean input to the ReLU activation function.
Sigma_in (torch.Tensor): A tensor of shape (batch_size, input_size, input_size),
representing the covariance input to the ReLU activation function.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors,
including the mean of the output and the covariance of the output.
"""
# Collect stats
batch_size = mu_in.size(0)
# Mean
mu_out = self.func(mu_in)
# Compute the derivative of the activation function with respect to the input mean
gradi = torch.autograd.grad(mu_out, mu_in, grad_outputs=torch.ones_like(mu_out), create_graph=True)[0].view(batch_size,-1)
# add an extra dimension to gradi at position 2 and 1
grad1 = gradi.unsqueeze(dim=2)
grad2 = gradi.unsqueeze(dim=1)
# compute the outer product of grad1 and grad2
outer_product = torch.bmm(grad1, grad2)
# element-wise multiply Sigma_in with the outer product
# and return the result
Sigma_out = torch.mul(Sigma_in, outer_product)
return mu_out, Sigma_out
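# Illustrative sketch: propagating a mean and covariance through a ReLU with
# the first-order approximation implemented above.
def _demo_rv_relu():
    rv_relu = RVNonLinearFunc(torch.nn.functional.relu)
    mu_in = torch.randn(4, 8, requires_grad=True)
    Sigma_in = torch.eye(8).expand(4, 8, 8)
    mu_out, Sigma_out = rv_relu(mu_in, Sigma_in)
    return mu_out.shape, Sigma_out.shape  # (4, 8) and (4, 8, 8)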

View File

@ -0,0 +1,154 @@
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.cuda import amp # needed by returnGrad's GradScaler below
class Subplots:
def __init__(self, figsize = (40, 5)):
self.fig = plt.figure(figsize=figsize)
def plot_img_list(self, img_list, savedir='figs/test',
nrows = 1, rownum = 0,
hold = False, coltitles=[], rowtitle=''):
for i, img in enumerate(img_list):
try:
npimg = img.clone().detach().cpu().numpy()
except AttributeError: # already a numpy array
npimg = img
tpimg = np.transpose(npimg, (1, 2, 0))
lenrow = len(img_list)
ax = self.fig.add_subplot(nrows, lenrow, i+1+(rownum*lenrow))
if len(coltitles) > i:
ax.set_title(coltitles[i])
if i == 0:
ax.annotate(rowtitle, xy=((-0.06 * len(rowtitle)), 0.4),
xycoords='axes fraction', textcoords='offset points',
size='large', ha='center', va='baseline')
ax.imshow(tpimg)
ax.axis('off')
if not hold:
self.fig.tight_layout()
plt.savefig(f'{savedir}.png')
plt.clf()
plt.close('all')
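# Illustrative sketch: plotting one row of image tensors side by side.
# Assumes the figs/ directory already exists.
def _demo_subplots():
    imgs = [torch.rand(3, 64, 64) for _ in range(4)]
    sp = Subplots(figsize=(20, 5))
    sp.plot_img_list(imgs, savedir='figs/demo_row', nrows=1, rownum=0,
                     coltitles=['a', 'b', 'c', 'd'], rowtitle='demo')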
def VisualizeNumpyImageGrayscale(image_3d):
r"""Normalizes a numpy image array to the [0, 1] range.
"""
vmin = np.min(image_3d)
image_2d = image_3d - vmin
vmax = np.max(image_2d)
return (image_2d / vmax)
def normalize_numpy(image_3d):
r"""Normalizes a numpy image array to the [0, 1] range (duplicate of VisualizeNumpyImageGrayscale, kept for compatibility).
"""
vmin = np.min(image_3d)
image_2d = image_3d - vmin
vmax = np.max(image_2d)
return (image_2d / vmax)
def normalize_tensor(image_3d):
r"""Returns a 3D tensor as a grayscale normalized between 0 and 1 2D tensor.
"""
image_2d = (image_3d - torch.min(image_3d))
return (image_2d / torch.max(image_2d))
def format_img(img_):
np_img = img_.numpy()
tp_img = np.transpose(np_img, (1, 2, 0))
return tp_img
def imshow(img, save_path=None):
try:
npimg = img.clone().detach().cpu().numpy()
except AttributeError: # already a numpy array
npimg = img
tpimg = np.transpose(npimg, (1, 2, 0))
plt.imshow(tpimg)
# plt.axis('off')
plt.tight_layout()
if save_path is not None:
plt.savefig(f"{save_path}.png")
# plt.show()
def imshow_img(img, imsave_path):
# works for tensors and numpy arrays
try:
npimg = VisualizeNumpyImageGrayscale(img.numpy())
except:
npimg = VisualizeNumpyImageGrayscale(img)
npimg = np.transpose(npimg, (2, 0, 1))
imshow(npimg, save_path=imsave_path)
print("Saving image as ", imsave_path)
def returnGrad(img, labels, model, compute_loss, loss_metric, augment=None, device = 'cpu'):
model.train()
model.to(device)
img = img.to(device)
img.requires_grad_(True)
labels = labels.to(device)
model.requires_grad_(True)
cuda = torch.device(device).type != 'cpu'
scaler = amp.GradScaler(enabled=cuda)
pred = model(img)
loss, loss_items = compute_loss(pred, labels, metric=loss_metric) # box, obj, cls
with torch.autograd.set_detect_anomaly(True):
scaler.scale(loss).backward(inputs=img)
# loss.backward()
# S_c = torch.max(pred[0].data, 0)[0]
Sc_dx = img.grad
model.eval()
Sc_dx = Sc_dx.clone().detach().to(torch.float32)
return Sc_dx
def calculate_snr(img, attr, dB=True):
try:
img_np = img.detach().cpu().numpy()
attr_np = attr.detach().cpu().numpy()
except AttributeError: # already numpy arrays
img_np = img
attr_np = attr
# Calculate the signal power
signal_power = np.mean(img_np**2)
# Calculate the noise power
noise_power = np.mean(attr_np**2)
if dB:
# Calculate SNR in decibels
snr = 10 * np.log10(signal_power / noise_power)
else:
# Calculate the raw power ratio
snr = signal_power / noise_power
return snr
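# Illustrative sketch: treating the image as signal and the attribution as
# noise gives a rough decibel measure of attribution energy.
def _demo_calculate_snr():
    img = torch.rand(3, 64, 64)
    attr = 0.1 * torch.rand(3, 64, 64)
    return calculate_snr(img, attr, dB=True)  # ~20 dB for this example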
def overlay_mask(img, mask, colormap: str = "jet", alpha: float = 0.7):
cmap = plt.get_cmap(colormap)
npmask = np.array(mask.clone().detach().cpu().squeeze(0))
cmpmask = (cmap(npmask)[:, :, :3]).transpose((2, 0, 1))
# Blend the image with the colormapped mask
overlayed_imgnp = (alpha * np.asarray(img.clone().detach().cpu()) + (1 - alpha) * cmpmask)
overlayed_tensor = torch.tensor(overlayed_imgnp, device=img.device)
return overlayed_tensor
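# Illustrative sketch: overlaying a normalized attribution map on an image and
# saving the blend with imshow. Assumes the figs/ directory already exists.
def _demo_overlay_mask():
    img = torch.rand(3, 64, 64)
    mask = normalize_tensor(torch.rand(64, 64))
    overlayed = overlay_mask(img, mask, alpha=0.75)
    imshow(overlayed, save_path='figs/overlay_demo')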

View File

@ -717,7 +717,7 @@ def plot_images(
):
"""Plot image grid with labels."""
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
images = images.detach().cpu().float().numpy()
if isinstance(cls, torch.Tensor):
cls = cls.cpu().numpy()
if isinstance(bboxes, torch.Tensor):