Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-23 13:34:23 +08:00
Updated grad to use create_graph during training, enabling better pgt_loss optimization
parent 411157c18a
commit 0efbfe7f4d
Training script:

@@ -5,6 +5,7 @@ from ultralytics.models.yolo.segment import PGTSegmentationTrainer
 import argparse
 from datetime import datetime
 
+# nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 &
 
 def main(args):
     # model = YOLOv10()
@@ -14,13 +15,16 @@ def main(args):
     # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
     # or
     # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
-    model = YOLOv10('yolov10n.pt', task='segment')
+    # model = YOLOv10('yolov10n.pt', task='segment')
 
-    args = dict(model='yolov10n.pt', data=args.data_yaml,
+    args_dict = dict(
+        model='yolov10n.pt',
+        data=args.data_yaml,
         epochs=args.epochs, batch=args.batch_size,
+        # pgt_coeff=5.0,
         # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
         )
-    trainer = PGTSegmentationTrainer(overrides=args)
+    trainer = PGTSegmentationTrainer(overrides=args_dict)
     trainer.train(
         # debug=True,
         # args = dict(pgt_coeff=0.1), # Should add later to config
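A plausible motivation for renaming `args` to `args_dict` (an inference from the surrounding code, not stated in the commit message): the old line rebound the argparse namespace `args` to a plain dict, so later attribute accesses such as `args.data_yaml` and `args.device` further down in the script would fail. A minimal, self-contained reproduction of that shadowing problem:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_yaml', default='coco.yaml')
args = parser.parse_args([])

# Rebinding the name `args` to a dict, as the old script did, breaks any later
# attribute-style access on the argparse namespace.
args = dict(model='yolov10n.pt', data=args.data_yaml)
try:
    print(args.data_yaml)
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'data_yaml'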
@@ -35,10 +39,10 @@ def main(args):
     current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
     data_yaml_base = os.path.splitext(os.path.basename(args.data_yaml))[0]
     model_save_path = os.path.join(model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
-    model.save(model_save_path)
+    trainer.model.save(model_save_path)
 
     # Evaluate the model on the validation set
-    results = model.val(data='coco.yaml')
+    results = trainer.val(data=args.data_yaml)
 
     # Print the evaluation results
     print(results)
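With the standalone `model = YOLOv10(...)` instantiation commented out in the earlier hunk, saving and validation now go through the trainer: `trainer.model.save(...)` writes the trained weights and `trainer.val(...)` evaluates on the same dataset YAML used for training rather than a hard-coded 'coco.yaml'. For reference, a small standalone sketch of the filename pattern produced by the save-path code above (the directory and YAML name here are hypothetical):

import os
from datetime import datetime

model_weights_dir = 'model_weights'      # hypothetical output directory
data_yaml = 'configs/my_dataset.yaml'    # hypothetical --data_yaml value

current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
data_yaml_base = os.path.splitext(os.path.basename(data_yaml))[0]
model_save_path = os.path.join(
    model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
print(model_save_path)  # e.g. model_weights/yolov10_my_dataset_trained_20250523_133423.pt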
@@ -54,4 +58,5 @@ if __name__ == "__main__":
     # Set CUDA device (only needed for multi-gpu machines)
     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
     os.environ["CUDA_VISIBLE_DEVICES"] = args.device
     main(args)
+
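A side note on the `if __name__ == "__main__":` block above: CUDA_DEVICE_ORDER and CUDA_VISIBLE_DEVICES only take effect if they are set before the process makes its first CUDA call, which is why the script exports them before calling main(args). A minimal sketch of the same pattern (the device id "0" is just an example):

import os

# Must be set before the first CUDA call in this process, or the mask is ignored.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

print(torch.cuda.device_count())  # counts only the devices left visible by the mask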
Loss module:

@@ -727,11 +727,12 @@ class v10DetectLoss:
         return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
 
 class v10PGTDetectLoss:
-    def __init__(self, model):
+    def __init__(self, model, pgt_coeff=3.0):
         self.one2many = v8DetectionLoss(model, tal_topk=10)
         self.one2one = v8DetectionLoss(model, tal_topk=1)
+        self.pgt_coeff = pgt_coeff
 
-    def __call__(self, preds, batch, return_plaus=True):
+    def __call__(self, preds, batch, return_plaus=True, inference=False):
         batch['img'] = batch['img'].requires_grad_(True)
         one2many = preds["one2many"]
         loss_one2many = self.one2many(one2many, batch)
@@ -740,13 +741,28 @@ class v10PGTDetectLoss:
 
         loss = loss_one2many[0] + loss_one2one[0]
         if return_plaus:
-            smask = get_dist_reg(batch['img'], batch['masks'])
+            smask = get_dist_reg(batch['img'], batch['masks'])#.requires_grad_(True)
 
-            grad = torch.autograd.grad(loss, batch['img'], retain_graph=True)[0]
-            grad = torch.abs(grad)
+            # graph = False if inference else True
+            # grad = torch.autograd.grad(loss, batch['img'],
+            #                            retain_graph=True,
+            #                            create_graph=graph,
+            #                            )[0]
+
+            try:
+                grad = torch.autograd.grad(loss, batch['img'],
+                                           retain_graph=True,
+                                           create_graph=True,
+                                           )[0]
+            except:
+                grad = torch.autograd.grad(loss, batch['img'],
+                                           retain_graph=True,
+                                           create_graph=False,
+                                           )[0]
+
+            grad = grad ** 2
 
-            pgt_coeff = 3.0
-            plaus_loss = plaus_loss_fn(grad, smask, pgt_coeff)
+            plaus_loss = plaus_loss_fn(grad, smask, self.pgt_coeff)
             # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
             loss += plaus_loss
 