From 38fa59edf2f706ae9fe0c013166b9d6556a685ae Mon Sep 17 00:00:00 2001 From: nielseni6 Date: Wed, 16 Oct 2024 19:26:41 -0400 Subject: [PATCH] PGT Training now functioning. Run run_pgt_train.py to train the model. --- run_attribution.py | 29 +- run_pgt_train.py | 47 ++ ultralytics/engine/model.py | 5 +- ultralytics/engine/pgt_trainer.py | 88 +- ultralytics/engine/pgt_validator.py | 345 ++++++++ ultralytics/models/yolo/detect/__init__.py | 3 +- ultralytics/models/yolo/detect/pgt_val.py | 300 +++++++ ultralytics/models/yolo/segment/__init__.py | 3 +- ultralytics/models/yolo/segment/pgt_train.py | 73 ++ ultralytics/models/yolov10/model.py | 2 +- ultralytics/models/yolov10/val.py | 23 +- ultralytics/nn/tasks.py | 6 +- ultralytics/utils/loss.py | 29 +- ultralytics/utils/plaus_functs.py | 818 +++++++++++++++++++ ultralytics/utils/plot_functs.py | 154 ++++ ultralytics/utils/plotting.py | 2 +- 16 files changed, 1895 insertions(+), 32 deletions(-) create mode 100644 run_pgt_train.py create mode 100644 ultralytics/engine/pgt_validator.py create mode 100644 ultralytics/models/yolo/detect/pgt_val.py create mode 100644 ultralytics/models/yolo/segment/pgt_train.py create mode 100644 ultralytics/utils/plaus_functs.py create mode 100644 ultralytics/utils/plot_functs.py diff --git a/run_attribution.py b/run_attribution.py index 8c87e592..3a603934 100644 --- a/run_attribution.py +++ b/run_attribution.py @@ -3,26 +3,41 @@ from ultralytics import YOLOv10, YOLO # from ultralytics import BaseTrainer # from ultralytics.engine.trainer import BaseTrainer import os +from ultralytics.models.yolo.segment import PGTSegmentationTrainer + # Set CUDA device (only needed for multi-gpu machines) os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = "4" # model = YOLOv10() +# model = YOLO('yolov8n-seg.yaml').load('yolov8n.pt') # build from YAML and transfer weights + # model = YOLO() # If you want to finetune the model with pretrained weights, you could load the # pretrained weights like below # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}') # or # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt -model = YOLOv10('yolov10n.pt') +model = YOLOv10('yolov10n.pt', task='segment') -model.train(data='coco.yaml', - trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT) - # Add return_images as input parameter - epochs=500, batch=16, imgsz=640, - debug=True, # If debug = True, the attributions will be saved in the figures folder - ) +args = dict(model='yolov10n.pt', data='coco128-seg.yaml') +trainer = PGTSegmentationTrainer(overrides=args) +trainer.train( + # debug=True, + # args = dict(pgt_coeff=0.1), + ) + +# model.train( +# # data='coco.yaml', +# data='coco128-seg.yaml', +# trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT) +# # Add return_images as input parameter +# epochs=500, batch=16, imgsz=640, +# debug=True, # If debug = True, the attributions will be saved in the figures folder +# # cfg='/home/nielseni6/PythonScripts/yolov10/ultralytics/cfg/models/v8/yolov8-seg.yaml', +# # overrides=dict(task="segment"), +# ) # Save the trained model model.save('yolov10_coco_trained.pt') diff --git a/run_pgt_train.py b/run_pgt_train.py new file mode 100644 index 00000000..e308365d --- /dev/null +++ b/run_pgt_train.py @@ -0,0 +1,47 @@ +from ultralytics import YOLOv10, YOLO +# from ultralytics.engine.pgt_trainer import 
PGTTrainer +import os +from ultralytics.models.yolo.segment import PGTSegmentationTrainer +import argparse + + +def main(args): + # model = YOLOv10() + + # If you want to finetune the model with pretrained weights, you could load the + # pretrained weights like below + # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}') + # or + # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt + model = YOLOv10('yolov10n.pt', task='segment') + + args = dict(model='yolov10n.pt', data='coco.yaml', + epochs=args.epochs, batch=args.batch_size, + # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process + ) + trainer = PGTSegmentationTrainer(overrides=args) + trainer.train( + # debug=True, + # args = dict(pgt_coeff=0.1), # Should add later to config + ) + + # Save the trained model + model.save('yolov10_coco_trained.pt') + + # Evaluate the model on the validation set + results = model.val(data='coco.yaml') + + # Print the evaluation results + print(results) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.') + parser.add_argument('--device', type=str, default='0', help='CUDA device number') + parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training') + parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training') + args = parser.parse_args() + + # Set CUDA device (only needed for multi-gpu machines) + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = args.device + main(args) \ No newline at end of file diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index be87559c..53cb1f40 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -656,7 +656,10 @@ class Model(nn.Module): pass self.trainer.hub_session = self.session # attach optional HUB session - self.trainer.train(debug=debug) + if debug: + self.trainer.train(debug=debug) + else: + self.trainer.train() # Update model and cfg after training if RANK in (-1, 0): ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last diff --git a/ultralytics/engine/pgt_trainer.py b/ultralytics/engine/pgt_trainer.py index 3d99ad41..6b7c973e 100644 --- a/ultralytics/engine/pgt_trainer.py +++ b/ultralytics/engine/pgt_trainer.py @@ -52,7 +52,9 @@ from ultralytics.utils.torch_utils import ( strip_optimizer, ) - +from ultralytics.utils.loss import v8DetectionLoss +from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn +import matplotlib.path as matplotlib_path class PGTTrainer: """ BaseTrainer. 
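For reference, the plausibility-guided objective this patch wires into training reduces to the following minimal sketch. It assumes `images` was registered with `requires_grad_(True)` and that the batch carries segmentation masks, mirroring the `get_dist_reg`/`plaus_loss_fn` usage in `v10PGTDetectLoss` later in this diff; `add_plaus_loss` is an illustrative helper name, not a function defined by the patch:

```python
import torch
from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn

def add_plaus_loss(det_loss, images, seg_masks, pgt_coeff=3.0):
    # Distance-based plausibility target built from the ground-truth masks
    smask = get_dist_reg(images, seg_masks)
    # Input-gradient attribution of the detection loss
    grad = torch.autograd.grad(det_loss, images, retain_graph=True)[0]
    grad = torch.abs(grad)
    # Penalize attribution mass that falls far from annotated objects
    plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
    return det_loss + plaus_loss, plaus_loss
```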
@@ -159,6 +161,7 @@ class PGTTrainer: self.loss_names = ["Loss"] self.csv = self.save_dir / "results.csv" self.plot_idx = [0, 1, 2] + self.num = int(time.time()) # Callbacks self.callbacks = _callbacks or callbacks.get_default_callbacks() @@ -328,7 +331,7 @@ class PGTTrainer: if world_size > 1: self._setup_ddp(world_size) self._setup_train(world_size) - + nb = len(self.train_loader) # number of batches nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations last_opt_step = -1 @@ -382,37 +385,88 @@ class PGTTrainer: with torch.cuda.amp.autocast(self.amp): batch = self.preprocess_batch(batch) (self.loss, self.loss_items), images = self.model(batch, return_images=True) + + # smask = get_dist_reg(images, batch['masks']) + + # grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0] + # grad = torch.abs(grad) + + # self.args.pgt_coeff = 1.1 + # plaus_loss = plaus_loss_fn(grad, smask, self.args.pgt_coeff) + # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0))) + # self.loss += plaus_loss + + debug_ = debug + if debug_ and (i % 25 == 0): + debug_ = False + # Create a tensor of zeros with the same size as images + mask = torch.zeros_like(images, dtype=torch.float32) + smask = get_dist_reg(images, batch['masks']) + grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0] + grad = torch.abs(grad) + + batch_size = images.shape[0] + imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device) + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) - if debug and (i % 250): - grad = torch.autograd.grad(self.loss, images, create_graph=True)[0] + # Iterate over each bounding box and set the corresponding pixels to 1 + for irx, bboxes in enumerate(gt_bboxes): + for idx in range(len(bboxes)): + x1, y1, x2, y2 = bboxes[idx] + x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2)) + mask[irx, :, y1:y2, x1:x2] = 1.0 + + save_imgs = True + if save_imgs: # Convert tensors to numpy arrays images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1) grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1) + mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1) + seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1) # Normalize grad for visualization grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min()) - + for ix in range(images_np.shape[0]): - fig, ax = plt.subplots(1, 3, figsize=(15, 5)) - ax[0].imshow(images_np[i]) + fig, ax = plt.subplots(1, 6, figsize=(30, 5)) + ax[0].imshow(images_np[ix]) ax[0].set_title('Image') - ax[1].imshow(grad_np[i], cmap='jet') + ax[1].imshow(grad_np[ix], cmap='jet') ax[1].set_title('Gradient') - ax[2].imshow(images_np[i]) - ax[2].imshow(grad_np[i], cmap='jet', alpha=0.5) + ax[2].imshow(images_np[ix]) + ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5) ax[2].set_title('Overlay') + ax[3].imshow(mask_np[ix], cmap='gray') + ax[3].set_title('Mask') + ax[4].imshow(seg_mask_np[ix], cmap='gray') + ax[4].set_title('Segmentation Mask') + + # Plot image with bounding boxes + ax[5].imshow(images_np[ix]) + for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]): + x1, y1, x2, y2 = bbox + x1, y1, x2, y2 = int(torch.round(x1)), 
int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2)) + rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2) + ax[5].add_patch(rect) + ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5)) + ax[5].set_title('Bounding Boxes') - save_dir_attr = "figures/attributions" + save_dir_attr = f"figures/attributions/run{self.num}" if not os.path.exists(save_dir_attr): os.makedirs(save_dir_attr) plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png') plt.close(fig) - - if RANK != -1: - self.loss *= world_size - self.tloss = ( - (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items - ) + images = images.detach() + + + if RANK != -1: + self.loss *= world_size + self.tloss = ( + (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items + ) # Backward self.scaler.scale(self.loss).backward() diff --git a/ultralytics/engine/pgt_validator.py b/ultralytics/engine/pgt_validator.py new file mode 100644 index 00000000..ce411ba9 --- /dev/null +++ b/ultralytics/engine/pgt_validator.py @@ -0,0 +1,345 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +""" +Check a model's accuracy on a test or val split of a dataset. + +Usage: + $ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640 + +Usage - formats: + $ yolo mode=val model=yolov8n.pt # PyTorch + yolov8n.torchscript # TorchScript + yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True + yolov8n_openvino_model # OpenVINO + yolov8n.engine # TensorRT + yolov8n.mlpackage # CoreML (macOS-only) + yolov8n_saved_model # TensorFlow SavedModel + yolov8n.pb # TensorFlow GraphDef + yolov8n.tflite # TensorFlow Lite + yolov8n_edgetpu.tflite # TensorFlow Edge TPU + yolov8n_paddle_model # PaddlePaddle + yolov8n_ncnn_model # NCNN +""" + +import json +import time +from pathlib import Path + +import numpy as np +import torch + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data.utils import check_cls_dataset, check_det_dataset +from ultralytics.nn.autobackend import AutoBackend +from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis +from ultralytics.utils.checks import check_imgsz +from ultralytics.utils.ops import Profile +from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode + + +class PGTValidator: + """ + PGTValidator. + + A base class for creating PGT validators, which keep input gradients enabled during validation so that attribution maps can be computed. + + Attributes: + args (SimpleNamespace): Configuration for the validator. + dataloader (DataLoader): Dataloader to use for validation. + pbar (tqdm): Progress bar to update during validation. + model (nn.Module): Model to validate. + data (dict): Data dictionary. + device (torch.device): Device to use for validation. + batch_i (int): Current batch index. + training (bool): Whether the model is in training mode. + names (dict): Class names. + seen: Records the number of images seen so far during validation. + stats: Placeholder for statistics during validation. + confusion_matrix: Placeholder for a confusion matrix. + nc: Number of classes. + iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05. + jdict (dict): Dictionary to store JSON validation results. + speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective + batch processing times in milliseconds. + save_dir (Path): Directory to save results. + plots (dict): Dictionary to store plots for visualization. 
+ callbacks (dict): Dictionary to store various callback functions. + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """ + Initializes a PGTValidator instance. + + Args: + dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation. + save_dir (Path, optional): Directory to save results. + pbar (tqdm.tqdm): Progress bar for displaying progress. + args (SimpleNamespace): Configuration for the validator. + _callbacks (dict): Dictionary to store various callback functions. + """ + self.args = get_cfg(overrides=args) + self.dataloader = dataloader + self.pbar = pbar + self.stride = None + self.data = None + self.device = None + self.batch_i = None + self.training = True + self.names = None + self.seen = None + self.stats = None + self.confusion_matrix = None + self.nc = None + self.iouv = None + self.jdict = None + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + + self.save_dir = save_dir or get_save_dir(self.args) + (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) + if self.args.conf is None: + self.args.conf = 0.001 # default conf=0.001 + self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1) + + self.plots = {} + self.callbacks = _callbacks or callbacks.get_default_callbacks() + + # @smart_inference_mode() + def __call__(self, trainer=None, model=None): + """Supports validation of a pretrained model if passed, or of a model being trained if a trainer is passed + (the trainer takes priority). + """ + self.training = trainer is not None + augment = self.args.augment and (not self.training) + if self.training: + self.device = trainer.device + self.data = trainer.data + # self.args.half = self.device.type != "cpu" # force FP16 val during training + model = trainer.ema.ema or trainer.model + model = model.half() if self.args.half else model.float() + # self.model = model + self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device) + self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1) + model.eval() + else: + callbacks.add_integration_callbacks(self) + model = AutoBackend( + weights=model or self.args.model, + device=select_device(self.args.device, self.args.batch), + dnn=self.args.dnn, + data=self.args.data, + fp16=self.args.half, + ) + # self.model = model + self.device = model.device # update device + self.args.half = model.fp16 # update half + stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine + imgsz = check_imgsz(self.args.imgsz, stride=stride) + if engine: + self.args.batch = model.batch_size + elif not pt and not jit: + self.args.batch = 1 # export.py models default to batch-size 1 + LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models") + + if str(self.args.data).split(".")[-1] in ("yaml", "yml"): + self.data = check_det_dataset(self.args.data) + elif self.args.task == "classify": + self.data = check_cls_dataset(self.args.data, split=self.args.split) + else: + raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌")) + + if self.device.type in ("cpu", "mps"): + self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading + if not pt: + self.args.rect = False + self.stride = model.stride # used in get_dataloader() for padding + self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch) + + model.eval() + 
model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup + + self.run_callbacks("on_val_start") + dt = ( + Profile(device=self.device), + Profile(device=self.device), + Profile(device=self.device), + Profile(device=self.device), + ) + bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader)) + self.init_metrics(de_parallel(model)) + self.jdict = [] # empty before each val + for batch_i, batch in enumerate(bar): + self.run_callbacks("on_val_batch_start") + self.batch_i = batch_i + # Preprocess + with dt[0]: + batch = self.preprocess(batch) + + # Inference + with dt[1]: + preds = model(batch["img"].requires_grad_(True), augment=augment) + + # Loss + with dt[2]: + if self.training: + self.loss += model.loss(batch, preds)[1] + + # Postprocess + with dt[3]: + preds = self.postprocess(preds) + + self.update_metrics(preds, batch) + if self.args.plots and batch_i < 3: + self.plot_val_samples(batch, batch_i) + self.plot_predictions(batch, preds, batch_i) + + self.run_callbacks("on_val_batch_end") + stats = self.get_stats() + self.check_stats(stats) + self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt))) + self.finalize_metrics() + if not (self.args.save_json and self.is_coco and len(self.jdict)): + self.print_results() + self.run_callbacks("on_val_end") + if self.training: + model.float() + if self.args.save_json and self.jdict: + with open(str(self.save_dir / "predictions.json"), "w") as f: + LOGGER.info(f"Saving {f.name}...") + json.dump(self.jdict, f) # flatten and save + stats = self.eval_json(stats) # update stats + stats['fitness'] = stats['metrics/mAP50-95(B)'] + results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")} + return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats + else: + LOGGER.info( + "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image" + % tuple(self.speed.values()) + ) + if self.args.save_json and self.jdict: + with open(str(self.save_dir / "predictions.json"), "w") as f: + LOGGER.info(f"Saving {f.name}...") + json.dump(self.jdict, f) # flatten and save + stats = self.eval_json(stats) # update stats + if self.args.plots or self.args.save_json: + LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") + return stats + + def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False): + """ + Matches predictions to ground truth objects (pred_classes, true_classes) using IoU. + + Args: + pred_classes (torch.Tensor): Predicted class indices of shape(N,). + true_classes (torch.Tensor): Target class indices of shape(M,). + iou (torch.Tensor): An MxN tensor of pairwise IoU values between the M ground-truth labels and the N predictions. + use_scipy (bool): Whether to use scipy for matching (more precise). + + Returns: + (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds. 
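+
+        Example (minimal sketch mirroring `_process_batch` in `pgt_val.py` below; `box_iou` is from
+        `ultralytics.utils.metrics`):
+            ```python
+            iou = box_iou(gt_bboxes, detections[:, :4])  # (M, N) pairwise IoU
+            correct = self.match_predictions(detections[:, 5], gt_cls, iou)
+            ```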
+ """ + # Dx10 matrix, where D - detections, 10 - IoU thresholds + correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool) + # LxD matrix where L - labels (rows), D - detections (columns) + correct_class = true_classes[:, None] == pred_classes + iou = iou * correct_class # zero out the wrong classes + iou = iou.cpu().numpy() + for i, threshold in enumerate(self.iouv.cpu().tolist()): + if use_scipy: + # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708 + import scipy # scope import to avoid importing for all commands + + cost_matrix = iou * (iou >= threshold) + if cost_matrix.any(): + labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True) + valid = cost_matrix[labels_idx, detections_idx] > 0 + if valid.any(): + correct[detections_idx[valid], i] = True + else: + matches = np.nonzero(iou >= threshold) # IoU > threshold and classes match + matches = np.array(matches).T + if matches.shape[0]: + if matches.shape[0] > 1: + matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + correct[matches[:, 1].astype(int), i] = True + return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device) + + def add_callback(self, event: str, callback): + """Appends the given callback.""" + self.callbacks[event].append(callback) + + def run_callbacks(self, event: str): + """Runs all callbacks associated with a specified event.""" + for callback in self.callbacks.get(event, []): + callback(self) + + def get_dataloader(self, dataset_path, batch_size): + """Get data loader from dataset path and batch size.""" + raise NotImplementedError("get_dataloader function not implemented for this validator") + + def build_dataset(self, img_path): + """Build dataset.""" + raise NotImplementedError("build_dataset function not implemented in validator") + + def preprocess(self, batch): + """Preprocesses an input batch.""" + return batch + + def postprocess(self, preds): + """Describes and summarizes the purpose of 'postprocess()' but no details mentioned.""" + return preds + + def init_metrics(self, model): + """Initialize performance metrics for the YOLO model.""" + pass + + def update_metrics(self, preds, batch): + """Updates metrics based on predictions and batch.""" + pass + + def finalize_metrics(self, *args, **kwargs): + """Finalizes and returns all metrics.""" + pass + + def get_stats(self): + """Returns statistics about the model's performance.""" + return {} + + def check_stats(self, stats): + """Checks statistics.""" + pass + + def print_results(self): + """Prints the results of the model's predictions.""" + pass + + def get_desc(self): + """Get description of the YOLO model.""" + pass + + @property + def metric_keys(self): + """Returns the metric keys used in YOLO training/validation.""" + return [] + + def on_plot(self, name, data=None): + """Registers plots (e.g. 
to be consumed in callbacks)""" + self.plots[Path(name)] = {"data": data, "timestamp": time.time()} + + # TODO: may need to put these following functions into callback + def plot_val_samples(self, batch, ni): + """Plots validation samples during training.""" + pass + + def plot_predictions(self, batch, preds, ni): + """Plots YOLO model predictions on batch images.""" + pass + + def pred_to_json(self, preds, batch): + """Convert predictions to JSON format.""" + pass + + def eval_json(self, stats): + """Evaluate and return JSON format of prediction statistics.""" + pass diff --git a/ultralytics/models/yolo/detect/__init__.py b/ultralytics/models/yolo/detect/__init__.py index b4cb6dce..656966f5 100644 --- a/ultralytics/models/yolo/detect/__init__.py +++ b/ultralytics/models/yolo/detect/__init__.py @@ -4,5 +4,6 @@ from .predict import DetectionPredictor from .pgt_train import PGTDetectionTrainer from .train import DetectionTrainer from .val import DetectionValidator +from .pgt_val import PGTDetectionValidator -__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer" +__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer", "PGTDetectionValidator" diff --git a/ultralytics/models/yolo/detect/pgt_val.py b/ultralytics/models/yolo/detect/pgt_val.py new file mode 100644 index 00000000..04138ff7 --- /dev/null +++ b/ultralytics/models/yolo/detect/pgt_val.py @@ -0,0 +1,300 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import os +from pathlib import Path + +import numpy as np +import torch + +from ultralytics.data import build_dataloader, build_yolo_dataset, converter +from ultralytics.engine.validator import BaseValidator +from ultralytics.engine.pgt_validator import PGTValidator +from ultralytics.utils import LOGGER, ops +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou +from ultralytics.utils.plotting import output_to_target, plot_images + + +class PGTDetectionValidator(PGTValidator): + """ + A class extending the PGTValidator class for validation based on a detection model. 
+ + Example: + ```python + from ultralytics.models.yolo.detect import PGTDetectionValidator + + args = dict(model='yolov8n.pt', data='coco8.yaml') + validator = PGTDetectionValidator(args=args) + validator() + ``` + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """Initialize the detection validator with necessary variables and settings.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.nt_per_class = None + self.is_coco = False + self.class_map = None + self.args.task = "detect" + self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) + self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95 + self.niou = self.iouv.numel() + self.lb = [] # for autolabelling + + def preprocess(self, batch): + """Preprocesses batch of images for YOLO training.""" + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255 + for k in ["batch_idx", "cls", "bboxes"]: + batch[k] = batch[k].to(self.device) + + if self.args.save_hybrid: + height, width = batch["img"].shape[2:] + nb = len(batch["img"]) + bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device) + self.lb = ( + [ + torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1) + for i in range(nb) + ] + if self.args.save_hybrid + else [] + ) # for autolabelling + + return batch + + def init_metrics(self, model): + """Initialize evaluation metrics for YOLO.""" + val = self.data.get(self.args.split, "") # validation path + self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO + self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000)) + self.args.save_json |= self.is_coco # run on final val if training COCO + self.names = model.names + self.nc = len(model.names) + self.metrics.names = self.names + self.metrics.plot = self.args.plots + self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf) + self.seen = 0 + self.jdict = [] + self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[]) + + def get_desc(self): + """Return a formatted string summarizing class metrics of YOLO model.""" + return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)") + + def postprocess(self, preds): + """Apply Non-maximum suppression to prediction outputs.""" + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + ) + + def _prepare_batch(self, si, batch): + """Prepares a batch of images and annotations for validation.""" + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] + if len(cls): + bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes + ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels + return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad) + + def _prepare_pred(self, pred, pbatch): + """Scales a batch of predictions back to native (original image) space.""" + predn = pred.clone() + ops.scale_boxes( + pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"] + ) # native-space 
pred + return predn + + def update_metrics(self, preds, batch): + """Metrics.""" + for si, pred in enumerate(preds): + self.seen += 1 + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls + if npr == 0: + if nl: + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + if self.args.plots: + self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) + continue + + # Predictions + if self.args.single_cls: + pred[:, 5] = 0 + predn = self._prepare_pred(pred, pbatch) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] + + # Evaluate + if nl: + stat["tp"] = self._process_batch(predn, bbox, cls) + if self.args.plots: + self.confusion_matrix.process_batch(predn, bbox, cls) + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + + # Save + if self.args.save_json: + self.pred_to_json(predn, batch["im_file"][si]) + if self.args.save_txt: + file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt' + self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file) + + def finalize_metrics(self, *args, **kwargs): + """Set final values for metrics speed and confusion matrix.""" + self.metrics.speed = self.speed + self.metrics.confusion_matrix = self.confusion_matrix + + def get_stats(self): + """Returns metrics statistics and results dictionary.""" + stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy + if len(stats) and stats["tp"].any(): + self.metrics.process(**stats) + self.nt_per_class = np.bincount( + stats["target_cls"].astype(int), minlength=self.nc + ) # number of targets per class + return self.metrics.results_dict + + def print_results(self): + """Prints training/validation set metrics per class.""" + pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format + LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) + if self.nt_per_class.sum() == 0: + LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels") + + # Print results per class + if self.args.verbose and not self.training and self.nc > 1 and len(self.stats): + for i, c in enumerate(self.metrics.ap_class_index): + LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i))) + + if self.args.plots: + for normalize in True, False: + self.confusion_matrix.plot( + save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot + ) + + def _process_batch(self, detections, gt_bboxes, gt_cls): + """ + Return correct prediction matrix. + + Args: + detections (torch.Tensor): Tensor of shape [N, 6] representing detections. + Each detection is of the format: x1, y1, x2, y2, conf, class. + labels (torch.Tensor): Tensor of shape [M, 5] representing labels. + Each label is of the format: class, x1, y1, x2, y2. + + Returns: + (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels. + """ + iou = box_iou(gt_bboxes, detections[:, :4]) + return self.match_predictions(detections[:, 5], gt_cls, iou) + + def build_dataset(self, img_path, mode="val", batch=None): + """ + Build YOLO Dataset. + + Args: + img_path (str): Path to the folder containing images. 
+ mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. Defaults to None. + """ + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride) + + def get_dataloader(self, dataset_path, batch_size): + """Construct and return dataloader.""" + dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val") + return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader + + def plot_val_samples(self, batch, ni): + """Plot validation image samples.""" + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def plot_predictions(self, batch, preds, ni): + """Plots predicted bounding boxes on input images and saves the result.""" + plot_images( + batch["img"], + *output_to_target(preds, max_det=self.args.max_det), + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred + + def save_one_txt(self, predn, save_conf, shape, file): + """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" + gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh + for *xyxy, conf, cls in predn.tolist(): + xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh + line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format + with open(file, "a") as f: + f.write(("%g " * len(line)).rstrip() % line + "\n") + + def pred_to_json(self, predn, filename): + """Serialize YOLO predictions to COCO json format.""" + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + box = ops.xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + for p, b in zip(predn.tolist(), box.tolist()): + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "score": round(p[4], 5), + } + ) + + def eval_json(self, stats): + """Evaluates YOLO output in JSON format and returns performance statistics.""" + if self.args.save_json and self.is_coco and len(self.jdict): + anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") + try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + check_requirements("pycocotools>=2.0.6") + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + for x in anno_json, pred_json: + assert x.is_file(), f"{x} file not found" + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) + eval = COCOeval(anno, pred, "bbox") + if self.is_coco: + eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval + eval.evaluate() + eval.accumulate() + eval.summarize() + stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50 + except Exception as e: + LOGGER.warning(f"pycocotools unable to run: {e}") + return stats diff --git 
a/ultralytics/models/yolo/segment/__init__.py b/ultralytics/models/yolo/segment/__init__.py index ec1ac799..31e23957 100644 --- a/ultralytics/models/yolo/segment/__init__.py +++ b/ultralytics/models/yolo/segment/__init__.py @@ -3,5 +3,6 @@ from .predict import SegmentationPredictor from .train import SegmentationTrainer from .val import SegmentationValidator +from .pgt_train import PGTSegmentationTrainer -__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator" +__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator", "PGTSegmentationTrainer" diff --git a/ultralytics/models/yolo/segment/pgt_train.py b/ultralytics/models/yolo/segment/pgt_train.py new file mode 100644 index 00000000..7e7712b9 --- /dev/null +++ b/ultralytics/models/yolo/segment/pgt_train.py @@ -0,0 +1,73 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from copy import copy + +from ultralytics.models import yolo +from ultralytics.nn.tasks import SegmentationModel, DetectionModel +from ultralytics.utils import DEFAULT_CFG, RANK +from ultralytics.utils.plotting import plot_images, plot_results +from ultralytics.models.yolov10.model import YOLOv10DetectionModel, YOLOv10PGTDetectionModel +from ultralytics.models.yolov10.val import YOLOv10DetectionValidator, YOLOv10PGTDetectionValidator + +class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer): + """ + A class extending the PGTDetectionTrainer class for training based on a segmentation model. + + Example: + ```python + from ultralytics.models.yolo.segment import PGTSegmentationTrainer + + args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3) + trainer = PGTSegmentationTrainer(overrides=args) + trainer.train() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize a PGTSegmentationTrainer object with given arguments.""" + if overrides is None: + overrides = {} + overrides["task"] = "segment" + super().__init__(cfg, overrides, _callbacks) + + def get_model(self, cfg=None, weights=None, verbose=True): + """Return a YOLOv10PGTDetectionModel (for YOLOv10 weights) or SegmentationModel initialized with the specified config and weights.""" + if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']: + model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) + else: + model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + + return model + + def get_validator(self): + """Return a YOLOv10PGTDetectionValidator or SegmentationValidator instance for validation of the YOLO model.""" + + if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']: + self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", "pgt_loss", + return YOLOv10PGTDetectionValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) + else: + self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss" + return yolo.segment.SegmentationValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) + + def plot_training_samples(self, batch, ni): + """Creates a plot of training sample images with labels and box coordinates.""" + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + masks=batch["masks"], + paths=batch["im_file"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) + + def plot_metrics(self): + """Plots 
training/val metrics.""" + plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png diff --git a/ultralytics/models/yolov10/model.py b/ultralytics/models/yolov10/model.py index 6e29c3d0..0c8bcb9a 100644 --- a/ultralytics/models/yolov10/model.py +++ b/ultralytics/models/yolov10/model.py @@ -1,5 +1,5 @@ from ultralytics.engine.model import Model -from ultralytics.nn.tasks import YOLOv10DetectionModel +from ultralytics.nn.tasks import YOLOv10DetectionModel, YOLOv10PGTDetectionModel from .val import YOLOv10DetectionValidator from .predict import YOLOv10DetectionPredictor from .train import YOLOv10DetectionTrainer diff --git a/ultralytics/models/yolov10/val.py b/ultralytics/models/yolov10/val.py index 19a019c8..a96483ca 100644 --- a/ultralytics/models/yolov10/val.py +++ b/ultralytics/models/yolov10/val.py @@ -1,4 +1,4 @@ -from ultralytics.models.yolo.detect import DetectionValidator +from ultralytics.models.yolo.detect import DetectionValidator, PGTDetectionValidator from ultralytics.utils import ops import torch @@ -7,6 +7,27 @@ class YOLOv10DetectionValidator(DetectionValidator): super().__init__(*args, **kwargs) self.args.save_json |= self.is_coco + def postprocess(self, preds): + if isinstance(preds, dict): + preds = preds["one2one"] + + if isinstance(preds, (list, tuple)): + preds = preds[0] + + # Acknowledgement: Thanks to sanha9999 in #190 and #181! + if preds.shape[-1] == 6: + return preds + else: + preds = preds.transpose(-1, -2) + boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc) + bboxes = ops.xywh2xyxy(boxes) + return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) + +class YOLOv10PGTDetectionValidator(PGTDetectionValidator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.args.save_json |= self.is_coco + def postprocess(self, preds): if isinstance(preds, dict): preds = preds["one2one"] diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 00fbadff..48242826 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -57,7 +57,7 @@ from ultralytics.nn.modules import ( ) from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml -from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss +from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss, v10PGTDetectLoss from ultralytics.utils.plotting import feature_visualization from ultralytics.utils.torch_utils import ( fuse_conv_and_bn, @@ -651,6 +651,10 @@ class WorldModel(DetectionModel): class YOLOv10DetectionModel(DetectionModel): def init_criterion(self): return v10DetectLoss(self) + +class YOLOv10PGTDetectionModel(DetectionModel): + def init_criterion(self): + return v10PGTDetectLoss(self) class Ensemble(nn.ModuleList): """Ensemble of models.""" diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py index d0ca9c39..e10b949f 100644 --- a/ultralytics/utils/loss.py +++ b/ultralytics/utils/loss.py @@ -9,7 +9,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors from .metrics import bbox_iou, probiou from .tal import bbox2dist - +from ultralytics.utils.plaus_functs import get_dist_reg, 
plaus_loss_fn class VarifocalLoss(nn.Module): """ @@ -725,3 +725,30 @@ class v10DetectLoss: one2one = preds["one2one"] loss_one2one = self.one2one(one2one, batch) return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1])) + +class v10PGTDetectLoss: + def __init__(self, model): + self.one2many = v8DetectionLoss(model, tal_topk=10) + self.one2one = v8DetectionLoss(model, tal_topk=1) + + def __call__(self, preds, batch): + batch['img'] = batch['img'].requires_grad_(True) + one2many = preds["one2many"] + loss_one2many = self.one2many(one2many, batch) + one2one = preds["one2one"] + loss_one2one = self.one2one(one2one, batch) + + loss = loss_one2many[0] + loss_one2one[0] + + smask = get_dist_reg(batch['img'], batch['masks']) + + grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0] + grad = torch.abs(grad) + + pgt_coeff = 3.0 + plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff) + # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0))) + loss += plaus_loss + + return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0))) + \ No newline at end of file diff --git a/ultralytics/utils/plaus_functs.py b/ultralytics/utils/plaus_functs.py new file mode 100644 index 00000000..3c5b689a --- /dev/null +++ b/ultralytics/utils/plaus_functs.py @@ -0,0 +1,818 @@ +import torch +import numpy as np +# from plot_functs import * +from .plot_functs import normalize_tensor, overlay_mask, imshow +import math +import time +import matplotlib.path as mplPath +from matplotlib.path import Path +# from utils.general import non_max_suppression, xyxy2xywh, scale_coords +from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppression +from .metrics import bbox_iou +import torchvision.transforms as T + +def plaus_loss_fn(grad, smask, pgt_coeff): + ################## Compute the PGT Loss ################## + # Positive regularization term for incentivizing pixels near the target to have high attribution + dist_attr_pos = attr_reg(grad, (1.0 - smask)) # dist_reg = seg_mask + # Negative regularization term for incentivizing pixels far from the target to have low attribution + dist_attr_neg = attr_reg(grad, smask) + # Calculate plausibility regularization term + # dist_reg = dist_attr_pos - dist_attr_neg + dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad))) + plaus_reg = (((1.0 + dist_reg) / 2.0)) + # Calculate plausibility loss + plaus_loss = (1 - plaus_reg) * pgt_coeff + return plaus_loss + +def get_dist_reg(images, seg_mask): + seg_mask = T.Resize((images.shape[2], images.shape[3]), antialias=True)(seg_mask).to(images.device) + seg_mask = seg_mask.to(dtype=torch.float32).unsqueeze(1).repeat(1, 3, 1, 1) + seg_mask[seg_mask > 0] = 1.0 + + smask = torch.zeros_like(seg_mask) + sigmas = [20.0 + (i_sig * 20.0) for i_sig in range(8)] + for k_it, sigma in enumerate(sigmas): + # Apply Gaussian blur to the mask + kernel_size = int(sigma + 50) + if kernel_size % 2 == 0: + kernel_size += 1 + seg_mask1 = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask) + if torch.max(seg_mask1) > 1.0: + seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min()) + smask = torch.max(smask, seg_mask1) + return smask + +def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False): + """ + Compute the gradient of an image with respect to a given tensor. + + Args: + img (torch.Tensor): The input image tensor. 
grad_wrt (torch.Tensor): The tensor with respect to which the gradient is computed. + norm (bool, optional): Whether to normalize the gradient. Defaults to False. + absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True. + grayscale (bool, optional): Whether to convert the gradient to grayscale. Defaults to False. + keepmean (bool, optional): Whether to keep the mean value of the attribution map. Defaults to False. + + Returns: + torch.Tensor: The computed attribution map. + + """ + if (grad_wrt.shape != torch.Size([1])) and (grad_wrt.shape != torch.Size([])): + grad_wrt_outputs = torch.ones_like(grad_wrt).clone().detach()#.requires_grad_(True)#.retains_grad_(True) + else: + grad_wrt_outputs = None + attribution_map = torch.autograd.grad(grad_wrt, img, + grad_outputs=grad_wrt_outputs, + create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly + )[0] + if absolute: + attribution_map = torch.abs(attribution_map) # attribution_map ** 2 # Take absolute values of gradients + if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval + attribution_map = torch.sum(attribution_map, 1, keepdim=True) + if norm: + if keepmean: + attmean = torch.mean(attribution_map) + attmin = torch.min(attribution_map) + attmax = torch.max(attribution_map) + attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch + if keepmean: + attribution_map -= attribution_map.mean() + attribution_map += (attmean / (attmax - attmin)) + + return attribution_map + +def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False): + """ + Generate Gaussian noise based on the input image. + + Args: + img (torch.Tensor): Input image. + grad_wrt: Gradient with respect to the input image. + norm (bool, optional): Whether to normalize the generated noise. Defaults to True. + absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True. + grayscale (bool, optional): Whether to convert the noise to grayscale. Defaults to True. + keepmean (bool, optional): Whether to keep the mean of the noise. Defaults to False. + + Returns: + torch.Tensor: Generated Gaussian noise. + """ + + gaussian_noise = torch.randn_like(img) + + if absolute: + gaussian_noise = torch.abs(gaussian_noise) # Take absolute values of gradients + if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval + gaussian_noise = torch.sum(gaussian_noise, 1, keepdim=True) + if norm: + if keepmean: + attmean = torch.mean(gaussian_noise) + attmin = torch.min(gaussian_noise) + attmax = torch.max(gaussian_noise) + gaussian_noise = normalize_batch(gaussian_noise) # Normalize attribution maps per image in batch + if keepmean: + gaussian_noise -= gaussian_noise.mean() + gaussian_noise += (attmean / (attmax - attmin)) + + return gaussian_noise + + +def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7): + # TODO: Remove imgs from this function and only take it as input if debug is True + """ + Calculates the plausibility score based on the given inputs. + + Args: + targets_out (torch.Tensor): The output targets. + attr (torch.Tensor): The attribution tensor. + debug (bool, optional): Whether to enable debug mode. Defaults to False. + corners (bool, optional): Whether targets are already given as pixel corner (xyxy) coordinates. Defaults to False. + imgs (torch.Tensor, optional): The input images, used only for debug visualization. Defaults to None. + + Returns: + torch.Tensor: The plausibility score. 
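+
+    Example (illustrative sketch; assumes `targets` rows follow the `[batch_idx, cls, x, y, w, h]`
+    normalized format used throughout this file):
+        ```python
+        attr = get_gradient(imgs, loss, norm=True)  # attribution maps, same shape as imgs
+        plaus = get_plaus_score(targets, attr)      # fraction of attribution inside GT boxes
+        ```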
+ """ + # # if imgs is None: + # # imgs = torch.zeros_like(attr) + # # with torch.no_grad(): + # target_inds = targets_out[:, 0].int() + # xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num] + # num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1)) + # # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1)) + # xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int() + # co = xyxy_corners + # if corners: + # co = targets_out[:, 2:6].int() + # coords_map = torch.zeros_like(attr, dtype=torch.bool) + # # rows = np.arange(co.shape[0]) + # x1, x2 = co[:,1], co[:,3] + # y1, y2 = co[:,0], co[:,2] + + # for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop + # coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True + + if torch.isnan(attr).any(): + attr = torch.nan_to_num(attr, nan=0.0) + + coords_map = get_bbox_map(targets_out, attr) + plaus_score = ((torch.sum((attr * coords_map))) / (torch.sum(attr))) + + if debug: + for i in range(len(coords_map)): + coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0) + test_bbox = torch.zeros_like(imgs[i]) + test_bbox[coords_map3ch] = imgs[i][coords_map3ch] + imshow(test_bbox, save_path='figs/test_bbox') + if imgs is None: + imgs = torch.zeros_like(attr) + imshow(imgs[i], save_path='figs/im0') + imshow(attr[i], save_path='figs/attr') + + # with torch.no_grad(): + # # att_select = attr[coords_map] + # att_select = attr * coords_map.to(torch.float32) + # att_total = attr + + # IoU_num = torch.sum(att_select) + # IoU_denom = torch.sum(att_total) + + # IoU_ = (IoU_num / IoU_denom) + # plaus_score = IoU_ + + # # plaus_score = ((torch.sum(attr[coords_map])) / (torch.sum(attr))) + + return plaus_score + +def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7): + # TODO: Remove imgs from this function and only take it as input if debug is True + """ + Calculates the plausibility score based on the given inputs. + + Args: + imgs (torch.Tensor): The input images. + targets_out (torch.Tensor): The output targets. + attr (torch.Tensor): The attribute tensor. + debug (bool, optional): Whether to enable debug mode. Defaults to False. + + Returns: + torch.Tensor: The plausibility score. 
+ """ + # if imgs is None: + # imgs = torch.zeros_like(attr) + # with torch.no_grad(): + target_inds = targets_out[:, 0].int() + xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num] + num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1)) + # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1)) + xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int() + co = xyxy_corners + if corners: + co = targets_out[:, 2:6].int() + coords_map = torch.zeros_like(attr, dtype=torch.bool) + # rows = np.arange(co.shape[0]) + x1, x2 = co[:,1], co[:,3] + y1, y2 = co[:,0], co[:,2] + + for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop + coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True + + if torch.isnan(attr).any(): + attr = torch.nan_to_num(attr, nan=0.0) + if debug: + for i in range(len(coords_map)): + coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0) + test_bbox = torch.zeros_like(imgs[i]) + test_bbox[coords_map3ch] = imgs[i][coords_map3ch] + imshow(test_bbox, save_path='figs/test_bbox') + imshow(imgs[i], save_path='figs/im0') + imshow(attr[i], save_path='figs/attr') + + # att_select = attr[coords_map] + # with torch.no_grad(): + # IoU_num = (torch.sum(attr[coords_map])) + # IoU_denom = torch.sum(attr) + # IoU_ = (IoU_num / (IoU_denom)) + + # IoU_ = torch.max(attr[coords_map]) - torch.max(attr[~coords_map]) + co = (xyxy_batch * num_pixels).int() + x1 = co[:,1] + 1 + y1 = co[:,0] + 1 + # with torch.no_grad(): + attr_ = torch.sum(attr, 1, keepdim=True) + corners_attr = None #torch.zeros(len(xyxy_batch), 4, device=attr.device) + for ic in range(co.shape[0]): + attr0 = attr_[target_inds[ic], :,:x1[ic],:y1[ic]] + attr1 = attr_[target_inds[ic], :,:x1[ic],y1[ic]:] + attr2 = attr_[target_inds[ic], :,x1[ic]:,:y1[ic]] + attr3 = attr_[target_inds[ic], :,x1[ic]:,y1[ic]:] + + x_0, y_0 = max_indices_2d(attr0[0]) + x_1, y_1 = max_indices_2d(attr1[0]) + x_2, y_2 = max_indices_2d(attr2[0]) + x_3, y_3 = max_indices_2d(attr3[0]) + + y_1 += y1[ic] + x_2 += x1[ic] + x_3 += x1[ic] + y_3 += y1[ic] + + max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2], + torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3], + torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2], + torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]]) + if corners_attr is None: + corners_attr = max_corners + else: + corners_attr = torch.cat([corners_attr, max_corners], dim=0) + # corners_attr[ic] = max_corners + # corners_attr = attr[:,0,:4,0] + corners_attr = corners_attr.view(-1, 4) + # corners_attr = torch.stack(corners_attr, dim=0) + IoU_ = bbox_iou(corners_attr.T, xyxy_batch, x1y1x2y2=False, metric='CIoU') + plaus_score = IoU_.mean() + + return plaus_score + +def max_indices_2d(x_inp): + # values, indices = x.reshape(x.size(0), -1).max(dim=-1) + torch.max(x_inp,) + index = torch.argmax(x_inp) + x = index // x_inp.shape[1] + y = index % x_inp.shape[1] + # x, y = divmod(index.item(), x_inp.shape[1]) + + return torch.cat([x.unsqueeze(0), y.unsqueeze(0)]) + + +def point_in_polygon(poly, grid): + # t0 = time.time() + num_points = poly.shape[0] + j = num_points - 1 + oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool) + for i in range(num_points): + cond1 = (poly[i, 1] < grid[..., 1]) & (poly[j, 1] >= grid[..., 1]) + cond2 = (poly[j, 1] < grid[..., 1]) & (poly[i, 1] >= 
grid[..., 1]) + cond3 = (grid[..., 0] - poly[i, 0]) < (poly[j, 0] - poly[i, 0]) * (grid[..., 1] - poly[i, 1]) / (poly[j, 1] - poly[i, 1]) + oddNodes = oddNodes ^ (cond1 | cond2) & cond3 + j = i + # t1 = time.time() + # print(f'point in polygon time: {t1-t0}') + return oddNodes + +def point_in_polygon_gpu(poly, grid): + num_points = poly.shape[0] + i = torch.arange(num_points) + j = (i - 1) % num_points + # Expand dimensions + # t0 = time.time() + poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid.shape[0], grid.shape[0]) + # t1 = time.time() + cond1 = (poly_expanded[i, 1] < grid[..., 1]) & (poly_expanded[j, 1] >= grid[..., 1]) + cond2 = (poly_expanded[j, 1] < grid[..., 1]) & (poly_expanded[i, 1] >= grid[..., 1]) + cond3 = (grid[..., 0] - poly_expanded[i, 0]) < (poly_expanded[j, 0] - poly_expanded[i, 0]) * (grid[..., 1] - poly_expanded[i, 1]) / (poly_expanded[j, 1] - poly_expanded[i, 1]) + # t2 = time.time() + oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool) + cond = (cond1 | cond2) & cond3 + # t3 = time.time() + # efficiently perform xor using gpu and avoiding cpu as much as possible + c = [] + while len(cond) > 1: + if len(cond) % 2 == 1: # odd number of elements + c.append(cond[-1]) + cond = cond[:-1] + cond = torch.bitwise_xor(cond[:int(len(cond)/2)], cond[int(len(cond)/2):]) + for c_ in c: + cond = torch.bitwise_xor(cond, c_) + oddNodes = cond + # t4 = time.time() + # for c in cond: + # oddNodes = oddNodes ^ c + # print(f'expand time: {t1-t0} | cond123 time: {t2-t1} | cond logic time: {t3-t2} | bitwise xor time: {t4-t3}') + # print(f'point in polygon time gpu: {t4-t0}') + # oddNodes = oddNodes ^ (cond1 | cond2) & cond3 + return oddNodes + + +def bitmap_for_polygon(poly, h, w): + y = torch.arange(h).to(poly.device).float() + x = torch.arange(w).to(poly.device).float() + grid_y, grid_x = torch.meshgrid(y, x) + grid = torch.stack((grid_x, grid_y), dim=-1) + bitmap = point_in_polygon(poly, grid) + return bitmap.unsqueeze(0) + + +def corners_coords(center_xywh): + center_x, center_y, w, h = center_xywh + x = center_x - w/2 + y = center_y - h/2 + return torch.tensor([x, y, x+w, y+h]) + +def corners_coords_batch(center_xywh): + center_x, center_y = center_xywh[:,0], center_xywh[:,1] + w, h = center_xywh[:,2], center_xywh[:,3] + x = center_x - w/2 + y = center_y - h/2 + return torch.stack([x, y, x+w, y+h], dim=1) + +def normalize_batch(x): + """ + Normalize a batch of tensors along each channel. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). + + Returns: + torch.Tensor: Normalized tensor of the same shape as the input. + """ + mins = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device) + maxs = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device) + for i in range(x.shape[0]): + mins[i] = x[i].min() + maxs[i] = x[i].max() + x_ = (x - mins) / (maxs - mins) + + return x_ + +def get_detections(model_clone, img): + """ + Get detections from a model given an input image and targets. + + Args: + model (nn.Module): The model to use for detection. + img (torch.Tensor): The input image tensor. + + Returns: + torch.Tensor: The detected bounding boxes. 
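+
+    Example (illustrative sketch; `model_clone` is assumed to be a copy of the training model so
+    that calling `eval()` here does not disturb training):
+        ```python
+        det_out, out = get_detections(model_clone, imgs)
+        ```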
+ """ + model_clone.eval() # Set model to evaluation mode + # Run inference + with torch.no_grad(): + det_out, out = model_clone(img) + + # model_.train() + del img + + return det_out, out + +def get_labels(det_out, imgs, targets, opt): + ###################### Get predicted labels ###################### + nb, _, height, width = imgs.shape # batch size, channels, height, width + targets_ = targets.clone() + targets_[:, 2:] = targets_[:, 2:] * torch.Tensor([width, height, width, height]).to(imgs.device) # to pixels + lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else [] # for autolabelling + o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True) + pred_labels = [] + for si, pred in enumerate(o): + labels = targets_[targets_[:, 0] == si, 1:] + nl = len(labels) + predn = pred.clone() + # Get the indices that sort the values in column 5 in ascending order + sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True) + # Apply the sorting indices to the tensor + sorted_pred = predn[sort_indices] + # Remove predictions with less than 0.1 confidence + n_conf = int(torch.sum(sorted_pred[:,4]>0.1)) + 1 + sorted_pred = sorted_pred[:n_conf] + new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si + preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1) + preds[:, 2:] = xyxy2xywh(preds[:, 2:]) # xywh + gn = torch.tensor([width, height])[[1, 0, 1, 0]] # normalization gain whwh + preds[:, 2:] /= gn.to(imgs.device) # from pixels + pred_labels.append(preds) + pred_labels = torch.cat(pred_labels, 0).to(imgs.device) + + return pred_labels + ################################################################## + +from torchvision.utils import make_grid + +def get_center_coords(attr): + img_tensor = img_tensor / img_tensor.max() + + # Define a brightness threshold + threshold = 0.95 + + # Create a binary mask of the bright pixels + mask = img_tensor > threshold + + # Get the coordinates of the bright pixels + y_coords, x_coords = torch.where(mask) + + # Calculate the centroid of the bright pixels + centroid_x = x_coords.float().mean().item() + centroid_y = y_coords.float().mean().item() + + print(f'The central bright point is at ({centroid_x}, {centroid_y})') + + return + + +def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False): + """ + Compute the distance grids from each pixel to the target coordinates. + + Args: + attr (torch.Tensor): Attribution maps. + targets (torch.Tensor): Target coordinates. + focus_coeff (float, optional): Focus coefficient, smaller means more focused. Defaults to 0.5. + debug (bool, optional): Whether to visualize debug information. Defaults to False. + + Returns: + torch.Tensor: Distance grids. 
+ """ + + # Assign the height and width of the input tensor to variables + height, width = attr.shape[-1], attr.shape[-2] + + # attr = torch.abs(attr) # Take absolute values of gradients + # attr = normalize_batch(attr) # Normalize attribution maps per image in batch + + # Create a grid of indices + xx, yy = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width))).to(attr.device) + idx_grid = torch.stack((xx, yy), dim=-1).float() + + # Expand the grid to match the batch size + idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1) + + # Initialize a list to store the distance grids + dist_grids_ = [[]] * attr.shape[0] + + # Loop over batches + for j in range(attr.shape[0]): + # Get the rows where the first column is the current unique value + rows = targets[targets[:, 0] == j] + + if len(rows) != 0: + # Create a tensor for the target coordinates + xy = rows[:,2:4] # y, x + # Flip the x and y coordinates and scale them to the image size + xy[:, 0], xy[:, 1] = xy[:, 1] * width, xy[:, 0] * height # y, x to x, y + xy_center = xy.unsqueeze(1).unsqueeze(1)#.requires_grad_(True) + + # Compute the Euclidean distance from each pixel to the target coordinates + dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1) + + # Pick the closest distance to any target for each pixel + dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0) + dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_ + else: + # Set grid to zero if no targets are present + dist_grid = torch.zeros_like(attr[j]) + + dist_grids_[j] = dist_grid + # Convert the list of distance grids to a tensor for faster computation + dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff + if torch.isnan(dist_grids).any(): + dist_grids = torch.nan_to_num(dist_grids, nan=0.0) + + if debug: + for i in range(len(dist_grids)): + if ((i % 8) == 0): + grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0) + imshow(grid_show, save_path='figs/dist_grids') + if imgs is None: + imgs = torch.zeros_like(attr) + imshow(imgs[i], save_path='figs/im0') + img_overlay = (overlay_mask(imgs[i], dist_grids[i][0], alpha = 0.75)) + imshow(img_overlay, save_path='figs/dist_grid_overlay') + weighted_attr = (dist_grids[i] * attr[i]) + imshow(weighted_attr, save_path='figs/weighted_attr') + imshow(attr[i], save_path='figs/attr') + + return dist_grids + +def attr_reg(attribution_map, distance_map): + + # dist_attr = distance_map * attribution_map + dist_attr = torch.mean(distance_map * attribution_map)#, dim=(1, 2, 3)) + # del distance_map, attribution_map + return dist_attr + +def get_bbox_map(targets_out, attr, corners=False): + target_inds = targets_out[:, 0].int() + xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num] + num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1)) + # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1)) + xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int() + co = xyxy_corners + if corners: + co = targets_out[:, 2:6].int() + coords_map = torch.zeros_like(attr, dtype=torch.bool) + # rows = np.arange(co.shape[0]) + x1, x2 = co[:,1], co[:,3] + y1, y2 = co[:,0], co[:,2] + + for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop + coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = 
+def get_bbox_map(targets_out, attr, corners=False):
+    """Build a binary map that is True inside every ground-truth box."""
+    target_inds = targets_out[:, 0].int()
+    xyxy_batch = targets_out[:, 2:6]  # * pre_gen_gains[out_num]
+    num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
+    # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1))
+    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
+    co = xyxy_corners
+    if corners:
+        co = targets_out[:, 2:6].int()
+    coords_map = torch.zeros_like(attr, dtype=torch.bool)
+    # rows = np.arange(co.shape[0])
+    x1, x2 = co[:, 1], co[:, 3]
+    y1, y2 = co[:, 0], co[:, 2]
+
+    for ic in range(co.shape[0]):  # potential for speedup here with torch indexing instead of for loop
+        coords_map[target_inds[ic], :, x1[ic]:x2[ic], y1[ic]:y2[ic]] = True
+
+    bbox_map = coords_map.to(torch.float32)
+
+    return bbox_map
+######################################## BCE #######################################
+def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False):
+    """Compute the plausibility loss (and optionally its component scores) for a batch of attribution maps."""
+    # if imgs is None:
+    #     imgs = torch.zeros_like(attribution_map)
+    # Calculate Plausibility IoU with attribution maps
+    # attribution_map.retains_grad = True
+    if not only_loss:
+        plaus_score = get_plaus_score(targets_out=targets, attr=attribution_map.clone().detach().requires_grad_(True), imgs=imgs)
+    else:
+        plaus_score = torch.tensor(0.0)
+
+    # attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
+
+    # Calculate distance regularization
+    distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff)
+    # distance_map = torch.ones_like(attribution_map)
+
+    if opt.dist_x_bbox:
+        # Zero out the distance prior inside ground-truth boxes
+        bbox_map = get_bbox_map(targets, attribution_map).to(torch.bool)
+        distance_map[bbox_map] = 0.0
+        # distance_map = distance_map * (1 - bbox_map)
+
+    # Positive regularization term for incentivizing pixels near the target to have high attribution
+    dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map))
+    # Negative regularization term for incentivizing pixels far from the target to have low attribution
+    dist_attr_neg = attr_reg(attribution_map, distance_map)
+    # Calculate plausibility regularization term
+    # dist_reg = dist_attr_pos - dist_attr_neg
+    dist_reg = ((dist_attr_pos / torch.mean(attribution_map)) - (dist_attr_neg / torch.mean(attribution_map)))
+    # dist_reg = torch.mean((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3))) - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3))))
+    # dist_reg = (torch.mean(torch.exp((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3)))) + \
+    #                        torch.exp(1 - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3)))))) \
+    #             / 2.5
+
+    if opt.bbox_coeff != 0.0:
+        bbox_map = get_bbox_map(targets, attribution_map)
+        attr_bbox_pos = attr_reg(attribution_map, bbox_map)
+        attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map))
+        bbox_reg = attr_bbox_pos - attr_bbox_neg
+        # bbox_reg = (attr_bbox_pos / torch.mean(attribution_map)) - (attr_bbox_neg / torch.mean(attribution_map))
+    else:
+        bbox_reg = 0.0
+
+    bbox_map = get_bbox_map(targets, attribution_map)
+    # Energy-based plausibility: the fraction of total attribution mass inside ground-truth boxes
+    # (note that this overrides the IoU-based plaus_score computed above)
+    plaus_score = ((torch.sum((attribution_map * bbox_map))) / (torch.sum(attribution_map)))
+    # iou_loss = (1.0 - plaus_score)
+
+    if not opt.dist_reg_only:
+        dist_reg_loss = (((1.0 + dist_reg) / 2.0))
+        plaus_reg = (plaus_score * opt.iou_coeff) + \
+                    (((dist_reg_loss * opt.dist_coeff) + \
+                      (bbox_reg * opt.bbox_coeff))\
+                    # ((((((1.0 + dist_reg) / 2.0) - 1.0) * opt.dist_coeff) + ((((1.0 + bbox_reg) / 2.0) - 1.0) * opt.bbox_coeff))\
+                    # / (plaus_score) \
+                    )
+    else:
+        plaus_reg = (((1.0 + dist_reg) / 2.0))
+        # plaus_reg = dist_reg
+    # Calculate plausibility loss
+    plaus_loss = (1 - plaus_reg) * opt.pgt_coeff
+    # plaus_loss = (plaus_reg) * opt.pgt_coeff
+    if only_loss:
+        return plaus_loss
+    if not debug:
+        return plaus_loss, (plaus_score, dist_reg, plaus_reg,)
+    else:
+        return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map
+
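+# Illustrative usage sketch for get_plaus_loss. The opt fields below mirror the
+# attributes read inside the function; the numeric values are placeholders only.
+#
+#   from types import SimpleNamespace
+#   opt = SimpleNamespace(focus_coeff=0.5, dist_x_bbox=False, bbox_coeff=0.0,
+#                         dist_reg_only=False, iou_coeff=0.05, dist_coeff=1.0,
+#                         pgt_coeff=0.1)
+#   plaus_loss, (plaus_score, dist_reg, plaus_reg) = get_plaus_loss(
+#       targets, attribution_map, opt)
+#   loss = detection_loss + plaus_loss   # PGT adds the plausibility term to the task loss
+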
+####################################################################################
+#### ALL FUNCTIONS BELOW ARE DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS ####
+####################################################################################
+
+def generate_vanilla_grad(model, input_tensor, loss_func=None,
+                          targets_list=None, targets=None, metric=None, out_num=1,
+                          n_max_labels=3, norm=True, abs=True, grayscale=True,
+                          class_specific_attr=True, device='cpu'):
+    """
+    Generate vanilla gradients for the given model and input tensor.
+
+    Args:
+        model (nn.Module): The model to generate gradients for.
+        input_tensor (torch.Tensor): The input tensor for which gradients are computed.
+        loss_func (callable, optional): The loss function to compute gradients with respect to. Defaults to None.
+        targets_list (list, optional): The list of target tensors. Defaults to None.
+        targets (torch.Tensor, optional): Stacked targets, used in place of targets_list when class_specific_attr is False. Defaults to None.
+        metric (callable, optional): The metric function to evaluate the loss. Defaults to None.
+        out_num (int, optional): The index of the output tensor to compute gradients with respect to. Defaults to 1.
+        n_max_labels (int, optional): The maximum number of labels to consider. Defaults to 3.
+        norm (bool, optional): Whether to normalize the attribution map. Defaults to True.
+        abs (bool, optional): Whether to take the absolute values of gradients. Defaults to True.
+        grayscale (bool, optional): Whether to convert the attribution map to grayscale. Defaults to True.
+        class_specific_attr (bool, optional): Whether to compute class-specific attribution maps. Defaults to True.
+        device (str, optional): The device to use for computation. Defaults to 'cpu'.
+
+    Returns:
+        list[torch.Tensor]: The generated vanilla-gradient attribution maps.
+    """
+    # Set model.train() at the beginning and revert back to original mode (model.eval() or model.train()) at the end
+    train_mode = model.training
+    if not train_mode:
+        model.train()
+
+    input_tensor.requires_grad = True  # Set requires_grad attribute of tensor. Important for computing gradients
+    model.zero_grad()  # Zero gradients
+    inpt = input_tensor
+    # Forward pass
+    train_out = model(inpt)  # training outputs (no inference outputs in train mode)
+
+    # train_out[1] = torch.Size([4, 3, 80, 80, 7]) HxWx(#anchorxC) cls (class probabilities)
+    # train_out[0] = torch.Size([4, 3, 160, 160, 7]) HxWx(#anchorx4) box or reg (location and scaling)
+    # train_out[2] = torch.Size([4, 3, 40, 40, 7]) HxWx(#anchorx1) obj (objectness score or confidence)
+
+    if class_specific_attr:
+        n_attr_list, index_classes = [], []
+        for i in range(len(input_tensor)):
+            if len(targets_list[i]) > n_max_labels:
+                targets_list[i] = targets_list[i][:n_max_labels]
+            if targets_list[i].numel() != 0:
+                # unique_classes = torch.unique(targets_list[i][:,1])
+                class_numbers = targets_list[i][:, 1]
+                index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers])
+                num_attrs = len(targets_list[i])
+                # index_classes.append([0, 1, 2, 3, 4] + [int(uc + 5) for uc in unique_classes])
+                # num_attrs = 1 #len(unique_classes)# if loss_func else len(targets_list[i])
+                n_attr_list.append(num_attrs)
+            else:
+                index_classes.append([0, 1, 2, 3, 4])
+                n_attr_list.append(0)
+
+        targets_list_filled = [targ.clone().detach() for targ in targets_list]
+        labels_len = [len(targets_list[ih]) for ih in range(len(targets_list))]
+        max_labels = np.max(labels_len)
+        max_index = np.argmax(labels_len)
+        for i in range(len(targets_list)):
+            # targets_list_filled[i] = targets_list[i]
+            if 0 < len(targets_list_filled[i]) < max_labels:  # skip empty lists; they are dropped below
+                # Pad shorter target lists by repeating their rows until they match the longest list
+                tlist = [targets_list_filled[i]] * math.ceil(max_labels / len(targets_list_filled[i]))
+                targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0)
+            else:
+                targets_list_filled[i] = targets_list_filled[i].unsqueeze(0)
+        for i in range(len(targets_list_filled)-1, -1, -1):
+            if targets_list_filled[i].numel() == 0:
+                targets_list_filled.pop(i)
+        targets_list_filled = torch.cat(targets_list_filled)
+
+    n_img_attrs = len(input_tensor) if class_specific_attr else 1
+    n_img_attrs = 1 if loss_func else n_img_attrs
+
+    attrs_batch = []
+    for i_batch in range(n_img_attrs):
+        if loss_func and class_specific_attr:
+            i_batch = max_index
+        # inpt = input_tensor[i_batch].unsqueeze(0)
+        # ##################################################################
+        # model.zero_grad() # Zero gradients
+        # train_out = model(inpt) # training outputs (no inference outputs in train mode)
+        # ##################################################################
+        n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1
+        n_label_attrs = 1 if not class_specific_attr else n_label_attrs
+        attrs_img = []
+        for i_attr in range(n_label_attrs):
+            if loss_func is None:
+                grad_wrt = train_out[out_num]
+                if class_specific_attr:
+                    grad_wrt = train_out[out_num][:, :, :, :, index_classes[i_batch][i_attr]]
+                grad_wrt_outputs = torch.ones_like(grad_wrt)
+            else:
+                # if class_specific_attr:
+                #     targets = targets_list[:][i_attr]
+                # n_targets = len(targets_list[i_batch])
+                if class_specific_attr:
+                    target_indiv = targets_list_filled[:, i_attr]  # batch image input
+                else:
+                    target_indiv = targets
+                # target_indiv = targets_list[i_batch][i_attr].unsqueeze(0) # single image input
+                # target_indiv[:,0] = 0 # this indicates the batch index of the target, should be 0 since we are only doing one image at a time
+
+                try:
+                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)  # loss scaled by batch_size
+                except RuntimeError:
+                    # Retry with every tensor explicitly moved to the requested device
+                    target_indiv = target_indiv.to(device)
+                    inpt = inpt.to(device)
+                    train_out = [tro.to(device) for tro in train_out]
+                    print("Error in loss function, trying again with device specified")
+                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)
+                grad_wrt = loss
+                grad_wrt_outputs = None
+
+            model.zero_grad()  # Zero gradients
+            gradients = torch.autograd.grad(grad_wrt, inpt,
+                                            grad_outputs=grad_wrt_outputs,
+                                            retain_graph=True,
+                                            # create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly
+                                            )
+
+            # Convert gradients to numpy array and back to ensure full separation from graph
+            # attribution_map = torch.tensor(torch.sum(gradients[0], 1, keepdim=True).clone().detach().cpu().numpy())
+            attribution_map = gradients[0]  # without converting to numpy
+
+            if grayscale:  # Convert to grayscale, saves vram and computation time for plaus_eval
+                attribution_map = torch.sum(attribution_map, 1, keepdim=True)
+            if abs:
+                attribution_map = torch.abs(attribution_map)  # Take absolute values of gradients
+            if norm:
+                attribution_map = normalize_batch(attribution_map)  # Normalize attribution maps per image in batch
+            attrs_img.append(attribution_map)
+        if len(attrs_img) == 0:
+            attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device))
+        else:
+            attrs_batch.append(torch.stack(attrs_img).to(device))
+
+    # out_attr = torch.tensor(attribution_map).unsqueeze(0).to(device) if ((loss_func) or (not class_specific_attr)) else torch.stack(attrs_batch).to(device)
+    # out_attr = [attrs_batch[0]] * len(input_tensor) if ((loss_func) or (not class_specific_attr)) else attrs_batch
+    out_attr = attrs_batch
+    # Set model back to original mode
+    if not train_mode:
+        model.eval()
+
+    return out_attr
+
+class RVNonLinearFunc(torch.nn.Module):
+    """
+    Custom Bayesian wrapper that propagates a random variable (mean and covariance)
+    through a non-linear activation function.
+
+    Attributes:
+        func (callable): The deterministic activation function to wrap.
+    """
+    def __init__(self, func):
+        super(RVNonLinearFunc, self).__init__()
+        self.func = func
+
+    def forward(self, mu_in, Sigma_in):
+        """
+        Forward pass of the Bayesian non-linear activation function.
+
+        Args:
+            mu_in (torch.Tensor): A tensor of shape (batch_size, input_size),
+                representing the mean input to the activation function.
+            Sigma_in (torch.Tensor): A tensor of shape (batch_size, input_size, input_size),
+                representing the covariance input to the activation function.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors,
+                including the mean of the output and the covariance of the output.
+        """
+        # Collect stats
+        batch_size = mu_in.size(0)
+
+        # Mean
+        mu_out = self.func(mu_in)
+
+        # Compute the derivative of the activation function with respect to the input mean
+        gradi = torch.autograd.grad(mu_out, mu_in, grad_outputs=torch.ones_like(mu_out), create_graph=True)[0].view(batch_size, -1)
+
+        # add an extra dimension to gradi at position 2 and 1
+        grad1 = gradi.unsqueeze(dim=2)
+        grad2 = gradi.unsqueeze(dim=1)
+
+        # compute the outer product of grad1 and grad2
+        outer_product = torch.bmm(grad1, grad2)
+
+        # element-wise multiply Sigma_in with the outer product
+        # and return the result (first-order covariance propagation)
+        Sigma_out = torch.mul(Sigma_in, outer_product)
+
+        return mu_out, Sigma_out
+
diff --git a/ultralytics/utils/plot_functs.py b/ultralytics/utils/plot_functs.py
new file mode 100644
index 00000000..4b403331
--- /dev/null
+++ b/ultralytics/utils/plot_functs.py
@@ -0,0 +1,154 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+from torch.cuda import amp  # needed by returnGrad's GradScaler below
+
+class Subplots:
+    def __init__(self, figsize=(40, 5)):
+        self.fig = plt.figure(figsize=figsize)
+
+    def plot_img_list(self, img_list, savedir='figs/test',
+                      nrows=1, rownum=0,
+                      hold=False, coltitles=[], rowtitle=''):
+
+        for i, img in enumerate(img_list):
+            try:
+                npimg = img.clone().detach().cpu().numpy()
+            except AttributeError:  # already a numpy array
+                npimg = img
+            tpimg = np.transpose(npimg, (1, 2, 0))
+            lenrow = len(img_list)
+            ax = self.fig.add_subplot(nrows, lenrow, i+1+(rownum*lenrow))
+            if len(coltitles) > i:
+                ax.set_title(coltitles[i])
+            if i == 0:
+                ax.annotate(rowtitle, xy=((-0.06 * len(rowtitle)), 0.4),  # xytext=(-ax.yaxis.labelpad - pad, 0),
+                            xycoords='axes fraction', textcoords='offset points',
+                            size='large', ha='center', va='baseline')
+                # ax.set_ylabel(rowtitle, rotation=90)
+            ax.imshow(tpimg)
+            ax.axis('off')
+
+        if not hold:
+            self.fig.tight_layout()
+            plt.savefig(f'{savedir}.png')
+            plt.clf()
+            plt.close('all')
+
+
+def VisualizeNumpyImageGrayscale(image_3d):
+    r"""Min-max normalizes a 3D image array to the [0, 1] range."""
+    vmin = np.min(image_3d)
+    image_2d = image_3d - vmin
+    vmax = np.max(image_2d)
+    return (image_2d / vmax)
+
+def normalize_numpy(image_3d):
+    r"""Min-max normalizes a 3D image array to the [0, 1] range (identical to VisualizeNumpyImageGrayscale)."""
+    vmin = np.min(image_3d)
+    image_2d = image_3d - vmin
+    vmax = np.max(image_2d)
+    return (image_2d / vmax)
+
+# def normalize_tensor(image_3d):
+#     r"""Min-max normalizes a 3D tensor to the [0, 1] range.
+#     """
+#     vmin = torch.min(image_3d)
+#     image_2d = image_3d - vmin
+#     vmax = torch.max(image_2d)
+#     return (image_2d / vmax)
+
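+# Illustrative sketch: the two helpers above are interchangeable min-max scalers, e.g.
+#
+#   a = np.array([[2.0, 4.0], [6.0, 10.0]])
+#   normalize_numpy(a)   # -> [[0.0, 0.25], [0.5, 1.0]]
+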
+ """ + image_2d = (image_3d - torch.min(image_3d)) + return (image_2d / torch.max(image_2d)) + +def format_img(img_): + np_img = img_.numpy() + tp_img = np.transpose(np_img, (1, 2, 0)) + return tp_img + +def imshow(img, save_path=None): + try: + npimg = img.clone().detach().cpu().numpy() + except: + npimg = img + tpimg = np.transpose(npimg, (1, 2, 0)) + plt.imshow(tpimg) + # plt.axis('off') + plt.tight_layout() + if save_path != None: + plt.savefig(str(str(save_path) + ".png")) + #plt.show()a + +def imshow_img(img, imsave_path): + # works for tensors and numpy arrays + try: + npimg = VisualizeNumpyImageGrayscale(img.numpy()) + except: + npimg = VisualizeNumpyImageGrayscale(img) + npimg = np.transpose(npimg, (2, 0, 1)) + imshow(npimg, save_path=imsave_path) + print("Saving image as ", imsave_path) + +def returnGrad(img, labels, model, compute_loss, loss_metric, augment=None, device = 'cpu'): + model.train() + model.to(device) + img = img.to(device) + img.requires_grad_(True) + labels.to(device).requires_grad_(True) + model.requires_grad_(True) + cuda = device.type != 'cpu' + scaler = amp.GradScaler(enabled=cuda) + pred = model(img) + # out, train_out = model(img, augment=augment) # inference and training outputs + loss, loss_items = compute_loss(pred, labels, metric=loss_metric)#[1][:3] # box, obj, cls + # loss = criterion(pred, torch.tensor([int(torch.max(pred[0], 0)[1])]).to(device)) + # loss = torch.sum(loss).requires_grad_(True) + + with torch.autograd.set_detect_anomaly(True): + scaler.scale(loss).backward(inputs=img) + # loss.backward() + +# S_c = torch.max(pred[0].data, 0)[0] + Sc_dx = img.grad + model.eval() + Sc_dx = torch.tensor(Sc_dx, dtype=torch.float32) + return Sc_dx + +def calculate_snr(img, attr, dB=True): + try: + img_np = img.detach().cpu().numpy() + attr_np = attr.detach().cpu().numpy() + except: + img_np = img + attr_np = attr + + # Calculate the signal power + signal_power = np.mean(img_np**2) + + # Calculate the noise power + noise_power = np.mean(attr_np**2) + + if dB == True: + # Calculate SNR in dB + snr = 10 * np.log10(signal_power / noise_power) + else: + # Calculate SNR + snr = signal_power / noise_power + + return snr + +def overlay_mask(img, mask, colormap: str = "jet", alpha: float = 0.7): + + cmap = plt.get_cmap(colormap) + npmask = np.array(mask.clone().detach().cpu().squeeze(0)) + # cmpmask = ((255 * cmap(npmask)[:, :, :3]).astype(np.uint8)).transpose((2, 0, 1)) + cmpmask = (cmap(npmask)[:, :, :3]).transpose((2, 0, 1)) + overlayed_imgnp = ((alpha * (np.asarray(img.clone().detach().cpu())) + (1 - alpha) * cmpmask)) + overlayed_tensor = torch.tensor(overlayed_imgnp, device=img.device) + + return overlayed_tensor \ No newline at end of file diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index d0215ba5..0e99c722 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -717,7 +717,7 @@ def plot_images( ): """Plot image grid with labels.""" if isinstance(images, torch.Tensor): - images = images.cpu().float().numpy() + images = images.detach().cpu().float().numpy() if isinstance(cls, torch.Tensor): cls = cls.cpu().numpy() if isinstance(bboxes, torch.Tensor):