mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 00:45:38 +08:00 
			
		
		
		
	Fix yolo mode=train CLI bug on model load (#133)
				
					
				
			Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									c3d961fb03
								
							
						
					
					
						commit
						8f3cd52844
					
				
							
								
								
									
										59
									
								
								tests/tests_cli.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								tests/tests_cli.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,59 @@
 | 
				
			|||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ultralytics.yolo.utils import ROOT
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_checks():
 | 
				
			||||||
 | 
					    os.system('yolo mode=checks')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Train checks ---------------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					def test_train_detect():
 | 
				
			||||||
 | 
					    os.system('yolo mode=train task=detect model=yolov8n.yaml data=coco128.yaml imgsz=32 epochs=1')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_train_segment():
 | 
				
			||||||
 | 
					    os.system('yolo mode=train task=segment model=yolov8n-seg.yaml data=coco128-seg.yaml imgsz=32 epochs=1')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_train_classify():
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Val checks -----------------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					def test_val_detect():
 | 
				
			||||||
 | 
					    os.system('yolo mode=val task=detect model=yolov8n.pt data=coco128.yaml imgsz=32 epochs=1')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_val_segment():
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_val_classify():
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Predict checks -------------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					def test_predict_detect():
 | 
				
			||||||
 | 
					    os.system(f"yolo mode=predict model=yolov8n.pt source={ROOT / 'assets'}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_predict_segment():
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_predict_classify():
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Export checks --------------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					def test_export_detect_torchscript():
 | 
				
			||||||
 | 
					    os.system('yolo mode=export model=yolov8n.pt format=torchscript')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_export_segment_torchscript():
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_export_classify_torchscript():
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
@ -5,11 +5,10 @@ import time
 | 
				
			|||||||
import requests
 | 
					import requests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ultralytics.hub.config import HUB_API_ROOT
 | 
					from ultralytics.hub.config import HUB_API_ROOT
 | 
				
			||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
 | 
					from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
 | 
				
			||||||
 | 
					
 | 
				
			||||||
PREFIX = colorstr('Ultralytics: ')
 | 
					PREFIX = colorstr('Ultralytics: ')
 | 
				
			||||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
 | 
					HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
 | 
				
			||||||
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):
 | 
					def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):
 | 
				
			||||||
 | 
				
			|||||||
@ -10,13 +10,11 @@ import torchvision
 | 
				
			|||||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
 | 
					from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
 | 
				
			||||||
                                    Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
 | 
					                                    Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
 | 
				
			||||||
                                    GhostBottleneck, GhostConv, Segment)
 | 
					                                    GhostBottleneck, GhostConv, Segment)
 | 
				
			||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
 | 
					from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
 | 
				
			||||||
from ultralytics.yolo.utils.checks import check_yaml
 | 
					from ultralytics.yolo.utils.checks import check_yaml
 | 
				
			||||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
 | 
					from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
 | 
				
			||||||
                                                model_info, scale_img, time_sync)
 | 
					                                                model_info, scale_img, time_sync)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseModel(nn.Module):
 | 
					class BaseModel(nn.Module):
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
@ -286,16 +284,15 @@ class ClassificationModel(BaseModel):
 | 
				
			|||||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
 | 
					def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
 | 
				
			||||||
    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
 | 
					    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
 | 
				
			||||||
    from ultralytics.yolo.utils.downloads import attempt_download
 | 
					    from ultralytics.yolo.utils.downloads import attempt_download
 | 
				
			||||||
    default_keys = DEFAULT_CONFIG_DICT.keys()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model = Ensemble()
 | 
					    model = Ensemble()
 | 
				
			||||||
    for w in weights if isinstance(weights, list) else [weights]:
 | 
					    for w in weights if isinstance(weights, list) else [weights]:
 | 
				
			||||||
        ckpt = torch.load(attempt_download(w), map_location='cpu')  # load
 | 
					        ckpt = torch.load(attempt_download(w), map_location='cpu')  # load
 | 
				
			||||||
        args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
 | 
					        args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}  # combine model and default args, preferring model args
 | 
				
			||||||
        ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
 | 
					        ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Model compatibility updates
 | 
					        # Model compatibility updates
 | 
				
			||||||
        ckpt.args = {k: v for k, v in args.items() if k in default_keys}
 | 
					        ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Append
 | 
					        # Append
 | 
				
			||||||
        model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval())  # model in eval mode
 | 
					        model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval())  # model in eval mode
 | 
				
			||||||
 | 
				
			|||||||
@ -362,7 +362,7 @@ class BaseTrainer:
 | 
				
			|||||||
            return
 | 
					            return
 | 
				
			||||||
            # We should improve the code flow here. This function looks hacky
 | 
					            # We should improve the code flow here. This function looks hacky
 | 
				
			||||||
        model = self.model
 | 
					        model = self.model
 | 
				
			||||||
        pretrained = not (str(model).endswith(".yaml"))
 | 
					        pretrained = not str(model).endswith(".yaml")
 | 
				
			||||||
        # config
 | 
					        # config
 | 
				
			||||||
        if not pretrained:
 | 
					        if not pretrained:
 | 
				
			||||||
            model = check_file(model)
 | 
					            model = check_file(model)
 | 
				
			||||||
 | 
				
			|||||||
@ -63,6 +63,11 @@ pd.options.display.max_columns = 10
 | 
				
			|||||||
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 | 
					cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 | 
				
			||||||
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
 | 
					os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Default config dictionary
 | 
				
			||||||
 | 
					with open(DEFAULT_CONFIG, errors='ignore') as f:
 | 
				
			||||||
 | 
					    DEFAULT_CONFIG_DICT = yaml.safe_load(f)
 | 
				
			||||||
 | 
					DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def is_colab():
 | 
					def is_colab():
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
				
			|||||||
@ -16,7 +16,7 @@ import torch.nn.functional as F
 | 
				
			|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
 | 
					from torch.nn.parallel import DistributedDataParallel as DDP
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import ultralytics
 | 
					import ultralytics
 | 
				
			||||||
from ultralytics.yolo.utils import LOGGER
 | 
					from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
 | 
				
			||||||
from ultralytics.yolo.utils.checks import git_describe
 | 
					from ultralytics.yolo.utils.checks import git_describe
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .checks import check_version
 | 
					from .checks import check_version
 | 
				
			||||||
@ -270,6 +270,7 @@ class ModelEMA:
 | 
				
			|||||||
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
 | 
					def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
 | 
				
			||||||
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
 | 
					    # Strip optimizer from 'f' to finalize training, optionally save as 's'
 | 
				
			||||||
    x = torch.load(f, map_location=torch.device('cpu'))
 | 
					    x = torch.load(f, map_location=torch.device('cpu'))
 | 
				
			||||||
 | 
					    args = {**DEFAULT_CONFIG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
 | 
				
			||||||
    if x.get('ema'):
 | 
					    if x.get('ema'):
 | 
				
			||||||
        x['model'] = x['ema']  # replace model with ema
 | 
					        x['model'] = x['ema']  # replace model with ema
 | 
				
			||||||
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
 | 
					    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
 | 
				
			||||||
@ -278,6 +279,7 @@ def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_op
 | 
				
			|||||||
    x['model'].half()  # to FP16
 | 
					    x['model'].half()  # to FP16
 | 
				
			||||||
    for p in x['model'].parameters():
 | 
					    for p in x['model'].parameters():
 | 
				
			||||||
        p.requires_grad = False
 | 
					        p.requires_grad = False
 | 
				
			||||||
 | 
					    x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}  # strip non-default keys
 | 
				
			||||||
    torch.save(x, s or f)
 | 
					    torch.save(x, s or f)
 | 
				
			||||||
    mb = os.path.getsize(s or f) / 1E6  # filesize
 | 
					    mb = os.path.getsize(s or f) / 1E6  # filesize
 | 
				
			||||||
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
 | 
					    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
 | 
				
			||||||
 | 
				
			|||||||
@ -54,9 +54,12 @@ class DetectionTrainer(BaseTrainer):
 | 
				
			|||||||
        self.model.names = self.data["names"]
 | 
					        self.model.names = self.data["names"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def load_model(self, model_cfg=None, weights=None, verbose=True):
 | 
					    def load_model(self, model_cfg=None, weights=None, verbose=True):
 | 
				
			||||||
        model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
 | 
					        model = DetectionModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml,
 | 
				
			||||||
 | 
					                               ch=3,
 | 
				
			||||||
 | 
					                               nc=self.data["nc"],
 | 
				
			||||||
 | 
					                               verbose=verbose)
 | 
				
			||||||
        if weights:
 | 
					        if weights:
 | 
				
			||||||
            model.load(weights, verbose)
 | 
					            model.load(weights['model'] if isinstance(weights, dict) else weights, verbose)
 | 
				
			||||||
        return model
 | 
					        return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_validator(self):
 | 
					    def get_validator(self):
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user