Mirror of https://github.com/THU-MIG/yolov10.git

commit 38fa59edf2 (parent 2a95a652bd)

    PGT Training now functioning. Run run_pgt_train.py to train the model.
@@ -3,26 +3,41 @@ from ultralytics import YOLOv10, YOLO
 # from ultralytics import BaseTrainer
 # from ultralytics.engine.trainer import BaseTrainer
 import os
+from ultralytics.models.yolo.segment import PGTSegmentationTrainer
+
 
 # Set CUDA device (only needed for multi-gpu machines)
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 os.environ["CUDA_VISIBLE_DEVICES"] = "4"
 
 # model = YOLOv10()
+# model = YOLO('yolov8n-seg.yaml').load('yolov8n.pt')  # build from YAML and transfer weights
+
 # model = YOLO()
 # If you want to finetune the model with pretrained weights, you could load the
 # pretrained weights like below
 # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
 # or
 # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
-model = YOLOv10('yolov10n.pt')
+model = YOLOv10('yolov10n.pt', task='segment')
 
-model.train(data='coco.yaml',
-            trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT)
-            # Add return_images as input parameter
-            epochs=500, batch=16, imgsz=640,
-            debug=True, # If debug = True, the attributions will be saved in the figures folder
-            )
+args = dict(model='yolov10n.pt', data='coco128-seg.yaml')
+trainer = PGTSegmentationTrainer(overrides=args)
+trainer.train(
+    # debug=True,
+    # args=dict(pgt_coeff=0.1),
+)
+
+# model.train(
+#             # data='coco.yaml',
+#             data='coco128-seg.yaml',
+#             trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT)
+#             # Add return_images as input parameter
+#             epochs=500, batch=16, imgsz=640,
+#             debug=True, # If debug = True, the attributions will be saved in the figures folder
+#             # cfg='/home/nielseni6/PythonScripts/yolov10/ultralytics/cfg/models/v8/yolov8-seg.yaml',
+#             # overrides=dict(task="segment"),
+#             )
 
 # Save the trained model
 model.save('yolov10_coco_trained.pt')
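The hunk above switches the example script from routing through `model.train(trainer=...)` to driving the PGT trainer directly; all training settings travel in the `overrides` dict. A minimal sketch combining the new call with the settings the old `model.train()` call used (the epochs/batch/imgsz values come from the old script, not from any requirement of the new API):

```python
from ultralytics.models.yolo.segment import PGTSegmentationTrainer

overrides = dict(model='yolov10n.pt', data='coco128-seg.yaml',
                 epochs=500, batch=16, imgsz=640)
trainer = PGTSegmentationTrainer(overrides=overrides)
trainer.train()
```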
run_pgt_train.py (new file, +47 lines)
@@ -0,0 +1,47 @@
+from ultralytics import YOLOv10, YOLO
+# from ultralytics.engine.pgt_trainer import PGTTrainer
+import os
+from ultralytics.models.yolo.segment import PGTSegmentationTrainer
+import argparse
+
+
+def main(args):
+  # model = YOLOv10()
+
+  # If you want to finetune the model with pretrained weights, you could load the
+  # pretrained weights like below
+  # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
+  # or
+  # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
+  model = YOLOv10('yolov10n.pt', task='segment')
+
+  args = dict(model='yolov10n.pt', data='coco.yaml',
+              epochs=args.epochs, batch=args.batch_size,
+              # cfg='pgt_train.yaml',  # can be edited for full control of the training process
+              )
+  trainer = PGTSegmentationTrainer(overrides=args)
+  trainer.train(
+      # debug=True,
+      # args=dict(pgt_coeff=0.1),  # should be added to the config later
+  )
+
+  # Save the trained model
+  model.save('yolov10_coco_trained.pt')
+
+  # Evaluate the model on the validation set
+  results = model.val(data='coco.yaml')
+
+  # Print the evaluation results
+  print(results)
+
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
+  parser.add_argument('--device', type=str, default='0', help='CUDA device number')
+  parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
+  parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
+  args = parser.parse_args()
+
+  # Set CUDA device (only needed for multi-gpu machines)
+  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+  os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+  main(args)
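The `overrides` dict is merged into ultralytics' default training configuration, so any standard training argument can be passed this way. A quick way to inspect the merged result, using `get_cfg` (which this commit's pgt_validator.py also imports; the values below mirror the script's defaults):

```python
from ultralytics.cfg import get_cfg

cfg = get_cfg(overrides=dict(model='yolov10n.pt', data='coco.yaml',
                             epochs=100, batch=128))
print(cfg.epochs, cfg.batch)  # -> 100 128
```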
@@ -656,7 +656,10 @@ class Model(nn.Module):
                     pass
 
         self.trainer.hub_session = self.session  # attach optional HUB session
-        self.trainer.train(debug=debug)
+        if debug:
+            self.trainer.train(debug=debug)
+        else:
+            self.trainer.train()
         # Update model and cfg after training
         if RANK in (-1, 0):
             ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
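The branch exists because only the PGT trainers accept a `debug` keyword; unconditionally forwarding it would raise a TypeError on the stock trainers. A more generic variant (a sketch, not this commit's code) could probe the signature instead:

```python
import inspect

def call_train(trainer, debug=False):
    """Forward `debug` only when the active trainer's train() accepts it."""
    if debug and "debug" in inspect.signature(trainer.train).parameters:
        return trainer.train(debug=debug)
    return trainer.train()
```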
@@ -52,7 +52,9 @@ from ultralytics.utils.torch_utils import (
     strip_optimizer,
 )
 
-
+from ultralytics.utils.loss import v8DetectionLoss
+from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn
+import matplotlib.path as matplotlib_path
 class PGTTrainer:
     """
     BaseTrainer.
@@ -159,6 +161,7 @@ class PGTTrainer:
         self.loss_names = ["Loss"]
         self.csv = self.save_dir / "results.csv"
         self.plot_idx = [0, 1, 2]
+        self.num = int(time.time())
 
         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
@@ -328,7 +331,7 @@ class PGTTrainer:
         if world_size > 1:
             self._setup_ddp(world_size)
         self._setup_train(world_size)
-
+
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
         last_opt_step = -1
@@ -382,37 +385,88 @@ class PGTTrainer:
                 with torch.cuda.amp.autocast(self.amp):
                     batch = self.preprocess_batch(batch)
                     (self.loss, self.loss_items), images = self.model(batch, return_images=True)
+
+                # smask = get_dist_reg(images, batch['masks'])
+
+                # grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
+                # grad = torch.abs(grad)
+
+                # self.args.pgt_coeff = 1.1
+                # plaus_loss = plaus_loss_fn(grad, smask, self.args.pgt_coeff)
+                # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
+                # self.loss += plaus_loss
+
+                debug_ = debug
+                if debug_ and (i % 25 == 0):
+                    debug_ = False
+                    # Create a tensor of zeros with the same size as images
+                    mask = torch.zeros_like(images, dtype=torch.float32)
+                    smask = get_dist_reg(images, batch['masks'])
+                    grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
+                    grad = torch.abs(grad)
+
+                    batch_size = images.shape[0]
+                    imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device)
+                    targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+                    targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+                    gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+                    mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
 
-                    if debug and (i % 250):
-                        grad = torch.autograd.grad(self.loss, images, create_graph=True)[0]
+                    # Iterate over each bounding box and set the corresponding pixels to 1
+                    for irx, bboxes in enumerate(gt_bboxes):
+                        for idx in range(len(bboxes)):
+                            x1, y1, x2, y2 = bboxes[idx]
+                            x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+                            mask[irx, :, y1:y2, x1:x2] = 1.0
+
+                    save_imgs = True
+                    if save_imgs:
                         # Convert tensors to numpy arrays
                         images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
                         grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                        mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
+                        seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
 
                         # Normalize grad for visualization
                         grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
 
                         for ix in range(images_np.shape[0]):
-                            fig, ax = plt.subplots(1, 3, figsize=(15, 5))
-                            ax[0].imshow(images_np[i])
+                            fig, ax = plt.subplots(1, 6, figsize=(30, 5))
+                            ax[0].imshow(images_np[ix])
                             ax[0].set_title('Image')
-                            ax[1].imshow(grad_np[i], cmap='jet')
+                            ax[1].imshow(grad_np[ix], cmap='jet')
                             ax[1].set_title('Gradient')
-                            ax[2].imshow(images_np[i])
-                            ax[2].imshow(grad_np[i], cmap='jet', alpha=0.5)
+                            ax[2].imshow(images_np[ix])
+                            ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
                             ax[2].set_title('Overlay')
+                            ax[3].imshow(mask_np[ix], cmap='gray')
+                            ax[3].set_title('Mask')
+                            ax[4].imshow(seg_mask_np[ix], cmap='gray')
+                            ax[4].set_title('Segmentation Mask')
+
+                            # Plot image with bounding boxes
+                            ax[5].imshow(images_np[ix])
+                            for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
+                                x1, y1, x2, y2 = bbox
+                                x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+                                rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
+                                ax[5].add_patch(rect)
+                                ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
+                            ax[5].set_title('Bounding Boxes')
 
-                            save_dir_attr = "figures/attributions"
+                            save_dir_attr = f"figures/attributions/run{self.num}"
                             if not os.path.exists(save_dir_attr):
                                 os.makedirs(save_dir_attr)
                             plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
                             plt.close(fig)
-
-                    if RANK != -1:
-                        self.loss *= world_size
-                    self.tloss = (
-                        (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
-                    )
+                    images = images.detach()
+
+                if RANK != -1:
+                    self.loss *= world_size
+                self.tloss = (
+                    (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
+                )
 
                 # Backward
                 self.scaler.scale(self.loss).backward()
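The commented-out block at the top of the last hunk marks where the plausibility-guided (PGT) term is meant to attach: take the gradient of the task loss with respect to the input images, compare it with a relevance mask derived from the ground-truth masks (`get_dist_reg`), and add a penalty weighted by `pgt_coeff` (`plaus_loss_fn`). Both helpers live in `ultralytics.utils.plaus_functs` and are not shown in this commit; the following is a minimal self-contained sketch of such a penalty, an assumption rather than the repo's implementation:

```python
import torch

def plausibility_penalty(loss, images, smask, pgt_coeff=0.1):
    """Hypothetical PGT-style term (not the repo's plaus_loss_fn): penalize
    input-gradient energy that falls outside the relevance mask.

    loss:   scalar task loss whose autograd graph includes `images`
    images: input batch created with requires_grad=True
    smask:  relevance mask in [0, 1], 1 on objects, same shape as `images`
    """
    grad = torch.autograd.grad(loss, images, retain_graph=True)[0].abs()
    off_object = (grad * (1.0 - smask)).sum()
    return pgt_coeff * off_object / grad.sum().clamp_min(1e-8)

# Assumed wiring inside the training step sketched above:
#   self.loss = self.loss + plausibility_penalty(self.loss, images, smask)
```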
ultralytics/engine/pgt_validator.py (new file, +345 lines)
@@ -0,0 +1,345 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+"""
+Check a model's accuracy on a test or val split of a dataset.
+
+Usage:
+    $ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
+
+Usage - formats:
+    $ yolo mode=val model=yolov8n.pt                 # PyTorch
+                          yolov8n.torchscript        # TorchScript
+                          yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
+                          yolov8n_openvino_model     # OpenVINO
+                          yolov8n.engine             # TensorRT
+                          yolov8n.mlpackage          # CoreML (macOS-only)
+                          yolov8n_saved_model        # TensorFlow SavedModel
+                          yolov8n.pb                 # TensorFlow GraphDef
+                          yolov8n.tflite             # TensorFlow Lite
+                          yolov8n_edgetpu.tflite     # TensorFlow Edge TPU
+                          yolov8n_paddle_model       # PaddlePaddle
+                          yolov8n_ncnn_model         # NCNN
+"""
+
+import json
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from ultralytics.cfg import get_cfg, get_save_dir
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.nn.autobackend import AutoBackend
+from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
+from ultralytics.utils.checks import check_imgsz
+from ultralytics.utils.ops import Profile
+from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
+
+
+class PGTValidator:
+    """
+    PGTValidator.
+
+    A base class for creating validators.
+
+    Attributes:
+        args (SimpleNamespace): Configuration for the validator.
+        dataloader (DataLoader): Dataloader to use for validation.
+        pbar (tqdm): Progress bar to update during validation.
+        model (nn.Module): Model to validate.
+        data (dict): Data dictionary.
+        device (torch.device): Device to use for validation.
+        batch_i (int): Current batch index.
+        training (bool): Whether the model is in training mode.
+        names (dict): Class names.
+        seen: Records the number of images seen so far during validation.
+        stats: Placeholder for statistics during validation.
+        confusion_matrix: Placeholder for a confusion matrix.
+        nc: Number of classes.
+        iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.
+        jdict (dict): Dictionary to store JSON validation results.
+        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
+                      batch processing times in milliseconds.
+        save_dir (Path): Directory to save results.
+        plots (dict): Dictionary to store plots for visualization.
+        callbacks (dict): Dictionary to store various callback functions.
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """
+        Initializes a PGTValidator instance.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (tqdm.tqdm): Progress bar for displaying progress.
+            args (SimpleNamespace): Configuration for the validator.
+            _callbacks (dict): Dictionary to store various callback functions.
+        """
+        self.args = get_cfg(overrides=args)
+        self.dataloader = dataloader
+        self.pbar = pbar
+        self.stride = None
+        self.data = None
+        self.device = None
+        self.batch_i = None
+        self.training = True
+        self.names = None
+        self.seen = None
+        self.stats = None
+        self.confusion_matrix = None
+        self.nc = None
+        self.iouv = None
+        self.jdict = None
+        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+
+        self.save_dir = save_dir or get_save_dir(self.args)
+        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+        if self.args.conf is None:
+            self.args.conf = 0.001  # default conf=0.001
+        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
+
+        self.plots = {}
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+
+    # @smart_inference_mode()  # disabled so input gradients stay available for attributions
+    def __call__(self, trainer=None, model=None):
+        """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
+        gets priority).
+        """
+        self.training = trainer is not None
+        augment = self.args.augment and (not self.training)
+        if self.training:
+            self.device = trainer.device
+            self.data = trainer.data
+            # self.args.half = self.device.type != "cpu"  # force FP16 val during training
+            model = trainer.ema.ema or trainer.model
+            model = model.half() if self.args.half else model.float()
+            # self.model = model
+            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
+            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
+            model.eval()
+        else:
+            callbacks.add_integration_callbacks(self)
+            model = AutoBackend(
+                weights=model or self.args.model,
+                device=select_device(self.args.device, self.args.batch),
+                dnn=self.args.dnn,
+                data=self.args.data,
+                fp16=self.args.half,
+            )
+            # self.model = model
+            self.device = model.device  # update device
+            self.args.half = model.fp16  # update half
+            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
+            imgsz = check_imgsz(self.args.imgsz, stride=stride)
+            if engine:
+                self.args.batch = model.batch_size
+            elif not pt and not jit:
+                self.args.batch = 1  # export.py models default to batch-size 1
+                LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
+
+            if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
+                self.data = check_det_dataset(self.args.data)
+            elif self.args.task == "classify":
+                self.data = check_cls_dataset(self.args.data, split=self.args.split)
+            else:
+                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
+
+            if self.device.type in ("cpu", "mps"):
+                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
+            if not pt:
+                self.args.rect = False
+            self.stride = model.stride  # used in get_dataloader() for padding
+            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
+
+            model.eval()
+            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup
+
+        self.run_callbacks("on_val_start")
+        dt = (
+            Profile(device=self.device),
+            Profile(device=self.device),
+            Profile(device=self.device),
+            Profile(device=self.device),
+        )
+        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
+        self.init_metrics(de_parallel(model))
+        self.jdict = []  # empty before each val
+        for batch_i, batch in enumerate(bar):
+            self.run_callbacks("on_val_batch_start")
+            self.batch_i = batch_i
+            # Preprocess
+            with dt[0]:
+                batch = self.preprocess(batch)
+
+            # Inference
+            with dt[1]:
+                preds = model(batch["img"].requires_grad_(True), augment=augment)
+
+            # Loss
+            with dt[2]:
+                if self.training:
+                    self.loss += model.loss(batch, preds)[1]
+
+            # Postprocess
+            with dt[3]:
+                preds = self.postprocess(preds)
+
+            self.update_metrics(preds, batch)
+            if self.args.plots and batch_i < 3:
+                self.plot_val_samples(batch, batch_i)
+                self.plot_predictions(batch, preds, batch_i)
+
+            self.run_callbacks("on_val_batch_end")
+        stats = self.get_stats()
+        self.check_stats(stats)
+        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
+        self.finalize_metrics()
+        if not (self.args.save_json and self.is_coco and len(self.jdict)):
+            self.print_results()
+        self.run_callbacks("on_val_end")
+        if self.training:
+            model.float()
+            if self.args.save_json and self.jdict:
+                with open(str(self.save_dir / "predictions.json"), "w") as f:
+                    LOGGER.info(f"Saving {f.name}...")
+                    json.dump(self.jdict, f)  # flatten and save
+                stats = self.eval_json(stats)  # update stats
+                stats['fitness'] = stats['metrics/mAP50-95(B)']
+            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
+            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
+        else:
+            LOGGER.info(
+                "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
+                % tuple(self.speed.values())
+            )
+            if self.args.save_json and self.jdict:
+                with open(str(self.save_dir / "predictions.json"), "w") as f:
+                    LOGGER.info(f"Saving {f.name}...")
+                    json.dump(self.jdict, f)  # flatten and save
+                stats = self.eval_json(stats)  # update stats
+            if self.args.plots or self.args.save_json:
+                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
+            return stats
+
+    def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
+        """
+        Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
+
+        Args:
+            pred_classes (torch.Tensor): Predicted class indices of shape(N,).
+            true_classes (torch.Tensor): Target class indices of shape(M,).
+            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
+            use_scipy (bool): Whether to use scipy for matching (more precise).
+
+        Returns:
+            (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
+        """
+        # Dx10 matrix, where D - detections, 10 - IoU thresholds
+        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
+        # LxD matrix where L - labels (rows), D - detections (columns)
+        correct_class = true_classes[:, None] == pred_classes
+        iou = iou * correct_class  # zero out the wrong classes
+        iou = iou.cpu().numpy()
+        for i, threshold in enumerate(self.iouv.cpu().tolist()):
+            if use_scipy:
+                # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
+                import scipy  # scope import to avoid importing for all commands
+
+                cost_matrix = iou * (iou >= threshold)
+                if cost_matrix.any():
+                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
+                    valid = cost_matrix[labels_idx, detections_idx] > 0
+                    if valid.any():
+                        correct[detections_idx[valid], i] = True
+            else:
+                matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
+                matches = np.array(matches).T
+                if matches.shape[0]:
+                    if matches.shape[0] > 1:
+                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
+                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+                        # matches = matches[matches[:, 2].argsort()[::-1]]
+                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+                    correct[matches[:, 1].astype(int), i] = True
+        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
+
+    def add_callback(self, event: str, callback):
+        """Appends the given callback."""
+        self.callbacks[event].append(callback)
+
+    def run_callbacks(self, event: str):
+        """Runs all callbacks associated with a specified event."""
+        for callback in self.callbacks.get(event, []):
+            callback(self)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        """Get data loader from dataset path and batch size."""
+        raise NotImplementedError("get_dataloader function not implemented for this validator")
+
+    def build_dataset(self, img_path):
+        """Build dataset."""
+        raise NotImplementedError("build_dataset function not implemented in validator")
+
+    def preprocess(self, batch):
+        """Preprocesses an input batch."""
+        return batch
+
+    def postprocess(self, preds):
+        """Post-processes the predictions (identity by default)."""
+        return preds
+
+    def init_metrics(self, model):
+        """Initialize performance metrics for the YOLO model."""
+        pass
+
+    def update_metrics(self, preds, batch):
+        """Updates metrics based on predictions and batch."""
+        pass
+
+    def finalize_metrics(self, *args, **kwargs):
+        """Finalizes and returns all metrics."""
+        pass
+
+    def get_stats(self):
+        """Returns statistics about the model's performance."""
+        return {}
+
+    def check_stats(self, stats):
+        """Checks statistics."""
+        pass
+
+    def print_results(self):
+        """Prints the results of the model's predictions."""
+        pass
+
+    def get_desc(self):
+        """Get description of the YOLO model."""
+        pass
+
+    @property
+    def metric_keys(self):
+        """Returns the metric keys used in YOLO training/validation."""
+        return []
+
+    def on_plot(self, name, data=None):
+        """Registers plots (e.g. to be consumed in callbacks)."""
+        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
+
+    # TODO: may need to put these following functions into callback
+    def plot_val_samples(self, batch, ni):
+        """Plots validation samples during training."""
+        pass
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots YOLO model predictions on batch images."""
+        pass
+
+    def pred_to_json(self, preds, batch):
+        """Convert predictions to JSON format."""
+        pass
+
+    def eval_json(self, stats):
+        """Evaluate and return JSON format of prediction statistics."""
+        pass
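PGTValidator is essentially ultralytics' BaseValidator with two visible changes: the `@smart_inference_mode()` decorator is commented out, and inference runs on `batch["img"].requires_grad_(True)`. Both exist so input gradients can still be computed during validation; under inference mode the backward graph is never recorded. A standalone toy illustration of what this enables (assumed shapes, not repo code):

```python
import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
imgs = torch.rand(2, 3, 64, 64, requires_grad=True)
score = conv(imgs).sum()
attr = torch.autograd.grad(score, imgs)[0]  # per-pixel attribution, shape (2, 3, 64, 64)

# Wrapping the forward pass in torch.inference_mode() would make the grad
# call above fail, which is why the decorator is disabled here.
```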
@@ -4,5 +4,6 @@ from .predict import DetectionPredictor
 from .pgt_train import PGTDetectionTrainer
 from .train import DetectionTrainer
 from .val import DetectionValidator
+from .pgt_val import PGTDetectionValidator
 
-__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer"
+__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer", "PGTDetectionValidator"
ultralytics/models/yolo/detect/pgt_val.py (new file, +300 lines)
@@ -0,0 +1,300 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from ultralytics.data import build_dataloader, build_yolo_dataset, converter
+from ultralytics.engine.validator import BaseValidator
+from ultralytics.engine.pgt_validator import PGTValidator
+from ultralytics.utils import LOGGER, ops
+from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
+from ultralytics.utils.plotting import output_to_target, plot_images
+
+
+class PGTDetectionValidator(PGTValidator):
+    """
+    A class extending the PGTValidator class for validation based on a detection model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.detect import PGTDetectionValidator
+
+        args = dict(model='yolov8n.pt', data='coco8.yaml')
+        validator = PGTDetectionValidator(args=args)
+        validator()
+        ```
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """Initialize detection model with necessary variables and settings."""
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.nt_per_class = None
+        self.is_coco = False
+        self.class_map = None
+        self.args.task = "detect"
+        self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+        self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
+        self.niou = self.iouv.numel()
+        self.lb = []  # for autolabelling
+
+    def preprocess(self, batch):
+        """Preprocesses batch of images for YOLO training."""
+        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
+        for k in ["batch_idx", "cls", "bboxes"]:
+            batch[k] = batch[k].to(self.device)
+
+        if self.args.save_hybrid:
+            height, width = batch["img"].shape[2:]
+            nb = len(batch["img"])
+            bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
+            self.lb = (
+                [
+                    torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
+                    for i in range(nb)
+                ]
+                if self.args.save_hybrid
+                else []
+            )  # for autolabelling
+
+        return batch
+
+    def init_metrics(self, model):
+        """Initialize evaluation metrics for YOLO."""
+        val = self.data.get(self.args.split, "")  # validation path
+        self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt")  # is COCO
+        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
+        self.args.save_json |= self.is_coco  # run on final val if training COCO
+        self.names = model.names
+        self.nc = len(model.names)
+        self.metrics.names = self.names
+        self.metrics.plot = self.args.plots
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
+        self.seen = 0
+        self.jdict = []
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])
+
+    def get_desc(self):
+        """Return a formatted string summarizing class metrics of YOLO model."""
+        return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
+
+    def postprocess(self, preds):
+        """Apply Non-maximum suppression to prediction outputs."""
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=True,
+            agnostic=self.args.single_cls,
+            max_det=self.args.max_det,
+        )
+
+    def _prepare_batch(self, si, batch):
+        """Prepares a batch of images and annotations for validation."""
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
+        if len(cls):
+            bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
+            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
+        return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
+
+    def _prepare_pred(self, pred, pbatch):
+        """Prepares predictions by scaling boxes back to original-image (native) space."""
+        predn = pred.clone()
+        ops.scale_boxes(
+            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+        )  # native-space pred
+        return predn
+
+    def update_metrics(self, preds, batch):
+        """Metrics."""
+        for si, pred in enumerate(preds):
+            self.seen += 1
+            npr = len(pred)
+            stat = dict(
+                conf=torch.zeros(0, device=self.device),
+                pred_cls=torch.zeros(0, device=self.device),
+                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+            )
+            pbatch = self._prepare_batch(si, batch)
+            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
+            nl = len(cls)
+            stat["target_cls"] = cls
+            if npr == 0:
+                if nl:
+                    for k in self.stats.keys():
+                        self.stats[k].append(stat[k])
+                    if self.args.plots:
+                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
+                continue
+
+            # Predictions
+            if self.args.single_cls:
+                pred[:, 5] = 0
+            predn = self._prepare_pred(pred, pbatch)
+            stat["conf"] = predn[:, 4]
+            stat["pred_cls"] = predn[:, 5]
+
+            # Evaluate
+            if nl:
+                stat["tp"] = self._process_batch(predn, bbox, cls)
+                if self.args.plots:
+                    self.confusion_matrix.process_batch(predn, bbox, cls)
+            for k in self.stats.keys():
+                self.stats[k].append(stat[k])
+
+            # Save
+            if self.args.save_json:
+                self.pred_to_json(predn, batch["im_file"][si])
+            if self.args.save_txt:
+                file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
+                self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)
+
+    def finalize_metrics(self, *args, **kwargs):
+        """Set final values for metrics speed and confusion matrix."""
+        self.metrics.speed = self.speed
+        self.metrics.confusion_matrix = self.confusion_matrix
+
+    def get_stats(self):
+        """Returns metrics statistics and results dictionary."""
+        stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
+        if len(stats) and stats["tp"].any():
+            self.metrics.process(**stats)
+        self.nt_per_class = np.bincount(
+            stats["target_cls"].astype(int), minlength=self.nc
+        )  # number of targets per class
+        return self.metrics.results_dict
+
+    def print_results(self):
+        """Prints training/validation set metrics per class."""
+        pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
+        LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
+        if self.nt_per_class.sum() == 0:
+            LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")
+
+        # Print results per class
+        if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
+            for i, c in enumerate(self.metrics.ap_class_index):
+                LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
+
+        if self.args.plots:
+            for normalize in True, False:
+                self.confusion_matrix.plot(
+                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+                )
+
+    def _process_batch(self, detections, gt_bboxes, gt_cls):
+        """
+        Return correct prediction matrix.
+
+        Args:
+            detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
+                Each detection is of the format: x1, y1, x2, y2, conf, class.
+            gt_bboxes (torch.Tensor): Tensor of shape [M, 4] of ground-truth boxes in x1, y1, x2, y2 format.
+            gt_cls (torch.Tensor): Tensor of shape [M] with ground-truth class indices.
+
+        Returns:
+            (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
+        """
+        iou = box_iou(gt_bboxes, detections[:, :4])
+        return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+    def build_dataset(self, img_path, mode="val", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        """Construct and return dataloader."""
+        dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
+        return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader
+
+    def plot_val_samples(self, batch, ni):
+        """Plot validation image samples."""
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots predicted bounding boxes on input images and saves the result."""
+        plot_images(
+            batch["img"],
+            *output_to_target(preds, max_det=self.args.max_det),
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred
+
+    def save_one_txt(self, predn, save_conf, shape, file):
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        gn = torch.tensor(shape)[[1, 0, 1, 0]]  # normalization gain whwh
+        for *xyxy, conf, cls in predn.tolist():
+            xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
+            line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
+            with open(file, "a") as f:
+                f.write(("%g " * len(line)).rstrip() % line + "\n")
+
+    def pred_to_json(self, predn, filename):
+        """Serialize YOLO predictions to COCO json format."""
+        stem = Path(filename).stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        for p, b in zip(predn.tolist(), box.tolist()):
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(p[5])],
+                    "bbox": [round(x, 3) for x in b],
+                    "score": round(p[4], 5),
+                }
+            )
+
+    def eval_json(self, stats):
+        """Evaluates YOLO output in JSON format and returns performance statistics."""
+        if self.args.save_json and self.is_coco and len(self.jdict):
+            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
+            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+                check_requirements("pycocotools>=2.0.6")
+                from pycocotools.coco import COCO  # noqa
+                from pycocotools.cocoeval import COCOeval  # noqa
+
+                for x in anno_json, pred_json:
+                    assert x.is_file(), f"{x} file not found"
+                anno = COCO(str(anno_json))  # init annotations api
+                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                eval = COCOeval(anno, pred, "bbox")
+                if self.is_coco:
+                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
+                eval.evaluate()
+                eval.accumulate()
+                eval.summarize()
+                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50
+            except Exception as e:
+                LOGGER.warning(f"pycocotools unable to run: {e}")
+        return stats
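For intuition, here is a tiny worked example of the per-threshold greedy IoU matching that `match_predictions()` (inherited from PGTValidator) performs inside `_process_batch()`; the IoU values and classes are made up for illustration:

```python
import numpy as np
import torch

iou = torch.tensor([[0.62, 0.10]])                     # 1 label (row) x 2 detections (cols)
pred_classes = torch.tensor([0, 0])
true_classes = torch.tensor([0])

correct_class = true_classes[:, None] == pred_classes  # zero out wrong-class pairs
iou = (iou * correct_class).numpy()
matches = np.array(np.nonzero(iou >= 0.5)).T           # (label_idx, det_idx) pairs
print(matches)                                         # [[0 0]]: detection 0 is a TP at IoU 0.5
```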
| @ -3,5 +3,6 @@ | |||||||
| from .predict import SegmentationPredictor | from .predict import SegmentationPredictor | ||||||
| from .train import SegmentationTrainer | from .train import SegmentationTrainer | ||||||
| from .val import SegmentationValidator | from .val import SegmentationValidator | ||||||
|  | from .pgt_train import PGTSegmentationTrainer | ||||||
| 
 | 
 | ||||||
| __all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator" | __all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator", "PGTSegmentationTrainer" | ||||||
|  | |||||||
73  ultralytics/models/yolo/segment/pgt_train.py  Normal file
							| @ -0,0 +1,73 @@ | |||||||
|  | # Ultralytics YOLO 🚀, AGPL-3.0 license | ||||||
|  | 
 | ||||||
|  | from copy import copy | ||||||
|  | 
 | ||||||
|  | from ultralytics.models import yolo | ||||||
|  | from ultralytics.nn.tasks import SegmentationModel, DetectionModel | ||||||
|  | from ultralytics.utils import DEFAULT_CFG, RANK | ||||||
|  | from ultralytics.utils.plotting import plot_images, plot_results | ||||||
|  | from ultralytics.models.yolov10.model import YOLOv10DetectionModel, YOLOv10PGTDetectionModel | ||||||
|  | from ultralytics.models.yolov10.val import YOLOv10DetectionValidator, YOLOv10PGTDetectionValidator | ||||||
|  | 
 | ||||||
|  | class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer): | ||||||
|  |     """ | ||||||
|  |     A class extending PGTDetectionTrainer for training a segmentation model with Plausibility-Guided Training (PGT). | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  |         ```python | ||||||
|  |         from ultralytics.models.yolo.segment import PGTSegmentationTrainer | ||||||
|  | 
 | ||||||
|  |         args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3) | ||||||
|  |         trainer = PGTSegmentationTrainer(overrides=args) | ||||||
|  |         trainer.train() | ||||||
|  |         ``` | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): | ||||||
|  |         """Initialize a SegmentationTrainer object with given arguments.""" | ||||||
|  |         if overrides is None: | ||||||
|  |             overrides = {} | ||||||
|  |         overrides["task"] = "segment" | ||||||
|  |         super().__init__(cfg, overrides, _callbacks) | ||||||
|  | 
 | ||||||
|  |     def get_model(self, cfg=None, weights=None, verbose=True): | ||||||
|  |         """Return SegmentationModel initialized with specified config and weights.""" | ||||||
|  |         if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']: | ||||||
|  |             model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) | ||||||
|  |         else: | ||||||
|  |             model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) | ||||||
|  |         if weights: | ||||||
|  |             model.load(weights) | ||||||
|  | 
 | ||||||
|  |         return model | ||||||
|  | 
 | ||||||
|  |     def get_validator(self): | ||||||
|  |         """Return an instance of SegmentationValidator for validation of YOLO model.""" | ||||||
|  |          | ||||||
|  |         if self.args.model in ['yolov10n.pt', 'yolov10m.pt', 'yolov10x.pt', 'yolov10s.pt', 'yolov10b.pt', 'yolov10l.pt']: | ||||||
|  |             self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", "pgt_loss", | ||||||
|  |             return YOLOv10PGTDetectionValidator( | ||||||
|  |                 self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss" | ||||||
|  |             return yolo.segment.SegmentationValidator( | ||||||
|  |                 self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |     def plot_training_samples(self, batch, ni): | ||||||
|  |         """Creates a plot of training sample images with labels and box coordinates.""" | ||||||
|  |         plot_images( | ||||||
|  |             batch["img"], | ||||||
|  |             batch["batch_idx"], | ||||||
|  |             batch["cls"].squeeze(-1), | ||||||
|  |             batch["bboxes"], | ||||||
|  |             masks=batch["masks"], | ||||||
|  |             paths=batch["im_file"], | ||||||
|  |             fname=self.save_dir / f"train_batch{ni}.jpg", | ||||||
|  |             on_plot=self.on_plot, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def plot_metrics(self): | ||||||
|  |         """Plots training/val metrics.""" | ||||||
|  |         plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png | ||||||
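Note that the seven-element `loss_names` tuple in `get_validator` above mirrors the loss-items vector that `v10PGTDetectLoss` (added to `ultralytics/utils/loss.py` below) returns; a sketch of the correspondence, assuming three one2many and three one2one items:

```python
# Sketch (assumed shapes): v10PGTDetectLoss concatenates
#   loss_one2many[1] (3 items) + loss_one2one[1] (3 items) + plaus_loss (1 item)
# which lines up index-for-index with:
loss_names = ("box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", "pgt_loss")
```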
| @ -1,5 +1,5 @@ | |||||||
| from ultralytics.engine.model import Model | from ultralytics.engine.model import Model | ||||||
| from ultralytics.nn.tasks import YOLOv10DetectionModel | from ultralytics.nn.tasks import YOLOv10DetectionModel, YOLOv10PGTDetectionModel | ||||||
| from .val import YOLOv10DetectionValidator | from .val import YOLOv10DetectionValidator | ||||||
| from .predict import YOLOv10DetectionPredictor | from .predict import YOLOv10DetectionPredictor | ||||||
| from .train import YOLOv10DetectionTrainer | from .train import YOLOv10DetectionTrainer | ||||||
|  | |||||||
| @ -1,4 +1,4 @@ | |||||||
| from ultralytics.models.yolo.detect import DetectionValidator | from ultralytics.models.yolo.detect import DetectionValidator, PGTDetectionValidator | ||||||
| from ultralytics.utils import ops | from ultralytics.utils import ops | ||||||
| import torch | import torch | ||||||
| 
 | 
 | ||||||
| @ -7,6 +7,27 @@ class YOLOv10DetectionValidator(DetectionValidator): | |||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         self.args.save_json |= self.is_coco |         self.args.save_json |= self.is_coco | ||||||
| 
 | 
 | ||||||
|  |     def postprocess(self, preds): | ||||||
|  |         if isinstance(preds, dict): | ||||||
|  |             preds = preds["one2one"] | ||||||
|  | 
 | ||||||
|  |         if isinstance(preds, (list, tuple)): | ||||||
|  |             preds = preds[0] | ||||||
|  |          | ||||||
|  |         # Acknowledgement: Thanks to sanha9999 in #190 and #181! | ||||||
|  |         if preds.shape[-1] == 6: | ||||||
|  |             return preds | ||||||
|  |         else: | ||||||
|  |             preds = preds.transpose(-1, -2) | ||||||
|  |             boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc) | ||||||
|  |             bboxes = ops.xywh2xyxy(boxes) | ||||||
|  |             return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) | ||||||
|  |          | ||||||
|  | class YOLOv10PGTDetectionValidator(PGTDetectionValidator): | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self.args.save_json |= self.is_coco | ||||||
|  | 
 | ||||||
|     def postprocess(self, preds): |     def postprocess(self, preds): | ||||||
|         if isinstance(preds, dict): |         if isinstance(preds, dict): | ||||||
|             preds = preds["one2one"] |             preds = preds["one2one"] | ||||||
|  | |||||||
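Both validators return detections in the `(N, 6)` layout `[x1, y1, x2, y2, score, class]`, as the `preds.shape[-1] == 6` early return above assumes; a minimal sketch of consuming that output:

```python
# Sketch (assumed dummy values): unpacking validator postprocess output
import torch

dets = torch.tensor([[10.0, 20.0, 110.0, 220.0, 0.87, 3.0]])  # (N, 6)
boxes, scores, labels = dets[:, :4], dets[:, 4], dets[:, 5].long()
```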
| @ -57,7 +57,7 @@ from ultralytics.nn.modules import ( | |||||||
| ) | ) | ||||||
| from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load | from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load | ||||||
| from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml | from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml | ||||||
| from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss | from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss, v10PGTDetectLoss | ||||||
| from ultralytics.utils.plotting import feature_visualization | from ultralytics.utils.plotting import feature_visualization | ||||||
| from ultralytics.utils.torch_utils import ( | from ultralytics.utils.torch_utils import ( | ||||||
|     fuse_conv_and_bn, |     fuse_conv_and_bn, | ||||||
| @ -651,6 +651,10 @@ class WorldModel(DetectionModel): | |||||||
| class YOLOv10DetectionModel(DetectionModel): | class YOLOv10DetectionModel(DetectionModel): | ||||||
|     def init_criterion(self): |     def init_criterion(self): | ||||||
|         return v10DetectLoss(self) |         return v10DetectLoss(self) | ||||||
|  |      | ||||||
|  | class YOLOv10PGTDetectionModel(DetectionModel): | ||||||
|  |     def init_criterion(self): | ||||||
|  |         return v10PGTDetectLoss(self) | ||||||
| 
 | 
 | ||||||
| class Ensemble(nn.ModuleList): | class Ensemble(nn.ModuleList): | ||||||
|     """Ensemble of models.""" |     """Ensemble of models.""" | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh | |||||||
| from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors | from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors | ||||||
| from .metrics import bbox_iou, probiou | from .metrics import bbox_iou, probiou | ||||||
| from .tal import bbox2dist | from .tal import bbox2dist | ||||||
| 
 | from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn | ||||||
| 
 | 
 | ||||||
| class VarifocalLoss(nn.Module): | class VarifocalLoss(nn.Module): | ||||||
|     """ |     """ | ||||||
| @ -725,3 +725,30 @@ class v10DetectLoss: | |||||||
|         one2one = preds["one2one"] |         one2one = preds["one2one"] | ||||||
|         loss_one2one = self.one2one(one2one, batch) |         loss_one2one = self.one2one(one2one, batch) | ||||||
|         return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1])) |         return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1])) | ||||||
|  | 
 | ||||||
|  | class v10PGTDetectLoss: | ||||||
|  |     def __init__(self, model): | ||||||
|  |         self.one2many = v8DetectionLoss(model, tal_topk=10) | ||||||
|  |         self.one2one = v8DetectionLoss(model, tal_topk=1) | ||||||
|  |      | ||||||
|  |     def __call__(self, preds, batch): | ||||||
|  |         batch['img'] = batch['img'].requires_grad_(True) | ||||||
|  |         one2many = preds["one2many"] | ||||||
|  |         loss_one2many = self.one2many(one2many, batch) | ||||||
|  |         one2one = preds["one2one"] | ||||||
|  |         loss_one2one = self.one2one(one2one, batch) | ||||||
|  | 
 | ||||||
|  |         loss = loss_one2many[0] + loss_one2one[0] | ||||||
|  |          | ||||||
|  |         smask = get_dist_reg(batch['img'], batch['masks']) | ||||||
|  | 
 | ||||||
|  |         grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0] | ||||||
|  |         grad = torch.abs(grad) | ||||||
|  | 
 | ||||||
|  |         pgt_coeff = 3.0 | ||||||
|  |         plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff) | ||||||
|  |         # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0))) | ||||||
|  |         loss += plaus_loss | ||||||
|  |          | ||||||
|  |         return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0))) | ||||||
|  |      | ||||||
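The plausibility term follows the arithmetic of `plaus_loss_fn` (defined in the new `plaus_functs.py` below): `dist_reg` is a mean-normalized difference of attribution mass under the two mask weightings, `(1 + dist_reg) / 2` maps it toward `[0, 1]`, and the loss is the complement scaled by `pgt_coeff`. A worked example with assumed values:

```python
# Worked example (assumed values, not from a real batch)
dist_attr_pos, dist_attr_neg, grad_mean = 0.012, 0.030, 0.020
dist_reg = dist_attr_pos / grad_mean - dist_attr_neg / grad_mean  # 0.6 - 1.5 = -0.9
plaus_reg = (1.0 + dist_reg) / 2.0                                # 0.05
plaus_loss = (1 - plaus_reg) * 3.0                                # 2.85 with pgt_coeff = 3.0
```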
818  ultralytics/utils/plaus_functs.py  Normal file
							| @ -0,0 +1,818 @@ | |||||||
|  | import torch | ||||||
|  | import numpy as np | ||||||
|  | # from plot_functs import *  | ||||||
|  | from .plot_functs import normalize_tensor, overlay_mask, imshow | ||||||
|  | import math    | ||||||
|  | import time | ||||||
|  | import matplotlib.path as mplPath | ||||||
|  | from matplotlib.path import Path | ||||||
|  | # from utils.general import non_max_suppression, xyxy2xywh, scale_coords | ||||||
|  | from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppression | ||||||
|  | from .metrics import bbox_iou | ||||||
|  | import torchvision.transforms as T | ||||||
|  | 
 | ||||||
|  | def plaus_loss_fn(grad, smask, pgt_coeff): | ||||||
|  |     ################## Compute the PGT Loss ################## | ||||||
|  |     # Attribution mass weighted by (1 - smask), i.e. over pixels where the soft target mask is low | ||||||
|  |     dist_attr_pos = attr_reg(grad, (1.0 - smask)) | ||||||
|  |     # Attribution mass weighted by smask, i.e. over pixels on and near the blurred target mask | ||||||
|  |     dist_attr_neg = attr_reg(grad, smask) | ||||||
|  |     # Calculate plausibility regularization term | ||||||
|  |     # dist_reg = dist_attr_pos - dist_attr_neg | ||||||
|  |     dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad))) | ||||||
|  |     plaus_reg = (((1.0 + dist_reg) / 2.0)) | ||||||
|  |     # Calculate plausibility loss | ||||||
|  |     plaus_loss = (1 - plaus_reg) * pgt_coeff | ||||||
|  |     return plaus_loss | ||||||
|  | 
 | ||||||
|  | def get_dist_reg(images, seg_mask): | ||||||
|  |     seg_mask = T.Resize((images.shape[2], images.shape[3]), antialias=True)(seg_mask).to(images.device) | ||||||
|  |     seg_mask = seg_mask.to(dtype=torch.float32).unsqueeze(1).repeat(1, 3, 1, 1) | ||||||
|  |     seg_mask[seg_mask > 0] = 1.0 | ||||||
|  |      | ||||||
|  |     smask = torch.zeros_like(seg_mask) | ||||||
|  |     sigmas = [20.0 + (i_sig * 20.0) for i_sig in range(8)] | ||||||
|  |     for k_it, sigma in enumerate(sigmas): | ||||||
|  |         # Apply Gaussian blur to the mask | ||||||
|  |         kernel_size = int(sigma + 50) | ||||||
|  |         if kernel_size % 2 == 0: | ||||||
|  |             kernel_size += 1 | ||||||
|  |         seg_mask1 = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask) | ||||||
|  |         if torch.max(seg_mask1) > 1.0: | ||||||
|  |             seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min()) | ||||||
|  |         smask = torch.max(smask, seg_mask1) | ||||||
|  |     return smask | ||||||
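A minimal end-to-end sketch of these two functions, assuming a 640x640 input so the largest blur kernel (211 px at sigma = 160) still fits:

```python
# Sketch (assumed shapes, random data)
import torch

images = torch.rand(1, 3, 640, 640)
masks = torch.zeros(1, 160, 160)     # low-resolution instance masks, as in batch['masks']
masks[:, 60:100, 60:100] = 1.0
smask = get_dist_reg(images, masks)  # (1, 3, 640, 640) soft mask, high on/near the object
grad = torch.rand_like(images)       # stand-in for |d loss / d image|
loss = plaus_loss_fn(grad, smask, pgt_coeff=3.0)
```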
|  | 
 | ||||||
|  | def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False): | ||||||
|  |     """ | ||||||
|  |     Compute the gradient of an image with respect to a given tensor. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         img (torch.Tensor): The input image tensor. | ||||||
|  |         grad_wrt (torch.Tensor): The tensor with respect to which the gradient is computed. | ||||||
|  |         norm (bool, optional): Whether to normalize the gradient. Defaults to False. | ||||||
|  |         absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True. | ||||||
|  |         grayscale (bool, optional): Whether to convert the gradient to grayscale. Defaults to False. | ||||||
|  |         keepmean (bool, optional): Whether to keep the mean value of the attribution map. Defaults to False. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         torch.Tensor: The computed attribution map. | ||||||
|  | 
 | ||||||
|  |     """ | ||||||
|  |     if (grad_wrt.shape != torch.Size([1])) and (grad_wrt.shape != torch.Size([])): | ||||||
|  |         grad_wrt_outputs = torch.ones_like(grad_wrt).clone().detach()#.requires_grad_(True)#.retains_grad_(True) | ||||||
|  |     else: | ||||||
|  |         grad_wrt_outputs = None | ||||||
|  |     attribution_map = torch.autograd.grad(grad_wrt, img,  | ||||||
|  |                                     grad_outputs=grad_wrt_outputs,  | ||||||
|  |                                     create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly | ||||||
|  |                                     )[0] | ||||||
|  |     if absolute: | ||||||
|  |         attribution_map = torch.abs(attribution_map) # attribution_map ** 2 # Take absolute values of gradients | ||||||
|  |     if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval | ||||||
|  |         attribution_map = torch.sum(attribution_map, 1, keepdim=True) | ||||||
|  |     if norm: | ||||||
|  |         if keepmean: | ||||||
|  |             attmean = torch.mean(attribution_map) | ||||||
|  |             attmin = torch.min(attribution_map) | ||||||
|  |             attmax = torch.max(attribution_map) | ||||||
|  |         attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch | ||||||
|  |         if keepmean: | ||||||
|  |             attribution_map -= attribution_map.mean() | ||||||
|  |             attribution_map += (attmean / (attmax - attmin)) | ||||||
|  |          | ||||||
|  |     return attribution_map | ||||||
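A hedged usage sketch for the scalar-loss path (where `grad_wrt_outputs` stays `None`), assuming a toy model:

```python
# Sketch (assumed toy model and data)
import torch
import torch.nn as nn

model = nn.Conv2d(3, 8, 3, padding=1)
img = torch.rand(2, 3, 64, 64, requires_grad=True)
loss = model(img).mean()  # scalar, so no grad_outputs tensor is needed
attr = get_gradient(img, loss, norm=True, grayscale=True)  # (2, 1, 64, 64)
```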
|  | 
 | ||||||
|  | def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False): | ||||||
|  |     """ | ||||||
|  |     Generate Gaussian noise based on the input image. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         img (torch.Tensor): Input image. | ||||||
|  |         grad_wrt: Gradient with respect to the input image. | ||||||
|  |         norm (bool, optional): Whether to normalize the generated noise. Defaults to True. | ||||||
|  |         absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True. | ||||||
|  |         grayscale (bool, optional): Whether to convert the noise to grayscale. Defaults to True. | ||||||
|  |         keepmean (bool, optional): Whether to keep the mean of the noise. Defaults to False. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         torch.Tensor: Generated Gaussian noise. | ||||||
|  |     """ | ||||||
|  |      | ||||||
|  |     gaussian_noise = torch.randn_like(img) | ||||||
|  |      | ||||||
|  |     if absolute: | ||||||
|  |         gaussian_noise = torch.abs(gaussian_noise) # Take absolute values of gradients | ||||||
|  |     if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval | ||||||
|  |         gaussian_noise = torch.sum(gaussian_noise, 1, keepdim=True) | ||||||
|  |     if norm: | ||||||
|  |         if keepmean: | ||||||
|  |             attmean = torch.mean(gaussian_noise) | ||||||
|  |             attmin = torch.min(gaussian_noise) | ||||||
|  |             attmax = torch.max(gaussian_noise) | ||||||
|  |         gaussian_noise = normalize_batch(gaussian_noise) # Normalize attribution maps per image in batch | ||||||
|  |         if keepmean: | ||||||
|  |             gaussian_noise -= gaussian_noise.mean() | ||||||
|  |             gaussian_noise += (attmean / (attmax - attmin)) | ||||||
|  |          | ||||||
|  |     return gaussian_noise | ||||||
|  |      | ||||||
|  | 
 | ||||||
|  | def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7): | ||||||
|  |     # TODO: Remove imgs from this function and only take it as input if debug is True | ||||||
|  |     """ | ||||||
|  |     Calculates the plausibility score based on the given inputs. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         imgs (torch.Tensor): The input images. | ||||||
|  |         targets_out (torch.Tensor): The output targets. | ||||||
|  |         attr (torch.Tensor): The attribute tensor. | ||||||
|  |         debug (bool, optional): Whether to enable debug mode. Defaults to False. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         torch.Tensor: The plausibility score. | ||||||
|  |     """ | ||||||
|  |     # # if imgs is None: | ||||||
|  |     # #     imgs = torch.zeros_like(attr) | ||||||
|  |     # # with torch.no_grad(): | ||||||
|  |     # target_inds = targets_out[:, 0].int() | ||||||
|  |     # xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num] | ||||||
|  |     # num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1)) | ||||||
|  |     # # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1)) | ||||||
|  |     # xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int() | ||||||
|  |     # co = xyxy_corners | ||||||
|  |     # if corners: | ||||||
|  |     #     co = targets_out[:, 2:6].int() | ||||||
|  |     # coords_map = torch.zeros_like(attr, dtype=torch.bool) | ||||||
|  |     # # rows = np.arange(co.shape[0]) | ||||||
|  |     # x1, x2 = co[:,1], co[:,3] | ||||||
|  |     # y1, y2 = co[:,0], co[:,2] | ||||||
|  |      | ||||||
|  |     # for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop | ||||||
|  |     #     coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True | ||||||
|  | 
 | ||||||
|  |     if torch.isnan(attr).any(): | ||||||
|  |         attr = torch.nan_to_num(attr, nan=0.0) | ||||||
|  |      | ||||||
|  |     coords_map = get_bbox_map(targets_out, attr) | ||||||
|  |     plaus_score = ((torch.sum((attr * coords_map))) / (torch.sum(attr))) | ||||||
|  | 
 | ||||||
|  |     if debug: | ||||||
|  |         for i in range(len(coords_map)): | ||||||
|  |             coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0) | ||||||
|  |             test_bbox = torch.zeros_like(imgs[i]) | ||||||
|  |             test_bbox[coords_map3ch] = imgs[i][coords_map3ch] | ||||||
|  |             imshow(test_bbox, save_path='figs/test_bbox') | ||||||
|  |             if imgs is None: | ||||||
|  |                 imgs = torch.zeros_like(attr) | ||||||
|  |             imshow(imgs[i], save_path='figs/im0') | ||||||
|  |             imshow(attr[i], save_path='figs/attr') | ||||||
|  |      | ||||||
|  |     # with torch.no_grad(): | ||||||
|  |     # # att_select = attr[coords_map] | ||||||
|  |     # att_select = attr * coords_map.to(torch.float32) | ||||||
|  |     # att_total = attr | ||||||
|  |      | ||||||
|  |     # IoU_num = torch.sum(att_select) | ||||||
|  |     # IoU_denom = torch.sum(att_total) | ||||||
|  |      | ||||||
|  |     # IoU_ = (IoU_num / IoU_denom) | ||||||
|  |     # plaus_score = IoU_ | ||||||
|  | 
 | ||||||
|  |     # # plaus_score = ((torch.sum(attr[coords_map])) / (torch.sum(attr))) | ||||||
|  |      | ||||||
|  |     return plaus_score | ||||||
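A small sketch of the expected `targets_out` layout (`[batch_idx, cls, x, y, w, h]` with normalized xywh boxes), using assumed random attributions:

```python
# Sketch (assumed data): one box centered in image 0
import torch

attr = torch.rand(1, 3, 64, 64)
targets_out = torch.tensor([[0.0, 0.0, 0.5, 0.5, 0.25, 0.25]])  # batch_idx, cls, x, y, w, h
score = get_plaus_score(targets_out, attr)  # fraction of attribution mass inside the box
```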
|  | 
 | ||||||
|  | def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7): | ||||||
|  |     # TODO: Remove imgs from this function and only take it as input if debug is True | ||||||
|  |     """ | ||||||
|  |     Calculates the plausibility score based on the given inputs. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         imgs (torch.Tensor): The input images. | ||||||
|  |         targets_out (torch.Tensor): The output targets. | ||||||
|  |         attr (torch.Tensor): The attribute tensor. | ||||||
|  |         debug (bool, optional): Whether to enable debug mode. Defaults to False. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         torch.Tensor: The plausibility score. | ||||||
|  |     """ | ||||||
|  |     # if imgs is None: | ||||||
|  |     #     imgs = torch.zeros_like(attr) | ||||||
|  |     # with torch.no_grad(): | ||||||
|  |     target_inds = targets_out[:, 0].int() | ||||||
|  |     xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num] | ||||||
|  |     num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1)) | ||||||
|  |     # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1)) | ||||||
|  |     xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int() | ||||||
|  |     co = xyxy_corners | ||||||
|  |     if corners: | ||||||
|  |         co = targets_out[:, 2:6].int() | ||||||
|  |     coords_map = torch.zeros_like(attr, dtype=torch.bool) | ||||||
|  |     # rows = np.arange(co.shape[0]) | ||||||
|  |     x1, x2 = co[:,1], co[:,3] | ||||||
|  |     y1, y2 = co[:,0], co[:,2] | ||||||
|  |      | ||||||
|  |     for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop | ||||||
|  |         coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True | ||||||
|  | 
 | ||||||
|  |     if torch.isnan(attr).any(): | ||||||
|  |         attr = torch.nan_to_num(attr, nan=0.0) | ||||||
|  |     if debug: | ||||||
|  |         for i in range(len(coords_map)): | ||||||
|  |             coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0) | ||||||
|  |             test_bbox = torch.zeros_like(imgs[i]) | ||||||
|  |             test_bbox[coords_map3ch] = imgs[i][coords_map3ch] | ||||||
|  |             imshow(test_bbox, save_path='figs/test_bbox') | ||||||
|  |             imshow(imgs[i], save_path='figs/im0') | ||||||
|  |             imshow(attr[i], save_path='figs/attr') | ||||||
|  |      | ||||||
|  |     # att_select = attr[coords_map] | ||||||
|  |     # with torch.no_grad(): | ||||||
|  |     # IoU_num = (torch.sum(attr[coords_map])) | ||||||
|  |     # IoU_denom = torch.sum(attr) | ||||||
|  |     # IoU_ = (IoU_num / (IoU_denom)) | ||||||
|  |      | ||||||
|  |     # IoU_ = torch.max(attr[coords_map]) - torch.max(attr[~coords_map]) | ||||||
|  |     co = (xyxy_batch * num_pixels).int() | ||||||
|  |     x1 = co[:,1] + 1 | ||||||
|  |     y1 = co[:,0] + 1 | ||||||
|  |     # with torch.no_grad(): | ||||||
|  |     attr_ = torch.sum(attr, 1, keepdim=True) | ||||||
|  |     corners_attr = None #torch.zeros(len(xyxy_batch), 4, device=attr.device) | ||||||
|  |     for ic in range(co.shape[0]): | ||||||
|  |         attr0 = attr_[target_inds[ic], :,:x1[ic],:y1[ic]] | ||||||
|  |         attr1 = attr_[target_inds[ic], :,:x1[ic],y1[ic]:] | ||||||
|  |         attr2 = attr_[target_inds[ic], :,x1[ic]:,:y1[ic]] | ||||||
|  |         attr3 = attr_[target_inds[ic], :,x1[ic]:,y1[ic]:] | ||||||
|  | 
 | ||||||
|  |         x_0, y_0 = max_indices_2d(attr0[0]) | ||||||
|  |         x_1, y_1 = max_indices_2d(attr1[0]) | ||||||
|  |         x_2, y_2 = max_indices_2d(attr2[0]) | ||||||
|  |         x_3, y_3 = max_indices_2d(attr3[0]) | ||||||
|  | 
 | ||||||
|  |         y_1 += y1[ic] | ||||||
|  |         x_2 += x1[ic] | ||||||
|  |         x_3 += x1[ic] | ||||||
|  |         y_3 += y1[ic] | ||||||
|  | 
 | ||||||
|  |         max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2], | ||||||
|  |                                     torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3], | ||||||
|  |                                     torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2], | ||||||
|  |                                     torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]]) | ||||||
|  |         if corners_attr is None: | ||||||
|  |             corners_attr = max_corners | ||||||
|  |         else: | ||||||
|  |             corners_attr = torch.cat([corners_attr, max_corners], dim=0) | ||||||
|  |         # corners_attr[ic] = max_corners | ||||||
|  |         # corners_attr = attr[:,0,:4,0] | ||||||
|  |     corners_attr = corners_attr.view(-1, 4) | ||||||
|  |     # corners_attr = torch.stack(corners_attr, dim=0) | ||||||
|  |     IoU_ = bbox_iou(corners_attr.T, xyxy_batch, x1y1x2y2=False, metric='CIoU') | ||||||
|  |     plaus_score = IoU_.mean() | ||||||
|  | 
 | ||||||
|  |     return plaus_score | ||||||
|  | 
 | ||||||
|  | def max_indices_2d(x_inp): | ||||||
|  |     """Return the (row, col) index of the maximum value of a 2D tensor.""" | ||||||
|  |     index = torch.argmax(x_inp) | ||||||
|  |     x = index // x_inp.shape[1]  # row | ||||||
|  |     y = index % x_inp.shape[1]  # column | ||||||
|  | 
 | ||||||
|  |     return torch.cat([x.unsqueeze(0), y.unsqueeze(0)]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def point_in_polygon(poly, grid): | ||||||
|  |     # t0 = time.time() | ||||||
|  |     num_points = poly.shape[0] | ||||||
|  |     j = num_points - 1 | ||||||
|  |     oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool) | ||||||
|  |     for i in range(num_points): | ||||||
|  |         cond1 = (poly[i, 1] < grid[..., 1]) & (poly[j, 1] >= grid[..., 1]) | ||||||
|  |         cond2 = (poly[j, 1] < grid[..., 1]) & (poly[i, 1] >= grid[..., 1]) | ||||||
|  |         cond3 = (grid[..., 0] - poly[i, 0]) < (poly[j, 0] - poly[i, 0]) * (grid[..., 1] - poly[i, 1]) / (poly[j, 1] - poly[i, 1]) | ||||||
|  |         oddNodes = oddNodes ^ (cond1 | cond2) & cond3 | ||||||
|  |         j = i | ||||||
|  |     # t1 = time.time() | ||||||
|  |     # print(f'point in polygon time: {t1-t0}') | ||||||
|  |     return oddNodes | ||||||
|  |      | ||||||
|  | def point_in_polygon_gpu(poly, grid): | ||||||
|  |     num_points = poly.shape[0] | ||||||
|  |     i = torch.arange(num_points) | ||||||
|  |     j = (i - 1) % num_points | ||||||
|  |     # Expand dimensions | ||||||
|  |     # t0 = time.time() | ||||||
|  |     poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid.shape[0], grid.shape[0]) | ||||||
|  |     # t1 = time.time() | ||||||
|  |     cond1 = (poly_expanded[i, 1] < grid[..., 1]) & (poly_expanded[j, 1] >= grid[..., 1]) | ||||||
|  |     cond2 = (poly_expanded[j, 1] < grid[..., 1]) & (poly_expanded[i, 1] >= grid[..., 1]) | ||||||
|  |     cond3 = (grid[..., 0] - poly_expanded[i, 0]) < (poly_expanded[j, 0] - poly_expanded[i, 0]) * (grid[..., 1] - poly_expanded[i, 1]) / (poly_expanded[j, 1] - poly_expanded[i, 1]) | ||||||
|  |     # t2 = time.time() | ||||||
|  |     oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool) | ||||||
|  |     cond = (cond1 | cond2) & cond3 | ||||||
|  |     # t3 = time.time() | ||||||
|  |     # efficiently perform xor using gpu and avoiding cpu as much as possible | ||||||
|  |     c = [] | ||||||
|  |     while len(cond) > 1:  | ||||||
|  |         if len(cond) % 2 == 1: # odd number of elements | ||||||
|  |             c.append(cond[-1]) | ||||||
|  |             cond = cond[:-1] | ||||||
|  |         cond = torch.bitwise_xor(cond[:int(len(cond)/2)], cond[int(len(cond)/2):]) | ||||||
|  |     for c_ in c: | ||||||
|  |         cond = torch.bitwise_xor(cond, c_) | ||||||
|  |     oddNodes = cond | ||||||
|  |     # t4 = time.time() | ||||||
|  |     # for c in cond: | ||||||
|  |     #     oddNodes = oddNodes ^ c | ||||||
|  |     # print(f'expand time: {t1-t0} | cond123 time: {t2-t1} | cond logic time: {t3-t2} |  bitwise xor time: {t4-t3}') | ||||||
|  |     # print(f'point in polygon time gpu: {t4-t0}') | ||||||
|  |     # oddNodes = oddNodes ^ (cond1 | cond2) & cond3 | ||||||
|  |     return oddNodes | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def bitmap_for_polygon(poly, h, w): | ||||||
|  |     y = torch.arange(h).to(poly.device).float() | ||||||
|  |     x = torch.arange(w).to(poly.device).float() | ||||||
|  |     grid_y, grid_x = torch.meshgrid(y, x) | ||||||
|  |     grid = torch.stack((grid_x, grid_y), dim=-1) | ||||||
|  |     bitmap = point_in_polygon(poly, grid) | ||||||
|  |     return bitmap.unsqueeze(0) | ||||||
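A quick sketch, assuming a square polygon given as (x, y) vertices in pixel coordinates:

```python
# Sketch (assumed polygon): rasterize a square inside an 8x8 grid
import torch

poly = torch.tensor([[2.0, 2.0], [6.0, 2.0], [6.0, 6.0], [2.0, 6.0]])
mask = bitmap_for_polygon(poly, h=8, w=8)  # (1, 8, 8) boolean bitmap
```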
|  | 
 | ||||||
|  | 
 | ||||||
|  | def corners_coords(center_xywh): | ||||||
|  |     center_x, center_y, w, h = center_xywh | ||||||
|  |     x = center_x - w/2 | ||||||
|  |     y = center_y - h/2 | ||||||
|  |     return torch.tensor([x, y, x+w, y+h]) | ||||||
|  | 
 | ||||||
|  | def corners_coords_batch(center_xywh): | ||||||
|  |     center_x, center_y = center_xywh[:,0], center_xywh[:,1] | ||||||
|  |     w, h = center_xywh[:,2], center_xywh[:,3] | ||||||
|  |     x = center_x - w/2 | ||||||
|  |     y = center_y - h/2 | ||||||
|  |     return torch.stack([x, y, x+w, y+h], dim=1) | ||||||
|  |      | ||||||
|  | def normalize_batch(x): | ||||||
|  |     """ | ||||||
|  |     Normalize a batch of tensors along each channel. | ||||||
|  |      | ||||||
|  |     Args: | ||||||
|  |         x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). | ||||||
|  |          | ||||||
|  |     Returns: | ||||||
|  |         torch.Tensor: Normalized tensor of the same shape as the input. | ||||||
|  |     """ | ||||||
|  |     mins = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device) | ||||||
|  |     maxs = torch.zeros((x.shape[0], *(1,)*len(x.shape[1:])), device=x.device) | ||||||
|  |     for i in range(x.shape[0]): | ||||||
|  |         mins[i] = x[i].min() | ||||||
|  |         maxs[i] = x[i].max() | ||||||
|  |     x_ = (x - mins) / (maxs - mins) | ||||||
|  |      | ||||||
|  |     return x_ | ||||||
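Each sample is rescaled to `[0, 1]` independently of the rest of the batch; for example:

```python
# Sketch: per-sample min-max normalization (assumed values)
import torch

x = torch.tensor([[[[0.0, 2.0], [4.0, 8.0]]],
                  [[[10.0, 15.0], [20.0, 30.0]]]])  # (2, 1, 2, 2)
y = normalize_batch(x)
# y[0] -> [[0.00, 0.25], [0.50, 1.00]], y[1] -> [[0.00, 0.25], [0.50, 1.00]]
```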
|  | 
 | ||||||
|  | def get_detections(model_clone, img): | ||||||
|  |     """ | ||||||
|  |     Get detections from a model given an input image and targets. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         model (nn.Module): The model to use for detection. | ||||||
|  |         img (torch.Tensor): The input image tensor. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         torch.Tensor: The detected bounding boxes. | ||||||
|  |     """ | ||||||
|  |     model_clone.eval() # Set model to evaluation mode | ||||||
|  |     # Run inference | ||||||
|  |     with torch.no_grad(): | ||||||
|  |         det_out, out = model_clone(img) | ||||||
|  |      | ||||||
|  |     # model_.train() | ||||||
|  |     del img  | ||||||
|  |      | ||||||
|  |     return det_out, out | ||||||
|  | 
 | ||||||
|  | def get_labels(det_out, imgs, targets, opt): | ||||||
|  |     ###################### Get predicted labels ######################  | ||||||
|  |     nb, _, height, width = imgs.shape  # batch size, channels, height, width  | ||||||
|  |     targets_ = targets.clone()  | ||||||
|  |     targets_[:, 2:] = targets_[:, 2:] * torch.Tensor([width, height, width, height]).to(imgs.device)  # to pixels | ||||||
|  |     lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else []  # for autolabelling | ||||||
|  |     o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True) | ||||||
|  |     pred_labels = []  | ||||||
|  |     for si, pred in enumerate(o): | ||||||
|  |         labels = targets_[targets_[:, 0] == si, 1:] | ||||||
|  |         nl = len(labels)  | ||||||
|  |         predn = pred.clone() | ||||||
|  |         # Get the indices that sort the values in column 5 in ascending order | ||||||
|  |         sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True) | ||||||
|  |         # Apply the sorting indices to the tensor | ||||||
|  |         sorted_pred = predn[sort_indices] | ||||||
|  |         # Remove predictions with less than 0.1 confidence | ||||||
|  |         n_conf = int(torch.sum(sorted_pred[:,4]>0.1)) + 1 | ||||||
|  |         sorted_pred = sorted_pred[:n_conf] | ||||||
|  |         new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si | ||||||
|  |         preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1) | ||||||
|  |         preds[:, 2:] = xyxy2xywh(preds[:, 2:])  # xywh | ||||||
|  |         gn = torch.tensor([width, height])[[1, 0, 1, 0]]  # normalization gain whwh | ||||||
|  |         preds[:, 2:] /= gn.to(imgs.device)  # from pixels | ||||||
|  |         pred_labels.append(preds) | ||||||
|  |     pred_labels = torch.cat(pred_labels, 0).to(imgs.device) | ||||||
|  |      | ||||||
|  |     return pred_labels | ||||||
|  |     ################################################################## | ||||||
|  | 
 | ||||||
|  | from torchvision.utils import make_grid | ||||||
|  | 
 | ||||||
|  | def get_center_coords(attr): | ||||||
|  |     """Estimate the centroid of the brightest pixels in a 2D attribution map.""" | ||||||
|  |     img_tensor = attr / attr.max() | ||||||
|  | 
 | ||||||
|  |     # Define a brightness threshold | ||||||
|  |     threshold = 0.95 | ||||||
|  | 
 | ||||||
|  |     # Create a binary mask of the bright pixels | ||||||
|  |     mask = img_tensor > threshold | ||||||
|  | 
 | ||||||
|  |     # Get the coordinates of the bright pixels | ||||||
|  |     y_coords, x_coords = torch.where(mask) | ||||||
|  | 
 | ||||||
|  |     # Calculate the centroid of the bright pixels | ||||||
|  |     centroid_x = x_coords.float().mean().item() | ||||||
|  |     centroid_y = y_coords.float().mean().item() | ||||||
|  | 
 | ||||||
|  |     print(f'The central bright point is at ({centroid_x}, {centroid_y})') | ||||||
|  |      | ||||||
|  |     return centroid_x, centroid_y | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False): | ||||||
|  |     """ | ||||||
|  |     Compute the distance grids from each pixel to the target coordinates. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         attr (torch.Tensor): Attribution maps. | ||||||
|  |         targets (torch.Tensor): Target coordinates. | ||||||
|  |         focus_coeff (float, optional): Focus coefficient, smaller means more focused. Defaults to 0.5. | ||||||
|  |         debug (bool, optional): Whether to visualize debug information. Defaults to False. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         torch.Tensor: Distance grids. | ||||||
|  |     """ | ||||||
|  |      | ||||||
|  |     # Assign the height and width of the input tensor to variables | ||||||
|  |     height, width = attr.shape[-1], attr.shape[-2] | ||||||
|  |      | ||||||
|  |     # attr = torch.abs(attr) # Take absolute values of gradients | ||||||
|  |     # attr = normalize_batch(attr) # Normalize attribution maps per image in batch | ||||||
|  | 
 | ||||||
|  |     # Create a grid of indices | ||||||
|  |     xx, yy = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width))).to(attr.device) | ||||||
|  |     idx_grid = torch.stack((xx, yy), dim=-1).float() | ||||||
|  |      | ||||||
|  |     # Expand the grid to match the batch size | ||||||
|  |     idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1) | ||||||
|  |      | ||||||
|  |     # Initialize a list to store the distance grids | ||||||
|  |     dist_grids_ = [[]] * attr.shape[0] | ||||||
|  | 
 | ||||||
|  |     # Loop over batches | ||||||
|  |     for j in range(attr.shape[0]): | ||||||
|  |         # Get the rows where the first column is the current unique value | ||||||
|  |         rows = targets[targets[:, 0] == j] | ||||||
|  |          | ||||||
|  |         if len(rows) != 0:  | ||||||
|  |             # Create a tensor for the target coordinates | ||||||
|  |             xy = rows[:,2:4] # y, x | ||||||
|  |             # Flip the x and y coordinates and scale them to the image size | ||||||
|  |             xy[:, 0], xy[:, 1] = xy[:, 1] * width, xy[:, 0] * height # y, x to x, y | ||||||
|  |             xy_center = xy.unsqueeze(1).unsqueeze(1)#.requires_grad_(True)  | ||||||
|  |              | ||||||
|  |             # Compute the Euclidean distance from each pixel to the target coordinates | ||||||
|  |             dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1) | ||||||
|  | 
 | ||||||
|  |             # Pick the closest distance to any target for each pixel  | ||||||
|  |             dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0)  | ||||||
|  |             dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_ | ||||||
|  |         else: | ||||||
|  |             # Set grid to zero if no targets are present | ||||||
|  |             dist_grid = torch.zeros_like(attr[j]) | ||||||
|  |              | ||||||
|  |         dist_grids_[j] = dist_grid | ||||||
|  |     # Convert the list of distance grids to a tensor for faster computation | ||||||
|  |     dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff | ||||||
|  |     if torch.isnan(dist_grids).any(): | ||||||
|  |         dist_grids = torch.nan_to_num(dist_grids, nan=0.0) | ||||||
|  | 
 | ||||||
|  |     if debug: | ||||||
|  |         for i in range(len(dist_grids)): | ||||||
|  |             if ((i % 8) == 0): | ||||||
|  |                 grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0) | ||||||
|  |                 imshow(grid_show, save_path='figs/dist_grids') | ||||||
|  |                 if imgs is None: | ||||||
|  |                     imgs = torch.zeros_like(attr) | ||||||
|  |                 imshow(imgs[i], save_path='figs/im0') | ||||||
|  |                 img_overlay = (overlay_mask(imgs[i], dist_grids[i][0], alpha = 0.75)) | ||||||
|  |                 imshow(img_overlay, save_path='figs/dist_grid_overlay') | ||||||
|  |                 weighted_attr = (dist_grids[i] * attr[i]) | ||||||
|  |                 imshow(weighted_attr, save_path='figs/weighted_attr') | ||||||
|  |                 imshow(attr[i], save_path='figs/attr') | ||||||
|  | 
 | ||||||
|  |     return dist_grids | ||||||
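A sketch with assumed normalized targets; each grid is 0 at its nearest target center and rises toward 1 with distance, sharpened by `focus_coeff`:

```python
# Sketch (assumed data): targets rows are [batch_idx, cls, x, y, w, h]
import torch

attr = torch.rand(2, 3, 32, 32)
targets = torch.tensor([[0.0, 0.0, 0.50, 0.50, 0.2, 0.2],
                        [1.0, 0.0, 0.25, 0.75, 0.1, 0.1]])
grids = get_distance_grids(attr, targets, focus_coeff=0.5)  # (2, 3, 32, 32)
```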
|  | 
 | ||||||
|  | def attr_reg(attribution_map, distance_map): | ||||||
|  | 
 | ||||||
|  |     # dist_attr = distance_map * attribution_map  | ||||||
|  |     dist_attr = torch.mean(distance_map * attribution_map)#, dim=(1, 2, 3))  | ||||||
|  |     # del distance_map, attribution_map | ||||||
|  |     return dist_attr | ||||||
|  | 
 | ||||||
|  | def get_bbox_map(targets_out, attr, corners=False): | ||||||
|  |     target_inds = targets_out[:, 0].int() | ||||||
|  |     xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num] | ||||||
|  |     num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1)) | ||||||
|  |     # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1)) | ||||||
|  |     xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int() | ||||||
|  |     co = xyxy_corners | ||||||
|  |     if corners: | ||||||
|  |         co = targets_out[:, 2:6].int() | ||||||
|  |     coords_map = torch.zeros_like(attr, dtype=torch.bool) | ||||||
|  |     # rows = np.arange(co.shape[0]) | ||||||
|  |     x1, x2 = co[:,1], co[:,3] | ||||||
|  |     y1, y2 = co[:,0], co[:,2] | ||||||
|  |      | ||||||
|  |     for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop | ||||||
|  |         coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True | ||||||
|  |      | ||||||
|  |     bbox_map = coords_map.to(torch.float32) | ||||||
|  | 
 | ||||||
|  |     return bbox_map | ||||||
|  | ######################################## BCE ####################################### | ||||||
|  | def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False): | ||||||
|  |     # if imgs is None: | ||||||
|  |     #     imgs = torch.zeros_like(attribution_map) | ||||||
|  |     # Calculate Plausibility IoU with attribution maps | ||||||
|  |     # attribution_map.retains_grad = True | ||||||
|  |     if not only_loss: | ||||||
|  |         plaus_score = get_plaus_score(targets_out = targets, attr = attribution_map.clone().detach().requires_grad_(True), imgs = imgs) | ||||||
|  |     else: | ||||||
|  |         plaus_score = torch.tensor(0.0) | ||||||
|  |      | ||||||
|  |     # attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch | ||||||
|  | 
 | ||||||
|  |     # Calculate distance regularization | ||||||
|  |     distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff) | ||||||
|  |     # distance_map = torch.ones_like(attribution_map) | ||||||
|  |      | ||||||
|  |     if opt.dist_x_bbox: | ||||||
|  |         bbox_map = get_bbox_map(targets, attribution_map).to(torch.bool) | ||||||
|  |         distance_map[bbox_map] = 0.0 | ||||||
|  |         # distance_map = distance_map * (1 - bbox_map) | ||||||
|  | 
 | ||||||
|  |     # Positive regularization term for incentivizing pixels near the target to have high attribution | ||||||
|  |     dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map)) | ||||||
|  |     # Negative regularization term for incentivizing pixels far from the target to have low attribution | ||||||
|  |     dist_attr_neg = attr_reg(attribution_map, distance_map) | ||||||
|  |     # Calculate plausibility regularization term | ||||||
|  |     # dist_reg = dist_attr_pos - dist_attr_neg | ||||||
|  |     dist_reg = ((dist_attr_pos / torch.mean(attribution_map)) - (dist_attr_neg / torch.mean(attribution_map))) | ||||||
|  |     # dist_reg = torch.mean((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3))) - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3))))  | ||||||
|  |     # dist_reg = (torch.mean(torch.exp((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3)))) + \ | ||||||
|  |     #                             torch.exp(1 - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3)))))) \ | ||||||
|  |     #                             / 2.5 | ||||||
|  | 
 | ||||||
|  |     if opt.bbox_coeff != 0.0: | ||||||
|  |         bbox_map = get_bbox_map(targets, attribution_map) | ||||||
|  |         attr_bbox_pos = attr_reg(attribution_map, bbox_map) | ||||||
|  |         attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map)) | ||||||
|  |         bbox_reg = attr_bbox_pos - attr_bbox_neg | ||||||
|  |         # bbox_reg = (attr_bbox_pos / torch.mean(attribution_map)) - (attr_bbox_neg / torch.mean(attribution_map)) | ||||||
|  |     else: | ||||||
|  |         bbox_reg = 0.0 | ||||||
|  | 
 | ||||||
|  |     bbox_map = get_bbox_map(targets, attribution_map) | ||||||
|  |     plaus_score = ((torch.sum((attribution_map * bbox_map))) / (torch.sum(attribution_map))) | ||||||
|  |     # iou_loss = (1.0 - plaus_score) | ||||||
|  | 
 | ||||||
|  |     if not opt.dist_reg_only: | ||||||
|  |         dist_reg_loss = (((1.0 + dist_reg) / 2.0)) | ||||||
|  |         plaus_reg = (plaus_score * opt.iou_coeff) + \ | ||||||
|  |                     (((dist_reg_loss * opt.dist_coeff) + \ | ||||||
|  |                       (bbox_reg * opt.bbox_coeff))\ | ||||||
|  |                     # ((((((1.0 + dist_reg) / 2.0) - 1.0) * opt.dist_coeff) + ((((1.0 + bbox_reg) / 2.0) - 1.0) * opt.bbox_coeff))\ | ||||||
|  |                     # / (plaus_score) \ | ||||||
|  |                     ) | ||||||
|  |     else: | ||||||
|  |         plaus_reg = (((1.0 + dist_reg) / 2.0)) | ||||||
|  |         # plaus_reg = dist_reg  | ||||||
|  |     # Calculate plausibility loss | ||||||
|  |     plaus_loss = (1 - plaus_reg) * opt.pgt_coeff | ||||||
|  |     # plaus_loss = (plaus_reg) * opt.pgt_coeff | ||||||
|  |     if only_loss: | ||||||
|  |         return plaus_loss | ||||||
|  |     if not debug: | ||||||
|  |         return plaus_loss, (plaus_score, dist_reg, plaus_reg,) | ||||||
|  |     else: | ||||||
|  |         return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map | ||||||
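`opt` only needs a handful of fields here; a minimal sketch, assuming the distance-regularization-only path:

```python
# Sketch (assumed opt fields and data)
from types import SimpleNamespace
import torch

opt = SimpleNamespace(focus_coeff=0.5, dist_x_bbox=False, bbox_coeff=0.0,
                      dist_reg_only=True, pgt_coeff=0.1)
attr = torch.rand(2, 3, 32, 32)
targets = torch.tensor([[0.0, 0.0, 0.5, 0.5, 0.2, 0.2]])
plaus_loss, (plaus_score, dist_reg, plaus_reg) = get_plaus_loss(targets, attr, opt)
```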
|  | 
 | ||||||
|  | #################################################################################### | ||||||
|  | #### ALL FUNCTIONS BELOW ARE DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS #### | ||||||
|  | #################################################################################### | ||||||
|  | 
 | ||||||
|  | def generate_vanilla_grad(model, input_tensor, loss_func = None,  | ||||||
|  |                           targets_list=None, targets=None, metric=None, out_num = 1,  | ||||||
|  |                           n_max_labels=3, norm=True, abs=True, grayscale=True,  | ||||||
|  |                           class_specific_attr = True, device='cpu'):     | ||||||
|  |     """ | ||||||
|  |     Generate vanilla gradients for the given model and input tensor. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         model (nn.Module): The model to generate gradients for. | ||||||
|  |         input_tensor (torch.Tensor): The input tensor for which gradients are computed. | ||||||
|  |         loss_func (callable, optional): The loss function to compute gradients with respect to. Defaults to None. | ||||||
|  |         targets_list (list, optional): The list of target tensors. Defaults to None. | ||||||
|  |         metric (callable, optional): The metric function to evaluate the loss. Defaults to None. | ||||||
|  |         out_num (int, optional): The index of the output tensor to compute gradients with respect to. Defaults to 1. | ||||||
|  |         n_max_labels (int, optional): The maximum number of labels to consider. Defaults to 3. | ||||||
|  |         norm (bool, optional): Whether to normalize the attribution map. Defaults to True. | ||||||
|  |         abs (bool, optional): Whether to take the absolute values of gradients. Defaults to True. | ||||||
|  |         grayscale (bool, optional): Whether to convert the attribution map to grayscale. Defaults to True. | ||||||
|  |         class_specific_attr (bool, optional): Whether to compute class-specific attribution maps. Defaults to True. | ||||||
|  |         device (str, optional): The device to use for computation. Defaults to 'cpu'. | ||||||
|  |      | ||||||
|  |     Returns: | ||||||
|  |         torch.Tensor: The generated vanilla gradients. | ||||||
|  |     """ | ||||||
|  |     # Set model.train() at the beginning and revert back to original mode (model.eval() or model.train()) at the end | ||||||
|  |     train_mode = model.training | ||||||
|  |     if not train_mode: | ||||||
|  |         model.train() | ||||||
|  |      | ||||||
|  |     input_tensor.requires_grad = True # Set requires_grad attribute of tensor. Important for computing gradients | ||||||
|  |     model.zero_grad() # Zero gradients | ||||||
|  |     inpt = input_tensor | ||||||
|  |     # Forward pass | ||||||
|  |     train_out = model(inpt) # training outputs (no inference outputs in train mode) | ||||||
|  |      | ||||||
|  |     # train_out[1] = torch.Size([4, 3, 80, 80, 7]) HxWx(#anchorxC) cls (class probabilities) | ||||||
|  |     # train_out[0] = torch.Size([4, 3, 160, 160, 7]) HxWx(#anchorx4) box or reg (location and scaling) | ||||||
|  |     # train_out[2] = torch.Size([4, 3, 40, 40, 7]) HxWx(#anchorx1) obj (objectness score or confidence) | ||||||
|  |      | ||||||
|  |     if class_specific_attr: | ||||||
|  |         n_attr_list, index_classes = [], [] | ||||||
|  |         for i in range(len(input_tensor)): | ||||||
|  |             if len(targets_list[i]) > n_max_labels: | ||||||
|  |                 targets_list[i] = targets_list[i][:n_max_labels] | ||||||
|  |             if targets_list[i].numel() != 0: | ||||||
|  |                 # unique_classes = torch.unique(targets_list[i][:,1]) | ||||||
|  |                 class_numbers = targets_list[i][:,1] | ||||||
|  |                 index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers]) | ||||||
|  |                 num_attrs = len(targets_list[i]) | ||||||
|  |                 # index_classes.append([0, 1, 2, 3, 4] + [int(uc + 5) for uc in unique_classes]) | ||||||
|  |                 # num_attrs = 1 #len(unique_classes)# if loss_func else len(targets_list[i]) | ||||||
|  |                 n_attr_list.append(num_attrs) | ||||||
|  |             else: | ||||||
|  |                 index_classes.append([0, 1, 2, 3, 4]) | ||||||
|  |                 n_attr_list.append(0) | ||||||
|  |      | ||||||
|  |         targets_list_filled = [targ.clone().detach() for targ in targets_list] | ||||||
|  |         labels_len = [len(targets_list[ih]) for ih in range(len(targets_list))] | ||||||
|  |         max_labels = np.max(labels_len) | ||||||
|  |         max_index = np.argmax(labels_len) | ||||||
|  |         for i in range(len(targets_list)): | ||||||
|  |             # targets_list_filled[i] = targets_list[i] | ||||||
|  |             if len(targets_list_filled[i]) < max_labels: | ||||||
|  |                 tlist = [targets_list_filled[i]] * math.ceil(max_labels / len(targets_list_filled[i])) | ||||||
|  |                 targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0) | ||||||
|  |             else: | ||||||
|  |                 targets_list_filled[i] = targets_list_filled[i].unsqueeze(0) | ||||||
|  |         for i in range(len(targets_list_filled)-1,-1,-1): | ||||||
|  |             if targets_list_filled[i].numel() == 0: | ||||||
|  |                 targets_list_filled.pop(i) | ||||||
|  |         targets_list_filled = torch.cat(targets_list_filled) | ||||||
|  |      | ||||||
|  |     n_img_attrs = len(input_tensor) if class_specific_attr else 1 | ||||||
|  |     n_img_attrs = 1 if loss_func else n_img_attrs | ||||||
|  |      | ||||||
|  |     attrs_batch = [] | ||||||
|  |     for i_batch in range(n_img_attrs): | ||||||
|  |         if loss_func and class_specific_attr: | ||||||
|  |             i_batch = max_index | ||||||
|  |         # inpt = input_tensor[i_batch].unsqueeze(0) | ||||||
|  |         # ################################################################## | ||||||
|  |         # model.zero_grad() # Zero gradients | ||||||
|  |         # train_out = model(inpt)  # training outputs (no inference outputs in train mode) | ||||||
|  |         # ################################################################## | ||||||
|  |         n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1 | ||||||
|  |         n_label_attrs = 1 if not class_specific_attr else n_label_attrs | ||||||
|  |         attrs_img = [] | ||||||
|  |         for i_attr in range(n_label_attrs): | ||||||
|  |             if loss_func is None: | ||||||
|  |                 grad_wrt = train_out[out_num] | ||||||
|  |                 if class_specific_attr: | ||||||
|  |                     grad_wrt = train_out[out_num][:,:,:,:,index_classes[i_batch][i_attr]] | ||||||
|  |                 grad_wrt_outputs = torch.ones_like(grad_wrt) | ||||||
|  |             else: | ||||||
|  |                 # if class_specific_attr: | ||||||
|  |                 #     targets = targets_list[:][i_attr] | ||||||
|  |                 # n_targets = len(targets_list[i_batch]) | ||||||
|  |                 if class_specific_attr: | ||||||
|  |                     target_indiv = targets_list_filled[:,i_attr] # batch image input | ||||||
|  |                 else: | ||||||
|  |                     target_indiv = targets | ||||||
|  |                 # target_indiv = targets_list[i_batch][i_attr].unsqueeze(0) # single image input | ||||||
|  |                 # target_indiv[:,0] = 0 # this indicates the batch index of the target, should be 0 since we are only doing one image at a time | ||||||
|  |                      | ||||||
|  |                 try: | ||||||
|  |                     loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)  # loss scaled by batch_size | ||||||
|  |                 except: | ||||||
|  |                     target_indiv = target_indiv.to(device) | ||||||
|  |                     inpt = inpt.to(device) | ||||||
|  |                     for tro in train_out: | ||||||
|  |                         tro = tro.to(device) | ||||||
|  |                     print("Error in loss function, trying again with device specified") | ||||||
|  |                     loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric) | ||||||
|  |                 grad_wrt = loss | ||||||
|  |                 grad_wrt_outputs = None | ||||||
|  |              | ||||||
|  |             model.zero_grad() # Zero gradients | ||||||
|  |             gradients = torch.autograd.grad(grad_wrt, inpt,  | ||||||
|  |                                                 grad_outputs=grad_wrt_outputs,  | ||||||
|  |                                                 retain_graph=True,  | ||||||
|  |                                                 # create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly | ||||||
|  |                                                 ) | ||||||
|  | 
 | ||||||
|  |             # Keep the gradient attached to the graph so PGT can train through it; | ||||||
|  |             # detach (as in the commented line below) only if full separation is needed | ||||||
|  |             # attribution_map = torch.tensor(torch.sum(gradients[0], 1, keepdim=True).clone().detach().cpu().numpy()) | ||||||
|  |             attribution_map = gradients[0] | ||||||
|  |              | ||||||
|  |             if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval | ||||||
|  |                 attribution_map = torch.sum(attribution_map, 1, keepdim=True) | ||||||
|  |             if abs: | ||||||
|  |                 attribution_map = torch.abs(attribution_map) # Take absolute values of gradients | ||||||
|  |             if norm: | ||||||
|  |                 attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch | ||||||
|  |             attrs_img.append(attribution_map) | ||||||
|  |         if len(attrs_img) == 0: | ||||||
|  |             attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device)) | ||||||
|  |         else: | ||||||
|  |             attrs_batch.append(torch.stack(attrs_img).to(device)) | ||||||
|  | 
 | ||||||
|  |     # out_attr = torch.tensor(attribution_map).unsqueeze(0).to(device) if ((loss_func) or (not class_specific_attr)) else torch.stack(attrs_batch).to(device) | ||||||
|  |     # out_attr = [attrs_batch[0]] * len(input_tensor) if ((loss_func) or (not class_specific_attr)) else attrs_batch | ||||||
|  |     out_attr = attrs_batch | ||||||
|  |     # Set model back to original mode | ||||||
|  |     if not train_mode: | ||||||
|  |         model.eval() | ||||||
|  |      | ||||||
|  |     return out_attr | ||||||
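For orientation: this routine returns one stacked attribution tensor per image (or a single entry when a loss function drives the gradient). A minimal call sketch follows; the function name generate_vanilla_grad and the exact signature are assumptions here, since the definition opens above this excerpt:

    # Hypothetical call; argument names are taken from the body above.
    attrs = generate_vanilla_grad(
        model, input_tensor=imgs,          # imgs: (B, 3, H, W), requires_grad=True
        loss_func=None,                    # None -> differentiate the raw head outputs
        targets_list=None, metric=None, out_num=1,
        norm=True, abs=True, grayscale=True,
        class_specific_attr=True, device=imgs.device,
    )
    # attrs[i] holds the stacked attribution maps for image i.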
|  | 
 | ||||||
|  | class RVNonLinearFunc(torch.nn.Module): | ||||||
|  |     """ | ||||||
|  |     Custom Bayesian ReLU activation function for random variables. | ||||||
|  | 
 | ||||||
|  |     Attributes: | ||||||
|  |         None | ||||||
|  |     """ | ||||||
|  |     def __init__(self, func): | ||||||
|  |         super(RVNonLinearFunc, self).__init__() | ||||||
|  |         self.func = func | ||||||
|  | 
 | ||||||
|  |     def forward(self, mu_in, Sigma_in): | ||||||
|  |         """ | ||||||
|  |         Forward pass of the Bayesian ReLU activation function. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             mu_in (torch.Tensor): A tensor of shape (batch_size, input_size), | ||||||
|  |                 representing the mean input to the ReLU activation function. | ||||||
|  |             Sigma_in (torch.Tensor): A tensor of shape (batch_size, input_size, input_size), | ||||||
|  |                 representing the covariance input to the ReLU activation function. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors, | ||||||
|  |                 including the mean of the output and the covariance of the output. | ||||||
|  |         """ | ||||||
|  |         # Collect stats | ||||||
|  |         batch_size = mu_in.size(0) | ||||||
|  |         | ||||||
|  |         # Mean | ||||||
|  |         mu_out = self.func(mu_in) | ||||||
|  |          | ||||||
|  |         # Compute the elementwise derivative of the wrapped function at the input mean | ||||||
|  |         gradi = torch.autograd.grad(mu_out, mu_in, grad_outputs=torch.ones_like(mu_out), create_graph=True)[0].view(batch_size,-1) | ||||||
|  | 
 | ||||||
|  |         # add an extra dimension to gradi at position 2 and 1 | ||||||
|  |         grad1 = gradi.unsqueeze(dim=2) | ||||||
|  |         grad2 = gradi.unsqueeze(dim=1) | ||||||
|  |         | ||||||
|  |         # compute the outer product of grad1 and grad2 | ||||||
|  |         outer_product = torch.bmm(grad1, grad2) | ||||||
|  |         | ||||||
|  |         # element-wise multiply Sigma_in with the outer product | ||||||
|  |         # and return the result | ||||||
|  |         Sigma_out = torch.mul(Sigma_in, outer_product) | ||||||
|  | 
 | ||||||
|  |         return mu_out, Sigma_out | ||||||
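A minimal usage sketch for RVNonLinearFunc (shapes follow the docstring; the identity covariance is illustrative):

    import torch
    import torch.nn.functional as F

    relu_rv = RVNonLinearFunc(F.relu)
    mu = torch.randn(4, 16, requires_grad=True)             # (batch, features)
    Sigma = torch.eye(16).expand(4, 16, 16).contiguous()    # (batch, features, features)
    mu_out, Sigma_out = relu_rv(mu, Sigma)                  # same shapes out

Since the derivative of ReLU is an indicator function, the outer product g g^T simply zeroes the covariance rows and columns of inactive units, which is the standard first-order (delta-method) propagation.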
ultralytics/utils/plot_functs.py (154 lines, new file)
							| @ -0,0 +1,154 @@ | |||||||
|  | import numpy as np | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from torch.cuda import amp  # needed for amp.GradScaler in returnGrad below | ||||||
|  | 
 | ||||||
|  | class Subplots: | ||||||
|  |     def __init__(self, figsize = (40, 5)): | ||||||
|  |         self.fig = plt.figure(figsize=figsize) | ||||||
|  |          | ||||||
|  |     def plot_img_list(self, img_list, savedir='figs/test',  | ||||||
|  |                     nrows = 1, rownum = 0,  | ||||||
|  |                     hold = False, coltitles=[], rowtitle=''): | ||||||
|  |          | ||||||
|  |         for i, img in enumerate(img_list): | ||||||
|  |             try: | ||||||
|  |                 npimg = img.clone().detach().cpu().numpy() | ||||||
|  |             except AttributeError:  # already a numpy array | ||||||
|  |                 npimg = img | ||||||
|  |             tpimg = np.transpose(npimg, (1, 2, 0)) | ||||||
|  |             lenrow = len(img_list) | ||||||
|  |             ax = self.fig.add_subplot(nrows, lenrow, i+1+(rownum*lenrow)) | ||||||
|  |             if len(coltitles) > i: | ||||||
|  |                 ax.set_title(coltitles[i]) | ||||||
|  |             if i == 0: | ||||||
|  |                 ax.annotate(rowtitle, xy=((-0.06 * len(rowtitle)), 0.4),# xytext=(-ax.yaxis.labelpad - pad, 0), | ||||||
|  |                 xycoords='axes fraction', textcoords='offset points', | ||||||
|  |                 size='large', ha='center', va='baseline') | ||||||
|  |                 # ax.set_ylabel(rowtitle, rotation=90) | ||||||
|  |             ax.imshow(tpimg) | ||||||
|  |             ax.axis('off') | ||||||
|  | 
 | ||||||
|  |         if not hold: | ||||||
|  |             self.fig.tight_layout() | ||||||
|  |             plt.savefig(f'{savedir}.png') | ||||||
|  |             plt.clf() | ||||||
|  |             plt.close('all') | ||||||
|  |              | ||||||
|  |                      | ||||||
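A minimal usage sketch for Subplots (paths and titles are illustrative, and a figs/ directory is assumed to exist):

    import torch

    sp = Subplots(figsize=(12, 6))
    imgs = [torch.rand(3, 64, 64) for _ in range(3)]
    sp.plot_img_list(imgs, nrows=2, rownum=0, hold=True,
                     coltitles=['img 0', 'img 1', 'img 2'], rowtitle='inputs')
    sp.plot_img_list(imgs, savedir='figs/example', nrows=2, rownum=1,
                     hold=False, rowtitle='attributions')  # writes figs/example.png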
|  | def VisualizeNumpyImageGrayscale(image_3d): | ||||||
|  |     r"""Shifts and scales a numpy image array so its values span [0, 1]. | ||||||
|  |     """ | ||||||
|  |     vmin = np.min(image_3d) | ||||||
|  |     image_2d = image_3d - vmin | ||||||
|  |     vmax = np.max(image_2d) | ||||||
|  |     return (image_2d / vmax) | ||||||
|  | 
 | ||||||
|  | def normalize_numpy(image_3d): | ||||||
|  |     r"""Shifts and scales a numpy image array so its values span [0, 1]. | ||||||
|  | 
|  |     Identical to VisualizeNumpyImageGrayscale; kept under a name that matches | ||||||
|  |     the tensor variant below. | ||||||
|  |     """ | ||||||
|  |     vmin = np.min(image_3d) | ||||||
|  |     image_2d = image_3d - vmin | ||||||
|  |     vmax = np.max(image_2d) | ||||||
|  |     return (image_2d / vmax) | ||||||
|  | 
 | ||||||
|  | def normalize_tensor(image_3d):  | ||||||
|  |     r"""Shifts and scales a tensor so its values span [0, 1]. | ||||||
|  |     """ | ||||||
|  |     image_2d = (image_3d - torch.min(image_3d)) | ||||||
|  |     # clamp_min avoids a 0/0 NaN when the input is constant | ||||||
|  |     return (image_2d / torch.max(image_2d).clamp_min(torch.finfo(image_2d.dtype).eps)) | ||||||
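The attribution code earlier in this commit calls normalize_batch, whose definition is not shown in this excerpt. A plausible per-image variant consistent with normalize_tensor might look like this (a sketch, not the committed code):

    def normalize_batch_sketch(x):
        # Normalize each image of a (B, C, H, W) batch to [0, 1] independently.
        flat = x.view(x.size(0), -1)
        mins = flat.min(dim=1).values.view(-1, 1, 1, 1)
        maxs = flat.max(dim=1).values.view(-1, 1, 1, 1)
        return (x - mins) / (maxs - mins).clamp_min(torch.finfo(x.dtype).eps)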
|  | 
 | ||||||
|  | def format_img(img_): | ||||||
|  |     np_img = img_.numpy() | ||||||
|  |     tp_img = np.transpose(np_img, (1, 2, 0)) | ||||||
|  |     return tp_img | ||||||
|  | 
 | ||||||
|  | def imshow(img, save_path=None): | ||||||
|  |     try: | ||||||
|  |         npimg = img.clone().detach().cpu().numpy() | ||||||
|  |     except AttributeError:  # already a numpy array | ||||||
|  |         npimg = img | ||||||
|  |     tpimg = np.transpose(npimg, (1, 2, 0)) | ||||||
|  |     plt.imshow(tpimg) | ||||||
|  |     # plt.axis('off') | ||||||
|  |     plt.tight_layout() | ||||||
|  |     if save_path is not None: | ||||||
|  |         plt.savefig(f"{save_path}.png") | ||||||
|  |     # plt.show() | ||||||
|  | 
 | ||||||
|  | def imshow_img(img, imsave_path): | ||||||
|  |     # works for tensors and numpy arrays | ||||||
|  |     try: | ||||||
|  |         npimg = VisualizeNumpyImageGrayscale(img.numpy()) | ||||||
|  |     except AttributeError:  # already a numpy array | ||||||
|  |         npimg = VisualizeNumpyImageGrayscale(img) | ||||||
|  |     npimg = np.transpose(npimg, (2, 0, 1)) | ||||||
|  |     imshow(npimg, save_path=imsave_path) | ||||||
|  |     print("Saving image as ", imsave_path) | ||||||
|  |      | ||||||
|  | def returnGrad(img, labels, model, compute_loss, loss_metric, augment=None, device='cpu'): | ||||||
|  |     model.train() | ||||||
|  |     model.to(device) | ||||||
|  |     img = img.to(device) | ||||||
|  |     img.requires_grad_(True) | ||||||
|  |     labels = labels.to(device)  # .to() is not in-place, so rebind the name | ||||||
|  |     model.requires_grad_(True) | ||||||
|  |     cuda = torch.device(device).type != 'cpu'  # accept str or torch.device | ||||||
|  |     scaler = amp.GradScaler(enabled=cuda) | ||||||
|  |     pred = model(img) | ||||||
|  |     # out, train_out = model(img, augment=augment)  # inference and training outputs | ||||||
|  |     loss, loss_items = compute_loss(pred, labels, metric=loss_metric)  # box, obj, cls | ||||||
|  |      | ||||||
|  |     with torch.autograd.set_detect_anomaly(True): | ||||||
|  |         scaler.scale(loss).backward(inputs=img)  # gradients w.r.t. the input image only | ||||||
|  |      | ||||||
|  |     Sc_dx = img.grad | ||||||
|  |     model.eval() | ||||||
|  |     # Detach and cast rather than torch.tensor(tensor), which copies and warns | ||||||
|  |     Sc_dx = Sc_dx.detach().clone().to(torch.float32) | ||||||
|  |     return Sc_dx | ||||||
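A minimal call sketch for returnGrad (compute_loss and loss_metric come from the surrounding training pipeline; their exact signatures are assumptions here):

    # img: (1, 3, H, W) input batch; labels: matching target tensor.
    saliency = returnGrad(img, labels, model, compute_loss, loss_metric,
                          device=torch.device('cuda:0'))
    # saliency has the same shape as img and can be normalized for display:
    imshow(normalize_tensor(saliency[0]), save_path='figs/saliency_example')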
|  | 
 | ||||||
|  | def calculate_snr(img, attr, dB=True): | ||||||
|  |     try: | ||||||
|  |         img_np = img.detach().cpu().numpy() | ||||||
|  |         attr_np = attr.detach().cpu().numpy() | ||||||
|  |     except AttributeError:  # already numpy arrays | ||||||
|  |         img_np = img | ||||||
|  |         attr_np = attr | ||||||
|  |      | ||||||
|  |     # Calculate the signal power | ||||||
|  |     signal_power = np.mean(img_np**2) | ||||||
|  | 
 | ||||||
|  |     # Calculate the noise power | ||||||
|  |     noise_power = np.mean(attr_np**2) | ||||||
|  | 
 | ||||||
|  |     if dB: | ||||||
|  |         # Calculate SNR in dB | ||||||
|  |         snr = 10 * np.log10(signal_power / noise_power) | ||||||
|  |     else: | ||||||
|  |         # Calculate SNR | ||||||
|  |         snr = signal_power / noise_power | ||||||
|  | 
 | ||||||
|  |     return snr | ||||||
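As a quick sanity check of the dB branch (illustrative values):

    img = np.ones((3, 8, 8))              # signal power: mean(img**2) == 1.0
    attr = np.full((3, 8, 8), 0.1)        # noise power: mean(attr**2) == 0.01
    assert np.isclose(calculate_snr(img, attr), 20.0)   # 10 * log10(1.0 / 0.01)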
|  | 
 | ||||||
|  | def overlay_mask(img, mask, colormap: str = "jet", alpha: float = 0.7): | ||||||
|  |     # img: (3, H, W) tensor in [0, 1]; mask: (1, H, W) tensor in [0, 1] | ||||||
|  |     cmap = plt.get_cmap(colormap) | ||||||
|  |     npmask = np.array(mask.clone().detach().cpu().squeeze(0)) | ||||||
|  |     # Map the mask through the colormap, keep RGB, and reorder to (3, H, W) | ||||||
|  |     cmpmask = (cmap(npmask)[:, :, :3]).transpose((2, 0, 1)) | ||||||
|  |     # Alpha-blend: alpha weights the image, (1 - alpha) weights the heatmap | ||||||
|  |     overlayed_imgnp = ((alpha * (np.asarray(img.clone().detach().cpu())) + (1 - alpha) * cmpmask)) | ||||||
|  |     overlayed_tensor = torch.tensor(overlayed_imgnp, device=img.device) | ||||||
|  |      | ||||||
|  |     return overlayed_tensor | ||||||
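Continuing the returnGrad sketch above, an attribution heatmap can be overlaid on its input (shapes follow the comments in overlay_mask; the path is illustrative):

    heat = normalize_tensor(torch.abs(saliency[0]).sum(0, keepdim=True))  # (1, H, W)
    overlay = overlay_mask(img[0], heat, colormap="jet", alpha=0.7)
    imshow(overlay, save_path='figs/overlay_example')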
| @ -717,7 +717,7 @@ def plot_images( | |||||||
| ): | ): | ||||||
|     """Plot image grid with labels.""" |     """Plot image grid with labels.""" | ||||||
|     if isinstance(images, torch.Tensor): |     if isinstance(images, torch.Tensor): | ||||||
|         images = images.cpu().float().numpy() |         images = images.detach().cpu().float().numpy() | ||||||
|     if isinstance(cls, torch.Tensor): |     if isinstance(cls, torch.Tensor): | ||||||
|         cls = cls.cpu().numpy() |         cls = cls.cpu().numpy() | ||||||
|     if isinstance(bboxes, torch.Tensor): |     if isinstance(bboxes, torch.Tensor): | ||||||
|  | |||||||