mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
added pgt_coeff to argparser
This commit is contained in:
parent
21e660fde9
commit
1ef6cbcec5
@ -10,12 +10,16 @@ import torch
|
|||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
model = YOLOv10PGT('yolov10n.pt')
|
model = YOLOv10PGT('yolov10n.pt')
|
||||||
|
|
||||||
|
if args.pgt_coeff is None:
|
||||||
|
model.train(data=args.data_yaml, epochs=args.epochs, batch=args.batch_size)
|
||||||
|
else:
|
||||||
model.train(
|
model.train(
|
||||||
data=args.data_yaml,
|
data=args.data_yaml,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
batch=args.batch_size,
|
batch=args.batch_size,
|
||||||
# amp=False,
|
# amp=False,
|
||||||
pgt_coeff=3.0,
|
pgt_coeff=args.pgt_coeff,
|
||||||
# cfg='pgt_train.yaml', # Load and train model with the config file
|
# cfg='pgt_train.yaml', # Load and train model with the config file
|
||||||
)
|
)
|
||||||
# If you want to finetune the model with pretrained weights, you could load the
|
# If you want to finetune the model with pretrained weights, you could load the
|
||||||
@ -24,20 +28,6 @@ def main(args):
|
|||||||
# or
|
# or
|
||||||
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
||||||
# model = YOLOv10('yolov10n.pt', task='segment')
|
# model = YOLOv10('yolov10n.pt', task='segment')
|
||||||
# model = YOLOv10('yolov10n.pt', task='segment')
|
|
||||||
|
|
||||||
# args_dict = dict(
|
|
||||||
# model='yolov10n.pt',
|
|
||||||
# data=args.data_yaml,
|
|
||||||
# epochs=args.epochs, batch=args.batch_size,
|
|
||||||
# # pgt_coeff=5.0,
|
|
||||||
# # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
|
|
||||||
# )
|
|
||||||
# trainer = PGTSegmentationTrainer(overrides=args_dict)
|
|
||||||
# trainer.train(
|
|
||||||
# # debug=True,
|
|
||||||
# # args = dict(pgt_coeff=0.1), # Should add later to config
|
|
||||||
# )
|
|
||||||
|
|
||||||
# Create a directory to save model weights if it doesn't exist
|
# Create a directory to save model weights if it doesn't exist
|
||||||
model_weights_dir = 'model_weights'
|
model_weights_dir = 'model_weights'
|
||||||
@ -63,6 +53,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
|
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
|
||||||
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
|
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
|
||||||
parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
|
parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
|
||||||
|
parser.add_argument('--pgt_coeff', type=float, default=None, help='Coefficient for PGT')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Set CUDA device (only needed for multi-gpu machines)
|
# Set CUDA device (only needed for multi-gpu machines)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user