yolov10/ultralytics/utils/plaus_functs.py

819 lines
36 KiB
Python

import torch
import numpy as np
# from plot_functs import *
from .plot_functs import normalize_tensor, overlay_mask, imshow
import math
import time
import matplotlib.path as mplPath
from matplotlib.path import Path
# from utils.general import non_max_suppression, xyxy2xywh, scale_coords
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppression
from .metrics import bbox_iou
import torchvision.transforms as T
def plaus_loss_fn(grad, smask, pgt_coeff):
################## Compute the PGT Loss ##################
# Positive regularization term for incentivizing pixels near the target to have high attribution
dist_attr_pos = attr_reg(grad, (1.0 - smask)) # dist_reg = seg_mask
# Negative regularization term for incentivizing pixels far from the target to have low attribution
dist_attr_neg = attr_reg(grad, smask)
# Calculate plausibility regularization term
# dist_reg = dist_attr_pos - dist_attr_neg
dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad)))
plaus_reg = (((1.0 + dist_reg) / 2.0))
# Calculate plausibility loss
plaus_loss = (1 - plaus_reg) * pgt_coeff
return plaus_loss
def get_dist_reg(images, seg_mask):
seg_mask = T.Resize((images.shape[2], images.shape[3]), antialias=True)(seg_mask).to(images.device)
seg_mask = seg_mask.to(dtype=torch.float32).unsqueeze(1).repeat(1, 3, 1, 1)
seg_mask[seg_mask > 0] = 1.0
smask = torch.zeros_like(seg_mask)
sigmas = [20.0 + (i_sig * 20.0) for i_sig in range(8)]
for k_it, sigma in enumerate(sigmas):
# Apply Gaussian blur to the mask
kernel_size = int(sigma + 50)
if kernel_size % 2 == 0:
kernel_size += 1
seg_mask1 = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask)
if torch.max(seg_mask1) > 1.0:
seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min())
smask = torch.max(smask, seg_mask1)
return smask
def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False):
"""
Compute the gradient of an image with respect to a given tensor.
Args:
img (torch.Tensor): The input image tensor.
grad_wrt (torch.Tensor): The tensor with respect to which the gradient is computed.
norm (bool, optional): Whether to normalize the gradient. Defaults to True.
absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
grayscale (bool, optional): Whether to convert the gradient to grayscale. Defaults to True.
keepmean (bool, optional): Whether to keep the mean value of the attribution map. Defaults to False.
Returns:
torch.Tensor: The computed attribution map.
"""
if (grad_wrt.shape != torch.Size([1])) and (grad_wrt.shape != torch.Size([])):
grad_wrt_outputs = torch.ones_like(grad_wrt).clone().detach()#.requires_grad_(True)#.retains_grad_(True)
else:
grad_wrt_outputs = None
attribution_map = torch.autograd.grad(grad_wrt, img,
grad_outputs=grad_wrt_outputs,
create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly
)[0]
if absolute:
attribution_map = torch.abs(attribution_map) # attribution_map ** 2 # Take absolute values of gradients
if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
attribution_map = torch.sum(attribution_map, 1, keepdim=True)
if norm:
if keepmean:
attmean = torch.mean(attribution_map)
attmin = torch.min(attribution_map)
attmax = torch.max(attribution_map)
attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
if keepmean:
attribution_map -= attribution_map.mean()
attribution_map += (attmean / (attmax - attmin))
return attribution_map
def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False):
"""
Generate Gaussian noise based on the input image.
Args:
img (torch.Tensor): Input image.
grad_wrt: Gradient with respect to the input image.
norm (bool, optional): Whether to normalize the generated noise. Defaults to True.
absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
grayscale (bool, optional): Whether to convert the noise to grayscale. Defaults to True.
keepmean (bool, optional): Whether to keep the mean of the noise. Defaults to False.
Returns:
torch.Tensor: Generated Gaussian noise.
"""
gaussian_noise = torch.randn_like(img)
if absolute:
gaussian_noise = torch.abs(gaussian_noise) # Take absolute values of gradients
if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
gaussian_noise = torch.sum(gaussian_noise, 1, keepdim=True)
if norm:
if keepmean:
attmean = torch.mean(gaussian_noise)
attmin = torch.min(gaussian_noise)
attmax = torch.max(gaussian_noise)
gaussian_noise = normalize_batch(gaussian_noise) # Normalize attribution maps per image in batch
if keepmean:
gaussian_noise -= gaussian_noise.mean()
gaussian_noise += (attmean / (attmax - attmin))
return gaussian_noise
def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
# TODO: Remove imgs from this function and only take it as input if debug is True
"""
Calculates the plausibility score based on the given inputs.
Args:
imgs (torch.Tensor): The input images.
targets_out (torch.Tensor): The output targets.
attr (torch.Tensor): The attribute tensor.
debug (bool, optional): Whether to enable debug mode. Defaults to False.
Returns:
torch.Tensor: The plausibility score.
"""
# # if imgs is None:
# # imgs = torch.zeros_like(attr)
# # with torch.no_grad():
# target_inds = targets_out[:, 0].int()
# xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num]
# num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
# # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1))
# xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
# co = xyxy_corners
# if corners:
# co = targets_out[:, 2:6].int()
# coords_map = torch.zeros_like(attr, dtype=torch.bool)
# # rows = np.arange(co.shape[0])
# x1, x2 = co[:,1], co[:,3]
# y1, y2 = co[:,0], co[:,2]
# for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop
# coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True
if torch.isnan(attr).any():
attr = torch.nan_to_num(attr, nan=0.0)
coords_map = get_bbox_map(targets_out, attr)
plaus_score = ((torch.sum((attr * coords_map))) / (torch.sum(attr)))
if debug:
for i in range(len(coords_map)):
coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
test_bbox = torch.zeros_like(imgs[i])
test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
imshow(test_bbox, save_path='figs/test_bbox')
if imgs is None:
imgs = torch.zeros_like(attr)
imshow(imgs[i], save_path='figs/im0')
imshow(attr[i], save_path='figs/attr')
# with torch.no_grad():
# # att_select = attr[coords_map]
# att_select = attr * coords_map.to(torch.float32)
# att_total = attr
# IoU_num = torch.sum(att_select)
# IoU_denom = torch.sum(att_total)
# IoU_ = (IoU_num / IoU_denom)
# plaus_score = IoU_
# # plaus_score = ((torch.sum(attr[coords_map])) / (torch.sum(attr)))
return plaus_score
def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
# TODO: Remove imgs from this function and only take it as input if debug is True
"""
Calculates the plausibility score based on the given inputs.
Args:
imgs (torch.Tensor): The input images.
targets_out (torch.Tensor): The output targets.
attr (torch.Tensor): The attribute tensor.
debug (bool, optional): Whether to enable debug mode. Defaults to False.
Returns:
torch.Tensor: The plausibility score.
"""
# if imgs is None:
# imgs = torch.zeros_like(attr)
# with torch.no_grad():
target_inds = targets_out[:, 0].int()
xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num]
num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
# num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1))
xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
co = xyxy_corners
if corners:
co = targets_out[:, 2:6].int()
coords_map = torch.zeros_like(attr, dtype=torch.bool)
# rows = np.arange(co.shape[0])
x1, x2 = co[:,1], co[:,3]
y1, y2 = co[:,0], co[:,2]
for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop
coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True
if torch.isnan(attr).any():
attr = torch.nan_to_num(attr, nan=0.0)
if debug:
for i in range(len(coords_map)):
coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
test_bbox = torch.zeros_like(imgs[i])
test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
imshow(test_bbox, save_path='figs/test_bbox')
imshow(imgs[i], save_path='figs/im0')
imshow(attr[i], save_path='figs/attr')
# att_select = attr[coords_map]
# with torch.no_grad():
# IoU_num = (torch.sum(attr[coords_map]))
# IoU_denom = torch.sum(attr)
# IoU_ = (IoU_num / (IoU_denom))
# IoU_ = torch.max(attr[coords_map]) - torch.max(attr[~coords_map])
co = (xyxy_batch * num_pixels).int()
x1 = co[:,1] + 1
y1 = co[:,0] + 1
# with torch.no_grad():
attr_ = torch.sum(attr, 1, keepdim=True)
corners_attr = None #torch.zeros(len(xyxy_batch), 4, device=attr.device)
for ic in range(co.shape[0]):
attr0 = attr_[target_inds[ic], :,:x1[ic],:y1[ic]]
attr1 = attr_[target_inds[ic], :,:x1[ic],y1[ic]:]
attr2 = attr_[target_inds[ic], :,x1[ic]:,:y1[ic]]
attr3 = attr_[target_inds[ic], :,x1[ic]:,y1[ic]:]
x_0, y_0 = max_indices_2d(attr0[0])
x_1, y_1 = max_indices_2d(attr1[0])
x_2, y_2 = max_indices_2d(attr2[0])
x_3, y_3 = max_indices_2d(attr3[0])
y_1 += y1[ic]
x_2 += x1[ic]
x_3 += x1[ic]
y_3 += y1[ic]
max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2],
torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3],
torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2],
torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]])
if corners_attr is None:
corners_attr = max_corners
else:
corners_attr = torch.cat([corners_attr, max_corners], dim=0)
# corners_attr[ic] = max_corners
# corners_attr = attr[:,0,:4,0]
corners_attr = corners_attr.view(-1, 4)
# corners_attr = torch.stack(corners_attr, dim=0)
IoU_ = bbox_iou(corners_attr.T, xyxy_batch, x1y1x2y2=False, metric='CIoU')
plaus_score = IoU_.mean()
return plaus_score
def max_indices_2d(x_inp):
# values, indices = x.reshape(x.size(0), -1).max(dim=-1)
torch.max(x_inp,)
index = torch.argmax(x_inp)
x = index // x_inp.shape[1]
y = index % x_inp.shape[1]
# x, y = divmod(index.item(), x_inp.shape[1])
return torch.cat([x.unsqueeze(0), y.unsqueeze(0)])
def point_in_polygon(poly, grid):
# t0 = time.time()
num_points = poly.shape[0]
j = num_points - 1
oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool)
for i in range(num_points):
cond1 = (poly[i, 1] < grid[..., 1]) & (poly[j, 1] >= grid[..., 1])
cond2 = (poly[j, 1] < grid[..., 1]) & (poly[i, 1] >= grid[..., 1])
cond3 = (grid[..., 0] - poly[i, 0]) < (poly[j, 0] - poly[i, 0]) * (grid[..., 1] - poly[i, 1]) / (poly[j, 1] - poly[i, 1])
oddNodes = oddNodes ^ (cond1 | cond2) & cond3
j = i
# t1 = time.time()
# print(f'point in polygon time: {t1-t0}')
return oddNodes
def point_in_polygon_gpu(poly, grid):
num_points = poly.shape[0]
i = torch.arange(num_points)
j = (i - 1) % num_points
# Expand dimensions
# t0 = time.time()
poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid.shape[0], grid.shape[0])
# t1 = time.time()
cond1 = (poly_expanded[i, 1] < grid[..., 1]) & (poly_expanded[j, 1] >= grid[..., 1])
cond2 = (poly_expanded[j, 1] < grid[..., 1]) & (poly_expanded[i, 1] >= grid[..., 1])
cond3 = (grid[..., 0] - poly_expanded[i, 0]) < (poly_expanded[j, 0] - poly_expanded[i, 0]) * (grid[..., 1] - poly_expanded[i, 1]) / (poly_expanded[j, 1] - poly_expanded[i, 1])
# t2 = time.time()
oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool)
cond = (cond1 | cond2) & cond3
# t3 = time.time()
# efficiently perform xor using gpu and avoiding cpu as much as possible
c = []
while len(cond) > 1:
if len(cond) % 2 == 1: # odd number of elements
c.append(cond[-1])
cond = cond[:-1]
cond = torch.bitwise_xor(cond[:int(len(cond)/2)], cond[int(len(cond)/2):])
for c_ in c:
cond = torch.bitwise_xor(cond, c_)
oddNodes = cond
# t4 = time.time()
# for c in cond:
# oddNodes = oddNodes ^ c
# print(f'expand time: {t1-t0} | cond123 time: {t2-t1} | cond logic time: {t3-t2} | bitwise xor time: {t4-t3}')
# print(f'point in polygon time gpu: {t4-t0}')
# oddNodes = oddNodes ^ (cond1 | cond2) & cond3
return oddNodes
def bitmap_for_polygon(poly, h, w):
y = torch.arange(h).to(poly.device).float()
x = torch.arange(w).to(poly.device).float()
grid_y, grid_x = torch.meshgrid(y, x)
grid = torch.stack((grid_x, grid_y), dim=-1)
bitmap = point_in_polygon(poly, grid)
return bitmap.unsqueeze(0)
def corners_coords(center_xywh):
center_x, center_y, w, h = center_xywh
x = center_x - w/2
y = center_y - h/2
return torch.tensor([x, y, x+w, y+h])
def corners_coords_batch(center_xywh):
center_x, center_y = center_xywh[:,0], center_xywh[:,1]
w, h = center_xywh[:,2], center_xywh[:,3]
x = center_x - w/2
y = center_y - h/2
return torch.stack([x, y, x+w, y+h], dim=1)
def normalize_batch(x):
"""
Normalize a batch of tensors along each channel.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
Returns:
torch.Tensor: Normalized tensor of the same shape as the input.
"""
mins = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device)
maxs = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device)
for i in range(x.shape[0]):
mins[i] = x[i].min()
maxs[i] = x[i].max()
x_ = (x - mins) / (maxs - mins)
return x_
def get_detections(model_clone, img):
"""
Get detections from a model given an input image and targets.
Args:
model (nn.Module): The model to use for detection.
img (torch.Tensor): The input image tensor.
Returns:
torch.Tensor: The detected bounding boxes.
"""
model_clone.eval() # Set model to evaluation mode
# Run inference
with torch.no_grad():
det_out, out = model_clone(img)
# model_.train()
del img
return det_out, out
def get_labels(det_out, imgs, targets, opt):
###################### Get predicted labels ######################
nb, _, height, width = imgs.shape # batch size, channels, height, width
targets_ = targets.clone()
targets_[:, 2:] = targets_[:, 2:] * torch.Tensor([width, height, width, height]).to(imgs.device) # to pixels
lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else [] # for autolabelling
o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True)
pred_labels = []
for si, pred in enumerate(o):
labels = targets_[targets_[:, 0] == si, 1:]
nl = len(labels)
predn = pred.clone()
# Get the indices that sort the values in column 5 in ascending order
sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True)
# Apply the sorting indices to the tensor
sorted_pred = predn[sort_indices]
# Remove predictions with less than 0.1 confidence
n_conf = int(torch.sum(sorted_pred[:,4]>0.1)) + 1
sorted_pred = sorted_pred[:n_conf]
new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si
preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1)
preds[:, 2:] = xyxy2xywh(preds[:, 2:]) # xywh
gn = torch.tensor([width, height])[[1, 0, 1, 0]] # normalization gain whwh
preds[:, 2:] /= gn.to(imgs.device) # from pixels
pred_labels.append(preds)
pred_labels = torch.cat(pred_labels, 0).to(imgs.device)
return pred_labels
##################################################################
from torchvision.utils import make_grid
def get_center_coords(attr):
img_tensor = img_tensor / img_tensor.max()
# Define a brightness threshold
threshold = 0.95
# Create a binary mask of the bright pixels
mask = img_tensor > threshold
# Get the coordinates of the bright pixels
y_coords, x_coords = torch.where(mask)
# Calculate the centroid of the bright pixels
centroid_x = x_coords.float().mean().item()
centroid_y = y_coords.float().mean().item()
print(f'The central bright point is at ({centroid_x}, {centroid_y})')
return
def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False):
"""
Compute the distance grids from each pixel to the target coordinates.
Args:
attr (torch.Tensor): Attribution maps.
targets (torch.Tensor): Target coordinates.
focus_coeff (float, optional): Focus coefficient, smaller means more focused. Defaults to 0.5.
debug (bool, optional): Whether to visualize debug information. Defaults to False.
Returns:
torch.Tensor: Distance grids.
"""
# Assign the height and width of the input tensor to variables
height, width = attr.shape[-1], attr.shape[-2]
# attr = torch.abs(attr) # Take absolute values of gradients
# attr = normalize_batch(attr) # Normalize attribution maps per image in batch
# Create a grid of indices
xx, yy = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width))).to(attr.device)
idx_grid = torch.stack((xx, yy), dim=-1).float()
# Expand the grid to match the batch size
idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1)
# Initialize a list to store the distance grids
dist_grids_ = [[]] * attr.shape[0]
# Loop over batches
for j in range(attr.shape[0]):
# Get the rows where the first column is the current unique value
rows = targets[targets[:, 0] == j]
if len(rows) != 0:
# Create a tensor for the target coordinates
xy = rows[:,2:4] # y, x
# Flip the x and y coordinates and scale them to the image size
xy[:, 0], xy[:, 1] = xy[:, 1] * width, xy[:, 0] * height # y, x to x, y
xy_center = xy.unsqueeze(1).unsqueeze(1)#.requires_grad_(True)
# Compute the Euclidean distance from each pixel to the target coordinates
dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1)
# Pick the closest distance to any target for each pixel
dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0)
dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_
else:
# Set grid to zero if no targets are present
dist_grid = torch.zeros_like(attr[j])
dist_grids_[j] = dist_grid
# Convert the list of distance grids to a tensor for faster computation
dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff
if torch.isnan(dist_grids).any():
dist_grids = torch.nan_to_num(dist_grids, nan=0.0)
if debug:
for i in range(len(dist_grids)):
if ((i % 8) == 0):
grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0)
imshow(grid_show, save_path='figs/dist_grids')
if imgs is None:
imgs = torch.zeros_like(attr)
imshow(imgs[i], save_path='figs/im0')
img_overlay = (overlay_mask(imgs[i], dist_grids[i][0], alpha = 0.75))
imshow(img_overlay, save_path='figs/dist_grid_overlay')
weighted_attr = (dist_grids[i] * attr[i])
imshow(weighted_attr, save_path='figs/weighted_attr')
imshow(attr[i], save_path='figs/attr')
return dist_grids
def attr_reg(attribution_map, distance_map):
# dist_attr = distance_map * attribution_map
dist_attr = torch.mean(distance_map * attribution_map)#, dim=(1, 2, 3))
# del distance_map, attribution_map
return dist_attr
def get_bbox_map(targets_out, attr, corners=False):
target_inds = targets_out[:, 0].int()
xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num]
num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
# num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1))
xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
co = xyxy_corners
if corners:
co = targets_out[:, 2:6].int()
coords_map = torch.zeros_like(attr, dtype=torch.bool)
# rows = np.arange(co.shape[0])
x1, x2 = co[:,1], co[:,3]
y1, y2 = co[:,0], co[:,2]
for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop
coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True
bbox_map = coords_map.to(torch.float32)
return bbox_map
######################################## BCE #######################################
def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False):
# if imgs is None:
# imgs = torch.zeros_like(attribution_map)
# Calculate Plausibility IoU with attribution maps
# attribution_map.retains_grad = True
if not only_loss:
plaus_score = get_plaus_score(targets_out = targets, attr = attribution_map.clone().detach().requires_grad_(True), imgs = imgs)
else:
plaus_score = torch.tensor(0.0)
# attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
# Calculate distance regularization
distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff)
# distance_map = torch.ones_like(attribution_map)
if opt.dist_x_bbox:
bbox_map = get_bbox_map(targets, attribution_map).to(torch.bool)
distance_map[bbox_map] = 0.0
# distance_map = distance_map * (1 - bbox_map)
# Positive regularization term for incentivizing pixels near the target to have high attribution
dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map))
# Negative regularization term for incentivizing pixels far from the target to have low attribution
dist_attr_neg = attr_reg(attribution_map, distance_map)
# Calculate plausibility regularization term
# dist_reg = dist_attr_pos - dist_attr_neg
dist_reg = ((dist_attr_pos / torch.mean(attribution_map)) - (dist_attr_neg / torch.mean(attribution_map)))
# dist_reg = torch.mean((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3))) - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3))))
# dist_reg = (torch.mean(torch.exp((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3)))) + \
# torch.exp(1 - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3)))))) \
# / 2.5
if opt.bbox_coeff != 0.0:
bbox_map = get_bbox_map(targets, attribution_map)
attr_bbox_pos = attr_reg(attribution_map, bbox_map)
attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map))
bbox_reg = attr_bbox_pos - attr_bbox_neg
# bbox_reg = (attr_bbox_pos / torch.mean(attribution_map)) - (attr_bbox_neg / torch.mean(attribution_map))
else:
bbox_reg = 0.0
bbox_map = get_bbox_map(targets, attribution_map)
plaus_score = ((torch.sum((attribution_map * bbox_map))) / (torch.sum(attribution_map)))
# iou_loss = (1.0 - plaus_score)
if not opt.dist_reg_only:
dist_reg_loss = (((1.0 + dist_reg) / 2.0))
plaus_reg = (plaus_score * opt.iou_coeff) + \
(((dist_reg_loss * opt.dist_coeff) + \
(bbox_reg * opt.bbox_coeff))\
# ((((((1.0 + dist_reg) / 2.0) - 1.0) * opt.dist_coeff) + ((((1.0 + bbox_reg) / 2.0) - 1.0) * opt.bbox_coeff))\
# / (plaus_score) \
)
else:
plaus_reg = (((1.0 + dist_reg) / 2.0))
# plaus_reg = dist_reg
# Calculate plausibility loss
plaus_loss = (1 - plaus_reg) * opt.pgt_coeff
# plaus_loss = (plaus_reg) * opt.pgt_coeff
if only_loss:
return plaus_loss
if not debug:
return plaus_loss, (plaus_score, dist_reg, plaus_reg,)
else:
return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map
####################################################################################
#### ALL FUNCTIONS BELOW ARE DEPRECIATED AND WILL BE REMOVED IN FUTURE VERSIONS ####
####################################################################################
def generate_vanilla_grad(model, input_tensor, loss_func = None,
targets_list=None, targets=None, metric=None, out_num = 1,
n_max_labels=3, norm=True, abs=True, grayscale=True,
class_specific_attr = True, device='cpu'):
"""
Generate vanilla gradients for the given model and input tensor.
Args:
model (nn.Module): The model to generate gradients for.
input_tensor (torch.Tensor): The input tensor for which gradients are computed.
loss_func (callable, optional): The loss function to compute gradients with respect to. Defaults to None.
targets_list (list, optional): The list of target tensors. Defaults to None.
metric (callable, optional): The metric function to evaluate the loss. Defaults to None.
out_num (int, optional): The index of the output tensor to compute gradients with respect to. Defaults to 1.
n_max_labels (int, optional): The maximum number of labels to consider. Defaults to 3.
norm (bool, optional): Whether to normalize the attribution map. Defaults to True.
abs (bool, optional): Whether to take the absolute values of gradients. Defaults to True.
grayscale (bool, optional): Whether to convert the attribution map to grayscale. Defaults to True.
class_specific_attr (bool, optional): Whether to compute class-specific attribution maps. Defaults to True.
device (str, optional): The device to use for computation. Defaults to 'cpu'.
Returns:
torch.Tensor: The generated vanilla gradients.
"""
# Set model.train() at the beginning and revert back to original mode (model.eval() or model.train()) at the end
train_mode = model.training
if not train_mode:
model.train()
input_tensor.requires_grad = True # Set requires_grad attribute of tensor. Important for computing gradients
model.zero_grad() # Zero gradients
inpt = input_tensor
# Forward pass
train_out = model(inpt) # training outputs (no inference outputs in train mode)
# train_out[1] = torch.Size([4, 3, 80, 80, 7]) HxWx(#anchorxC) cls (class probabilities)
# train_out[0] = torch.Size([4, 3, 160, 160, 7]) HxWx(#anchorx4) box or reg (location and scaling)
# train_out[2] = torch.Size([4, 3, 40, 40, 7]) HxWx(#anchorx1) obj (objectness score or confidence)
if class_specific_attr:
n_attr_list, index_classes = [], []
for i in range(len(input_tensor)):
if len(targets_list[i]) > n_max_labels:
targets_list[i] = targets_list[i][:n_max_labels]
if targets_list[i].numel() != 0:
# unique_classes = torch.unique(targets_list[i][:,1])
class_numbers = targets_list[i][:,1]
index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers])
num_attrs = len(targets_list[i])
# index_classes.append([0, 1, 2, 3, 4] + [int(uc + 5) for uc in unique_classes])
# num_attrs = 1 #len(unique_classes)# if loss_func else len(targets_list[i])
n_attr_list.append(num_attrs)
else:
index_classes.append([0, 1, 2, 3, 4])
n_attr_list.append(0)
targets_list_filled = [targ.clone().detach() for targ in targets_list]
labels_len = [len(targets_list[ih]) for ih in range(len(targets_list))]
max_labels = np.max(labels_len)
max_index = np.argmax(labels_len)
for i in range(len(targets_list)):
# targets_list_filled[i] = targets_list[i]
if len(targets_list_filled[i]) < max_labels:
tlist = [targets_list_filled[i]] * math.ceil(max_labels / len(targets_list_filled[i]))
targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0)
else:
targets_list_filled[i] = targets_list_filled[i].unsqueeze(0)
for i in range(len(targets_list_filled)-1,-1,-1):
if targets_list_filled[i].numel() == 0:
targets_list_filled.pop(i)
targets_list_filled = torch.cat(targets_list_filled)
n_img_attrs = len(input_tensor) if class_specific_attr else 1
n_img_attrs = 1 if loss_func else n_img_attrs
attrs_batch = []
for i_batch in range(n_img_attrs):
if loss_func and class_specific_attr:
i_batch = max_index
# inpt = input_tensor[i_batch].unsqueeze(0)
# ##################################################################
# model.zero_grad() # Zero gradients
# train_out = model(inpt) # training outputs (no inference outputs in train mode)
# ##################################################################
n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1
n_label_attrs = 1 if not class_specific_attr else n_label_attrs
attrs_img = []
for i_attr in range(n_label_attrs):
if loss_func is None:
grad_wrt = train_out[out_num]
if class_specific_attr:
grad_wrt = train_out[out_num][:,:,:,:,index_classes[i_batch][i_attr]]
grad_wrt_outputs = torch.ones_like(grad_wrt)
else:
# if class_specific_attr:
# targets = targets_list[:][i_attr]
# n_targets = len(targets_list[i_batch])
if class_specific_attr:
target_indiv = targets_list_filled[:,i_attr] # batch image input
else:
target_indiv = targets
# target_indiv = targets_list[i_batch][i_attr].unsqueeze(0) # single image input
# target_indiv[:,0] = 0 # this indicates the batch index of the target, should be 0 since we are only doing one image at a time
try:
loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric) # loss scaled by batch_size
except:
target_indiv = target_indiv.to(device)
inpt = inpt.to(device)
for tro in train_out:
tro = tro.to(device)
print("Error in loss function, trying again with device specified")
loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)
grad_wrt = loss
grad_wrt_outputs = None
model.zero_grad() # Zero gradients
gradients = torch.autograd.grad(grad_wrt, inpt,
grad_outputs=grad_wrt_outputs,
retain_graph=True,
# create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly
)
# Convert gradients to numpy array and back to ensure full separation from graph
# attribution_map = torch.tensor(torch.sum(gradients[0], 1, keepdim=True).clone().detach().cpu().numpy())
attribution_map = gradients[0]#.clone().detach() # without converting to numpy
if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
attribution_map = torch.sum(attribution_map, 1, keepdim=True)
if abs:
attribution_map = torch.abs(attribution_map) # Take absolute values of gradients
if norm:
attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
attrs_img.append(attribution_map)
if len(attrs_img) == 0:
attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device))
else:
attrs_batch.append(torch.stack(attrs_img).to(device))
# out_attr = torch.tensor(attribution_map).unsqueeze(0).to(device) if ((loss_func) or (not class_specific_attr)) else torch.stack(attrs_batch).to(device)
# out_attr = [attrs_batch[0]] * len(input_tensor) if ((loss_func) or (not class_specific_attr)) else attrs_batch
out_attr = attrs_batch
# Set model back to original mode
if not train_mode:
model.eval()
return out_attr
class RVNonLinearFunc(torch.nn.Module):
"""
Custom Bayesian ReLU activation function for random variables.
Attributes:
None
"""
def __init__(self, func):
super(RVNonLinearFunc, self).__init__()
self.func = func
def forward(self, mu_in, Sigma_in):
"""
Forward pass of the Bayesian ReLU activation function.
Args:
mu_in (torch.Tensor): A tensor of shape (batch_size, input_size),
representing the mean input to the ReLU activation function.
Sigma_in (torch.Tensor): A tensor of shape (batch_size, input_size, input_size),
representing the covariance input to the ReLU activation function.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors,
including the mean of the output and the covariance of the output.
"""
# Collect stats
batch_size = mu_in.size(0)
# Mean
mu_out = self.func(mu_in)
# Compute the derivative of the ReLU activation function with respect to the input mean
gradi = torch.autograd.grad(mu_out, mu_in, grad_outputs=torch.ones_like(mu_out), create_graph=True)[0].view(batch_size,-1)
# add an extra dimension to gradi at position 2 and 1
grad1 = gradi.unsqueeze(dim=2)
grad2 = gradi.unsqueeze(dim=1)
# compute the outer product of grad1 and grad2
outer_product = torch.bmm(grad1, grad2)
# element-wise multiply Sigma_in with the outer product
# and return the result
Sigma_out = torch.mul(Sigma_in, outer_product)
return mu_out, Sigma_out