diff --git a/run_pgt_train.py b/run_pgt_train.py index 9bb8aae6..a89763a4 100644 --- a/run_pgt_train.py +++ b/run_pgt_train.py @@ -10,34 +10,24 @@ import torch def main(args): model = YOLOv10PGT('yolov10n.pt') - model.train( - data=args.data_yaml, - epochs=args.epochs, - batch=args.batch_size, - # amp=False, - pgt_coeff=3.0, - # cfg='pgt_train.yaml', # Load and train model with the config file - ) + + if args.pgt_coeff is None: + model.train(data=args.data_yaml, epochs=args.epochs, batch=args.batch_size) + else: + model.train( + data=args.data_yaml, + epochs=args.epochs, + batch=args.batch_size, + # amp=False, + pgt_coeff=args.pgt_coeff, + # cfg='pgt_train.yaml', # Load and train model with the config file + ) # If you want to finetune the model with pretrained weights, you could load the # pretrained weights like below # model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}') # or # wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt # model = YOLOv10('yolov10n.pt', task='segment') - # model = YOLOv10('yolov10n.pt', task='segment') - - # args_dict = dict( - # model='yolov10n.pt', - # data=args.data_yaml, - # epochs=args.epochs, batch=args.batch_size, - # # pgt_coeff=5.0, - # # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process - # ) - # trainer = PGTSegmentationTrainer(overrides=args_dict) - # trainer.train( - # # debug=True, - # # args = dict(pgt_coeff=0.1), # Should add later to config - # ) # Create a directory to save model weights if it doesn't exist model_weights_dir = 'model_weights' @@ -63,6 +53,7 @@ if __name__ == "__main__": parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training') parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training') parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file') + parser.add_argument('--pgt_coeff', type=float, default=None, help='Coefficient for PGT') args = parser.parse_args() # Set CUDA device (only needed for multi-gpu machines)