mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
changed save path to not override previous training
This commit is contained in:
parent
3a449d5a6c
commit
5953d3c9c6
@ -3,6 +3,7 @@ from ultralytics import YOLOv10, YOLO
|
||||
import os
|
||||
from ultralytics.models.yolo.segment import PGTSegmentationTrainer
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -15,7 +16,7 @@ def main(args):
|
||||
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
||||
model = YOLOv10('yolov10n.pt', task='segment')
|
||||
|
||||
args = dict(model='yolov10n.pt', data='coco128-seg.yaml',
|
||||
args = dict(model='yolov10n.pt', data=args.data_yaml,
|
||||
epochs=args.epochs, batch=args.batch_size,
|
||||
# cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
|
||||
)
|
||||
@ -25,8 +26,16 @@ def main(args):
|
||||
# args = dict(pgt_coeff=0.1), # Should add later to config
|
||||
)
|
||||
|
||||
# Save the trained model
|
||||
model.save('yolov10_coco_trained.pt')
|
||||
# Create a directory to save model weights if it doesn't exist
|
||||
model_weights_dir = 'model_weights'
|
||||
if not os.path.exists(model_weights_dir):
|
||||
os.makedirs(model_weights_dir)
|
||||
|
||||
# Save the trained model with a unique name based on the current date and time
|
||||
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
data_yaml_base = os.path.splitext(os.path.basename(args.data_yaml))[0]
|
||||
model_save_path = os.path.join(model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
|
||||
model.save(model_save_path)
|
||||
|
||||
# Evaluate the model on the validation set
|
||||
results = model.val(data='coco.yaml')
|
||||
@ -39,6 +48,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--device', type=str, default='0', help='CUDA device number')
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
|
||||
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
|
||||
parser.add_argument('--data_yaml', type=str, required=True, default='coco.yaml', help='Path to the data YAML file')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set CUDA device (only needed for multi-gpu machines)
|
||||
|
Loading…
x
Reference in New Issue
Block a user