import torch import numpy as np # from plot_functs import * from .plot_functs import normalize_tensor, overlay_mask, imshow import math import time import matplotlib.path as mplPath from matplotlib.path import Path # from utils.general import non_max_suppression, xyxy2xywh, scale_coords from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppression from .metrics import bbox_iou import torchvision.transforms as T def plaus_loss_fn(grad, smask, pgt_coeff): ################## Compute the PGT Loss ################## # Positive regularization term for incentivizing pixels near the target to have high attribution dist_attr_pos = attr_reg(grad, (1.0 - smask)) # dist_reg = seg_mask # Negative regularization term for incentivizing pixels far from the target to have low attribution dist_attr_neg = attr_reg(grad, smask) # Calculate plausibility regularization term # dist_reg = dist_attr_pos - dist_attr_neg dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad))) plaus_reg = (((1.0 + dist_reg) / 2.0)) # Calculate plausibility loss plaus_loss = (1 - plaus_reg) * pgt_coeff return plaus_loss def get_dist_reg(images, seg_mask): seg_mask = T.Resize((images.shape[2], images.shape[3]), antialias=True)(seg_mask).to(images.device) seg_mask = seg_mask.to(dtype=torch.float32).unsqueeze(1).repeat(1, 3, 1, 1) seg_mask[seg_mask > 0] = 1.0 smask = torch.zeros_like(seg_mask) sigmas = [20.0 + (i_sig * 20.0) for i_sig in range(8)] for k_it, sigma in enumerate(sigmas): # Apply Gaussian blur to the mask kernel_size = int(sigma + 50) if kernel_size % 2 == 0: kernel_size += 1 seg_mask1 = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask) if torch.max(seg_mask1) > 1.0: # seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min()) seg_mask1 = normalize_tensor(seg_mask1) smask = torch.max(smask, seg_mask1) smask = normalize_tensor(smask) return smask def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False): """ Compute the gradient of an image with respect to a given tensor. Args: img (torch.Tensor): The input image tensor. grad_wrt (torch.Tensor): The tensor with respect to which the gradient is computed. norm (bool, optional): Whether to normalize the gradient. Defaults to True. absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True. grayscale (bool, optional): Whether to convert the gradient to grayscale. Defaults to True. keepmean (bool, optional): Whether to keep the mean value of the attribution map. Defaults to False. Returns: torch.Tensor: The computed attribution map. """ if (grad_wrt.shape != torch.Size([1])) and (grad_wrt.shape != torch.Size([])): grad_wrt_outputs = torch.ones_like(grad_wrt).clone().detach()#.requires_grad_(True)#.retains_grad_(True) else: grad_wrt_outputs = None attribution_map = torch.autograd.grad(grad_wrt, img, grad_outputs=grad_wrt_outputs, create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly )[0] if absolute: attribution_map = torch.abs(attribution_map) # attribution_map ** 2 # Take absolute values of gradients if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval attribution_map = torch.sum(attribution_map, 1, keepdim=True) if norm: if keepmean: attmean = torch.mean(attribution_map) attmin = torch.min(attribution_map) attmax = torch.max(attribution_map) attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch if keepmean: attribution_map -= attribution_map.mean() attribution_map += (attmean / (attmax - attmin)) return attribution_map def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False): """ Generate Gaussian noise based on the input image. Args: img (torch.Tensor): Input image. grad_wrt: Gradient with respect to the input image. norm (bool, optional): Whether to normalize the generated noise. Defaults to True. absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True. grayscale (bool, optional): Whether to convert the noise to grayscale. Defaults to True. keepmean (bool, optional): Whether to keep the mean of the noise. Defaults to False. Returns: torch.Tensor: Generated Gaussian noise. """ gaussian_noise = torch.randn_like(img) if absolute: gaussian_noise = torch.abs(gaussian_noise) # Take absolute values of gradients if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval gaussian_noise = torch.sum(gaussian_noise, 1, keepdim=True) if norm: if keepmean: attmean = torch.mean(gaussian_noise) attmin = torch.min(gaussian_noise) attmax = torch.max(gaussian_noise) gaussian_noise = normalize_batch(gaussian_noise) # Normalize attribution maps per image in batch if keepmean: gaussian_noise -= gaussian_noise.mean() gaussian_noise += (attmean / (attmax - attmin)) return gaussian_noise def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7): # TODO: Remove imgs from this function and only take it as input if debug is True """ Calculates the plausibility score based on the given inputs. Args: imgs (torch.Tensor): The input images. targets_out (torch.Tensor): The output targets. attr (torch.Tensor): The attribute tensor. debug (bool, optional): Whether to enable debug mode. Defaults to False. Returns: torch.Tensor: The plausibility score. """ # # if imgs is None: # # imgs = torch.zeros_like(attr) # # with torch.no_grad(): # target_inds = targets_out[:, 0].int() # xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num] # num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1)) # # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1)) # xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int() # co = xyxy_corners # if corners: # co = targets_out[:, 2:6].int() # coords_map = torch.zeros_like(attr, dtype=torch.bool) # # rows = np.arange(co.shape[0]) # x1, x2 = co[:,1], co[:,3] # y1, y2 = co[:,0], co[:,2] # for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop # coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True if torch.isnan(attr).any(): attr = torch.nan_to_num(attr, nan=0.0) coords_map = get_bbox_map(targets_out, attr) plaus_score = ((torch.sum((attr * coords_map))) / (torch.sum(attr))) if debug: for i in range(len(coords_map)): coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0) test_bbox = torch.zeros_like(imgs[i]) test_bbox[coords_map3ch] = imgs[i][coords_map3ch] imshow(test_bbox, save_path='figs/test_bbox') if imgs is None: imgs = torch.zeros_like(attr) imshow(imgs[i], save_path='figs/im0') imshow(attr[i], save_path='figs/attr') # with torch.no_grad(): # # att_select = attr[coords_map] # att_select = attr * coords_map.to(torch.float32) # att_total = attr # IoU_num = torch.sum(att_select) # IoU_denom = torch.sum(att_total) # IoU_ = (IoU_num / IoU_denom) # plaus_score = IoU_ # # plaus_score = ((torch.sum(attr[coords_map])) / (torch.sum(attr))) return plaus_score def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7): # TODO: Remove imgs from this function and only take it as input if debug is True """ Calculates the plausibility score based on the given inputs. Args: imgs (torch.Tensor): The input images. targets_out (torch.Tensor): The output targets. attr (torch.Tensor): The attribute tensor. debug (bool, optional): Whether to enable debug mode. Defaults to False. Returns: torch.Tensor: The plausibility score. """ # if imgs is None: # imgs = torch.zeros_like(attr) # with torch.no_grad(): target_inds = targets_out[:, 0].int() xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num] num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1)) # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1)) xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int() co = xyxy_corners if corners: co = targets_out[:, 2:6].int() coords_map = torch.zeros_like(attr, dtype=torch.bool) # rows = np.arange(co.shape[0]) x1, x2 = co[:,1], co[:,3] y1, y2 = co[:,0], co[:,2] for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True if torch.isnan(attr).any(): attr = torch.nan_to_num(attr, nan=0.0) if debug: for i in range(len(coords_map)): coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0) test_bbox = torch.zeros_like(imgs[i]) test_bbox[coords_map3ch] = imgs[i][coords_map3ch] imshow(test_bbox, save_path='figs/test_bbox') imshow(imgs[i], save_path='figs/im0') imshow(attr[i], save_path='figs/attr') # att_select = attr[coords_map] # with torch.no_grad(): # IoU_num = (torch.sum(attr[coords_map])) # IoU_denom = torch.sum(attr) # IoU_ = (IoU_num / (IoU_denom)) # IoU_ = torch.max(attr[coords_map]) - torch.max(attr[~coords_map]) co = (xyxy_batch * num_pixels).int() x1 = co[:,1] + 1 y1 = co[:,0] + 1 # with torch.no_grad(): attr_ = torch.sum(attr, 1, keepdim=True) corners_attr = None #torch.zeros(len(xyxy_batch), 4, device=attr.device) for ic in range(co.shape[0]): attr0 = attr_[target_inds[ic], :,:x1[ic],:y1[ic]] attr1 = attr_[target_inds[ic], :,:x1[ic],y1[ic]:] attr2 = attr_[target_inds[ic], :,x1[ic]:,:y1[ic]] attr3 = attr_[target_inds[ic], :,x1[ic]:,y1[ic]:] x_0, y_0 = max_indices_2d(attr0[0]) x_1, y_1 = max_indices_2d(attr1[0]) x_2, y_2 = max_indices_2d(attr2[0]) x_3, y_3 = max_indices_2d(attr3[0]) y_1 += y1[ic] x_2 += x1[ic] x_3 += x1[ic] y_3 += y1[ic] max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2], torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3], torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2], torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]]) if corners_attr is None: corners_attr = max_corners else: corners_attr = torch.cat([corners_attr, max_corners], dim=0) # corners_attr[ic] = max_corners # corners_attr = attr[:,0,:4,0] corners_attr = corners_attr.view(-1, 4) # corners_attr = torch.stack(corners_attr, dim=0) IoU_ = bbox_iou(corners_attr.T, xyxy_batch, x1y1x2y2=False, metric='CIoU') plaus_score = IoU_.mean() return plaus_score def max_indices_2d(x_inp): # values, indices = x.reshape(x.size(0), -1).max(dim=-1) torch.max(x_inp,) index = torch.argmax(x_inp) x = index // x_inp.shape[1] y = index % x_inp.shape[1] # x, y = divmod(index.item(), x_inp.shape[1]) return torch.cat([x.unsqueeze(0), y.unsqueeze(0)]) def point_in_polygon(poly, grid): # t0 = time.time() num_points = poly.shape[0] j = num_points - 1 oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool) for i in range(num_points): cond1 = (poly[i, 1] < grid[..., 1]) & (poly[j, 1] >= grid[..., 1]) cond2 = (poly[j, 1] < grid[..., 1]) & (poly[i, 1] >= grid[..., 1]) cond3 = (grid[..., 0] - poly[i, 0]) < (poly[j, 0] - poly[i, 0]) * (grid[..., 1] - poly[i, 1]) / (poly[j, 1] - poly[i, 1]) oddNodes = oddNodes ^ (cond1 | cond2) & cond3 j = i # t1 = time.time() # print(f'point in polygon time: {t1-t0}') return oddNodes def point_in_polygon_gpu(poly, grid): num_points = poly.shape[0] i = torch.arange(num_points) j = (i - 1) % num_points # Expand dimensions # t0 = time.time() poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid.shape[0], grid.shape[0]) # t1 = time.time() cond1 = (poly_expanded[i, 1] < grid[..., 1]) & (poly_expanded[j, 1] >= grid[..., 1]) cond2 = (poly_expanded[j, 1] < grid[..., 1]) & (poly_expanded[i, 1] >= grid[..., 1]) cond3 = (grid[..., 0] - poly_expanded[i, 0]) < (poly_expanded[j, 0] - poly_expanded[i, 0]) * (grid[..., 1] - poly_expanded[i, 1]) / (poly_expanded[j, 1] - poly_expanded[i, 1]) # t2 = time.time() oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool) cond = (cond1 | cond2) & cond3 # t3 = time.time() # efficiently perform xor using gpu and avoiding cpu as much as possible c = [] while len(cond) > 1: if len(cond) % 2 == 1: # odd number of elements c.append(cond[-1]) cond = cond[:-1] cond = torch.bitwise_xor(cond[:int(len(cond)/2)], cond[int(len(cond)/2):]) for c_ in c: cond = torch.bitwise_xor(cond, c_) oddNodes = cond # t4 = time.time() # for c in cond: # oddNodes = oddNodes ^ c # print(f'expand time: {t1-t0} | cond123 time: {t2-t1} | cond logic time: {t3-t2} | bitwise xor time: {t4-t3}') # print(f'point in polygon time gpu: {t4-t0}') # oddNodes = oddNodes ^ (cond1 | cond2) & cond3 return oddNodes def bitmap_for_polygon(poly, h, w): y = torch.arange(h).to(poly.device).float() x = torch.arange(w).to(poly.device).float() grid_y, grid_x = torch.meshgrid(y, x) grid = torch.stack((grid_x, grid_y), dim=-1) bitmap = point_in_polygon(poly, grid) return bitmap.unsqueeze(0) def corners_coords(center_xywh): center_x, center_y, w, h = center_xywh x = center_x - w/2 y = center_y - h/2 return torch.tensor([x, y, x+w, y+h]) def corners_coords_batch(center_xywh): center_x, center_y = center_xywh[:,0], center_xywh[:,1] w, h = center_xywh[:,2], center_xywh[:,3] x = center_x - w/2 y = center_y - h/2 return torch.stack([x, y, x+w, y+h], dim=1) def normalize_batch(x): """ Normalize a batch of tensors along each channel. Args: x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). Returns: torch.Tensor: Normalized tensor of the same shape as the input. """ mins = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device) maxs = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device) for i in range(x.shape[0]): mins[i] = x[i].min() maxs[i] = x[i].max() x_ = (x - mins) / (maxs - mins) return x_ def normalize_batch_nonan(x): """ Normalize a batch of tensors along each channel. Args: x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). Returns: torch.Tensor: Normalized tensor of the same shape as the input. """ mins = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device) maxs = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device) x_ = torch.zeros_like(x) for i in range(x.shape[0]): if torch.all(x[i] == 0): x_[i] = x[i] else: mins[i] = x[i].min() maxs[i] = x[i].max() x_[i] = (x[i] - mins[i]) / (maxs[i] - mins[i]) return x_ def get_detections(model_clone, img): """ Get detections from a model given an input image and targets. Args: model (nn.Module): The model to use for detection. img (torch.Tensor): The input image tensor. Returns: torch.Tensor: The detected bounding boxes. """ model_clone.eval() # Set model to evaluation mode # Run inference with torch.no_grad(): det_out, out = model_clone(img) # model_.train() del img return det_out, out def get_labels(det_out, imgs, targets, opt): ###################### Get predicted labels ###################### nb, _, height, width = imgs.shape # batch size, channels, height, width targets_ = targets.clone() targets_[:, 2:] = targets_[:, 2:] * torch.Tensor([width, height, width, height]).to(imgs.device) # to pixels lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else [] # for autolabelling o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True) pred_labels = [] for si, pred in enumerate(o): labels = targets_[targets_[:, 0] == si, 1:] nl = len(labels) predn = pred.clone() # Get the indices that sort the values in column 5 in ascending order sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True) # Apply the sorting indices to the tensor sorted_pred = predn[sort_indices] # Remove predictions with less than 0.1 confidence n_conf = int(torch.sum(sorted_pred[:,4]>0.1)) + 1 sorted_pred = sorted_pred[:n_conf] new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1) preds[:, 2:] = xyxy2xywh(preds[:, 2:]) # xywh gn = torch.tensor([width, height])[[1, 0, 1, 0]] # normalization gain whwh preds[:, 2:] /= gn.to(imgs.device) # from pixels pred_labels.append(preds) pred_labels = torch.cat(pred_labels, 0).to(imgs.device) return pred_labels ################################################################## from torchvision.utils import make_grid def get_center_coords(attr): img_tensor = img_tensor / img_tensor.max() # Define a brightness threshold threshold = 0.95 # Create a binary mask of the bright pixels mask = img_tensor > threshold # Get the coordinates of the bright pixels y_coords, x_coords = torch.where(mask) # Calculate the centroid of the bright pixels centroid_x = x_coords.float().mean().item() centroid_y = y_coords.float().mean().item() print(f'The central bright point is at ({centroid_x}, {centroid_y})') return def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False): """ Compute the distance grids from each pixel to the target coordinates. Args: attr (torch.Tensor): Attribution maps. targets (torch.Tensor): Target coordinates. focus_coeff (float, optional): Focus coefficient, smaller means more focused. Defaults to 0.5. debug (bool, optional): Whether to visualize debug information. Defaults to False. Returns: torch.Tensor: Distance grids. """ # Assign the height and width of the input tensor to variables height, width = attr.shape[-1], attr.shape[-2] # attr = torch.abs(attr) # Take absolute values of gradients # attr = normalize_batch(attr) # Normalize attribution maps per image in batch # Create a grid of indices xx, yy = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width))).to(attr.device) idx_grid = torch.stack((xx, yy), dim=-1).float() # Expand the grid to match the batch size idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1) # Initialize a list to store the distance grids dist_grids_ = [[]] * attr.shape[0] # Loop over batches for j in range(attr.shape[0]): # Get the rows where the first column is the current unique value rows = targets[targets[:, 0] == j] if len(rows) != 0: # Create a tensor for the target coordinates xy = rows[:,2:4] # y, x # Flip the x and y coordinates and scale them to the image size xy[:, 0], xy[:, 1] = xy[:, 1] * width, xy[:, 0] * height # y, x to x, y xy_center = xy.unsqueeze(1).unsqueeze(1)#.requires_grad_(True) # Compute the Euclidean distance from each pixel to the target coordinates dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1) # Pick the closest distance to any target for each pixel dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0) dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_ else: # Set grid to zero if no targets are present dist_grid = torch.zeros_like(attr[j]) dist_grids_[j] = dist_grid # Convert the list of distance grids to a tensor for faster computation dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff if torch.isnan(dist_grids).any(): dist_grids = torch.nan_to_num(dist_grids, nan=0.0) if debug: for i in range(len(dist_grids)): if ((i % 8) == 0): grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0) imshow(grid_show, save_path='figs/dist_grids') if imgs is None: imgs = torch.zeros_like(attr) imshow(imgs[i], save_path='figs/im0') img_overlay = (overlay_mask(imgs[i], dist_grids[i][0], alpha = 0.75)) imshow(img_overlay, save_path='figs/dist_grid_overlay') weighted_attr = (dist_grids[i] * attr[i]) imshow(weighted_attr, save_path='figs/weighted_attr') imshow(attr[i], save_path='figs/attr') return dist_grids def attr_reg(attribution_map, distance_map): # dist_attr = distance_map * attribution_map dist_attr = torch.mean(distance_map * attribution_map)#, dim=(1, 2, 3)) # del distance_map, attribution_map return dist_attr def get_bbox_map(targets_out, attr, corners=False): target_inds = targets_out[:, 0].int() xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num] num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1)) # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1)) xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int() co = xyxy_corners if corners: co = targets_out[:, 2:6].int() coords_map = torch.zeros_like(attr, dtype=torch.bool) # rows = np.arange(co.shape[0]) x1, x2 = co[:,1], co[:,3] y1, y2 = co[:,0], co[:,2] for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True bbox_map = coords_map.to(torch.float32) return bbox_map ######################################## BCE ####################################### def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False): # if imgs is None: # imgs = torch.zeros_like(attribution_map) # Calculate Plausibility IoU with attribution maps # attribution_map.retains_grad = True if not only_loss: plaus_score = get_plaus_score(targets_out = targets, attr = attribution_map.clone().detach().requires_grad_(True), imgs = imgs) else: plaus_score = torch.tensor(0.0) # attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch # Calculate distance regularization distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff) # distance_map = torch.ones_like(attribution_map) if opt.dist_x_bbox: bbox_map = get_bbox_map(targets, attribution_map).to(torch.bool) distance_map[bbox_map] = 0.0 # distance_map = distance_map * (1 - bbox_map) # Positive regularization term for incentivizing pixels near the target to have high attribution dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map)) # Negative regularization term for incentivizing pixels far from the target to have low attribution dist_attr_neg = attr_reg(attribution_map, distance_map) # Calculate plausibility regularization term # dist_reg = dist_attr_pos - dist_attr_neg dist_reg = ((dist_attr_pos / torch.mean(attribution_map)) - (dist_attr_neg / torch.mean(attribution_map))) # dist_reg = torch.mean((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3))) - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3)))) # dist_reg = (torch.mean(torch.exp((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3)))) + \ # torch.exp(1 - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3)))))) \ # / 2.5 if opt.bbox_coeff != 0.0: bbox_map = get_bbox_map(targets, attribution_map) attr_bbox_pos = attr_reg(attribution_map, bbox_map) attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map)) bbox_reg = attr_bbox_pos - attr_bbox_neg # bbox_reg = (attr_bbox_pos / torch.mean(attribution_map)) - (attr_bbox_neg / torch.mean(attribution_map)) else: bbox_reg = 0.0 bbox_map = get_bbox_map(targets, attribution_map) plaus_score = ((torch.sum((attribution_map * bbox_map))) / (torch.sum(attribution_map))) # iou_loss = (1.0 - plaus_score) if not opt.dist_reg_only: dist_reg_loss = (((1.0 + dist_reg) / 2.0)) plaus_reg = (plaus_score * opt.iou_coeff) + \ (((dist_reg_loss * opt.dist_coeff) + \ (bbox_reg * opt.bbox_coeff))\ # ((((((1.0 + dist_reg) / 2.0) - 1.0) * opt.dist_coeff) + ((((1.0 + bbox_reg) / 2.0) - 1.0) * opt.bbox_coeff))\ # / (plaus_score) \ ) else: plaus_reg = (((1.0 + dist_reg) / 2.0)) # plaus_reg = dist_reg # Calculate plausibility loss plaus_loss = (1 - plaus_reg) * opt.pgt_coeff # plaus_loss = (plaus_reg) * opt.pgt_coeff if only_loss: return plaus_loss if not debug: return plaus_loss, (plaus_score, dist_reg, plaus_reg,) else: return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map #################################################################################### #### ALL FUNCTIONS BELOW ARE DEPRECIATED AND WILL BE REMOVED IN FUTURE VERSIONS #### #################################################################################### def generate_vanilla_grad(model, input_tensor, loss_func = None, targets_list=None, targets=None, metric=None, out_num = 1, n_max_labels=3, norm=True, abs=True, grayscale=True, class_specific_attr = True, device='cpu'): """ Generate vanilla gradients for the given model and input tensor. Args: model (nn.Module): The model to generate gradients for. input_tensor (torch.Tensor): The input tensor for which gradients are computed. loss_func (callable, optional): The loss function to compute gradients with respect to. Defaults to None. targets_list (list, optional): The list of target tensors. Defaults to None. metric (callable, optional): The metric function to evaluate the loss. Defaults to None. out_num (int, optional): The index of the output tensor to compute gradients with respect to. Defaults to 1. n_max_labels (int, optional): The maximum number of labels to consider. Defaults to 3. norm (bool, optional): Whether to normalize the attribution map. Defaults to True. abs (bool, optional): Whether to take the absolute values of gradients. Defaults to True. grayscale (bool, optional): Whether to convert the attribution map to grayscale. Defaults to True. class_specific_attr (bool, optional): Whether to compute class-specific attribution maps. Defaults to True. device (str, optional): The device to use for computation. Defaults to 'cpu'. Returns: torch.Tensor: The generated vanilla gradients. """ # Set model.train() at the beginning and revert back to original mode (model.eval() or model.train()) at the end train_mode = model.training if not train_mode: model.train() input_tensor.requires_grad = True # Set requires_grad attribute of tensor. Important for computing gradients model.zero_grad() # Zero gradients inpt = input_tensor # Forward pass train_out = model(inpt) # training outputs (no inference outputs in train mode) # train_out[1] = torch.Size([4, 3, 80, 80, 7]) HxWx(#anchorxC) cls (class probabilities) # train_out[0] = torch.Size([4, 3, 160, 160, 7]) HxWx(#anchorx4) box or reg (location and scaling) # train_out[2] = torch.Size([4, 3, 40, 40, 7]) HxWx(#anchorx1) obj (objectness score or confidence) if class_specific_attr: n_attr_list, index_classes = [], [] for i in range(len(input_tensor)): if len(targets_list[i]) > n_max_labels: targets_list[i] = targets_list[i][:n_max_labels] if targets_list[i].numel() != 0: # unique_classes = torch.unique(targets_list[i][:,1]) class_numbers = targets_list[i][:,1] index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers]) num_attrs = len(targets_list[i]) # index_classes.append([0, 1, 2, 3, 4] + [int(uc + 5) for uc in unique_classes]) # num_attrs = 1 #len(unique_classes)# if loss_func else len(targets_list[i]) n_attr_list.append(num_attrs) else: index_classes.append([0, 1, 2, 3, 4]) n_attr_list.append(0) targets_list_filled = [targ.clone().detach() for targ in targets_list] labels_len = [len(targets_list[ih]) for ih in range(len(targets_list))] max_labels = np.max(labels_len) max_index = np.argmax(labels_len) for i in range(len(targets_list)): # targets_list_filled[i] = targets_list[i] if len(targets_list_filled[i]) < max_labels: tlist = [targets_list_filled[i]] * math.ceil(max_labels / len(targets_list_filled[i])) targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0) else: targets_list_filled[i] = targets_list_filled[i].unsqueeze(0) for i in range(len(targets_list_filled)-1,-1,-1): if targets_list_filled[i].numel() == 0: targets_list_filled.pop(i) targets_list_filled = torch.cat(targets_list_filled) n_img_attrs = len(input_tensor) if class_specific_attr else 1 n_img_attrs = 1 if loss_func else n_img_attrs attrs_batch = [] for i_batch in range(n_img_attrs): if loss_func and class_specific_attr: i_batch = max_index # inpt = input_tensor[i_batch].unsqueeze(0) # ################################################################## # model.zero_grad() # Zero gradients # train_out = model(inpt) # training outputs (no inference outputs in train mode) # ################################################################## n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1 n_label_attrs = 1 if not class_specific_attr else n_label_attrs attrs_img = [] for i_attr in range(n_label_attrs): if loss_func is None: grad_wrt = train_out[out_num] if class_specific_attr: grad_wrt = train_out[out_num][:,:,:,:,index_classes[i_batch][i_attr]] grad_wrt_outputs = torch.ones_like(grad_wrt) else: # if class_specific_attr: # targets = targets_list[:][i_attr] # n_targets = len(targets_list[i_batch]) if class_specific_attr: target_indiv = targets_list_filled[:,i_attr] # batch image input else: target_indiv = targets # target_indiv = targets_list[i_batch][i_attr].unsqueeze(0) # single image input # target_indiv[:,0] = 0 # this indicates the batch index of the target, should be 0 since we are only doing one image at a time try: loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric) # loss scaled by batch_size except: target_indiv = target_indiv.to(device) inpt = inpt.to(device) for tro in train_out: tro = tro.to(device) print("Error in loss function, trying again with device specified") loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric) grad_wrt = loss grad_wrt_outputs = None model.zero_grad() # Zero gradients gradients = torch.autograd.grad(grad_wrt, inpt, grad_outputs=grad_wrt_outputs, retain_graph=True, # create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly ) # Convert gradients to numpy array and back to ensure full separation from graph # attribution_map = torch.tensor(torch.sum(gradients[0], 1, keepdim=True).clone().detach().cpu().numpy()) attribution_map = gradients[0]#.clone().detach() # without converting to numpy if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval attribution_map = torch.sum(attribution_map, 1, keepdim=True) if abs: attribution_map = torch.abs(attribution_map) # Take absolute values of gradients if norm: attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch attrs_img.append(attribution_map) if len(attrs_img) == 0: attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device)) else: attrs_batch.append(torch.stack(attrs_img).to(device)) # out_attr = torch.tensor(attribution_map).unsqueeze(0).to(device) if ((loss_func) or (not class_specific_attr)) else torch.stack(attrs_batch).to(device) # out_attr = [attrs_batch[0]] * len(input_tensor) if ((loss_func) or (not class_specific_attr)) else attrs_batch out_attr = attrs_batch # Set model back to original mode if not train_mode: model.eval() return out_attr class RVNonLinearFunc(torch.nn.Module): """ Custom Bayesian ReLU activation function for random variables. Attributes: None """ def __init__(self, func): super(RVNonLinearFunc, self).__init__() self.func = func def forward(self, mu_in, Sigma_in): """ Forward pass of the Bayesian ReLU activation function. Args: mu_in (torch.Tensor): A tensor of shape (batch_size, input_size), representing the mean input to the ReLU activation function. Sigma_in (torch.Tensor): A tensor of shape (batch_size, input_size, input_size), representing the covariance input to the ReLU activation function. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors, including the mean of the output and the covariance of the output. """ # Collect stats batch_size = mu_in.size(0) # Mean mu_out = self.func(mu_in) # Compute the derivative of the ReLU activation function with respect to the input mean gradi = torch.autograd.grad(mu_out, mu_in, grad_outputs=torch.ones_like(mu_out), create_graph=True)[0].view(batch_size,-1) # add an extra dimension to gradi at position 2 and 1 grad1 = gradi.unsqueeze(dim=2) grad2 = gradi.unsqueeze(dim=1) # compute the outer product of grad1 and grad2 outer_product = torch.bmm(grad1, grad2) # element-wise multiply Sigma_in with the outer product # and return the result Sigma_out = torch.mul(Sigma_in, outer_product) return mu_out, Sigma_out