Fixed issues with pgt training, adding the pgt loss to the YOLO loss function. PGT loss can now be tracked during training.

This commit is contained in:
nielseni6 2024-11-01 11:04:12 -04:00
parent 44c912647e
commit 21e660fde9
4 changed files with 211 additions and 65 deletions

View File

@ -6,7 +6,7 @@ import argparse
from datetime import datetime
import torch
# nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 &
# nohup python run_pgt_train.py --device 7 > ./output_logs/gpu7_yolov10_pgt_train.log 2>&1 &
def main(args):
model = YOLOv10PGT('yolov10n.pt')
@ -15,7 +15,7 @@ def main(args):
epochs=args.epochs,
batch=args.batch_size,
# amp=False,
# pgt_coeff=1.5,
pgt_coeff=3.0,
# cfg='pgt_train.yaml', # Load and train model with the config file
)
# If you want to finetune the model with pretrained weights, you could load the
@ -51,7 +51,6 @@ def main(args):
model.save(model_save_path)
# torch.save(trainer.model.state_dict(), model_save_path)
# Evaluate the model on the validation set
results = model.val(data=args.data_yaml)
@ -61,7 +60,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
parser.add_argument('--device', type=str, default='0', help='CUDA device number')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
args = parser.parse_args()

View File

@ -401,67 +401,9 @@ class PGTTrainer:
debug_ = debug
if debug_ and (i % 25 == 0):
debug_ = False
# Create a tensor of zeros with the same size as images
mask = torch.zeros_like(images, dtype=torch.float32)
smask = get_dist_reg(images, batch['masks'])
grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
grad = torch.abs(grad)
batch_size = images.shape[0]
imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device)
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
plot_grads(batch, self, i)
# Iterate over each bounding box and set the corresponding pixels to 1
for irx, bboxes in enumerate(gt_bboxes):
for idx in range(len(bboxes)):
x1, y1, x2, y2 = bboxes[idx]
x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
mask[irx, :, y1:y2, x1:x2] = 1.0
save_imgs = True
if save_imgs:
# Convert tensors to numpy arrays
images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
# Normalize grad for visualization
grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
for ix in range(images_np.shape[0]):
fig, ax = plt.subplots(1, 6, figsize=(30, 5))
ax[0].imshow(images_np[ix])
ax[0].set_title('Image')
ax[1].imshow(grad_np[ix], cmap='jet')
ax[1].set_title('Gradient')
ax[2].imshow(images_np[ix])
ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
ax[2].set_title('Overlay')
ax[3].imshow(mask_np[ix], cmap='gray')
ax[3].set_title('Mask')
ax[4].imshow(seg_mask_np[ix], cmap='gray')
ax[4].set_title('Segmentation Mask')
# Plot image with bounding boxes
ax[5].imshow(images_np[ix])
for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
x1, y1, x2, y2 = bbox
x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
ax[5].add_patch(rect)
ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
ax[5].set_title('Bounding Boxes')
save_dir_attr = f"figures/attributions/run{self.num}"
if not os.path.exists(save_dir_attr):
os.makedirs(save_dir_attr)
plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
plt.close(fig)
images = images.detach()
if RANK != -1:
@ -500,6 +442,7 @@ class PGTTrainer:
self.run_callbacks("on_batch_end")
if self.args.plots and ni in self.plot_idx:
self.plot_training_samples(batch, ni)
plot_grads(batch, self, ni)
self.run_callbacks("on_train_batch_end")
@ -839,3 +782,69 @@ class PGTTrainer:
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
)
return optimizer
def plot_grads(batch, obj, i, nsamples=16):
# Create a tensor of zeros with the same size as images
images = batch['img'].requires_grad_(True)
mask = torch.zeros_like(images, dtype=torch.float32)
smask = get_dist_reg(images, batch['masks'])
loss, loss_items = obj.model(batch)
grad = torch.autograd.grad(loss, images, retain_graph=True)[0]
grad = torch.abs(grad)
batch_size = images.shape[0]
imgsz = torch.tensor(batch['resized_shape'][0]).to(obj.device)
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = v8DetectionLoss.preprocess(obj, targets=targets.to(obj.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# Iterate over each bounding box and set the corresponding pixels to 1
for irx, bboxes in enumerate(gt_bboxes):
for idx in range(len(bboxes)):
x1, y1, x2, y2 = bboxes[idx]
x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
mask[irx, :, y1:y2, x1:x2] = 1.0
save_imgs = True
if save_imgs:
# Convert tensors to numpy arrays
images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
# Normalize grad for visualization
grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
range_val = min(nsamples, images_np.shape[0])
for ix in range(range_val):
fig, ax = plt.subplots(1, 6, figsize=(30, 5))
ax[0].imshow(images_np[ix])
ax[0].set_title('Image')
ax[1].imshow(grad_np[ix], cmap='jet')
ax[1].set_title('Gradient')
ax[2].imshow(images_np[ix])
ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
ax[2].set_title('Overlay')
ax[3].imshow(mask_np[ix], cmap='gray')
ax[3].set_title('Mask')
ax[4].imshow(seg_mask_np[ix], cmap='gray')
ax[4].set_title('Segmentation Mask')
# Plot image with bounding boxes
ax[5].imshow(images_np[ix])
for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
x1, y1, x2, y2 = bbox
x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
ax[5].add_patch(rect)
ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
ax[5].set_title('Bounding Boxes')
save_dir_attr = f"{obj.save_dir._str}/attributions"
if not os.path.exists(save_dir_attr):
os.makedirs(save_dir_attr)
plt.savefig(f'{save_dir_attr}/debug_epoch_{obj.epoch}_batch_{i}_image_{ix}.png')
plt.close(fig)

View File

@ -11,7 +11,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppr
from .metrics import bbox_iou
import torchvision.transforms as T
def plaus_loss_fn(grad, smask, pgt_coeff):
def plaus_loss_fn(grad, smask, pgt_coeff, square=True):
################## Compute the PGT Loss ##################
# Positive regularization term for incentivizing pixels near the target to have high attribution
dist_attr_pos = attr_reg(grad, (1.0 - smask)) # dist_reg = seg_mask
@ -22,7 +22,7 @@ def plaus_loss_fn(grad, smask, pgt_coeff):
dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad)))
plaus_reg = (((1.0 + dist_reg) / 2.0))
# Calculate plausibility loss
plaus_loss = (1 - plaus_reg) * pgt_coeff
plaus_loss = ((1 - plaus_reg) ** 2 if square else (1 - plaus_reg)) * pgt_coeff
return plaus_loss
def get_dist_reg(images, seg_mask):

View File

@ -837,6 +837,144 @@ def plot_images(
if on_plot:
on_plot(fname)
@threaded
def plot_gradients(
images,
batch_idx,
cls,
bboxes=np.zeros(0, dtype=np.float32),
confs=None,
masks=np.zeros(0, dtype=np.uint8),
kpts=np.zeros((0, 51), dtype=np.float32),
paths=None,
fname="images.jpg",
names=None,
on_plot=None,
max_subplots=16,
save=True,
conf_thres=0.25,
):
"""Plot image grid with labels."""
if isinstance(images, torch.Tensor):
images = images.detach().cpu().float().numpy()
if isinstance(cls, torch.Tensor):
cls = cls.cpu().numpy()
if isinstance(bboxes, torch.Tensor):
bboxes = bboxes.cpu().numpy()
if isinstance(masks, torch.Tensor):
masks = masks.cpu().numpy().astype(int)
if isinstance(kpts, torch.Tensor):
kpts = kpts.cpu().numpy()
if isinstance(batch_idx, torch.Tensor):
batch_idx = batch_idx.cpu().numpy()
max_size = 1920 # max image size
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs**0.5) # number of subplots (square)
if np.max(images[0]) <= 1:
images *= 255 # de-normalise (optional)
# Build Image
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i in range(bs):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
# Resize (optional)
scale = max_size / ns / max(h, w)
if scale < 1:
h = math.ceil(scale * h)
w = math.ceil(scale * w)
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
# Annotate
fs = int((h + w) * ns * 0.01) # font size
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
for i in range(bs):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
if paths:
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
if len(cls) > 0:
idx = batch_idx == i
classes = cls[idx].astype("int")
labels = confs is None
if len(bboxes):
boxes = bboxes[idx]
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
is_obb = boxes.shape[-1] == 5 # xywhr
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
if len(boxes):
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
boxes[..., 0::2] *= w # scale to pixels
boxes[..., 1::2] *= h
elif scale < 1: # absolute coords need scale if image scales
boxes[..., :4] *= scale
boxes[..., 0::2] += x
boxes[..., 1::2] += y
for j, box in enumerate(boxes.astype(np.int64).tolist()):
c = classes[j]
color = colors(c)
c = names.get(c, c) if names else c
if labels or conf[j] > conf_thres:
label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
annotator.box_label(box, label, color=color, rotated=is_obb)
elif len(classes):
for c in classes:
color = colors(c)
c = names.get(c, c) if names else c
annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
# Plot keypoints
if len(kpts):
kpts_ = kpts[idx].copy()
if len(kpts_):
if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01
kpts_[..., 0] *= w # scale to pixels
kpts_[..., 1] *= h
elif scale < 1: # absolute coords need scale if image scales
kpts_ *= scale
kpts_[..., 0] += x
kpts_[..., 1] += y
for j in range(len(kpts_)):
if labels or conf[j] > conf_thres:
annotator.kpts(kpts_[j])
# Plot masks
if len(masks):
if idx.shape[0] == masks.shape[0]: # overlap_masks=False
image_masks = masks[idx]
else: # overlap_masks=True
image_masks = masks[[i]] # (1, 640, 640)
nl = idx.sum()
index = np.arange(nl).reshape((nl, 1, 1)) + 1
image_masks = np.repeat(image_masks, nl, axis=0)
image_masks = np.where(image_masks == index, 1.0, 0.0)
im = np.asarray(annotator.im).copy()
for j in range(len(image_masks)):
if labels or conf[j] > conf_thres:
color = colors(classes[j])
mh, mw = image_masks[j].shape
if mh != h or mw != w:
mask = image_masks[j].astype(np.uint8)
mask = cv2.resize(mask, (w, h))
mask = mask.astype(bool)
else:
mask = image_masks[j].astype(bool)
with contextlib.suppress(Exception):
im[y : y + h, x : x + w, :][mask] = (
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
)
annotator.fromarray(im)
if not save:
return np.asarray(annotator.im)
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
@plt_settings()
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):