mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Fixed issues with pgt training, adding the pgt loss to the YOLO loss function. PGT loss can now be tracked during training.
This commit is contained in:
parent
44c912647e
commit
21e660fde9
@ -6,7 +6,7 @@ import argparse
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 &
|
# nohup python run_pgt_train.py --device 7 > ./output_logs/gpu7_yolov10_pgt_train.log 2>&1 &
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
model = YOLOv10PGT('yolov10n.pt')
|
model = YOLOv10PGT('yolov10n.pt')
|
||||||
@ -15,7 +15,7 @@ def main(args):
|
|||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
batch=args.batch_size,
|
batch=args.batch_size,
|
||||||
# amp=False,
|
# amp=False,
|
||||||
# pgt_coeff=1.5,
|
pgt_coeff=3.0,
|
||||||
# cfg='pgt_train.yaml', # Load and train model with the config file
|
# cfg='pgt_train.yaml', # Load and train model with the config file
|
||||||
)
|
)
|
||||||
# If you want to finetune the model with pretrained weights, you could load the
|
# If you want to finetune the model with pretrained weights, you could load the
|
||||||
@ -51,7 +51,6 @@ def main(args):
|
|||||||
model.save(model_save_path)
|
model.save(model_save_path)
|
||||||
# torch.save(trainer.model.state_dict(), model_save_path)
|
# torch.save(trainer.model.state_dict(), model_save_path)
|
||||||
|
|
||||||
|
|
||||||
# Evaluate the model on the validation set
|
# Evaluate the model on the validation set
|
||||||
results = model.val(data=args.data_yaml)
|
results = model.val(data=args.data_yaml)
|
||||||
|
|
||||||
@ -61,7 +60,7 @@ def main(args):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
|
parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
|
||||||
parser.add_argument('--device', type=str, default='0', help='CUDA device number')
|
parser.add_argument('--device', type=str, default='0', help='CUDA device number')
|
||||||
parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
|
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
|
||||||
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
|
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
|
||||||
parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
|
parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
@ -401,67 +401,9 @@ class PGTTrainer:
|
|||||||
debug_ = debug
|
debug_ = debug
|
||||||
if debug_ and (i % 25 == 0):
|
if debug_ and (i % 25 == 0):
|
||||||
debug_ = False
|
debug_ = False
|
||||||
# Create a tensor of zeros with the same size as images
|
|
||||||
mask = torch.zeros_like(images, dtype=torch.float32)
|
|
||||||
smask = get_dist_reg(images, batch['masks'])
|
|
||||||
grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
|
|
||||||
grad = torch.abs(grad)
|
|
||||||
|
|
||||||
batch_size = images.shape[0]
|
plot_grads(batch, self, i)
|
||||||
imgsz = torch.tensor(batch['resized_shape'][0]).to(self.device)
|
|
||||||
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
|
||||||
targets = v8DetectionLoss.preprocess(self, targets=targets.to(self.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
||||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
|
||||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
|
||||||
|
|
||||||
# Iterate over each bounding box and set the corresponding pixels to 1
|
|
||||||
for irx, bboxes in enumerate(gt_bboxes):
|
|
||||||
for idx in range(len(bboxes)):
|
|
||||||
x1, y1, x2, y2 = bboxes[idx]
|
|
||||||
x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
|
|
||||||
mask[irx, :, y1:y2, x1:x2] = 1.0
|
|
||||||
|
|
||||||
save_imgs = True
|
|
||||||
if save_imgs:
|
|
||||||
# Convert tensors to numpy arrays
|
|
||||||
images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
|
|
||||||
grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
|
|
||||||
mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
|
|
||||||
seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
|
|
||||||
|
|
||||||
# Normalize grad for visualization
|
|
||||||
grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
|
|
||||||
|
|
||||||
for ix in range(images_np.shape[0]):
|
|
||||||
fig, ax = plt.subplots(1, 6, figsize=(30, 5))
|
|
||||||
ax[0].imshow(images_np[ix])
|
|
||||||
ax[0].set_title('Image')
|
|
||||||
ax[1].imshow(grad_np[ix], cmap='jet')
|
|
||||||
ax[1].set_title('Gradient')
|
|
||||||
ax[2].imshow(images_np[ix])
|
|
||||||
ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
|
|
||||||
ax[2].set_title('Overlay')
|
|
||||||
ax[3].imshow(mask_np[ix], cmap='gray')
|
|
||||||
ax[3].set_title('Mask')
|
|
||||||
ax[4].imshow(seg_mask_np[ix], cmap='gray')
|
|
||||||
ax[4].set_title('Segmentation Mask')
|
|
||||||
|
|
||||||
# Plot image with bounding boxes
|
|
||||||
ax[5].imshow(images_np[ix])
|
|
||||||
for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
|
|
||||||
x1, y1, x2, y2 = bbox
|
|
||||||
x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
|
|
||||||
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
|
|
||||||
ax[5].add_patch(rect)
|
|
||||||
ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
|
|
||||||
ax[5].set_title('Bounding Boxes')
|
|
||||||
|
|
||||||
save_dir_attr = f"figures/attributions/run{self.num}"
|
|
||||||
if not os.path.exists(save_dir_attr):
|
|
||||||
os.makedirs(save_dir_attr)
|
|
||||||
plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
|
|
||||||
plt.close(fig)
|
|
||||||
images = images.detach()
|
|
||||||
|
|
||||||
|
|
||||||
if RANK != -1:
|
if RANK != -1:
|
||||||
@ -500,6 +442,7 @@ class PGTTrainer:
|
|||||||
self.run_callbacks("on_batch_end")
|
self.run_callbacks("on_batch_end")
|
||||||
if self.args.plots and ni in self.plot_idx:
|
if self.args.plots and ni in self.plot_idx:
|
||||||
self.plot_training_samples(batch, ni)
|
self.plot_training_samples(batch, ni)
|
||||||
|
plot_grads(batch, self, ni)
|
||||||
|
|
||||||
self.run_callbacks("on_train_batch_end")
|
self.run_callbacks("on_train_batch_end")
|
||||||
|
|
||||||
@ -839,3 +782,69 @@ class PGTTrainer:
|
|||||||
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
|
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
|
||||||
)
|
)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def plot_grads(batch, obj, i, nsamples=16):
|
||||||
|
# Create a tensor of zeros with the same size as images
|
||||||
|
images = batch['img'].requires_grad_(True)
|
||||||
|
mask = torch.zeros_like(images, dtype=torch.float32)
|
||||||
|
smask = get_dist_reg(images, batch['masks'])
|
||||||
|
loss, loss_items = obj.model(batch)
|
||||||
|
grad = torch.autograd.grad(loss, images, retain_graph=True)[0]
|
||||||
|
grad = torch.abs(grad)
|
||||||
|
|
||||||
|
batch_size = images.shape[0]
|
||||||
|
imgsz = torch.tensor(batch['resized_shape'][0]).to(obj.device)
|
||||||
|
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
||||||
|
targets = v8DetectionLoss.preprocess(obj, targets=targets.to(obj.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||||
|
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||||
|
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||||
|
|
||||||
|
# Iterate over each bounding box and set the corresponding pixels to 1
|
||||||
|
for irx, bboxes in enumerate(gt_bboxes):
|
||||||
|
for idx in range(len(bboxes)):
|
||||||
|
x1, y1, x2, y2 = bboxes[idx]
|
||||||
|
x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
|
||||||
|
mask[irx, :, y1:y2, x1:x2] = 1.0
|
||||||
|
|
||||||
|
save_imgs = True
|
||||||
|
if save_imgs:
|
||||||
|
# Convert tensors to numpy arrays
|
||||||
|
images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
|
||||||
|
grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
|
||||||
|
mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
|
||||||
|
seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)
|
||||||
|
|
||||||
|
# Normalize grad for visualization
|
||||||
|
grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
|
||||||
|
|
||||||
|
range_val = min(nsamples, images_np.shape[0])
|
||||||
|
for ix in range(range_val):
|
||||||
|
fig, ax = plt.subplots(1, 6, figsize=(30, 5))
|
||||||
|
ax[0].imshow(images_np[ix])
|
||||||
|
ax[0].set_title('Image')
|
||||||
|
ax[1].imshow(grad_np[ix], cmap='jet')
|
||||||
|
ax[1].set_title('Gradient')
|
||||||
|
ax[2].imshow(images_np[ix])
|
||||||
|
ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
|
||||||
|
ax[2].set_title('Overlay')
|
||||||
|
ax[3].imshow(mask_np[ix], cmap='gray')
|
||||||
|
ax[3].set_title('Mask')
|
||||||
|
ax[4].imshow(seg_mask_np[ix], cmap='gray')
|
||||||
|
ax[4].set_title('Segmentation Mask')
|
||||||
|
|
||||||
|
# Plot image with bounding boxes
|
||||||
|
ax[5].imshow(images_np[ix])
|
||||||
|
for bbox, cls in zip(gt_bboxes[ix], gt_labels[ix]):
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
x1, y1, x2, y2 = int(torch.round(x1)), int(torch.round(y1)), int(torch.round(x2)), int(torch.round(y2))
|
||||||
|
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
|
||||||
|
ax[5].add_patch(rect)
|
||||||
|
ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
|
||||||
|
ax[5].set_title('Bounding Boxes')
|
||||||
|
|
||||||
|
save_dir_attr = f"{obj.save_dir._str}/attributions"
|
||||||
|
if not os.path.exists(save_dir_attr):
|
||||||
|
os.makedirs(save_dir_attr)
|
||||||
|
plt.savefig(f'{save_dir_attr}/debug_epoch_{obj.epoch}_batch_{i}_image_{ix}.png')
|
||||||
|
plt.close(fig)
|
@ -11,7 +11,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppr
|
|||||||
from .metrics import bbox_iou
|
from .metrics import bbox_iou
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
|
|
||||||
def plaus_loss_fn(grad, smask, pgt_coeff):
|
def plaus_loss_fn(grad, smask, pgt_coeff, square=True):
|
||||||
################## Compute the PGT Loss ##################
|
################## Compute the PGT Loss ##################
|
||||||
# Positive regularization term for incentivizing pixels near the target to have high attribution
|
# Positive regularization term for incentivizing pixels near the target to have high attribution
|
||||||
dist_attr_pos = attr_reg(grad, (1.0 - smask)) # dist_reg = seg_mask
|
dist_attr_pos = attr_reg(grad, (1.0 - smask)) # dist_reg = seg_mask
|
||||||
@ -22,7 +22,7 @@ def plaus_loss_fn(grad, smask, pgt_coeff):
|
|||||||
dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad)))
|
dist_reg = ((dist_attr_pos / torch.mean(grad)) - (dist_attr_neg / torch.mean(grad)))
|
||||||
plaus_reg = (((1.0 + dist_reg) / 2.0))
|
plaus_reg = (((1.0 + dist_reg) / 2.0))
|
||||||
# Calculate plausibility loss
|
# Calculate plausibility loss
|
||||||
plaus_loss = (1 - plaus_reg) * pgt_coeff
|
plaus_loss = ((1 - plaus_reg) ** 2 if square else (1 - plaus_reg)) * pgt_coeff
|
||||||
return plaus_loss
|
return plaus_loss
|
||||||
|
|
||||||
def get_dist_reg(images, seg_mask):
|
def get_dist_reg(images, seg_mask):
|
||||||
|
@ -837,6 +837,144 @@ def plot_images(
|
|||||||
if on_plot:
|
if on_plot:
|
||||||
on_plot(fname)
|
on_plot(fname)
|
||||||
|
|
||||||
|
@threaded
|
||||||
|
def plot_gradients(
|
||||||
|
images,
|
||||||
|
batch_idx,
|
||||||
|
cls,
|
||||||
|
bboxes=np.zeros(0, dtype=np.float32),
|
||||||
|
confs=None,
|
||||||
|
masks=np.zeros(0, dtype=np.uint8),
|
||||||
|
kpts=np.zeros((0, 51), dtype=np.float32),
|
||||||
|
paths=None,
|
||||||
|
fname="images.jpg",
|
||||||
|
names=None,
|
||||||
|
on_plot=None,
|
||||||
|
max_subplots=16,
|
||||||
|
save=True,
|
||||||
|
conf_thres=0.25,
|
||||||
|
):
|
||||||
|
"""Plot image grid with labels."""
|
||||||
|
if isinstance(images, torch.Tensor):
|
||||||
|
images = images.detach().cpu().float().numpy()
|
||||||
|
if isinstance(cls, torch.Tensor):
|
||||||
|
cls = cls.cpu().numpy()
|
||||||
|
if isinstance(bboxes, torch.Tensor):
|
||||||
|
bboxes = bboxes.cpu().numpy()
|
||||||
|
if isinstance(masks, torch.Tensor):
|
||||||
|
masks = masks.cpu().numpy().astype(int)
|
||||||
|
if isinstance(kpts, torch.Tensor):
|
||||||
|
kpts = kpts.cpu().numpy()
|
||||||
|
if isinstance(batch_idx, torch.Tensor):
|
||||||
|
batch_idx = batch_idx.cpu().numpy()
|
||||||
|
|
||||||
|
max_size = 1920 # max image size
|
||||||
|
bs, _, h, w = images.shape # batch size, _, height, width
|
||||||
|
bs = min(bs, max_subplots) # limit plot images
|
||||||
|
ns = np.ceil(bs**0.5) # number of subplots (square)
|
||||||
|
if np.max(images[0]) <= 1:
|
||||||
|
images *= 255 # de-normalise (optional)
|
||||||
|
|
||||||
|
# Build Image
|
||||||
|
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
|
||||||
|
for i in range(bs):
|
||||||
|
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
||||||
|
mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
|
||||||
|
|
||||||
|
# Resize (optional)
|
||||||
|
scale = max_size / ns / max(h, w)
|
||||||
|
if scale < 1:
|
||||||
|
h = math.ceil(scale * h)
|
||||||
|
w = math.ceil(scale * w)
|
||||||
|
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
|
||||||
|
|
||||||
|
# Annotate
|
||||||
|
fs = int((h + w) * ns * 0.01) # font size
|
||||||
|
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
|
||||||
|
for i in range(bs):
|
||||||
|
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
||||||
|
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
||||||
|
if paths:
|
||||||
|
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
||||||
|
if len(cls) > 0:
|
||||||
|
idx = batch_idx == i
|
||||||
|
classes = cls[idx].astype("int")
|
||||||
|
labels = confs is None
|
||||||
|
|
||||||
|
if len(bboxes):
|
||||||
|
boxes = bboxes[idx]
|
||||||
|
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
|
||||||
|
is_obb = boxes.shape[-1] == 5 # xywhr
|
||||||
|
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
|
||||||
|
if len(boxes):
|
||||||
|
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
|
||||||
|
boxes[..., 0::2] *= w # scale to pixels
|
||||||
|
boxes[..., 1::2] *= h
|
||||||
|
elif scale < 1: # absolute coords need scale if image scales
|
||||||
|
boxes[..., :4] *= scale
|
||||||
|
boxes[..., 0::2] += x
|
||||||
|
boxes[..., 1::2] += y
|
||||||
|
for j, box in enumerate(boxes.astype(np.int64).tolist()):
|
||||||
|
c = classes[j]
|
||||||
|
color = colors(c)
|
||||||
|
c = names.get(c, c) if names else c
|
||||||
|
if labels or conf[j] > conf_thres:
|
||||||
|
label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
|
||||||
|
annotator.box_label(box, label, color=color, rotated=is_obb)
|
||||||
|
|
||||||
|
elif len(classes):
|
||||||
|
for c in classes:
|
||||||
|
color = colors(c)
|
||||||
|
c = names.get(c, c) if names else c
|
||||||
|
annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
|
||||||
|
|
||||||
|
# Plot keypoints
|
||||||
|
if len(kpts):
|
||||||
|
kpts_ = kpts[idx].copy()
|
||||||
|
if len(kpts_):
|
||||||
|
if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01
|
||||||
|
kpts_[..., 0] *= w # scale to pixels
|
||||||
|
kpts_[..., 1] *= h
|
||||||
|
elif scale < 1: # absolute coords need scale if image scales
|
||||||
|
kpts_ *= scale
|
||||||
|
kpts_[..., 0] += x
|
||||||
|
kpts_[..., 1] += y
|
||||||
|
for j in range(len(kpts_)):
|
||||||
|
if labels or conf[j] > conf_thres:
|
||||||
|
annotator.kpts(kpts_[j])
|
||||||
|
|
||||||
|
# Plot masks
|
||||||
|
if len(masks):
|
||||||
|
if idx.shape[0] == masks.shape[0]: # overlap_masks=False
|
||||||
|
image_masks = masks[idx]
|
||||||
|
else: # overlap_masks=True
|
||||||
|
image_masks = masks[[i]] # (1, 640, 640)
|
||||||
|
nl = idx.sum()
|
||||||
|
index = np.arange(nl).reshape((nl, 1, 1)) + 1
|
||||||
|
image_masks = np.repeat(image_masks, nl, axis=0)
|
||||||
|
image_masks = np.where(image_masks == index, 1.0, 0.0)
|
||||||
|
|
||||||
|
im = np.asarray(annotator.im).copy()
|
||||||
|
for j in range(len(image_masks)):
|
||||||
|
if labels or conf[j] > conf_thres:
|
||||||
|
color = colors(classes[j])
|
||||||
|
mh, mw = image_masks[j].shape
|
||||||
|
if mh != h or mw != w:
|
||||||
|
mask = image_masks[j].astype(np.uint8)
|
||||||
|
mask = cv2.resize(mask, (w, h))
|
||||||
|
mask = mask.astype(bool)
|
||||||
|
else:
|
||||||
|
mask = image_masks[j].astype(bool)
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
im[y : y + h, x : x + w, :][mask] = (
|
||||||
|
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
|
||||||
|
)
|
||||||
|
annotator.fromarray(im)
|
||||||
|
if not save:
|
||||||
|
return np.asarray(annotator.im)
|
||||||
|
annotator.im.save(fname) # save
|
||||||
|
if on_plot:
|
||||||
|
on_plot(fname)
|
||||||
|
|
||||||
@plt_settings()
|
@plt_settings()
|
||||||
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
|
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user