mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 08:56:11 +08:00 
			
		
		
		
	Fix yolo mode=train CLI bug on model load (#133)
				
					
				
			Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									c3d961fb03
								
							
						
					
					
						commit
						8f3cd52844
					
				
							
								
								
									
										59
									
								
								tests/tests_cli.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								tests/tests_cli.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,59 @@
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
from ultralytics.yolo.utils import ROOT
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_checks():
 | 
			
		||||
    os.system('yolo mode=checks')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Train checks ---------------------------------------------------------------------------------------------------------
 | 
			
		||||
def test_train_detect():
 | 
			
		||||
    os.system('yolo mode=train task=detect model=yolov8n.yaml data=coco128.yaml imgsz=32 epochs=1')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_train_segment():
 | 
			
		||||
    os.system('yolo mode=train task=segment model=yolov8n-seg.yaml data=coco128-seg.yaml imgsz=32 epochs=1')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_train_classify():
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Val checks -----------------------------------------------------------------------------------------------------------
 | 
			
		||||
def test_val_detect():
 | 
			
		||||
    os.system('yolo mode=val task=detect model=yolov8n.pt data=coco128.yaml imgsz=32 epochs=1')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_val_segment():
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_val_classify():
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Predict checks -------------------------------------------------------------------------------------------------------
 | 
			
		||||
def test_predict_detect():
 | 
			
		||||
    os.system(f"yolo mode=predict model=yolov8n.pt source={ROOT / 'assets'}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_predict_segment():
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_predict_classify():
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Export checks --------------------------------------------------------------------------------------------------------
 | 
			
		||||
def test_export_detect_torchscript():
 | 
			
		||||
    os.system('yolo mode=export model=yolov8n.pt format=torchscript')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_export_segment_torchscript():
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_export_classify_torchscript():
 | 
			
		||||
    pass
 | 
			
		||||
@ -5,11 +5,10 @@ import time
 | 
			
		||||
import requests
 | 
			
		||||
 | 
			
		||||
from ultralytics.hub.config import HUB_API_ROOT
 | 
			
		||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
 | 
			
		||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
 | 
			
		||||
 | 
			
		||||
PREFIX = colorstr('Ultralytics: ')
 | 
			
		||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
 | 
			
		||||
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):
 | 
			
		||||
 | 
			
		||||
@ -10,13 +10,11 @@ import torchvision
 | 
			
		||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
 | 
			
		||||
                                    Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
 | 
			
		||||
                                    GhostBottleneck, GhostConv, Segment)
 | 
			
		||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
 | 
			
		||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
 | 
			
		||||
from ultralytics.yolo.utils.checks import check_yaml
 | 
			
		||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
 | 
			
		||||
                                                model_info, scale_img, time_sync)
 | 
			
		||||
 | 
			
		||||
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseModel(nn.Module):
 | 
			
		||||
    '''
 | 
			
		||||
@ -286,16 +284,15 @@ class ClassificationModel(BaseModel):
 | 
			
		||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
 | 
			
		||||
    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
 | 
			
		||||
    from ultralytics.yolo.utils.downloads import attempt_download
 | 
			
		||||
    default_keys = DEFAULT_CONFIG_DICT.keys()
 | 
			
		||||
 | 
			
		||||
    model = Ensemble()
 | 
			
		||||
    for w in weights if isinstance(weights, list) else [weights]:
 | 
			
		||||
        ckpt = torch.load(attempt_download(w), map_location='cpu')  # load
 | 
			
		||||
        args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
 | 
			
		||||
        args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}  # combine model and default args, preferring model args
 | 
			
		||||
        ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
 | 
			
		||||
 | 
			
		||||
        # Model compatibility updates
 | 
			
		||||
        ckpt.args = {k: v for k, v in args.items() if k in default_keys}
 | 
			
		||||
        ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}
 | 
			
		||||
 | 
			
		||||
        # Append
 | 
			
		||||
        model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval())  # model in eval mode
 | 
			
		||||
 | 
			
		||||
@ -362,7 +362,7 @@ class BaseTrainer:
 | 
			
		||||
            return
 | 
			
		||||
            # We should improve the code flow here. This function looks hacky
 | 
			
		||||
        model = self.model
 | 
			
		||||
        pretrained = not (str(model).endswith(".yaml"))
 | 
			
		||||
        pretrained = not str(model).endswith(".yaml")
 | 
			
		||||
        # config
 | 
			
		||||
        if not pretrained:
 | 
			
		||||
            model = check_file(model)
 | 
			
		||||
 | 
			
		||||
@ -63,6 +63,11 @@ pd.options.display.max_columns = 10
 | 
			
		||||
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 | 
			
		||||
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
 | 
			
		||||
 | 
			
		||||
# Default config dictionary
 | 
			
		||||
with open(DEFAULT_CONFIG, errors='ignore') as f:
 | 
			
		||||
    DEFAULT_CONFIG_DICT = yaml.safe_load(f)
 | 
			
		||||
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_colab():
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,7 @@ import torch.nn.functional as F
 | 
			
		||||
from torch.nn.parallel import DistributedDataParallel as DDP
 | 
			
		||||
 | 
			
		||||
import ultralytics
 | 
			
		||||
from ultralytics.yolo.utils import LOGGER
 | 
			
		||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
 | 
			
		||||
from ultralytics.yolo.utils.checks import git_describe
 | 
			
		||||
 | 
			
		||||
from .checks import check_version
 | 
			
		||||
@ -270,6 +270,7 @@ class ModelEMA:
 | 
			
		||||
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
 | 
			
		||||
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
 | 
			
		||||
    x = torch.load(f, map_location=torch.device('cpu'))
 | 
			
		||||
    args = {**DEFAULT_CONFIG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
 | 
			
		||||
    if x.get('ema'):
 | 
			
		||||
        x['model'] = x['ema']  # replace model with ema
 | 
			
		||||
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
 | 
			
		||||
@ -278,6 +279,7 @@ def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_op
 | 
			
		||||
    x['model'].half()  # to FP16
 | 
			
		||||
    for p in x['model'].parameters():
 | 
			
		||||
        p.requires_grad = False
 | 
			
		||||
    x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}  # strip non-default keys
 | 
			
		||||
    torch.save(x, s or f)
 | 
			
		||||
    mb = os.path.getsize(s or f) / 1E6  # filesize
 | 
			
		||||
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
 | 
			
		||||
 | 
			
		||||
@ -54,9 +54,12 @@ class DetectionTrainer(BaseTrainer):
 | 
			
		||||
        self.model.names = self.data["names"]
 | 
			
		||||
 | 
			
		||||
    def load_model(self, model_cfg=None, weights=None, verbose=True):
 | 
			
		||||
        model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
 | 
			
		||||
        model = DetectionModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml,
 | 
			
		||||
                               ch=3,
 | 
			
		||||
                               nc=self.data["nc"],
 | 
			
		||||
                               verbose=verbose)
 | 
			
		||||
        if weights:
 | 
			
		||||
            model.load(weights, verbose)
 | 
			
		||||
            model.load(weights['model'] if isinstance(weights, dict) else weights, verbose)
 | 
			
		||||
        return model
 | 
			
		||||
 | 
			
		||||
    def get_validator(self):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user