From 21e660fde9f09ad1fffe17c50e2a9b9db8109bd8 Mon Sep 17 00:00:00 2001
From: nielseni6
Date: Fri, 1 Nov 2024 11:04:12 -0400
Subject: [PATCH] Fix issues with PGT training by adding the PGT loss to the
 YOLO loss function. The PGT loss can now be tracked during training.

---
 run_pgt_train.py                  |   7 +-
 ultralytics/engine/pgt_trainer.py | 127 ++++++++++++++-------------
 ultralytics/utils/plaus_functs.py |   4 +-
 ultralytics/utils/plotting.py     | 138 ++++++++++++++++++++++++++++++
 4 files changed, 211 insertions(+), 65 deletions(-)

diff --git a/run_pgt_train.py b/run_pgt_train.py
index 6a8f3ab8..9bb8aae6 100644
--- a/run_pgt_train.py
+++ b/run_pgt_train.py
@@ -6,7 +6,7 @@ import argparse
 from datetime import datetime
 import torch

-# nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 &
+# nohup python run_pgt_train.py --device 7 > ./output_logs/gpu7_yolov10_pgt_train.log 2>&1 &

 def main(args):
     model = YOLOv10PGT('yolov10n.pt')
@@ -15,7 +15,7 @@ def main(args):
         epochs=args.epochs,
         batch=args.batch_size,
         # amp=False,
-        # pgt_coeff=1.5,
+        pgt_coeff=3.0,
         # cfg='pgt_train.yaml', # Load and train model with the config file
     )
     # If you want to finetune the model with pretrained weights, you could load the
@@ -51,7 +51,6 @@ def main(args):
         model.save(model_save_path)
     # torch.save(trainer.model.state_dict(), model_save_path)

-
     # Evaluate the model on the validation set
     results = model.val(data=args.data_yaml)

@@ -61,7 +60,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
     parser.add_argument('--device', type=str, default='0', help='CUDA device number')
-    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
+    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
     parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
     parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
     args = parser.parse_args()
diff --git a/ultralytics/engine/pgt_trainer.py b/ultralytics/engine/pgt_trainer.py
index 559db7c2..36a0276d 100644
--- a/ultralytics/engine/pgt_trainer.py
+++ b/ultralytics/engine/pgt_trainer.py
@@ -401,67 +401,9 @@ class PGTTrainer:
                 debug_ = debug
                 if debug_ and (i % 25 == 0):
                     debug_ = False
-                    # Create a tensor of zeros with the same size as images
-                    mask = torch.zeros_like(images, dtype=torch.float32)
-                    smask = get_dist_reg(images, batch['masks'])
-                    grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
-                    grad = torch.abs(grad)
-                    batch_size = images.shape[0]
-                    imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device)
-                    targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-                    targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-                    gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-                    mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+                    plot_grads(batch, self, i)

-                    # Iterate over each bounding box and set the corresponding pixels to 1
-                    for irx, bboxes in enumerate(gt_bboxes):
-                        for idx in range(len(bboxes)):
-                            x1, y1, x2, y2 = bboxes[idx]
-                            x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
-                            mask[irx, :, y1:y2, x1:x2] = 1.0
-
-                    save_imgs = True
-                    if save_imgs:
-                        # Convert tensors to numpy arrays
-                        images_np = 
images.detach().cpu().numpy().transpose(0, 2, 3, 1) - grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1) - mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1) - seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1) - - # Normalize grad for visualization - grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min()) - - for ix in range(images_np.shape[0]): - fig, ax = plt.subplots(1, 6, figsize=(30, 5)) - ax[0].imshow(images_np[ix]) - ax[0].set_title('Image') - ax[1].imshow(grad_np[ix], cmap='jet') - ax[1].set_title('Gradient') - ax[2].imshow(images_np[ix]) - ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5) - ax[2].set_title('Overlay') - ax[3].imshow(mask_np[ix], cmap='gray') - ax[3].set_title('Mask') - ax[4].imshow(seg_mask_np[ix], cmap='gray') - ax[4].set_title('Segmentation Mask') - - # Plot image with bounding boxes - ax[5].imshow(images_np[ix]) - for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]): - x1, y1, x2, y2 = bbox - x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2)) - rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2) - ax[5].add_patch(rect) - ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5)) - ax[5].set_title('Bounding Boxes') - - save_dir_attr = f"figures/attributions/run{self.num}" - if not os.path.exists(save_dir_attr): - os.makedirs(save_dir_attr) - plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png') - plt.close(fig) - images = images.detach() if RANK != -1: @@ -500,6 +442,7 @@ class PGTTrainer: self.run_callbacks("on_batch_end") if self.args.plots and ni in self.plot_idx: self.plot_training_samples(batch, ni) + plot_grads(batch, self, ni) self.run_callbacks("on_train_batch_end") @@ -839,3 +782,69 @@ class PGTTrainer: f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)' ) return optimizer + + +def plot_grads(batch, obj, i, nsamples=16): + # Create a tensor of zeros with the same size as images + images = batch['img'].requires_grad_(True) + mask = torch.zeros_like(images, dtype=torch.float32) + smask = get_dist_reg(images, batch['masks']) + loss, loss_items = obj.model(batch) + grad = torch.autograd.grad(loss, images, retain_graph=True)[0] + grad = torch.abs(grad) + + batch_size = images.shape[0] + imgsz = torch.tensor(batch['resized_shape'][0]).to(obj.device) + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + targets = v8DetectionLoss.preprocess(obj, targets=targets.to(obj.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) + + # Iterate over each bounding box and set the corresponding pixels to 1 + for irx, bboxes in enumerate(gt_bboxes): + for idx in range(len(bboxes)): + x1, y1, x2, y2 = bboxes[idx] + x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2)) + mask[irx, :, y1:y2, x1:x2] = 1.0 + + save_imgs = True + if save_imgs: + # Convert tensors to numpy arrays + images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1) + grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1) + mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1) + seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1) + + # Normalize grad for visualization + grad_np = (grad_np - 
grad_np.min()) / (grad_np.max() - grad_np.min()) + + range_val = min(nsamples, images_np.shape[0]) + for ix in range(range_val): + fig, ax = plt.subplots(1, 6, figsize=(30, 5)) + ax[0].imshow(images_np[ix]) + ax[0].set_title('Image') + ax[1].imshow(grad_np[ix], cmap='jet') + ax[1].set_title('Gradient') + ax[2].imshow(images_np[ix]) + ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5) + ax[2].set_title('Overlay') + ax[3].imshow(mask_np[ix], cmap='gray') + ax[3].set_title('Mask') + ax[4].imshow(seg_mask_np[ix], cmap='gray') + ax[4].set_title('Segmentation Mask') + + # Plot image with bounding boxes + ax[5].imshow(images_np[ix]) + for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]): + x1, y1, x2, y2 = bbox + x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2)) + rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2) + ax[5].add_patch(rect) + ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5)) + ax[5].set_title('Bounding Boxes') + + save_dir_attr = f"{obj.save_dir._str}/attributions" + if not os.path.exists(save_dir_attr): + os.makedirs(save_dir_attr) + plt.savefig(f'{save_dir_attr}/debug_epoch_{obj.epoch}_batch_{i}_image_{ix}.png') + plt.close(fig) \ No newline at end of file diff --git a/ultralytics/utils/plaus_functs.py b/ultralytics/utils/plaus_functs.py index 96fb8b0b..2ca60ba9 100644 --- a/ultralytics/utils/plaus_functs.py +++ b/ultralytics/utils/plaus_functs.py @@ -11,7 +11,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppr from .metrics import bbox_iou import torchvision.transforms as T -def plaus_loss_fn(grad, smask, pgt_coeff): +def plaus_loss_fn(grad, smask, pgt_coeff, square=True): ################## Compute the PGT Loss ################## # Positive regularization term for incentivizing pixels near the target to have high attribution dist_attr_pos = attr_reg(grad, (1.0 - smask)) # dist_reg = seg_mask @@ -22,7 +22,7 @@ def plaus_loss_fn(grad, smask, pgt_coeff): dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad))) plaus_reg = (((1.0 + dist_reg) / 2.0)) # Calculate plausibility loss - plaus_loss = (1 - plaus_reg) * pgt_coeff + plaus_loss = ((1 - plaus_reg) ** 2 if square else (1 - plaus_reg)) * pgt_coeff return plaus_loss def get_dist_reg(images, seg_mask): diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index 0e99c722..e94b8e63 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -837,6 +837,144 @@ def plot_images( if on_plot: on_plot(fname) +@threaded +def plot_gradients( + images, + batch_idx, + cls, + bboxes=np.zeros(0, dtype=np.float32), + confs=None, + masks=np.zeros(0, dtype=np.uint8), + kpts=np.zeros((0, 51), dtype=np.float32), + paths=None, + fname="images.jpg", + names=None, + on_plot=None, + max_subplots=16, + save=True, + conf_thres=0.25, +): + """Plot image grid with labels.""" + if isinstance(images, torch.Tensor): + images = images.detach().cpu().float().numpy() + if isinstance(cls, torch.Tensor): + cls = cls.cpu().numpy() + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.cpu().numpy() + if isinstance(masks, torch.Tensor): + masks = masks.cpu().numpy().astype(int) + if isinstance(kpts, torch.Tensor): + kpts = kpts.cpu().numpy() + if isinstance(batch_idx, torch.Tensor): + batch_idx = batch_idx.cpu().numpy() + + max_size = 1920 # max image size + bs, _, h, w = images.shape # batch size, _, height, 
width + bs = min(bs, max_subplots) # limit plot images + ns = np.ceil(bs**0.5) # number of subplots (square) + if np.max(images[0]) <= 1: + images *= 255 # de-normalise (optional) + + # Build Image + mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init + for i in range(bs): + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0) + + # Resize (optional) + scale = max_size / ns / max(h, w) + if scale < 1: + h = math.ceil(scale * h) + w = math.ceil(scale * w) + mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h))) + + # Annotate + fs = int((h + w) * ns * 0.01) # font size + annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names) + for i in range(bs): + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders + if paths: + annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames + if len(cls) > 0: + idx = batch_idx == i + classes = cls[idx].astype("int") + labels = confs is None + + if len(bboxes): + boxes = bboxes[idx] + conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred) + is_obb = boxes.shape[-1] == 5 # xywhr + boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes) + if len(boxes): + if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1 + boxes[..., 0::2] *= w # scale to pixels + boxes[..., 1::2] *= h + elif scale < 1: # absolute coords need scale if image scales + boxes[..., :4] *= scale + boxes[..., 0::2] += x + boxes[..., 1::2] += y + for j, box in enumerate(boxes.astype(np.int64).tolist()): + c = classes[j] + color = colors(c) + c = names.get(c, c) if names else c + if labels or conf[j] > conf_thres: + label = f"{c}" if labels else f"{c} {conf[j]:.1f}" + annotator.box_label(box, label, color=color, rotated=is_obb) + + elif len(classes): + for c in classes: + color = colors(c) + c = names.get(c, c) if names else c + annotator.text((x, y), f"{c}", txt_color=color, box_style=True) + + # Plot keypoints + if len(kpts): + kpts_ = kpts[idx].copy() + if len(kpts_): + if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01 + kpts_[..., 0] *= w # scale to pixels + kpts_[..., 1] *= h + elif scale < 1: # absolute coords need scale if image scales + kpts_ *= scale + kpts_[..., 0] += x + kpts_[..., 1] += y + for j in range(len(kpts_)): + if labels or conf[j] > conf_thres: + annotator.kpts(kpts_[j]) + + # Plot masks + if len(masks): + if idx.shape[0] == masks.shape[0]: # overlap_masks=False + image_masks = masks[idx] + else: # overlap_masks=True + image_masks = masks[[i]] # (1, 640, 640) + nl = idx.sum() + index = np.arange(nl).reshape((nl, 1, 1)) + 1 + image_masks = np.repeat(image_masks, nl, axis=0) + image_masks = np.where(image_masks == index, 1.0, 0.0) + + im = np.asarray(annotator.im).copy() + for j in range(len(image_masks)): + if labels or conf[j] > conf_thres: + color = colors(classes[j]) + mh, mw = image_masks[j].shape + if mh != h or mw != w: + mask = image_masks[j].astype(np.uint8) + mask = cv2.resize(mask, (w, h)) + mask = mask.astype(bool) + else: + mask = image_masks[j].astype(bool) + with contextlib.suppress(Exception): + im[y : y + h, x : x + w, :][mask] = ( + im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6 + ) + annotator.fromarray(im) + if not save: + return np.asarray(annotator.im) + 
annotator.im.save(fname) # save + if on_plot: + on_plot(fname) @plt_settings() def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
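
The run_pgt_train.py changes surface pgt_coeff directly in train(); it is the weight applied to the PGT plausibility term that this patch folds into the YOLO loss. A minimal invocation mirroring the script (the YOLOv10PGT import sits above the first hunk, so its exact path is assumed here):

from ultralytics import YOLOv10PGT  # import path assumed; not visible in this diff

model = YOLOv10PGT('yolov10n.pt')
model.train(
    data='coco.yaml',
    epochs=100,
    batch=32,        # new default in this patch (was 64)
    pgt_coeff=3.0,   # weight of the PGT loss term (previously commented out as 1.5)
)
results = model.val(data='coco.yaml')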
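
The square=True flag added to plaus_loss_fn makes the penalty quadratic in (1 - plaus_reg) before scaling by pgt_coeff, so the loss falls off faster as attributions become plausible. A standalone sketch of the updated function follows; attr_reg is defined outside this diff, so the masked-mean stand-in is an assumption:

import torch

def attr_reg(grad, mask):
    # Assumed behavior: mean attribution magnitude inside the masked region
    # (attr_reg's body is not part of this patch).
    return torch.mean(grad * mask)

def plaus_loss_fn(grad, smask, pgt_coeff, square=True):
    # Pixels near the target (where smask is low) should have high attribution...
    dist_attr_pos = attr_reg(grad, (1.0 - smask))
    # ...and pixels far from the target should have low attribution.
    dist_attr_neg = attr_reg(grad, smask)
    dist_reg = (dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad))
    plaus_reg = (1.0 + dist_reg) / 2.0
    # New in this patch: optionally square the penalty before scaling.
    return ((1 - plaus_reg) ** 2 if square else (1 - plaus_reg)) * pgt_coeff

grad = torch.rand(2, 3, 64, 64)       # toy |d loss / d image| attribution
smask = torch.ones_like(grad)         # ~0 near targets, ~1 far away (toy mask)
smask[..., 16:48, 16:48] = 0.0
print(plaus_loss_fn(grad, smask, pgt_coeff=3.0))                # squared (new default)
print(plaus_loss_fn(grad, smask, pgt_coeff=3.0, square=False))  # previous linear behavior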
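
plot_grads replaces the inline debug block in PGTTrainer: rather than reusing self.loss, it runs its own forward pass so input gradients are available whenever it is called, including from the plot_idx path. The attribution step at its core reduces to this sketch (a simplification of the function added in pgt_trainer.py):

import torch

def input_saliency(model, batch):
    images = batch['img'].requires_grad_(True)  # let autograd reach the input pixels
    loss, loss_items = model(batch)             # trainer models return (loss, loss_items)
    grad = torch.autograd.grad(loss, images, retain_graph=True)[0]
    return torch.abs(grad)                      # magnitude only, as in plot_grads

Note also that the saved figure grids (image, gradient, overlay, box mask, segmentation mask, labeled boxes) now land under {save_dir}/attributions instead of the global figures/attributions/run{num} directory.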
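
The new plot_gradients follows the plot_images mosaic layout: up to max_subplots images are pasted into an ns x ns grid, filled column by column. The block-origin arithmetic is the only non-obvious part; a minimal sketch with assumed tile sizes:

import numpy as np

bs, h, w = 16, 640, 640        # batch size (capped at max_subplots) and tile size
ns = int(np.ceil(bs ** 0.5))   # mosaic is ns x ns tiles
for i in range(bs):
    # x advances every ns images; y cycles within a column
    x, y = int(w * (i // ns)), int(h * (i % ns))
    # image i is pasted at mosaic[y:y+h, x:x+w]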