mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 08:56:11 +08:00 
			
		
		
		
	Global settings typechecking (#148)
This commit is contained in:
		
							parent
							
								
									19334ebb16
								
							
						
					
					
						commit
						3cbf3ec455
					
				@ -48,7 +48,6 @@ class YOLO:
 | 
				
			|||||||
        self.ckpt_path = None
 | 
					        self.ckpt_path = None
 | 
				
			||||||
        self.cfg = None  # if loaded from *.yaml
 | 
					        self.cfg = None  # if loaded from *.yaml
 | 
				
			||||||
        self.overrides = {}  # overrides for trainer object
 | 
					        self.overrides = {}  # overrides for trainer object
 | 
				
			||||||
        self.init_disabled = False  # disable model initialization
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Load or create new YOLO model
 | 
					        # Load or create new YOLO model
 | 
				
			||||||
        {'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
 | 
					        {'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
 | 
				
			||||||
 | 
				
			|||||||
@ -365,8 +365,15 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):
 | 
				
			|||||||
            yaml_save(file, defaults)
 | 
					            yaml_save(file, defaults)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        settings = yaml_load(file)
 | 
					        settings = yaml_load(file)
 | 
				
			||||||
        if settings.keys() != defaults.keys():
 | 
					
 | 
				
			||||||
            settings = {**defaults, **settings}  # merge **defaults with **settings (prefer **settings)
 | 
					        # Check that settings keys and types match defaults
 | 
				
			||||||
 | 
					        correct = settings.keys() == defaults.keys() and \
 | 
				
			||||||
 | 
					                  all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values()))
 | 
				
			||||||
 | 
					        if not correct:
 | 
				
			||||||
 | 
					            LOGGER.warning('WARNING ⚠️ Different global settings detected, resetting to defaults. '
 | 
				
			||||||
 | 
					                           'This may be due to an ultralytics package update. '
 | 
				
			||||||
 | 
					                           f'View and update your global settings directly in {file}')
 | 
				
			||||||
 | 
					            settings = defaults  # merge **defaults with **settings (prefer **settings)
 | 
				
			||||||
            yaml_save(file, settings)  # save updated defaults
 | 
					            yaml_save(file, settings)  # save updated defaults
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return settings
 | 
					        return settings
 | 
				
			||||||
 | 
				
			|||||||
@ -268,8 +268,23 @@ class ModelEMA:
 | 
				
			|||||||
        copy_attr(self.ema, model, include, exclude)
 | 
					        copy_attr(self.ema, model, include, exclude)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
 | 
					def strip_optimizer(f='best.pt', s=''):
 | 
				
			||||||
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
 | 
					    """
 | 
				
			||||||
 | 
					    Strip optimizer from 'f' to finalize training, optionally save as 's'.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Usage:
 | 
				
			||||||
 | 
					        from ultralytics.yolo.utils.torch_utils import strip_optimizer
 | 
				
			||||||
 | 
					        from pathlib import Path
 | 
				
			||||||
 | 
					        for f in Path('/Users/glennjocher/Downloads/weights').glob('*.pt'):
 | 
				
			||||||
 | 
					            strip_optimizer(f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        f (str): file path to model state to strip the optimizer from. Default is 'best.pt'.
 | 
				
			||||||
 | 
					        s (str): file path to save the model with stripped optimizer to. Default is ''. If not provided, the original file will be overwritten.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        None
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
    x = torch.load(f, map_location=torch.device('cpu'))
 | 
					    x = torch.load(f, map_location=torch.device('cpu'))
 | 
				
			||||||
    args = {**DEFAULT_CONFIG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
 | 
					    args = {**DEFAULT_CONFIG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
 | 
				
			||||||
    if x.get('ema'):
 | 
					    if x.get('ema'):
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user