Mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-05-23 21:44:22 +08:00)

Commit 38fa59edf2 (parent 2a95a652bd): PGT Training now functioning. Run run_pgt_train.py to train the model.
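For reference, training can be launched from the shell via the new script (python run_pgt_train.py --device 0 --batch_size 16 --epochs 100, with flags as defined in run_pgt_train.py below) or programmatically. A minimal sketch of the programmatic path, using only names this commit introduces (the weights file and dataset YAML are assumed to be available locally; values are illustrative):

from ultralytics.models.yolo.segment import PGTSegmentationTrainer

# Overrides mirror the dict built in run_pgt_train.py
args = dict(model='yolov10n.pt', data='coco128-seg.yaml', epochs=100, batch=16)
trainer = PGTSegmentationTrainer(overrides=args)
trainer.train()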
@@ -3,26 +3,41 @@ from ultralytics import YOLOv10, YOLO
 # from ultralytics import BaseTrainer
 # from ultralytics.engine.trainer import BaseTrainer
 import os
+from ultralytics.models.yolo.segment import PGTSegmentationTrainer


 # Set CUDA device (only needed for multi-gpu machines)
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 os.environ["CUDA_VISIBLE_DEVICES"] = "4"

 # model = YOLOv10()
+# model = YOLO('yolov8n-seg.yaml').load('yolov8n.pt')  # build from YAML and transfer weights

 # model = YOLO()
 # If you want to finetune the model with pretrained weights, you could load the
 # pretrained weights like below
 # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
 # or
 # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
-model = YOLOv10('yolov10n.pt')
+model = YOLOv10('yolov10n.pt', task='segment')

-model.train(data='coco.yaml',
-            trainer=model._smart_load("pgt_trainer"),  # This is needed to generate attributions (will be used later to train via PGT)
-            # Add return_images as input parameter
-            epochs=500, batch=16, imgsz=640,
-            debug=True,  # If debug = True, the attributions will be saved in the figures folder
-            )
+args = dict(model='yolov10n.pt', data='coco128-seg.yaml')
+trainer = PGTSegmentationTrainer(overrides=args)
+trainer.train(
+    # debug=True,
+    # args = dict(pgt_coeff=0.1),
+)
+
+# model.train(
+#     # data='coco.yaml',
+#     data='coco128-seg.yaml',
+#     trainer=model._smart_load("pgt_trainer"),  # This is needed to generate attributions (will be used later to train via PGT)
+#     # Add return_images as input parameter
+#     epochs=500, batch=16, imgsz=640,
+#     debug=True,  # If debug = True, the attributions will be saved in the figures folder
+#     # cfg='/home/nielseni6/PythonScripts/yolov10/ultralytics/cfg/models/v8/yolov8-seg.yaml',
+#     # overrides=dict(task="segment"),
+# )

 # Save the trained model
 model.save('yolov10_coco_trained.pt')
run_pgt_train.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from ultralytics import YOLOv10, YOLO
# from ultralytics.engine.pgt_trainer import PGTTrainer
import os
from ultralytics.models.yolo.segment import PGTSegmentationTrainer
import argparse


def main(args):
    # model = YOLOv10()
    # If you want to finetune the model with pretrained weights, you could load the
    # pretrained weights like below
    # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
    # or
    # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
    model = YOLOv10('yolov10n.pt', task='segment')

    # Note: 'args' is rebound here from the argparse namespace to the trainer overrides dict
    # (the namespace values on the right-hand side are read before the rebinding takes effect).
    args = dict(model='yolov10n.pt', data='coco.yaml',
                epochs=args.epochs, batch=args.batch_size,
                # cfg = 'pgt_train.yaml',  # This can be edited for full control of the training process
                )
    trainer = PGTSegmentationTrainer(overrides=args)
    trainer.train(
        # debug=True,
        # args = dict(pgt_coeff=0.1),  # Should add later to config
    )

    # Save the trained model
    model.save('yolov10_coco_trained.pt')

    # Evaluate the model on the validation set
    results = model.val(data='coco.yaml')

    # Print the evaluation results
    print(results)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
    parser.add_argument('--device', type=str, default='0', help='CUDA device number')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
    args = parser.parse_args()

    # Set CUDA device (only needed for multi-gpu machines)
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    main(args)
@@ -656,7 +656,10 @@ class Model(nn.Module):
             pass

         self.trainer.hub_session = self.session  # attach optional HUB session
-        self.trainer.train(debug=debug)
+        if debug:
+            self.trainer.train(debug=debug)
+        else:
+            self.trainer.train()
         # Update model and cfg after training
         if RANK in (-1, 0):
             ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
@@ -52,7 +52,9 @@ from ultralytics.utils.torch_utils import (
     strip_optimizer,
 )

+from ultralytics.utils.loss import v8DetectionLoss
+from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn
+import matplotlib.path as matplotlib_path
 class PGTTrainer:
     """
     BaseTrainer.
@@ -159,6 +161,7 @@ class PGTTrainer:
         self.loss_names = ["Loss"]
         self.csv = self.save_dir / "results.csv"
         self.plot_idx = [0, 1, 2]
+        self.num = int(time.time())

         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
@@ -383,36 +386,87 @@ class PGTTrainer:
                     batch = self.preprocess_batch(batch)
                     (self.loss, self.loss_items), images = self.model(batch, return_images=True)

-                    if debug and (i % 250):
-                        grad = torch.autograd.grad(self.loss, images, create_graph=True)[0]
+                    # smask = get_dist_reg(images, batch['masks'])
+                    # grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
+                    # grad = torch.abs(grad)
+
+                    # self.args.pgt_coeff = 1.1
+                    # plaus_loss = plaus_loss_fn(grad, smask, self.args.pgt_coeff)
+                    # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
+                    # self.loss += plaus_loss
+
+                    debug_ = debug
+                    if debug_ and (i % 25 == 0):
+                        debug_ = False
+                        # Create a tensor of zeros with the same size as images
+                        mask = torch.zeros_like(images, dtype=torch.float32)
+                        smask = get_dist_reg(images, batch['masks'])
+                        grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
+                        grad = torch.abs(grad)
+
+                        batch_size = images.shape[0]
+                        imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device)
+                        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+                        targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+                        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+                        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+
+                        # Iterate over each bounding box and set the corresponding pixels to 1
+                        for irx, bboxes in enumerate(gt_bboxes):
+                            for idx in range(len(bboxes)):
+                                x1, y1, x2, y2 = bboxes[idx]
+                                x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+                                mask[irx, :, y1:y2, x1:x2] = 1.0
+
+                        save_imgs = True
+                        if save_imgs:
                             # Convert tensors to numpy arrays
                             images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
                             grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                            mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                            seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)

                             # Normalize grad for visualization
                             grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())

                             for ix in range(images_np.shape[0]):
-                                fig, ax = plt.subplots(1, 3, figsize=(15, 5))
-                                ax[0].imshow(images_np[i])
+                                fig, ax = plt.subplots(1, 6, figsize=(30, 5))
+                                ax[0].imshow(images_np[ix])
                                 ax[0].set_title('Image')
-                                ax[1].imshow(grad_np[i], cmap='jet')
+                                ax[1].imshow(grad_np[ix], cmap='jet')
                                 ax[1].set_title('Gradient')
-                                ax[2].imshow(images_np[i])
-                                ax[2].imshow(grad_np[i], cmap='jet', alpha=0.5)
+                                ax[2].imshow(images_np[ix])
+                                ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
                                 ax[2].set_title('Overlay')
+                                ax[3].imshow(mask_np[ix], cmap='gray')
+                                ax[3].set_title('Mask')
+                                ax[4].imshow(seg_mask_np[ix], cmap='gray')
+                                ax[4].set_title('Segmentation Mask')

-                                save_dir_attr = "figures/attributions"
+                                # Plot image with bounding boxes
+                                ax[5].imshow(images_np[ix])
+                                for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
+                                    x1, y1, x2, y2 = bbox
+                                    x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+                                    rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
+                                    ax[5].add_patch(rect)
+                                    ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
+                                ax[5].set_title('Bounding Boxes')
+
+                                save_dir_attr = f"figures/attributions/run{self.num}"
                                 if not os.path.exists(save_dir_attr):
                                     os.makedirs(save_dir_attr)
                                 plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
                                 plt.close(fig)
+                        images = images.detach()

                     if RANK != -1:
                         self.loss *= world_size
                     self.tloss = (
                         (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
                     )

                 # Backward
                 self.scaler.scale(self.loss).backward()
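The debug branch above treats raw input gradients as attribution maps. A self-contained sketch of that core idea, with a dummy convolution standing in for the detector (all names here are illustrative, not the repo's API):

import torch
import torch.nn as nn

model = nn.Conv2d(3, 1, 3, padding=1)          # stand-in for the detection model
images = torch.rand(1, 3, 32, 32, requires_grad=True)
loss = model(images).mean()                    # stand-in for the detection loss
grad = torch.autograd.grad(loss, images, retain_graph=True)[0]
attribution = torch.abs(grad)                  # the |d loss / d pixel| map the hunk visualizes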
ultralytics/engine/pgt_validator.py (new file, 345 lines)
@@ -0,0 +1,345 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Check a model's accuracy on a test or val split of a dataset.

Usage:
    $ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640

Usage - formats:
    $ yolo mode=val model=yolov8n.pt                 # PyTorch
                          yolov8n.torchscript        # TorchScript
                          yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
                          yolov8n_openvino_model     # OpenVINO
                          yolov8n.engine             # TensorRT
                          yolov8n.mlpackage          # CoreML (macOS-only)
                          yolov8n_saved_model        # TensorFlow SavedModel
                          yolov8n.pb                 # TensorFlow GraphDef
                          yolov8n.tflite             # TensorFlow Lite
                          yolov8n_edgetpu.tflite     # TensorFlow Edge TPU
                          yolov8n_paddle_model       # PaddlePaddle
                          yolov8n_ncnn_model         # NCNN
"""

import json
import time
from pathlib import Path

import numpy as np
import torch

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode


class PGTValidator:
    """
    BaseValidator.

    A base class for creating validators.

    Attributes:
        args (SimpleNamespace): Configuration for the validator.
        dataloader (DataLoader): Dataloader to use for validation.
        pbar (tqdm): Progress bar to update during validation.
        model (nn.Module): Model to validate.
        data (dict): Data dictionary.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether the model is in training mode.
        names (dict): Class names.
        seen: Records the number of images seen so far during validation.
        stats: Placeholder for statistics during validation.
        confusion_matrix: Placeholder for a confusion matrix.
        nc: Number of classes.
        iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
        jdict (dict): Dictionary to store JSON validation results.
        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
            batch processing times in milliseconds.
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization.
        callbacks (dict): Dictionary to store various callback functions.
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """
        Initializes a BaseValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            pbar (tqdm.tqdm): Progress bar for displaying progress.
            args (SimpleNamespace): Configuration for the validator.
            _callbacks (dict): Dictionary to store various callback functions.
        """
        self.args = get_cfg(overrides=args)
        self.dataloader = dataloader
        self.pbar = pbar
        self.stride = None
        self.data = None
        self.device = None
        self.batch_i = None
        self.training = True
        self.names = None
        self.seen = None
        self.stats = None
        self.confusion_matrix = None
        self.nc = None
        self.iouv = None
        self.jdict = None
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

        self.save_dir = save_dir or get_save_dir(self.args)
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.001  # default conf=0.001
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

        self.plots = {}
        self.callbacks = _callbacks or callbacks.get_default_callbacks()

    # @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
        gets priority).
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            # self.args.half = self.device.type != "cpu"  # force FP16 val during training
            model = trainer.ema.ema or trainer.model
            model = model.half() if self.args.half else model.float()
            # self.model = model
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(
                weights=model or self.args.model,
                device=select_device(self.args.device, self.args.batch),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            # self.model = model
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if engine:
                self.args.batch = model.batch_size
            elif not pt and not jit:
                self.args.batch = 1  # export.py models default to batch-size 1
                LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")

            if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in ("cpu", "mps"):
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not pt:
                self.args.rect = False
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(de_parallel(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                preds = model(batch["img"].requires_grad_(True), augment=augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")
        stats = self.get_stats()
        self.check_stats(stats)
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
        self.finalize_metrics()
        if not (self.args.save_json and self.is_coco and len(self.jdict)):
            self.print_results()
        self.run_callbacks("on_val_end")
        if self.training:
            model.float()
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
                stats['fitness'] = stats['metrics/mAP50-95(B)']
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            LOGGER.info(
                "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
                % tuple(self.speed.values())
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats

    def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
        """
        Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

        Args:
            pred_classes (torch.Tensor): Predicted class indices of shape(N,).
            true_classes (torch.Tensor): Target class indices of shape(M,).
            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth
            use_scipy (bool): Whether to use scipy for matching (more precise).

        Returns:
            (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
        """
        # Dx10 matrix, where D - detections, 10 - IoU thresholds
        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
        # LxD matrix where L - labels (rows), D - detections (columns)
        correct_class = true_classes[:, None] == pred_classes
        iou = iou * correct_class  # zero out the wrong classes
        iou = iou.cpu().numpy()
        for i, threshold in enumerate(self.iouv.cpu().tolist()):
            if use_scipy:
                # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
                import scipy  # scope import to avoid importing for all commands

                cost_matrix = iou * (iou >= threshold)
                if cost_matrix.any():
                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
                    valid = cost_matrix[labels_idx, detections_idx] > 0
                    if valid.any():
                        correct[detections_idx[valid], i] = True
            else:
                matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
                matches = np.array(matches).T
                if matches.shape[0]:
                    if matches.shape[0] > 1:
                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                        # matches = matches[matches[:, 2].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                    correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)

    def add_callback(self, event: str, callback):
        """Appends the given callback."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Runs all callbacks associated with a specified event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset."""
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocesses an input batch."""
        return batch

    def postprocess(self, preds):
        """Post-processes the predictions."""
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass

    def update_metrics(self, preds, batch):
        """Updates metrics based on predictions and batch."""
        pass

    def finalize_metrics(self, *args, **kwargs):
        """Finalizes and returns all metrics."""
        pass

    def get_stats(self):
        """Returns statistics about the model's performance."""
        return {}

    def check_stats(self, stats):
        """Checks statistics."""
        pass

    def print_results(self):
        """Prints the results of the model's predictions."""
        pass

    def get_desc(self):
        """Get description of the YOLO model."""
        pass

    @property
    def metric_keys(self):
        """Returns the metric keys used in YOLO training/validation."""
        return []

    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)."""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    # TODO: may need to put these following functions into callback
    def plot_val_samples(self, batch, ni):
        """Plots validation samples during training."""
        pass

    def plot_predictions(self, batch, preds, ni):
        """Plots YOLO model predictions on batch images."""
        pass

    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format."""
        pass

    def eval_json(self, stats):
        """Evaluate and return JSON format of prediction statistics."""
        pass
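The main functional difference from the stock BaseValidator is that @smart_inference_mode() is commented out and the forward pass runs on batch["img"].requires_grad_(True), so gradients can flow back to the input image for attribution. A minimal illustration with dummy tensors (the quadratic stands in for the model's forward pass):

import torch

x = torch.rand(1, 3, 8, 8).requires_grad_(True)  # mirrors batch["img"].requires_grad_(True)
y = (x ** 2).sum()                               # stand-in for the forward pass
g = torch.autograd.grad(y, x)[0]                 # input gradients; unavailable under inference_mode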
@@ -4,5 +4,6 @@ from .predict import DetectionPredictor
 from .pgt_train import PGTDetectionTrainer
 from .train import DetectionTrainer
 from .val import DetectionValidator
+from .pgt_val import PGTDetectionValidator

-__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer"
+__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer", "PGTDetectionValidator"
ultralytics/models/yolo/detect/pgt_val.py (new file, 300 lines)
@@ -0,0 +1,300 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import os
from pathlib import Path

import numpy as np
import torch

from ultralytics.data import build_dataloader, build_yolo_dataset, converter
from ultralytics.engine.validator import BaseValidator
from ultralytics.engine.pgt_validator import PGTValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.utils.plotting import output_to_target, plot_images


class PGTDetectionValidator(PGTValidator):
    """
    A class extending the BaseValidator class for validation based on a detection model.

    Example:
        ```python
        from ultralytics.models.yolo.detect import PGTDetectionValidator

        args = dict(model='yolov8n.pt', data='coco8.yaml')
        validator = PGTDetectionValidator(args=args)
        validator()
        ```
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize detection model with necessary variables and settings."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.nt_per_class = None
        self.is_coco = False
        self.class_map = None
        self.args.task = "detect"
        self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
        self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
        self.niou = self.iouv.numel()
        self.lb = []  # for autolabelling

    def preprocess(self, batch):
        """Preprocesses batch of images for YOLO training."""
        batch["img"] = batch["img"].to(self.device, non_blocking=True)
        batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
        for k in ["batch_idx", "cls", "bboxes"]:
            batch[k] = batch[k].to(self.device)

        if self.args.save_hybrid:
            height, width = batch["img"].shape[2:]
            nb = len(batch["img"])
            bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
            self.lb = (
                [
                    torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
                    for i in range(nb)
                ]
                if self.args.save_hybrid
                else []
            )  # for autolabelling

        return batch

    def init_metrics(self, model):
        """Initialize evaluation metrics for YOLO."""
        val = self.data.get(self.args.split, "")  # validation path
        self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt")  # is COCO
        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
        self.args.save_json |= self.is_coco  # run on final val if training COCO
        self.names = model.names
        self.nc = len(model.names)
        self.metrics.names = self.names
        self.metrics.plot = self.args.plots
        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
        self.seen = 0
        self.jdict = []
        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])

    def get_desc(self):
        """Return a formatted string summarizing class metrics of YOLO model."""
        return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")

    def postprocess(self, preds):
        """Apply Non-maximum suppression to prediction outputs."""
        return ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            labels=self.lb,
            multi_label=True,
            agnostic=self.args.single_cls,
            max_det=self.args.max_det,
        )

    def _prepare_batch(self, si, batch):
        """Prepares a batch of images and annotations for validation."""
        idx = batch["batch_idx"] == si
        cls = batch["cls"][idx].squeeze(-1)
        bbox = batch["bboxes"][idx]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]
        ratio_pad = batch["ratio_pad"][si]
        if len(cls):
            bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
        return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)

    def _prepare_pred(self, pred, pbatch):
        """Prepares a batch of images and annotations for validation."""
        predn = pred.clone()
        ops.scale_boxes(
            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
        )  # native-space pred
        return predn

    def update_metrics(self, preds, batch):
        """Metrics."""
        for si, pred in enumerate(preds):
            self.seen += 1
            npr = len(pred)
            stat = dict(
                conf=torch.zeros(0, device=self.device),
                pred_cls=torch.zeros(0, device=self.device),
                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
            )
            pbatch = self._prepare_batch(si, batch)
            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
            nl = len(cls)
            stat["target_cls"] = cls
            if npr == 0:
                if nl:
                    for k in self.stats.keys():
                        self.stats[k].append(stat[k])
                    if self.args.plots:
                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
                continue

            # Predictions
            if self.args.single_cls:
                pred[:, 5] = 0
            predn = self._prepare_pred(pred, pbatch)
            stat["conf"] = predn[:, 4]
            stat["pred_cls"] = predn[:, 5]

            # Evaluate
            if nl:
                stat["tp"] = self._process_batch(predn, bbox, cls)
                if self.args.plots:
                    self.confusion_matrix.process_batch(predn, bbox, cls)
            for k in self.stats.keys():
                self.stats[k].append(stat[k])

            # Save
            if self.args.save_json:
                self.pred_to_json(predn, batch["im_file"][si])
            if self.args.save_txt:
                file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
                self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)

    def finalize_metrics(self, *args, **kwargs):
        """Set final values for metrics speed and confusion matrix."""
        self.metrics.speed = self.speed
        self.metrics.confusion_matrix = self.confusion_matrix

    def get_stats(self):
        """Returns metrics statistics and results dictionary."""
        stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
        if len(stats) and stats["tp"].any():
            self.metrics.process(**stats)
        self.nt_per_class = np.bincount(
            stats["target_cls"].astype(int), minlength=self.nc
        )  # number of targets per class
        return self.metrics.results_dict

    def print_results(self):
        """Prints training/validation set metrics per class."""
        pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
        LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
        if self.nt_per_class.sum() == 0:
            LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")

        # Print results per class
        if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
            for i, c in enumerate(self.metrics.ap_class_index):
                LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))

        if self.args.plots:
            for normalize in True, False:
                self.confusion_matrix.plot(
                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
                )

    def _process_batch(self, detections, gt_bboxes, gt_cls):
        """
        Return correct prediction matrix.

        Args:
            detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
                Each detection is of the format: x1, y1, x2, y2, conf, class.
            labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
                Each label is of the format: class, x1, y1, x2, y2.

        Returns:
            (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
        """
        iou = box_iou(gt_bboxes, detections[:, :4])
        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build YOLO Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)

    def get_dataloader(self, dataset_path, batch_size):
        """Construct and return dataloader."""
        dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
        return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader

    def plot_val_samples(self, batch, ni):
        """Plot validation image samples."""
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )

    def plot_predictions(self, batch, preds, ni):
        """Plots predicted bounding boxes on input images and saves the result."""
        plot_images(
            batch["img"],
            *output_to_target(preds, max_det=self.args.max_det),
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred

    def save_one_txt(self, predn, save_conf, shape, file):
        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
        gn = torch.tensor(shape)[[1, 0, 1, 0]]  # normalization gain whwh
        for *xyxy, conf, cls in predn.tolist():
            xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
            line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
            with open(file, "a") as f:
                f.write(("%g " * len(line)).rstrip() % line + "\n")

    def pred_to_json(self, predn, filename):
        """Serialize YOLO predictions to COCO json format."""
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = ops.xyxy2xywh(predn[:, :4])  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
        for p, b in zip(predn.tolist(), box.tolist()):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "category_id": self.class_map[int(p[5])],
                    "bbox": [round(x, 3) for x in b],
                    "score": round(p[4], 5),
                }
            )

    def eval_json(self, stats):
        """Evaluates YOLO output in JSON format and returns performance statistics."""
        if self.args.save_json and self.is_coco and len(self.jdict):
            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
            pred_json = self.save_dir / "predictions.json"  # predictions
            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                check_requirements("pycocotools>=2.0.6")
                from pycocotools.coco import COCO  # noqa
                from pycocotools.cocoeval import COCOeval  # noqa

                for x in anno_json, pred_json:
                    assert x.is_file(), f"{x} file not found"
                anno = COCO(str(anno_json))  # init annotations api
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
                eval = COCOeval(anno, pred, "bbox")
                if self.is_coco:
                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
                eval.evaluate()
                eval.accumulate()
                eval.summarize()
                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50
            except Exception as e:
                LOGGER.warning(f"pycocotools unable to run: {e}")
        return stats
@@ -3,5 +3,6 @@
 from .predict import SegmentationPredictor
 from .train import SegmentationTrainer
 from .val import SegmentationValidator
+from .pgt_train import PGTSegmentationTrainer

-__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
+__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator", "PGTSegmentationTrainer"
ultralytics/models/yolo/segment/pgt_train.py (new file, 73 lines)
@@ -0,0 +1,73 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from copy import copy

from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel, DetectionModel
from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.models.yolov10.model import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
from ultralytics.models.yolov10.val import YOLOv10DetectionValidator, YOLOv10PGTDetectionValidator


class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a segmentation model.

    Example:
        ```python
        from ultralytics.models.yolo.segment import PGTSegmentationTrainer

        args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
        trainer = PGTSegmentationTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a SegmentationTrainer object with given arguments."""
        if overrides is None:
            overrides = {}
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return SegmentationModel initialized with specified config and weights."""
        if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']:
            model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        else:
            model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """Return an instance of SegmentationValidator for validation of YOLO model."""
        if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']:
            self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", "pgt_loss",
            return YOLOv10PGTDetectionValidator(
                self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
            )
        else:
            self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
            return yolo.segment.SegmentationValidator(
                self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
            )

    def plot_training_samples(self, batch, ni):
        """Creates a plot of training sample images with labels and box coordinates."""
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            masks=batch["masks"],
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plots training/val metrics."""
        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
@@ -1,5 +1,5 @@
 from ultralytics.engine.model import Model
-from ultralytics.nn.tasks import YOLOv10DetectionModel
+from ultralytics.nn.tasks import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
 from .val import YOLOv10DetectionValidator
 from .predict import YOLOv10DetectionPredictor
 from .train import YOLOv10DetectionTrainer
@@ -1,4 +1,4 @@
-from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.models.yolo.detect import DetectionValidator, PGTDetectionValidator
 from ultralytics.utils import ops
 import torch
@@ -22,3 +22,24 @@ class YOLOv10DetectionValidator(DetectionValidator):
         boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
         bboxes = ops.xywh2xyxy(boxes)
         return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
+
+class YOLOv10PGTDetectionValidator(PGTDetectionValidator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.args.save_json |= self.is_coco
+
+    def postprocess(self, preds):
+        if isinstance(preds, dict):
+            preds = preds["one2one"]
+
+        if isinstance(preds, (list, tuple)):
+            preds = preds[0]
+
+        # Acknowledgement: Thanks to sanha9999 in #190 and #181!
+        if preds.shape[-1] == 6:
+            return preds
+        else:
+            preds = preds.transpose(-1, -2)
+            boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
+            bboxes = ops.xywh2xyxy(boxes)
+            return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
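The shape check in postprocess() above distinguishes already-decoded one2one outputs (last dimension 6: xyxy, score, label) from raw head outputs that still need ops.v10postprocess. A dummy-tensor illustration of the two branches (the shapes, e.g. 84 = 4 box coords + 80 COCO classes, are illustrative assumptions):

import torch

decoded = torch.rand(2, 300, 6)   # (batch, max_det, xyxy+score+label): returned as-is
raw = torch.rand(2, 84, 8400)     # (batch, 4+nc, anchors): still needs decoding
raw = raw.transpose(-1, -2)       # -> (batch, anchors, 4+nc), the layout ops.v10postprocess expects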
@@ -57,7 +57,7 @@ from ultralytics.nn.modules import (
 )
 from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
-from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss
+from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss, v10PGTDetectLoss
 from ultralytics.utils.plotting import feature_visualization
 from ultralytics.utils.torch_utils import (
     fuse_conv_and_bn,
@@ -652,6 +652,10 @@ class YOLOv10DetectionModel(DetectionModel):
     def init_criterion(self):
         return v10DetectLoss(self)

+class YOLOv10PGTDetectionModel(DetectionModel):
+    def init_criterion(self):
+        return v10PGTDetectLoss(self)
+
 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
@@ -9,7 +9,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
 from .metrics import bbox_iou, probiou
 from .tal import bbox2dist
+from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn

 class VarifocalLoss(nn.Module):
     """
@@ -725,3 +725,30 @@ class v10DetectLoss:
         one2one = preds["one2one"]
         loss_one2one = self.one2one(one2one, batch)
         return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
+
+class v10PGTDetectLoss:
+    def __init__(self, model):
+        self.one2many = v8DetectionLoss(model, tal_topk=10)
+        self.one2one = v8DetectionLoss(model, tal_topk=1)
+
+    def __call__(self, preds, batch):
+        batch['img'] = batch['img'].requires_grad_(True)
+        one2many = preds["one2many"]
+        loss_one2many = self.one2many(one2many, batch)
+        one2one = preds["one2one"]
+        loss_one2one = self.one2one(one2one, batch)
+
+        loss = loss_one2many[0] + loss_one2one[0]
+
+        smask = get_dist_reg(batch['img'], batch['masks'])
+
+        grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0]
+        grad = torch.abs(grad)
+
+        pgt_coeff = 3.0
+        plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
+        # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
+        loss += plaus_loss
+
+        return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0)))
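Conceptually, v10PGTDetectLoss adds a plausibility term, computed from input gradients and a blurred segmentation prior, on top of the standard one2many/one2one losses. A toy sketch of that flow with dummy tensors (a squared-error term stands in for the real detection loss; helper behavior is assumed to be as used above):

import torch
from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn

imgs = torch.rand(2, 3, 160, 160).requires_grad_(True)  # batch['img']
masks = torch.zeros(2, 160, 160)
masks[:, 40:120, 40:120] = 1.0                          # batch['masks']: one square object per image
det_loss = (imgs ** 2).mean()                           # stand-in for loss_one2many[0] + loss_one2one[0]
smask = get_dist_reg(imgs, masks)                       # blurred segmentation prior
grad = torch.abs(torch.autograd.grad(det_loss, imgs, retain_graph=True)[0])
total = det_loss + plaus_loss_fn(grad, smask, 3.0)      # pgt_coeff = 3.0, as hard-coded above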
ultralytics/utils/plaus_functs.py (new file, 818 lines)
@@ -0,0 +1,818 @@
import torch
import numpy as np
# from plot_functs import *
from .plot_functs import normalize_tensor, overlay_mask, imshow
import math
import time
import matplotlib.path as mplPath
from matplotlib.path import Path
# from utils.general import non_max_suppression, xyxy2xywh, scale_coords
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppression
from .metrics import bbox_iou
import torchvision.transforms as T


def plaus_loss_fn(grad, smask, pgt_coeff):
    ################## Compute the PGT Loss ##################
    # Positive regularization term for incentivizing pixels near the target to have high attribution
    dist_attr_pos = attr_reg(grad, (1.0 - smask))  # dist_reg = seg_mask
    # Negative regularization term for incentivizing pixels far from the target to have low attribution
    dist_attr_neg = attr_reg(grad, smask)
    # Calculate plausibility regularization term
    # dist_reg = dist_attr_pos - dist_attr_neg
    dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad)))
    plaus_reg = (((1.0 + dist_reg) / 2.0))
    # Calculate plausibility loss
    plaus_loss = (1 - plaus_reg) * pgt_coeff
    return plaus_loss


def get_dist_reg(images, seg_mask):
    seg_mask = T.Resize((images.shape[2], images.shape[3]), antialias=True)(seg_mask).to(images.device)
    seg_mask = seg_mask.to(dtype=torch.float32).unsqueeze(1).repeat(1, 3, 1, 1)
    seg_mask[seg_mask > 0] = 1.0

    smask = torch.zeros_like(seg_mask)
    sigmas = [20.0 + (i_sig * 20.0) for i_sig in range(8)]
    for k_it, sigma in enumerate(sigmas):
        # Apply Gaussian blur to the mask
        kernel_size = int(sigma + 50)
        if kernel_size % 2 == 0:
            kernel_size += 1
        seg_mask1 = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask)
        if torch.max(seg_mask1) > 1.0:
            seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min())
        smask = torch.max(smask, seg_mask1)
    return smask
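get_dist_reg builds a soft distance prior by taking a pixel-wise max over progressively blurred copies of the binarized segmentation mask (sigmas 20 through 160), so attribution can be rewarded near objects and penalized far from them. A quick way to visualize the prior on a dummy mask (matplotlib assumed available; only the spatial size of the image tensor is used):

import torch
import matplotlib.pyplot as plt
from ultralytics.utils.plaus_functs import get_dist_reg

imgs = torch.zeros(1, 3, 160, 160)            # only images.shape/device are read
masks = torch.zeros(1, 160, 160)
masks[:, 60:100, 60:100] = 1.0                # one square object
smask = get_dist_reg(imgs, masks)             # (1, 3, 160, 160), values in [0, 1]
plt.imshow(smask[0, 0], cmap='gray')
plt.title('distance prior')
plt.show()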
|
|
||||||
|
def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False):
    """
    Compute the gradient of an image with respect to a given tensor.

    Args:
        img (torch.Tensor): The input image tensor.
        grad_wrt (torch.Tensor): The tensor with respect to which the gradient is computed.
        norm (bool, optional): Whether to normalize the gradient. Defaults to False.
        absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
        grayscale (bool, optional): Whether to convert the gradient to grayscale. Defaults to False.
        keepmean (bool, optional): Whether to keep the mean value of the attribution map. Defaults to False.

    Returns:
        torch.Tensor: The computed attribution map.
    """
    if (grad_wrt.shape != torch.Size([1])) and (grad_wrt.shape != torch.Size([])):
        grad_wrt_outputs = torch.ones_like(grad_wrt).clone().detach()
    else:
        grad_wrt_outputs = None
    attribution_map = torch.autograd.grad(grad_wrt, img,
                                          grad_outputs=grad_wrt_outputs,
                                          create_graph=True,  # allows higher-order derivatives but slows computation significantly
                                          )[0]
    if absolute:
        attribution_map = torch.abs(attribution_map)  # take absolute values of gradients
    if grayscale:  # convert to grayscale, saves vram and computation time for plaus_eval
        attribution_map = torch.sum(attribution_map, 1, keepdim=True)
    if norm:
        if keepmean:
            attmean = torch.mean(attribution_map)
            attmin = torch.min(attribution_map)
            attmax = torch.max(attribution_map)
        attribution_map = normalize_batch(attribution_map)  # normalize attribution maps per image in batch
        if keepmean:
            attribution_map -= attribution_map.mean()
            attribution_map += (attmean / (attmax - attmin))

    return attribution_map

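# Minimal example of get_gradient (illustrative): attribution of a toy scalar
# loss with respect to the input. Because create_graph=True inside get_gradient,
# the returned map itself stays differentiable, which is what lets the PGT loss
# backpropagate through the attribution.
def _example_get_gradient():
    x = torch.rand(1, 3, 32, 32, requires_grad=True)
    y = (x ** 2).sum()                        # toy scalar "loss"
    attr = get_gradient(x, y, norm=True, grayscale=True)
    assert attr.shape == (1, 1, 32, 32)       # grayscale collapses the channel dim
    return attr
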
def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False):
    """
    Generate Gaussian noise based on the input image.

    Mirrors the post-processing of get_gradient so the noise can serve as a
    random baseline for attribution maps. `grad_wrt` is unused and kept only
    for signature parity with get_gradient.

    Args:
        img (torch.Tensor): Input image.
        grad_wrt: Gradient with respect to the input image (unused).
        norm (bool, optional): Whether to normalize the generated noise. Defaults to True.
        absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
        grayscale (bool, optional): Whether to convert the noise to grayscale. Defaults to True.
        keepmean (bool, optional): Whether to keep the mean of the noise. Defaults to False.

    Returns:
        torch.Tensor: Generated Gaussian noise.
    """
    gaussian_noise = torch.randn_like(img)

    if absolute:
        gaussian_noise = torch.abs(gaussian_noise)  # take absolute values
    if grayscale:  # convert to grayscale, saves vram and computation time for plaus_eval
        gaussian_noise = torch.sum(gaussian_noise, 1, keepdim=True)
    if norm:
        if keepmean:
            attmean = torch.mean(gaussian_noise)
            attmin = torch.min(gaussian_noise)
            attmax = torch.max(gaussian_noise)
        gaussian_noise = normalize_batch(gaussian_noise)  # normalize per image in batch
        if keepmean:
            gaussian_noise -= gaussian_noise.mean()
            gaussian_noise += (attmean / (attmax - attmin))

    return gaussian_noise


def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps=1e-7):
    # TODO: Remove imgs from this function and only take it as input if debug is True
    """
    Calculates the plausibility score based on the given inputs.

    Args:
        targets_out (torch.Tensor): The output targets, rows of (batch_idx, class, x, y, w, h).
        attr (torch.Tensor): The attribution maps.
        debug (bool, optional): Whether to enable debug mode. Defaults to False.
        imgs (torch.Tensor, optional): The input images, only used for debug visualization.

    Returns:
        torch.Tensor: The plausibility score, i.e. the fraction of total
        attribution mass that falls inside the labeled bounding boxes.
    """
    if torch.isnan(attr).any():
        attr = torch.nan_to_num(attr, nan=0.0)

    coords_map = get_bbox_map(targets_out, attr)
    plaus_score = torch.sum(attr * coords_map) / torch.sum(attr)

    if debug:
        if imgs is None:
            imgs = torch.zeros_like(attr)
        for i in range(len(coords_map)):
            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0).to(torch.bool)
            test_bbox = torch.zeros_like(imgs[i])
            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
            imshow(test_bbox, save_path='figs/test_bbox')
            imshow(imgs[i], save_path='figs/im0')
            imshow(attr[i], save_path='figs/attr')

    return plaus_score

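# Worked toy example (illustrative): if attribution is spread uniformly, the
# score equals the fraction of pixels covered by the labeled boxes. Rows of
# targets_out are (batch_idx, class, x, y, w, h) in normalized xywh coordinates.
def _example_plaus_score():
    attr = torch.ones(1, 3, 100, 100)
    # one box centered at (0.5, 0.5) spanning half of each side -> 1/4 of the image
    targets_out = torch.tensor([[0.0, 0.0, 0.5, 0.5, 0.5, 0.5]])
    score = get_plaus_score(targets_out, attr)  # ~0.25
    return score
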

def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps=1e-7):
    # TODO: Remove imgs from this function and only take it as input if debug is True
    """
    Estimates bounding-box corners from the attribution map and scores them
    against the labeled boxes with CIoU.

    For each target, the attribution map is split into four quadrants around the
    box center; the argmax of each quadrant votes for one corner of the
    attribution-implied box.

    Args:
        targets_out (torch.Tensor): The output targets, rows of (batch_idx, class, x, y, w, h).
        attr (torch.Tensor): The attribution maps.
        debug (bool, optional): Whether to enable debug mode. Defaults to False.
        imgs (torch.Tensor, optional): The input images, only used for debug visualization.

    Returns:
        torch.Tensor: Mean CIoU between attribution-derived boxes and the targets.
    """
    target_inds = targets_out[:, 0].int()
    xyxy_batch = targets_out[:, 2:6]
    num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
    co = xyxy_corners
    if corners:
        co = targets_out[:, 2:6].int()
    coords_map = torch.zeros_like(attr, dtype=torch.bool)
    x1, x2 = co[:, 1], co[:, 3]
    y1, y2 = co[:, 0], co[:, 2]

    for ic in range(co.shape[0]):  # potential for speedup here with torch indexing instead of for loop
        coords_map[target_inds[ic], :, x1[ic]:x2[ic], y1[ic]:y2[ic]] = True

    if torch.isnan(attr).any():
        attr = torch.nan_to_num(attr, nan=0.0)
    if debug:
        if imgs is None:
            imgs = torch.zeros_like(attr)
        for i in range(len(coords_map)):
            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
            test_bbox = torch.zeros_like(imgs[i])
            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
            imshow(test_bbox, save_path='figs/test_bbox')
            imshow(imgs[i], save_path='figs/im0')
            imshow(attr[i], save_path='figs/attr')

    co = (xyxy_batch * num_pixels).int()
    x1 = co[:, 1] + 1
    y1 = co[:, 0] + 1
    attr_ = torch.sum(attr, 1, keepdim=True)
    corners_attr = None
    for ic in range(co.shape[0]):
        # Split the attribution around the box center and take the argmax of each quadrant
        attr0 = attr_[target_inds[ic], :, :x1[ic], :y1[ic]]
        attr1 = attr_[target_inds[ic], :, :x1[ic], y1[ic]:]
        attr2 = attr_[target_inds[ic], :, x1[ic]:, :y1[ic]]
        attr3 = attr_[target_inds[ic], :, x1[ic]:, y1[ic]:]

        x_0, y_0 = max_indices_2d(attr0[0])
        x_1, y_1 = max_indices_2d(attr1[0])
        x_2, y_2 = max_indices_2d(attr2[0])
        x_3, y_3 = max_indices_2d(attr3[0])

        # Shift quadrant-local indices back to image coordinates
        y_1 += y1[ic]
        x_2 += x1[ic]
        x_3 += x1[ic]
        y_3 += y1[ic]

        max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2],
                                 torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3],
                                 torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2],
                                 torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]])
        if corners_attr is None:
            corners_attr = max_corners
        else:
            corners_attr = torch.cat([corners_attr, max_corners], dim=0)
    corners_attr = corners_attr.view(-1, 4)
    IoU_ = bbox_iou(corners_attr.T, xyxy_batch, x1y1x2y2=False, metric='CIoU')
    plaus_score = IoU_.mean()

    return plaus_score

def max_indices_2d(x_inp):
    """Return the (row, col) indices of the maximum of a 2D tensor."""
    index = torch.argmax(x_inp)
    x = index // x_inp.shape[1]
    y = index % x_inp.shape[1]
    return torch.cat([x.unsqueeze(0), y.unsqueeze(0)])

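# Example (illustrative): max_indices_2d returns the argmax position of a 2D
# tensor as a length-2 tensor of (row, col).
def _example_max_indices_2d():
    m = torch.zeros(4, 5)
    m[2, 3] = 1.0
    rc = max_indices_2d(m)
    assert rc.tolist() == [2, 3]
    return rc
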

def point_in_polygon(poly, grid):
    num_points = poly.shape[0]
    j = num_points - 1
    oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool)
    for i in range(num_points):
        cond1 = (poly[i, 1] < grid[..., 1]) & (poly[j, 1] >= grid[..., 1])
        cond2 = (poly[j, 1] < grid[..., 1]) & (poly[i, 1] >= grid[..., 1])
        cond3 = (grid[..., 0] - poly[i, 0]) < (poly[j, 0] - poly[i, 0]) * (grid[..., 1] - poly[i, 1]) / (poly[j, 1] - poly[i, 1])
        oddNodes = oddNodes ^ (cond1 | cond2) & cond3
        j = i
    return oddNodes

def point_in_polygon_gpu(poly, grid):
    num_points = poly.shape[0]
    i = torch.arange(num_points)
    j = (i - 1) % num_points
    # Expand dimensions so every edge is tested against every grid point at once
    poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid.shape[0], grid.shape[0])
    cond1 = (poly_expanded[i, 1] < grid[..., 1]) & (poly_expanded[j, 1] >= grid[..., 1])
    cond2 = (poly_expanded[j, 1] < grid[..., 1]) & (poly_expanded[i, 1] >= grid[..., 1])
    cond3 = (grid[..., 0] - poly_expanded[i, 0]) < (poly_expanded[j, 0] - poly_expanded[i, 0]) * (grid[..., 1] - poly_expanded[i, 1]) / (poly_expanded[j, 1] - poly_expanded[i, 1])
    cond = (cond1 | cond2) & cond3
    # efficiently perform xor using gpu and avoiding cpu as much as possible
    c = []
    while len(cond) > 1:
        if len(cond) % 2 == 1:  # odd number of elements
            c.append(cond[-1])
            cond = cond[:-1]
        cond = torch.bitwise_xor(cond[:int(len(cond) / 2)], cond[int(len(cond) / 2):])
    for c_ in c:
        cond = torch.bitwise_xor(cond, c_)
    oddNodes = cond
    return oddNodes

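# The while-loop above is a pairwise XOR reduction: instead of folding the
# (num_points, H, W) edge-crossing stack one row at a time, it repeatedly XORs
# the two halves, needing only O(log n) kernel launches. A reference O(n)
# implementation for comparison (illustrative):
def _example_xor_reduction(cond):
    # cond: (n, H, W) boolean stack; returns the parity of crossings per pixel
    out = torch.zeros_like(cond[0])
    for row in cond:
        out = out ^ row
    return out
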

def bitmap_for_polygon(poly, h, w):
    y = torch.arange(h).to(poly.device).float()
    x = torch.arange(w).to(poly.device).float()
    grid_y, grid_x = torch.meshgrid(y, x)
    grid = torch.stack((grid_x, grid_y), dim=-1)
    bitmap = point_in_polygon(poly, grid)
    return bitmap.unsqueeze(0)

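# Example (illustrative): rasterize a triangle into a (1, h, w) boolean mask.
def _example_polygon_bitmap():
    tri = torch.tensor([[10.0, 5.0], [50.0, 8.0], [30.0, 40.0]])  # (x, y) vertices
    mask = bitmap_for_polygon(tri, h=64, w=64)
    assert mask.shape == (1, 64, 64) and mask.dtype == torch.bool
    return mask
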

def corners_coords(center_xywh):
    center_x, center_y, w, h = center_xywh
    x = center_x - w / 2
    y = center_y - h / 2
    return torch.tensor([x, y, x + w, y + h])


def corners_coords_batch(center_xywh):
    center_x, center_y = center_xywh[:, 0], center_xywh[:, 1]
    w, h = center_xywh[:, 2], center_xywh[:, 3]
    x = center_x - w / 2
    y = center_y - h / 2
    return torch.stack([x, y, x + w, y + h], dim=1)

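# Quick numeric check (illustrative): a box centered at (0.5, 0.5) with
# width/height 0.2 becomes corner form (0.4, 0.4, 0.6, 0.6).
def _example_corners_coords_batch():
    xywh = torch.tensor([[0.5, 0.5, 0.2, 0.2]])
    return corners_coords_batch(xywh)  # tensor([[0.4, 0.4, 0.6, 0.6]])
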

def normalize_batch(x):
    """
    Normalize each tensor in a batch to [0, 1] independently (per image, across all channels).

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).

    Returns:
        torch.Tensor: Normalized tensor of the same shape as the input.
    """
    mins = torch.zeros((x.shape[0], *(1,) * len(x.shape[1:])), device=x.device)
    maxs = torch.zeros((x.shape[0], *(1,) * len(x.shape[1:])), device=x.device)
    for i in range(x.shape[0]):
        mins[i] = x[i].min()
        maxs[i] = x[i].max()
    x_ = (x - mins) / (maxs - mins)

    return x_

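# Example (illustrative): each image is rescaled to [0, 1] on its own, unlike a
# global min-max over the whole batch. A constant image yields 0/0 = NaN, which
# is why callers in this file guard results with torch.nan_to_num.
def _example_normalize_batch():
    x = torch.stack([torch.full((1, 4, 4), 5.0), torch.arange(16.0).view(1, 4, 4)])
    return normalize_batch(x)  # second image spans [0, 1]; first becomes NaN
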

def get_detections(model_clone, img):
    """
    Get detections from a model given an input image.

    Args:
        model_clone (nn.Module): The model to use for detection.
        img (torch.Tensor): The input image tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The inference detections and the raw outputs.
    """
    model_clone.eval()  # set model to evaluation mode
    # Run inference
    with torch.no_grad():
        det_out, out = model_clone(img)

    del img

    return det_out, out


def get_labels(det_out, imgs, targets, opt):
    ###################### Get predicted labels ######################
    nb, _, height, width = imgs.shape  # batch size, channels, height, width
    targets_ = targets.clone()
    targets_[:, 2:] = targets_[:, 2:] * torch.Tensor([width, height, width, height]).to(imgs.device)  # to pixels
    lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else []  # for autolabelling
    o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True)
    pred_labels = []
    for si, pred in enumerate(o):
        labels = targets_[targets_[:, 0] == si, 1:]
        nl = len(labels)
        predn = pred.clone()
        # Sort predictions by confidence (column 4) in descending order
        sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True)
        sorted_pred = predn[sort_indices]
        # Keep predictions above 0.1 confidence (plus one so the list is never empty)
        n_conf = int(torch.sum(sorted_pred[:, 4] > 0.1)) + 1
        sorted_pred = sorted_pred[:n_conf]
        new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si
        preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1)
        preds[:, 2:] = xyxy2xywh(preds[:, 2:])  # xywh
        gn = torch.tensor([width, height])[[1, 0, 1, 0]]  # normalization gain whwh
        preds[:, 2:] /= gn.to(imgs.device)  # from pixels
        pred_labels.append(preds)
    pred_labels = torch.cat(pred_labels, 0).to(imgs.device)

    return pred_labels
##################################################################


from torchvision.utils import make_grid


def get_center_coords(attr):
    # Assumes a single-channel (H, W) attribution map; normalize so the
    # brightest value is 1.0
    img_tensor = attr / attr.max()

    # Define a brightness threshold
    threshold = 0.95

    # Create a binary mask of the bright pixels
    mask = img_tensor > threshold

    # Get the coordinates of the bright pixels
    y_coords, x_coords = torch.where(mask)

    # Calculate the centroid of the bright pixels
    centroid_x = x_coords.float().mean().item()
    centroid_y = y_coords.float().mean().item()

    print(f'The central bright point is at ({centroid_x}, {centroid_y})')

    return centroid_x, centroid_y


def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False):
    """
    Compute the distance grids from each pixel to the target coordinates.

    Args:
        attr (torch.Tensor): Attribution maps.
        targets (torch.Tensor): Target coordinates.
        imgs (torch.Tensor, optional): Input images, only used for debug visualization.
        focus_coeff (float, optional): Focus coefficient, smaller means more focused. Defaults to 0.5.
        debug (bool, optional): Whether to visualize debug information. Defaults to False.

    Returns:
        torch.Tensor: Distance grids.
    """
    # Assign the height and width of the input tensor to variables
    height, width = attr.shape[-1], attr.shape[-2]

    # Create a grid of indices
    xx, yy = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width))).to(attr.device)
    idx_grid = torch.stack((xx, yy), dim=-1).float()

    # Expand the grid to match the batch size
    idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1)

    # Initialize a list to store the distance grids
    dist_grids_ = [[]] * attr.shape[0]

    # Loop over batches
    for j in range(attr.shape[0]):
        # Get the rows where the first column is the current batch index
        rows = targets[targets[:, 0] == j]

        if len(rows) != 0:
            # Create a tensor for the target coordinates
            xy = rows[:, 2:4]
            # Flip the x and y coordinates and scale them to the image size
            xy[:, 0], xy[:, 1] = xy[:, 1] * width, xy[:, 0] * height
            xy_center = xy.unsqueeze(1).unsqueeze(1)

            # Compute the Euclidean distance from each pixel to the target coordinates
            dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1)

            # Pick the closest distance to any target for each pixel
            dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0)
            dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_
        else:
            # Set grid to zero if no targets are present
            dist_grid = torch.zeros_like(attr[j])

        dist_grids_[j] = dist_grid
    # Convert the list of distance grids to a tensor for faster computation
    dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff
    if torch.isnan(dist_grids).any():
        dist_grids = torch.nan_to_num(dist_grids, nan=0.0)

    if debug:
        for i in range(len(dist_grids)):
            if (i % 8) == 0:
                grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0)
                imshow(grid_show, save_path='figs/dist_grids')
                if imgs is None:
                    imgs = torch.zeros_like(attr)
                imshow(imgs[i], save_path='figs/im0')
                img_overlay = overlay_mask(imgs[i], dist_grids[i][0], alpha=0.75)
                imshow(img_overlay, save_path='figs/dist_grid_overlay')
                weighted_attr = dist_grids[i] * attr[i]
                imshow(weighted_attr, save_path='figs/weighted_attr')
                imshow(attr[i], save_path='figs/attr')

    return dist_grids

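# Example (illustrative): distance grids are 0 at each target center and grow
# toward 1 at the farthest pixel, then raised to focus_coeff; multiplying them
# with an attribution map (as attr_reg does below) penalizes attribution far
# from every labeled object.
def _example_distance_grids():
    attr = torch.rand(1, 3, 64, 64)
    # one target in image 0: (batch_idx, class, x, y, w, h), normalized
    targets = torch.tensor([[0.0, 0.0, 0.5, 0.5, 0.2, 0.2]])
    grids = get_distance_grids(attr, targets, focus_coeff=0.5)
    assert grids.shape == attr.shape
    return grids
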

def attr_reg(attribution_map, distance_map):
    dist_attr = torch.mean(distance_map * attribution_map)
    return dist_attr


def get_bbox_map(targets_out, attr, corners=False):
    target_inds = targets_out[:, 0].int()
    xyxy_batch = targets_out[:, 2:6]
    num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
    co = xyxy_corners
    if corners:
        co = targets_out[:, 2:6].int()
    coords_map = torch.zeros_like(attr, dtype=torch.bool)
    x1, x2 = co[:, 1], co[:, 3]
    y1, y2 = co[:, 0], co[:, 2]

    for ic in range(co.shape[0]):  # potential for speedup here with torch indexing instead of for loop
        coords_map[target_inds[ic], :, x1[ic]:x2[ic], y1[ic]:y2[ic]] = True

    bbox_map = coords_map.to(torch.float32)

    return bbox_map

######################################## BCE #######################################
def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False):
    # Calculate Plausibility IoU with attribution maps
    if not only_loss:
        plaus_score = get_plaus_score(targets_out=targets, attr=attribution_map.clone().detach().requires_grad_(True), imgs=imgs)
    else:
        plaus_score = torch.tensor(0.0)

    # Calculate distance regularization
    distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff)

    if opt.dist_x_bbox:
        # Zero out the distance penalty inside the labeled boxes
        bbox_map = get_bbox_map(targets, attribution_map).to(torch.bool)
        distance_map[bbox_map] = 0.0

    # Positive regularization term for incentivizing pixels near the target to have high attribution
    dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map))
    # Negative regularization term for incentivizing pixels far from the target to have low attribution
    dist_attr_neg = attr_reg(attribution_map, distance_map)
    # Calculate plausibility regularization term
    # dist_reg = dist_attr_pos - dist_attr_neg
    dist_reg = (dist_attr_pos / torch.mean(attribution_map)) - (dist_attr_neg / torch.mean(attribution_map))

    if opt.bbox_coeff != 0.0:
        bbox_map = get_bbox_map(targets, attribution_map)
        attr_bbox_pos = attr_reg(attribution_map, bbox_map)
        attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map))
        bbox_reg = attr_bbox_pos - attr_bbox_neg
    else:
        bbox_reg = 0.0

    # Fraction of attribution mass inside the labeled boxes (overwrites the score above)
    bbox_map = get_bbox_map(targets, attribution_map)
    plaus_score = torch.sum(attribution_map * bbox_map) / torch.sum(attribution_map)

    if not opt.dist_reg_only:
        dist_reg_loss = (1.0 + dist_reg) / 2.0
        plaus_reg = (plaus_score * opt.iou_coeff) + ((dist_reg_loss * opt.dist_coeff) + (bbox_reg * opt.bbox_coeff))
    else:
        plaus_reg = (1.0 + dist_reg) / 2.0
    # Calculate plausibility loss
    plaus_loss = (1 - plaus_reg) * opt.pgt_coeff
    if only_loss:
        return plaus_loss
    if not debug:
        return plaus_loss, (plaus_score, dist_reg, plaus_reg,)
    else:
        return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map

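# Sketch of how get_plaus_loss plugs into a training step (illustrative; the
# actual wiring lives in the PGT trainer). `opt` only needs the coefficient
# fields read above; the values shown are placeholders, and `detection_loss`,
# `targets`, `attribution_map`, and `imgs` are assumed to come from the step.
# from types import SimpleNamespace
# opt = SimpleNamespace(pgt_coeff=0.1, focus_coeff=0.5, dist_coeff=1.0,
#                       bbox_coeff=0.0, iou_coeff=0.0,
#                       dist_x_bbox=False, dist_reg_only=True)
# plaus_loss, (plaus_score, dist_reg, plaus_reg) = get_plaus_loss(
#     targets, attribution_map, opt, imgs=imgs)
# total_loss = detection_loss + plaus_loss
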
####################################################################################
#### ALL FUNCTIONS BELOW ARE DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS #####
####################################################################################

def generate_vanilla_grad(model, input_tensor, loss_func=None,
                          targets_list=None, targets=None, metric=None, out_num=1,
                          n_max_labels=3, norm=True, abs=True, grayscale=True,
                          class_specific_attr=True, device='cpu'):
    """
    Generate vanilla gradients for the given model and input tensor.

    Args:
        model (nn.Module): The model to generate gradients for.
        input_tensor (torch.Tensor): The input tensor for which gradients are computed.
        loss_func (callable, optional): The loss function to compute gradients with respect to. Defaults to None.
        targets_list (list, optional): The list of target tensors. Defaults to None.
        targets (torch.Tensor, optional): Targets used when class_specific_attr is False. Defaults to None.
        metric (callable, optional): The metric function to evaluate the loss. Defaults to None.
        out_num (int, optional): The index of the output tensor to compute gradients with respect to. Defaults to 1.
        n_max_labels (int, optional): The maximum number of labels to consider. Defaults to 3.
        norm (bool, optional): Whether to normalize the attribution map. Defaults to True.
        abs (bool, optional): Whether to take the absolute values of gradients. Defaults to True.
        grayscale (bool, optional): Whether to convert the attribution map to grayscale. Defaults to True.
        class_specific_attr (bool, optional): Whether to compute class-specific attribution maps. Defaults to True.
        device (str, optional): The device to use for computation. Defaults to 'cpu'.

    Returns:
        list[torch.Tensor]: The generated vanilla gradient attribution maps.
    """
    # Set model.train() at the beginning and revert back to original mode at the end
    train_mode = model.training
    if not train_mode:
        model.train()

    input_tensor.requires_grad = True  # important for computing gradients wrt the input
    model.zero_grad()  # zero gradients
    inpt = input_tensor
    # Forward pass
    train_out = model(inpt)  # training outputs (no inference outputs in train mode)

    # train_out[1] = torch.Size([4, 3, 80, 80, 7]) HxWx(#anchor x C) cls (class probabilities)
    # train_out[0] = torch.Size([4, 3, 160, 160, 7]) HxWx(#anchor x 4) box or reg (location and scaling)
    # train_out[2] = torch.Size([4, 3, 40, 40, 7]) HxWx(#anchor x 1) obj (objectness score or confidence)

    if class_specific_attr:
        n_attr_list, index_classes = [], []
        for i in range(len(input_tensor)):
            if len(targets_list[i]) > n_max_labels:
                targets_list[i] = targets_list[i][:n_max_labels]
            if targets_list[i].numel() != 0:
                class_numbers = targets_list[i][:, 1]
                index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers])
                n_attr_list.append(len(targets_list[i]))
            else:
                index_classes.append([0, 1, 2, 3, 4])
                n_attr_list.append(0)

        # Pad every image's target list to the same length by repeating entries
        targets_list_filled = [targ.clone().detach() for targ in targets_list]
        labels_len = [len(targets_list[ih]) for ih in range(len(targets_list))]
        max_labels = np.max(labels_len)
        max_index = np.argmax(labels_len)
        for i in range(len(targets_list)):
            if len(targets_list_filled[i]) < max_labels:
                tlist = [targets_list_filled[i]] * math.ceil(max_labels / len(targets_list_filled[i]))
                targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0)
            else:
                targets_list_filled[i] = targets_list_filled[i].unsqueeze(0)
        for i in range(len(targets_list_filled) - 1, -1, -1):
            if targets_list_filled[i].numel() == 0:
                targets_list_filled.pop(i)
        targets_list_filled = torch.cat(targets_list_filled)

    n_img_attrs = len(input_tensor) if class_specific_attr else 1
    n_img_attrs = 1 if loss_func else n_img_attrs

    attrs_batch = []
    for i_batch in range(n_img_attrs):
        if loss_func and class_specific_attr:
            i_batch = max_index
        n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1
        attrs_img = []
        for i_attr in range(n_label_attrs):
            if loss_func is None:
                grad_wrt = train_out[out_num]
                if class_specific_attr:
                    grad_wrt = train_out[out_num][:, :, :, :, index_classes[i_batch][i_attr]]
                grad_wrt_outputs = torch.ones_like(grad_wrt)
            else:
                if class_specific_attr:
                    target_indiv = targets_list_filled[:, i_attr]  # batch image input
                else:
                    target_indiv = targets
                try:
                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)  # loss scaled by batch_size
                except:
                    target_indiv = target_indiv.to(device)
                    inpt = inpt.to(device)
                    for tro in train_out:
                        tro = tro.to(device)
                    print("Error in loss function, trying again with device specified")
                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)
                grad_wrt = loss
                grad_wrt_outputs = None

            model.zero_grad()  # zero gradients
            gradients = torch.autograd.grad(grad_wrt, inpt,
                                            grad_outputs=grad_wrt_outputs,
                                            retain_graph=True,
                                            )
            attribution_map = gradients[0]

            if grayscale:  # convert to grayscale, saves vram and computation time for plaus_eval
                attribution_map = torch.sum(attribution_map, 1, keepdim=True)
            if abs:
                attribution_map = torch.abs(attribution_map)  # take absolute values of gradients
            if norm:
                attribution_map = normalize_batch(attribution_map)  # normalize attribution maps per image in batch
            attrs_img.append(attribution_map)
        if len(attrs_img) == 0:
            attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device))
        else:
            attrs_batch.append(torch.stack(attrs_img).to(device))

    out_attr = attrs_batch
    # Set model back to original mode
    if not train_mode:
        model.eval()

    return out_attr

class RVNonLinearFunc(torch.nn.Module):
    """
    Custom Bayesian ReLU activation function for random variables.
    """
    def __init__(self, func):
        super(RVNonLinearFunc, self).__init__()
        self.func = func

    def forward(self, mu_in, Sigma_in):
        """
        Forward pass of the Bayesian ReLU activation function.

        Args:
            mu_in (torch.Tensor): A tensor of shape (batch_size, input_size),
                representing the mean input to the ReLU activation function.
            Sigma_in (torch.Tensor): A tensor of shape (batch_size, input_size, input_size),
                representing the covariance input to the ReLU activation function.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors,
                including the mean of the output and the covariance of the output.
        """
        # Collect stats
        batch_size = mu_in.size(0)

        # Mean
        mu_out = self.func(mu_in)

        # Compute the derivative of the activation function with respect to the input mean
        gradi = torch.autograd.grad(mu_out, mu_in, grad_outputs=torch.ones_like(mu_out), create_graph=True)[0].view(batch_size, -1)

        # Add an extra dimension to gradi at positions 2 and 1
        grad1 = gradi.unsqueeze(dim=2)
        grad2 = gradi.unsqueeze(dim=1)

        # Compute the outer product of grad1 and grad2
        outer_product = torch.bmm(grad1, grad2)

        # Element-wise multiply Sigma_in with the outer product and return the result
        Sigma_out = torch.mul(Sigma_in, outer_product)

        return mu_out, Sigma_out

154 ultralytics/utils/plot_functs.py Normal file
@@ -0,0 +1,154 @@
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.cuda import amp  # used by returnGrad's GradScaler below

class Subplots:
    def __init__(self, figsize=(40, 5)):
        self.fig = plt.figure(figsize=figsize)

    def plot_img_list(self, img_list, savedir='figs/test',
                      nrows=1, rownum=0,
                      hold=False, coltitles=[], rowtitle=''):
        for i, img in enumerate(img_list):
            try:
                npimg = img.clone().detach().cpu().numpy()
            except:
                npimg = img
            tpimg = np.transpose(npimg, (1, 2, 0))
            lenrow = int(len(img_list))
            ax = self.fig.add_subplot(nrows, lenrow, i + 1 + (rownum * lenrow))
            if len(coltitles) > i:
                ax.set_title(coltitles[i])
            if i == 0:
                ax.annotate(rowtitle, xy=((-0.06 * len(rowtitle)), 0.4),
                            xycoords='axes fraction', textcoords='offset points',
                            size='large', ha='center', va='baseline')
            ax.imshow(tpimg)
            ax.axis('off')

        if not hold:
            self.fig.tight_layout()
            plt.savefig(f'{savedir}.png')
            plt.clf()
            plt.close('all')


def VisualizeNumpyImageGrayscale(image_3d):
    r"""Min-max normalize a numpy array to the [0, 1] range."""
    vmin = np.min(image_3d)
    image_2d = image_3d - vmin
    vmax = np.max(image_2d)
    return image_2d / vmax


def normalize_numpy(image_3d):
    r"""Min-max normalize a numpy array to the [0, 1] range (same as VisualizeNumpyImageGrayscale)."""
    vmin = np.min(image_3d)
    image_2d = image_3d - vmin
    vmax = np.max(image_2d)
    return image_2d / vmax


def normalize_tensor(image_3d):
    r"""Min-max normalize a tensor to the [0, 1] range."""
    image_2d = image_3d - torch.min(image_3d)
    return image_2d / torch.max(image_2d)


def format_img(img_):
    np_img = img_.numpy()
    tp_img = np.transpose(np_img, (1, 2, 0))
    return tp_img


def imshow(img, save_path=None):
    try:
        npimg = img.clone().detach().cpu().numpy()
    except:
        npimg = img
    tpimg = np.transpose(npimg, (1, 2, 0))
    plt.imshow(tpimg)
    # plt.axis('off')
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(str(save_path) + ".png")
    # plt.show()


def imshow_img(img, imsave_path):
    # works for tensors and numpy arrays
    try:
        npimg = VisualizeNumpyImageGrayscale(img.numpy())
    except:
        npimg = VisualizeNumpyImageGrayscale(img)
    npimg = np.transpose(npimg, (2, 0, 1))
    imshow(npimg, save_path=imsave_path)
    print("Saving image as ", imsave_path)

def returnGrad(img, labels, model, compute_loss, loss_metric, augment=None, device='cpu'):
    model.train()
    model.to(device)
    img = img.to(device)
    img.requires_grad_(True)
    labels = labels.to(device)
    model.requires_grad_(True)
    cuda = torch.device(device).type != 'cpu'  # accepts both str and torch.device
    scaler = amp.GradScaler(enabled=cuda)
    pred = model(img)
    loss, loss_items = compute_loss(pred, labels, metric=loss_metric)  # box, obj, cls

    with torch.autograd.set_detect_anomaly(True):
        scaler.scale(loss).backward(inputs=img)

    Sc_dx = img.grad
    model.eval()
    Sc_dx = Sc_dx.clone().detach().to(torch.float32)
    return Sc_dx

def calculate_snr(img, attr, dB=True):
    try:
        img_np = img.detach().cpu().numpy()
        attr_np = attr.detach().cpu().numpy()
    except:
        img_np = img
        attr_np = attr

    # Calculate the signal power
    signal_power = np.mean(img_np ** 2)
    # Calculate the noise power
    noise_power = np.mean(attr_np ** 2)

    if dB:
        # SNR in decibels
        snr = 10 * np.log10(signal_power / noise_power)
    else:
        # Plain power ratio
        snr = signal_power / noise_power

    return snr

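# Example (illustrative): SNR of an image against its attribution map, in dB.
# With signal power 4.0 and noise power 1.0 this prints ~6.02 dB.
# img = np.full((3, 8, 8), 2.0)
# attr = np.ones((3, 8, 8))
# print(calculate_snr(img, attr, dB=True))  # 10 * log10(4.0) ≈ 6.02
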
def overlay_mask(img, mask, colormap: str = "jet", alpha: float = 0.7):
    cmap = plt.get_cmap(colormap)
    npmask = np.array(mask.clone().detach().cpu().squeeze(0))
    cmpmask = (cmap(npmask)[:, :, :3]).transpose((2, 0, 1))
    overlayed_imgnp = (alpha * np.asarray(img.clone().detach().cpu())) + ((1 - alpha) * cmpmask)
    overlayed_tensor = torch.tensor(overlayed_imgnp, device=img.device)

    return overlayed_tensor

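# Example (illustrative): blend a heatmap over a (3, H, W) image tensor; mask
# values should lie in [0, 1] so the colormap lookup is well defined.
# img = torch.rand(3, 64, 64)
# heat = torch.rand(1, 64, 64)   # squeezed to (64, 64) internally
# over = overlay_mask(img, heat, colormap="jet", alpha=0.7)
# imshow(over, save_path='figs/overlay_demo')
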
@@ -717,7 +717,7 @@ def plot_images(
 ):
     """Plot image grid with labels."""
     if isinstance(images, torch.Tensor):
-        images = images.cpu().float().numpy()
+        images = images.detach().cpu().float().numpy()
     if isinstance(cls, torch.Tensor):
         cls = cls.cpu().numpy()
     if isinstance(bboxes, torch.Tensor):