mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-31 14:35:40 +08:00 
			
		
		
		
	RTDETRDetectionModel TorchScript, ONNX Predict and Val support (#8818)
				
					
				
			Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
		
							parent
							
								
									911d18e4f3
								
							
						
					
					
						commit
						af6c02c39b
					
				
							
								
								
									
										2
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							| @ -118,7 +118,7 @@ jobs: | |||||||
|         run: | |         run: | | ||||||
|           yolo checks |           yolo checks | ||||||
|           pip list |           pip list | ||||||
|       - name: Benchmark World DetectionModel |       - name: Benchmark YOLOWorld DetectionModel | ||||||
|         shell: bash |         shell: bash | ||||||
|         run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318 |         run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318 | ||||||
|       - name: Benchmark SegmentationModel |       - name: Benchmark SegmentationModel | ||||||
|  | |||||||
| @ -65,8 +65,8 @@ def test_export(model, format): | |||||||
| def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"): | def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"): | ||||||
|     """Test the RTDETR functionality with the Ultralytics framework.""" |     """Test the RTDETR functionality with the Ultralytics framework.""" | ||||||
|     # Warning: MUST use imgsz=640 |     # Warning: MUST use imgsz=640 | ||||||
|     run(f"yolo train {task} model={model} data={data} --imgsz= 640 epochs =1, cache = disk")  # add coma, spaces to args |     run(f"yolo train {task} model={model} data={data} --imgsz= 160 epochs =1, cache = disk")  # add coma, spaces to args | ||||||
|     run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=640 save save_crop save_txt") |     run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM Clip is not supported in Python 3.12") | @pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM Clip is not supported in Python 3.12") | ||||||
|  | |||||||
| @ -36,8 +36,6 @@ class RTDETR(Model): | |||||||
|         Raises: |         Raises: | ||||||
|             NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'. |             NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'. | ||||||
|         """ |         """ | ||||||
|         if model and Path(model).suffix not in (".pt", ".yaml", ".yml"): |  | ||||||
|             raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.") |  | ||||||
|         super().__init__(model=model, task="detect") |         super().__init__(model=model, task="detect") | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|  | |||||||
| @ -38,7 +38,7 @@ class RTDETRPredictor(BasePredictor): | |||||||
|         The method filters detections based on confidence and class if specified in `self.args`. |         The method filters detections based on confidence and class if specified in `self.args`. | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|             preds (torch.Tensor): Raw predictions from the model. |             preds (list): List of [predictions, extra] from the model. | ||||||
|             img (torch.Tensor): Processed input images. |             img (torch.Tensor): Processed input images. | ||||||
|             orig_imgs (list or torch.Tensor): Original, unprocessed images. |             orig_imgs (list or torch.Tensor): Original, unprocessed images. | ||||||
| 
 | 
 | ||||||
| @ -46,6 +46,9 @@ class RTDETRPredictor(BasePredictor): | |||||||
|             (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores, |             (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores, | ||||||
|                 and class labels. |                 and class labels. | ||||||
|         """ |         """ | ||||||
|  |         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference | ||||||
|  |             preds = [preds, None] | ||||||
|  | 
 | ||||||
|         nd = preds[0].shape[-1] |         nd = preds[0].shape[-1] | ||||||
|         bboxes, scores = preds[0].split((4, nd - 4), dim=-1) |         bboxes, scores = preds[0].split((4, nd - 4), dim=-1) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -94,6 +94,9 @@ class RTDETRValidator(DetectionValidator): | |||||||
| 
 | 
 | ||||||
|     def postprocess(self, preds): |     def postprocess(self, preds): | ||||||
|         """Apply Non-maximum suppression to prediction outputs.""" |         """Apply Non-maximum suppression to prediction outputs.""" | ||||||
|  |         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference | ||||||
|  |             preds = [preds, None] | ||||||
|  | 
 | ||||||
|         bs, _, nd = preds[0].shape |         bs, _, nd = preds[0].shape | ||||||
|         bboxes, scores = preds[0].split((4, nd - 4), dim=-1) |         bboxes, scores = preds[0].split((4, nd - 4), dim=-1) | ||||||
|         bboxes *= self.args.imgsz |         bboxes *= self.args.imgsz | ||||||
|  | |||||||
| @ -493,7 +493,7 @@ def check_file(file, suffix="", download=True, hard=True): | |||||||
|             downloads.safe_download(url=url, file=file, unzip=False) |             downloads.safe_download(url=url, file=file, unzip=False) | ||||||
|         return file |         return file | ||||||
|     else:  # search |     else:  # search | ||||||
|         files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True)  # find file |         files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))  # find file | ||||||
|         if not files and hard: |         if not files and hard: | ||||||
|             raise FileNotFoundError(f"'{file}' does not exist") |             raise FileNotFoundError(f"'{file}' does not exist") | ||||||
|         elif len(files) > 1 and hard: |         elif len(files) > 1 and hard: | ||||||
|  | |||||||
| @ -145,3 +145,44 @@ def get_latest_run(search_dir="."): | |||||||
|     """Return path to most recent 'last.pt' in /runs (i.e. to --resume from).""" |     """Return path to most recent 'last.pt' in /runs (i.e. to --resume from).""" | ||||||
|     last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) |     last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) | ||||||
|     return max(last_list, key=os.path.getctime) if last_list else "" |     return max(last_list, key=os.path.getctime) if last_list else "" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False): | ||||||
|  |     """ | ||||||
|  |     Updates and re-saves specified YOLO models in an 'updated_models' subdirectory. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt"). | ||||||
|  |         source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory. | ||||||
|  |         update_names (bool, optional): Update model names from a data YAML. | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  |         ```python | ||||||
|  |         from ultralytics.utils.files import update_models | ||||||
|  | 
 | ||||||
|  |         model_names = (f"rtdetr-{size}.pt" for size in "lx") | ||||||
|  |         update_models(model_names) | ||||||
|  |         ``` | ||||||
|  |     """ | ||||||
|  |     from ultralytics import YOLO | ||||||
|  |     from ultralytics.nn.autobackend import default_class_names | ||||||
|  | 
 | ||||||
|  |     target_dir = source_dir / "updated_models" | ||||||
|  |     target_dir.mkdir(parents=True, exist_ok=True)  # Ensure target directory exists | ||||||
|  | 
 | ||||||
|  |     for model_name in model_names: | ||||||
|  |         model_path = source_dir / model_name | ||||||
|  |         print(f"Loading model from {model_path}") | ||||||
|  | 
 | ||||||
|  |         # Load model | ||||||
|  |         model = YOLO(model_path) | ||||||
|  |         model.half() | ||||||
|  |         if update_names:  # update model names from a dataset YAML | ||||||
|  |             model.model.names = default_class_names("coco8.yaml") | ||||||
|  | 
 | ||||||
|  |         # Define new save path | ||||||
|  |         save_path = target_dir / model_name | ||||||
|  | 
 | ||||||
|  |         # Save model using model.save() | ||||||
|  |         print(f"Re-saving {model_name} model to {save_path}") | ||||||
|  |         model.save(save_path, use_dill=False) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Glenn Jocher
						Glenn Jocher