mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
updated grad to create_graph during training, enabling better pgt_loss optimization
This commit is contained in:
parent
411157c18a
commit
0efbfe7f4d
@ -5,6 +5,7 @@ from ultralytics.models.yolo.segment import PGTSegmentationTrainer
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
# nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 &
|
||||
|
||||
def main(args):
|
||||
# model = YOLOv10()
|
||||
@ -14,13 +15,16 @@ def main(args):
|
||||
# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
|
||||
# or
|
||||
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
||||
model = YOLOv10('yolov10n.pt', task='segment')
|
||||
# model = YOLOv10('yolov10n.pt', task='segment')
|
||||
|
||||
args = dict(model='yolov10n.pt', data=args.data_yaml,
|
||||
args_dict = dict(
|
||||
model='yolov10n.pt',
|
||||
data=args.data_yaml,
|
||||
epochs=args.epochs, batch=args.batch_size,
|
||||
# pgt_coeff=5.0,
|
||||
# cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
|
||||
)
|
||||
trainer = PGTSegmentationTrainer(overrides=args)
|
||||
trainer = PGTSegmentationTrainer(overrides=args_dict)
|
||||
trainer.train(
|
||||
# debug=True,
|
||||
# args = dict(pgt_coeff=0.1), # Should add later to config
|
||||
@ -35,10 +39,10 @@ def main(args):
|
||||
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
data_yaml_base = os.path.splitext(os.path.basename(args.data_yaml))[0]
|
||||
model_save_path = os.path.join(model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
|
||||
model.save(model_save_path)
|
||||
trainer.model.save(model_save_path)
|
||||
|
||||
# Evaluate the model on the validation set
|
||||
results = model.val(data='coco.yaml')
|
||||
results = trainer.val(data=args.data_yaml)
|
||||
|
||||
# Print the evaluation results
|
||||
print(results)
|
||||
@ -55,3 +59,4 @@ if __name__ == "__main__":
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
|
||||
main(args)
|
||||
|
@ -727,11 +727,12 @@ class v10DetectLoss:
|
||||
return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
|
||||
|
||||
class v10PGTDetectLoss:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model, pgt_coeff=3.0):
|
||||
self.one2many = v8DetectionLoss(model, tal_topk=10)
|
||||
self.one2one = v8DetectionLoss(model, tal_topk=1)
|
||||
self.pgt_coeff = pgt_coeff
|
||||
|
||||
def __call__(self, preds, batch, return_plaus=True):
|
||||
def __call__(self, preds, batch, return_plaus=True, inference=False):
|
||||
batch['img'] = batch['img'].requires_grad_(True)
|
||||
one2many = preds["one2many"]
|
||||
loss_one2many = self.one2many(one2many, batch)
|
||||
@ -740,13 +741,28 @@ class v10PGTDetectLoss:
|
||||
|
||||
loss = loss_one2many[0] + loss_one2one[0]
|
||||
if return_plaus:
|
||||
smask = get_dist_reg(batch['img'], batch['masks'])
|
||||
smask = get_dist_reg(batch['img'], batch['masks'])#.requires_grad_(True)
|
||||
|
||||
grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0]
|
||||
grad = torch.abs(grad)
|
||||
# graph = False if inference else True
|
||||
# grad = torch.autograd.grad(loss, batch['img'],
|
||||
# retain_graph=True,
|
||||
# create_graph=graph,
|
||||
# )[0]
|
||||
try:
|
||||
grad = torch.autograd.grad(loss, batch['img'],
|
||||
retain_graph=True,
|
||||
create_graph=True,
|
||||
)[0]
|
||||
except:
|
||||
grad = torch.autograd.grad(loss, batch['img'],
|
||||
retain_graph=True,
|
||||
create_graph=False,
|
||||
)[0]
|
||||
|
||||
pgt_coeff = 3.0
|
||||
plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
|
||||
|
||||
grad = grad ** 2
|
||||
|
||||
plaus_loss = plaus_loss_fn(grad, smask, self.pgt_coeff)
|
||||
# self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
|
||||
loss += plaus_loss
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user