Mirror of https://github.com/THU-MIG/yolov10.git
Fixed issues with PGT training by adding the PGT loss to the YOLO loss function. The PGT loss can now be tracked during training.
This commit is contained in:
parent 44c912647e
commit 21e660fde9
@@ -6,7 +6,7 @@ import argparse
from datetime import datetime
import torch

# nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 &
# nohup python run_pgt_train.py --device 7 > ./output_logs/gpu7_yolov10_pgt_train.log 2>&1 &

def main(args):
    model = YOLOv10PGT('yolov10n.pt')
@@ -15,7 +15,7 @@ def main(args):
        epochs=args.epochs,
        batch=args.batch_size,
        # amp=False,
-       # pgt_coeff=1.5,
+       pgt_coeff=3.0,
        # cfg='pgt_train.yaml', # Load and train model with the config file
    )
    # If you want to finetune the model with pretrained weights, you could load the
@@ -51,7 +51,6 @@ def main(args):
    model.save(model_save_path)
    # torch.save(trainer.model.state_dict(), model_save_path)

    # Evaluate the model on the validation set
    results = model.val(data=args.data_yaml)

@@ -61,7 +60,7 @@ def main(args):
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
    parser.add_argument('--device', type=str, default='0', help='CUDA device number')
-   parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
+   parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
    parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
    args = parser.parse_args()
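Note: for readers skimming the diff, a minimal sketch of how the updated script is meant to be invoked follows. The argument names mirror the hunks above; the import path for YOLOv10PGT and the device selection via CUDA_VISIBLE_DEVICES are assumptions, not verbatim repo code.

    # Sketch of the training entry point after this commit (assumptions noted inline).
    import os

    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # pick a GPU, as the nohup examples above do

    from ultralytics import YOLOv10PGT  # assumed import path for the PGT model wrapper

    model = YOLOv10PGT('yolov10n.pt')   # start from pretrained YOLOv10-N weights
    model.train(
        data='coco.yaml',
        epochs=100,
        batch=32,                       # new default batch size in this commit
        pgt_coeff=3.0,                  # weight of the PGT (plausibility) term in the loss
    )
    results = model.val(data='coco.yaml')  # evaluate on the validation split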
@@ -401,67 +401,9 @@ class PGTTrainer:
                debug_ = debug
                if debug_ and (i % 25 == 0):
                    debug_ = False
-                   # Create a tensor of zeros with the same size as images
-                   mask = torch.zeros_like(images, dtype=torch.float32)
-                   smask = get_dist_reg(images, batch['masks'])
-                   grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
-                   grad = torch.abs(grad)
-
-                   batch_size = images.shape[0]
-                   imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device)
-                   targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-                   targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-                   gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-                   mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+                   plot_grads(batch, self, i)
-
-                   # Iterate over each bounding box and set the corresponding pixels to 1
-                   for irx, bboxes in enumerate(gt_bboxes):
-                       for idx in range(len(bboxes)):
-                           x1, y1, x2, y2 = bboxes[idx]
-                           x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
-                           mask[irx, :, y1:y2, x1:x2] = 1.0
-
-                   save_imgs = True
-                   if save_imgs:
-                       # Convert tensors to numpy arrays
-                       images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
-                       grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
-                       mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
-                       seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
-
-                       # Normalize grad for visualization
-                       grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
-
-                       for ix in range(images_np.shape[0]):
-                           fig, ax = plt.subplots(1, 6, figsize=(30, 5))
-                           ax[0].imshow(images_np[ix])
-                           ax[0].set_title('Image')
-                           ax[1].imshow(grad_np[ix], cmap='jet')
-                           ax[1].set_title('Gradient')
-                           ax[2].imshow(images_np[ix])
-                           ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
-                           ax[2].set_title('Overlay')
-                           ax[3].imshow(mask_np[ix], cmap='gray')
-                           ax[3].set_title('Mask')
-                           ax[4].imshow(seg_mask_np[ix], cmap='gray')
-                           ax[4].set_title('Segmentation Mask')
-
-                           # Plot image with bounding boxes
-                           ax[5].imshow(images_np[ix])
-                           for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
-                               x1, y1, x2, y2 = bbox
-                               x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
-                               rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
-                               ax[5].add_patch(rect)
-                               ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
-                           ax[5].set_title('Bounding Boxes')
-
-                           save_dir_attr = f"figures/attributions/run{self.num}"
-                           if not os.path.exists(save_dir_attr):
-                               os.makedirs(save_dir_attr)
-                           plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
-                           plt.close(fig)
                images = images.detach()

                if RANK != -1:
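Note: the debug block removed above reduces to one reusable pattern, which the new plot_grads helper also relies on: make the input batch differentiable, then take the absolute gradient of the training loss with respect to the pixels. A minimal self-contained sketch of that pattern, using a stand-in model and loss rather than the trainer's actual objects:

    # Input-gradient (saliency) pattern used by the debug/plot code (sketch, not repo code).
    import torch
    import torch.nn as nn

    model = nn.Conv2d(3, 8, 3, padding=1)                  # stand-in for the detector
    images = torch.rand(4, 3, 64, 64, requires_grad=True)  # pixels must require grad

    loss = model(images).abs().mean()                      # stand-in for the YOLO training loss
    grad = torch.autograd.grad(loss, images, retain_graph=True)[0]  # d(loss)/d(pixels)
    grad = torch.abs(grad)                                 # magnitude-only attribution, as in the diff

    images = images.detach()                               # drop the graph before the optimizer step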
@@ -500,6 +442,7 @@ class PGTTrainer:
                self.run_callbacks("on_batch_end")
                if self.args.plots and ni in self.plot_idx:
                    self.plot_training_samples(batch, ni)
+                   plot_grads(batch, self, ni)

                self.run_callbacks("on_train_batch_end")
@@ -839,3 +782,69 @@ class PGTTrainer:
            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
        )
        return optimizer
+
+
+def plot_grads(batch, obj, i, nsamples=16):
+    # Create a tensor of zeros with the same size as images
+    images = batch['img'].requires_grad_(True)
+    mask = torch.zeros_like(images, dtype=torch.float32)
+    smask = get_dist_reg(images, batch['masks'])
+    loss, loss_items = obj.model(batch)
+    grad = torch.autograd.grad(loss, images, retain_graph=True)[0]
+    grad = torch.abs(grad)
+
+    batch_size = images.shape[0]
+    imgsz = torch.tensor(batch['resized_shape'][0]).to(obj.device)
+    targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+    targets = v8DetectionLoss.preprocess(obj, targets=targets.to(obj.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+    gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+    mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+
+    # Iterate over each bounding box and set the corresponding pixels to 1
+    for irx, bboxes in enumerate(gt_bboxes):
+        for idx in range(len(bboxes)):
+            x1, y1, x2, y2 = bboxes[idx]
+            x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+            mask[irx, :, y1:y2, x1:x2] = 1.0
+
+    save_imgs = True
+    if save_imgs:
+        # Convert tensors to numpy arrays
+        images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
+        grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
+        mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
+        seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
+
+        # Normalize grad for visualization
+        grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
+
+        range_val = min(nsamples, images_np.shape[0])
+        for ix in range(range_val):
+            fig, ax = plt.subplots(1, 6, figsize=(30, 5))
+            ax[0].imshow(images_np[ix])
+            ax[0].set_title('Image')
+            ax[1].imshow(grad_np[ix], cmap='jet')
+            ax[1].set_title('Gradient')
+            ax[2].imshow(images_np[ix])
+            ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
+            ax[2].set_title('Overlay')
+            ax[3].imshow(mask_np[ix], cmap='gray')
+            ax[3].set_title('Mask')
+            ax[4].imshow(seg_mask_np[ix], cmap='gray')
+            ax[4].set_title('Segmentation Mask')
+
+            # Plot image with bounding boxes
+            ax[5].imshow(images_np[ix])
+            for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
+                x1, y1, x2, y2 = bbox
+                x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
+                rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
+                ax[5].add_patch(rect)
+                ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
+            ax[5].set_title('Bounding Boxes')
+
+            save_dir_attr = f"{obj.save_dir._str}/attributions"
+            if not os.path.exists(save_dir_attr):
+                os.makedirs(save_dir_attr)
+            plt.savefig(f'{save_dir_attr}/debug_epoch_{obj.epoch}_batch_{i}_image_{ix}.png')
+            plt.close(fig)
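Note: both the removed inline block and the new plot_grads rasterize ground-truth boxes the same way: fill each box region of a zero tensor with ones. A tiny standalone sketch of that step, with a hypothetical one-image, two-box batch already in xyxy pixel coordinates:

    # Rasterize xyxy boxes into a binary pixel mask (sketch of the loop in plot_grads).
    import torch

    images = torch.zeros(1, 3, 8, 8)                    # (batch, channels, H, W)
    gt_bboxes = torch.tensor([[[1., 1., 4., 4.],
                               [5., 2., 8., 6.]]])      # (batch, nboxes, 4), xyxy in pixels

    mask = torch.zeros_like(images, dtype=torch.float32)
    for irx, bboxes in enumerate(gt_bboxes):            # one image at a time
        for x1, y1, x2, y2 in bboxes.round().int().tolist():
            mask[irx, :, y1:y2, x1:x2] = 1.0            # rows index y, columns index x

    assert mask[0, 0, 2, 2] == 1.0 and mask[0, 0, 0, 0] == 0.0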
@@ -11,7 +11,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppr
from .metrics import bbox_iou
import torchvision.transforms as T

-def plaus_loss_fn(grad, smask, pgt_coeff):
+def plaus_loss_fn(grad, smask, pgt_coeff, square=True):
    ################## Compute the PGT Loss ##################
    # Positive regularization term for incentivizing pixels near the target to have high attribution
    dist_attr_pos = attr_reg(grad, (1.0 - smask))  # dist_reg = seg_mask
@@ -22,7 +22,7 @@ def plaus_loss_fn(grad, smask, pgt_coeff):
    dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad)))
    plaus_reg = (((1.0 + dist_reg) / 2.0))
    # Calculate plausibility loss
-   plaus_loss = (1 - plaus_reg) * pgt_coeff
+   plaus_loss = ((1 - plaus_reg) ** 2 if square else (1 - plaus_reg)) * pgt_coeff
    return plaus_loss

def get_dist_reg(images, seg_mask):
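Note: spelled out, the updated loss is plaus_loss = (1 - plaus_reg)^2 * pgt_coeff when square=True, where plaus_reg maps the normalized difference of attribution mass between target and non-target regions into [0, 1]. A self-contained sketch follows; attr_reg and the dist_attr_neg line are not shown in this diff, so the masked-mean definition below is an assumption:

    # Sketch of the plausibility (PGT) loss from the hunk above (attr_reg is assumed).
    import torch

    def attr_reg(grad, mask):
        # Assumed definition: attribution magnitude averaged under a (soft) mask.
        return torch.mean(grad * mask)

    def plaus_loss_fn(grad, smask, pgt_coeff, square=True):
        # Attribution mass near the target region ...
        dist_attr_pos = attr_reg(grad, (1.0 - smask))
        # ... versus attribution mass away from it (assumed symmetric term, elided in the diff).
        dist_attr_neg = attr_reg(grad, smask)
        dist_reg = (dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad))
        plaus_reg = (1.0 + dist_reg) / 2.0  # map roughly from [-1, 1] into [0, 1]
        plaus_loss = ((1 - plaus_reg) ** 2 if square else (1 - plaus_reg)) * pgt_coeff
        return plaus_loss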
@@ -837,6 +837,144 @@ def plot_images(
    if on_plot:
        on_plot(fname)

+
+@threaded
+def plot_gradients(
+    images,
+    batch_idx,
+    cls,
+    bboxes=np.zeros(0, dtype=np.float32),
+    confs=None,
+    masks=np.zeros(0, dtype=np.uint8),
+    kpts=np.zeros((0, 51), dtype=np.float32),
+    paths=None,
+    fname="images.jpg",
+    names=None,
+    on_plot=None,
+    max_subplots=16,
+    save=True,
+    conf_thres=0.25,
+):
+    """Plot image grid with labels."""
+    if isinstance(images, torch.Tensor):
+        images = images.detach().cpu().float().numpy()
+    if isinstance(cls, torch.Tensor):
+        cls = cls.cpu().numpy()
+    if isinstance(bboxes, torch.Tensor):
+        bboxes = bboxes.cpu().numpy()
+    if isinstance(masks, torch.Tensor):
+        masks = masks.cpu().numpy().astype(int)
+    if isinstance(kpts, torch.Tensor):
+        kpts = kpts.cpu().numpy()
+    if isinstance(batch_idx, torch.Tensor):
+        batch_idx = batch_idx.cpu().numpy()
+
+    max_size = 1920  # max image size
+    bs, _, h, w = images.shape  # batch size, _, height, width
+    bs = min(bs, max_subplots)  # limit plot images
+    ns = np.ceil(bs**0.5)  # number of subplots (square)
+    if np.max(images[0]) <= 1:
+        images *= 255  # de-normalise (optional)
+
+    # Build Image
+    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
+    for i in range(bs):
+        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
+        mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
+
+    # Resize (optional)
+    scale = max_size / ns / max(h, w)
+    if scale < 1:
+        h = math.ceil(scale * h)
+        w = math.ceil(scale * w)
+        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
+
+    # Annotate
+    fs = int((h + w) * ns * 0.01)  # font size
+    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
+    for i in range(bs):
+        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
+        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
+        if paths:
+            annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
+        if len(cls) > 0:
+            idx = batch_idx == i
+            classes = cls[idx].astype("int")
+            labels = confs is None
+
+            if len(bboxes):
+                boxes = bboxes[idx]
+                conf = confs[idx] if confs is not None else None  # check for confidence presence (label vs pred)
+                is_obb = boxes.shape[-1] == 5  # xywhr
+                boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
+                if len(boxes):
+                    if boxes[:, :4].max() <= 1.1:  # if normalized with tolerance 0.1
+                        boxes[..., 0::2] *= w  # scale to pixels
+                        boxes[..., 1::2] *= h
+                    elif scale < 1:  # absolute coords need scale if image scales
+                        boxes[..., :4] *= scale
+                boxes[..., 0::2] += x
+                boxes[..., 1::2] += y
+                for j, box in enumerate(boxes.astype(np.int64).tolist()):
+                    c = classes[j]
+                    color = colors(c)
+                    c = names.get(c, c) if names else c
+                    if labels or conf[j] > conf_thres:
+                        label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
+                        annotator.box_label(box, label, color=color, rotated=is_obb)
+
+            elif len(classes):
+                for c in classes:
+                    color = colors(c)
+                    c = names.get(c, c) if names else c
+                    annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
+
+            # Plot keypoints
+            if len(kpts):
+                kpts_ = kpts[idx].copy()
+                if len(kpts_):
+                    if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01:  # if normalized with tolerance .01
+                        kpts_[..., 0] *= w  # scale to pixels
+                        kpts_[..., 1] *= h
+                    elif scale < 1:  # absolute coords need scale if image scales
+                        kpts_ *= scale
+                kpts_[..., 0] += x
+                kpts_[..., 1] += y
+                for j in range(len(kpts_)):
+                    if labels or conf[j] > conf_thres:
+                        annotator.kpts(kpts_[j])
+
+            # Plot masks
+            if len(masks):
+                if idx.shape[0] == masks.shape[0]:  # overlap_masks=False
+                    image_masks = masks[idx]
+                else:  # overlap_masks=True
+                    image_masks = masks[[i]]  # (1, 640, 640)
+                    nl = idx.sum()
+                    index = np.arange(nl).reshape((nl, 1, 1)) + 1
+                    image_masks = np.repeat(image_masks, nl, axis=0)
+                    image_masks = np.where(image_masks == index, 1.0, 0.0)
+
+                im = np.asarray(annotator.im).copy()
+                for j in range(len(image_masks)):
+                    if labels or conf[j] > conf_thres:
+                        color = colors(classes[j])
+                        mh, mw = image_masks[j].shape
+                        if mh != h or mw != w:
+                            mask = image_masks[j].astype(np.uint8)
+                            mask = cv2.resize(mask, (w, h))
+                            mask = mask.astype(bool)
+                        else:
+                            mask = image_masks[j].astype(bool)
+                        with contextlib.suppress(Exception):
+                            im[y : y + h, x : x + w, :][mask] = (
+                                im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
+                            )
+                annotator.fromarray(im)
+    if not save:
+        return np.asarray(annotator.im)
+    annotator.im.save(fname)  # save
+    if on_plot:
+        on_plot(fname)
+
+
@plt_settings()
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
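Note: the new plot_gradients shares plot_images' signature, so a call site presumably looks like the usual training-plot hook. A hypothetical toy call (shapes follow the function's own unpacking; the file name and class map are made up):

    # Hypothetical call into plot_gradients, following the signature in the hunk above.
    import torch

    bs, h, w = 4, 640, 640
    plot_gradients(
        images=torch.rand(bs, 3, h, w),                    # values in [0, 1]; de-normalised inside
        batch_idx=torch.arange(bs),                        # one box per image in this toy call
        cls=torch.zeros(bs),                               # class ids
        bboxes=torch.tensor([[0.5, 0.5, 0.2, 0.3]] * bs),  # normalized xywh, scaled to pixels inside
        fname='train_batch_grads.jpg',                     # arbitrary output name
        names={0: 'person'},                               # made-up class map
    )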