updated grad computation to use create_graph=True during training, enabling better pgt_loss optimization

nielseni6 2024-10-22 23:40:46 -04:00
parent 411157c18a
commit 0efbfe7f4d
2 changed files with 34 additions and 13 deletions
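Why this matters for the pgt_loss: the plausibility term is built from the gradient of the detection loss with respect to the input image, and that term is added back into the loss that gets backpropagated. With create_graph=False the input gradient comes back detached, so the plausibility term contributes nothing to the weight updates; with create_graph=True the double-backward path exists and the term can actually be optimized. A minimal sketch of that distinction, using a toy model and illustrative names rather than anything from this repo:

import torch
import torch.nn as nn

model = nn.Conv2d(3, 1, 3, padding=1)            # toy stand-in for the detector
img = torch.rand(2, 3, 8, 8, requires_grad=True)
smask = torch.rand(2, 3, 8, 8)                   # toy stand-in for the saliency mask

task_loss = model(img).mean()

# create_graph=True keeps the gradient itself differentiable, so a loss built
# from it can propagate back into the model weights; with create_graph=False
# the gradient is detached and the extra term would add no weight gradients.
grad = torch.autograd.grad(task_loss, img, retain_graph=True, create_graph=True)[0]
plaus_loss = ((grad ** 2) - smask).abs().mean()  # illustrative penalty, not the repo's plaus_loss_fn

(task_loss + plaus_loss).backward()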

View File

@@ -5,6 +5,7 @@ from ultralytics.models.yolo.segment import PGTSegmentationTrainer
 import argparse
 from datetime import datetime
+# nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 &
 def main(args):
     # model = YOLOv10()
@@ -14,13 +15,16 @@ def main(args):
     # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
     # or
     # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
-    model = YOLOv10('yolov10n.pt', task='segment')
-    args = dict(model='yolov10n.pt', data=args.data_yaml,
+    # model = YOLOv10('yolov10n.pt', task='segment')
+    args_dict = dict(
+        model='yolov10n.pt',
+        data=args.data_yaml,
         epochs=args.epochs, batch=args.batch_size,
+        # pgt_coeff=5.0,
         # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
     )
-    trainer = PGTSegmentationTrainer(overrides=args)
+    trainer = PGTSegmentationTrainer(overrides=args_dict)
     trainer.train(
         # debug=True,
         # args = dict(pgt_coeff=0.1), # Should add later to config
@@ -35,10 +39,10 @@ def main(args):
     current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
     data_yaml_base = os.path.splitext(os.path.basename(args.data_yaml))[0]
     model_save_path = os.path.join(model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
-    model.save(model_save_path)
+    trainer.model.save(model_save_path)
     # Evaluate the model on the validation set
-    results = model.val(data='coco.yaml')
+    results = trainer.val(data=args.data_yaml)
     # Print the evaluation results
     print(results)
@@ -55,3 +59,4 @@ if __name__ == "__main__":
     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
     os.environ["CUDA_VISIBLE_DEVICES"] = args.device
     main(args)
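A side note on the script change above: rebinding args to the overrides dict (the old code) shadows the argparse Namespace, which the later save/validation code still reads via args.data_yaml; renaming it to args_dict avoids that. A minimal sketch of the failure mode, assuming a plain argparse setup:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_yaml', default='coco.yaml')
args = parser.parse_args([])

# Old pattern: rebinding `args` replaces the Namespace with a plain dict...
args = dict(model='yolov10n.pt', data=args.data_yaml)

# ...so any later attribute-style access fails.
try:
    print(args.data_yaml)
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'data_yaml'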

View File

@@ -727,11 +727,12 @@ class v10DetectLoss:
         return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
 
 class v10PGTDetectLoss:
-    def __init__(self, model):
+    def __init__(self, model, pgt_coeff=3.0):
         self.one2many = v8DetectionLoss(model, tal_topk=10)
         self.one2one = v8DetectionLoss(model, tal_topk=1)
+        self.pgt_coeff = pgt_coeff
 
-    def __call__(self, preds, batch, return_plaus=True):
+    def __call__(self, preds, batch, return_plaus=True, inference=False):
         batch['img'] = batch['img'].requires_grad_(True)
         one2many = preds["one2many"]
         loss_one2many = self.one2many(one2many, batch)
@@ -740,13 +741,28 @@ class v10PGTDetectLoss:
         loss = loss_one2many[0] + loss_one2one[0]
 
         if return_plaus:
-            smask = get_dist_reg(batch['img'], batch['masks'])
-            grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0]
-            grad = torch.abs(grad)
-            pgt_coeff = 3.0
-            plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
+            smask = get_dist_reg(batch['img'], batch['masks'])#.requires_grad_(True)
+            # graph = False if inference else True
+            # grad = torch.autograd.grad(loss, batch['img'],
+            #                            retain_graph=True,
+            #                            create_graph=graph,
+            #                            )[0]
+            try:
+                grad = torch.autograd.grad(loss, batch['img'],
+                                           retain_graph=True,
+                                           create_graph=True,
+                                           )[0]
+            except:
+                grad = torch.autograd.grad(loss, batch['img'],
+                                           retain_graph=True,
+                                           create_graph=False,
+                                           )[0]
+            grad = grad ** 2
+            plaus_loss = plaus_loss_fn(grad, smask, self.pgt_coeff)
             # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
             loss += plaus_loss
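Taken together, the loss-side change does three things: the coefficient moves from a hard-coded 3.0 inside __call__ to a constructor argument, the gradient saliency switches from torch.abs(grad) to grad ** 2, and create_graph=True is attempted first with a fallback for contexts where the double-backward graph cannot be built (e.g. evaluation). A compact, self-contained sketch of that pattern; the toy model and toy_plaus_loss_fn below are illustrative stand-ins, not the repo's code:

import torch
import torch.nn as nn

def toy_plaus_loss_fn(grad, smask, coeff):
    # stand-in for plaus_loss_fn: penalize gradient energy outside the mask
    return coeff * (grad * (1.0 - smask)).mean()

class ToyPGTLoss:
    def __init__(self, model, pgt_coeff=3.0):
        self.model = model
        self.pgt_coeff = pgt_coeff          # previously hard-coded inside __call__

    def __call__(self, img, smask):
        img = img.requires_grad_(True)
        loss = self.model(img).mean()       # stand-in for the one2many + one2one detection losses
        try:
            # training path: keep the graph so the plausibility term can update the weights
            grad = torch.autograd.grad(loss, img, retain_graph=True, create_graph=True)[0]
        except RuntimeError:
            # fallback when the double-backward graph cannot be built
            grad = torch.autograd.grad(loss, img, retain_graph=True, create_graph=False)[0]
        grad = grad ** 2                    # squared gradient replaces torch.abs(grad)
        return loss + toy_plaus_loss_fn(grad, smask, self.pgt_coeff)

criterion = ToyPGTLoss(nn.Conv2d(3, 1, 3, padding=1), pgt_coeff=3.0)
total = criterion(torch.rand(2, 3, 8, 8), torch.rand(2, 3, 8, 8))
total.backward()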