mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-22 00:15:38 +08:00 
			
		
		
		
	Move check_amp() to checks.py (#2948)
				
					
				
			Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
		
							parent
							
								
									d07ba25dc4
								
							
						
					
					
						commit
						67cf53b475
					
				| @ -101,7 +101,7 @@ def test_val_scratch(): | ||||
| 
 | ||||
| def test_amp(): | ||||
|     if torch.cuda.is_available(): | ||||
|         from ultralytics.yolo.engine.trainer import check_amp | ||||
|         from ultralytics.yolo.utils.checks import check_amp | ||||
|         model = YOLO(MODEL).model.cuda() | ||||
|         assert check_amp(model) | ||||
| 
 | ||||
|  | ||||
| @ -24,10 +24,10 @@ from tqdm import tqdm | ||||
| from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights | ||||
| from ultralytics.yolo.cfg import get_cfg | ||||
| from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset | ||||
| from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__, | ||||
|                                     callbacks, clean_url, colorstr, emojis, yaml_save) | ||||
| from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, | ||||
|                                     clean_url, colorstr, emojis, yaml_save) | ||||
| from ultralytics.yolo.utils.autobatch import check_train_batch_size | ||||
| from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args | ||||
| from ultralytics.yolo.utils.checks import check_amp, check_file, check_imgsz, print_args | ||||
| from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command | ||||
| from ultralytics.yolo.utils.files import get_latest_run, increment_path | ||||
| from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, | ||||
| @ -648,52 +648,3 @@ class BaseTrainer: | ||||
|         LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups " | ||||
|                     f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias') | ||||
|         return optimizer | ||||
| 
 | ||||
| 
 | ||||
| def check_amp(model): | ||||
|     """ | ||||
|     This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. | ||||
|     If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP | ||||
|     results, so AMP will be disabled during training. | ||||
| 
 | ||||
|     Args: | ||||
|         model (nn.Module): A YOLOv8 model instance. | ||||
| 
 | ||||
|     Returns: | ||||
|         (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False. | ||||
| 
 | ||||
|     Raises: | ||||
|         AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system. | ||||
|     """ | ||||
|     device = next(model.parameters()).device  # get model device | ||||
|     if device.type in ('cpu', 'mps'): | ||||
|         return False  # AMP only used on CUDA devices | ||||
| 
 | ||||
|     def amp_allclose(m, im): | ||||
|         """All close FP32 vs AMP results.""" | ||||
|         a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference | ||||
|         with torch.cuda.amp.autocast(True): | ||||
|             b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference | ||||
|         del m | ||||
|         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance | ||||
| 
 | ||||
|     f = ROOT / 'assets/bus.jpg'  # image to check | ||||
|     im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3)) | ||||
|     prefix = colorstr('AMP: ') | ||||
|     LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...') | ||||
|     warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False." | ||||
|     try: | ||||
|         from ultralytics import YOLO | ||||
|         assert amp_allclose(YOLO('yolov8n.pt'), im) | ||||
|         LOGGER.info(f'{prefix}checks passed ✅') | ||||
|     except ConnectionError: | ||||
|         LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}') | ||||
|     except (AttributeError, ModuleNotFoundError): | ||||
|         LOGGER.warning( | ||||
|             f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}' | ||||
|         ) | ||||
|     except AssertionError: | ||||
|         LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to ' | ||||
|                        f'NaN losses or zero-mAP results, so AMP will be disabled during training.') | ||||
|         return False | ||||
|     return True | ||||
|  | ||||
| @ -344,6 +344,55 @@ def check_yolo(verbose=True, device=''): | ||||
|     LOGGER.info(f'Setup complete ✅ {s}') | ||||
| 
 | ||||
| 
 | ||||
| def check_amp(model): | ||||
|     """ | ||||
|     This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. | ||||
|     If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP | ||||
|     results, so AMP will be disabled during training. | ||||
| 
 | ||||
|     Args: | ||||
|         model (nn.Module): A YOLOv8 model instance. | ||||
| 
 | ||||
|     Returns: | ||||
|         (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False. | ||||
| 
 | ||||
|     Raises: | ||||
|         AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system. | ||||
|     """ | ||||
|     device = next(model.parameters()).device  # get model device | ||||
|     if device.type in ('cpu', 'mps'): | ||||
|         return False  # AMP only used on CUDA devices | ||||
| 
 | ||||
|     def amp_allclose(m, im): | ||||
|         """All close FP32 vs AMP results.""" | ||||
|         a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference | ||||
|         with torch.cuda.amp.autocast(True): | ||||
|             b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference | ||||
|         del m | ||||
|         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance | ||||
| 
 | ||||
|     f = ROOT / 'assets/bus.jpg'  # image to check | ||||
|     im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3)) | ||||
|     prefix = colorstr('AMP: ') | ||||
|     LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...') | ||||
|     warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False." | ||||
|     try: | ||||
|         from ultralytics import YOLO | ||||
|         assert amp_allclose(YOLO('yolov8n.pt'), im) | ||||
|         LOGGER.info(f'{prefix}checks passed ✅') | ||||
|     except ConnectionError: | ||||
|         LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}') | ||||
|     except (AttributeError, ModuleNotFoundError): | ||||
|         LOGGER.warning( | ||||
|             f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}' | ||||
|         ) | ||||
|     except AssertionError: | ||||
|         LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to ' | ||||
|                        f'NaN losses or zero-mAP results, so AMP will be disabled during training.') | ||||
|         return False | ||||
|     return True | ||||
| 
 | ||||
| 
 | ||||
| def git_describe(path=ROOT):  # path must be a directory | ||||
|     # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe | ||||
|     try: | ||||
|  | ||||
| @ -386,7 +386,7 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None: | ||||
|         import pickle | ||||
| 
 | ||||
|     x = torch.load(f, map_location=torch.device('cpu')) | ||||
|     args = {**DEFAULT_CFG_DICT, **x['train_args']}  # combine model args with default args, preferring model args | ||||
|     args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None  # combine args | ||||
|     if x.get('ema'): | ||||
|         x['model'] = x['ema']  # replace model with ema | ||||
|     for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Ayush Chaurasia
						Ayush Chaurasia