From af6c02c39be4ee30e0119cc24468912257a3b529 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 9 Mar 2024 23:25:01 +0100 Subject: [PATCH] `RTDETRDetectionModel` TorchScript, ONNX Predict and Val support (#8818) Signed-off-by: Glenn Jocher --- .github/workflows/ci.yaml | 2 +- tests/test_cli.py | 4 +-- ultralytics/models/rtdetr/model.py | 2 -- ultralytics/models/rtdetr/predict.py | 5 +++- ultralytics/models/rtdetr/val.py | 3 ++ ultralytics/utils/checks.py | 2 +- ultralytics/utils/files.py | 41 ++++++++++++++++++++++++++++ 7 files changed, 52 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9e1b2b08..4eadc3ba 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -118,7 +118,7 @@ jobs: run: | yolo checks pip list - - name: Benchmark World DetectionModel + - name: Benchmark YOLOWorld DetectionModel shell: bash run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318 - name: Benchmark SegmentationModel diff --git a/tests/test_cli.py b/tests/test_cli.py index 4d1de053..fe257952 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -65,8 +65,8 @@ def test_export(model, format): def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"): """Test the RTDETR functionality with the Ultralytics framework.""" # Warning: MUST use imgsz=640 - run(f"yolo train {task} model={model} data={data} --imgsz= 640 epochs =1, cache = disk") # add coma, spaces to args - run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=640 save save_crop save_txt") + run(f"yolo train {task} model={model} data={data} --imgsz= 160 epochs =1, cache = disk") # add coma, spaces to args + run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt") @pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM Clip is not supported in Python 3.12") diff --git a/ultralytics/models/rtdetr/model.py b/ultralytics/models/rtdetr/model.py index 362edce0..68d43be7 100644 --- a/ultralytics/models/rtdetr/model.py +++ b/ultralytics/models/rtdetr/model.py @@ -36,8 +36,6 @@ class RTDETR(Model): Raises: NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'. """ - if model and Path(model).suffix not in (".pt", ".yaml", ".yml"): - raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.") super().__init__(model=model, task="detect") @property diff --git a/ultralytics/models/rtdetr/predict.py b/ultralytics/models/rtdetr/predict.py index 8ad92de8..7fc918b0 100644 --- a/ultralytics/models/rtdetr/predict.py +++ b/ultralytics/models/rtdetr/predict.py @@ -38,7 +38,7 @@ class RTDETRPredictor(BasePredictor): The method filters detections based on confidence and class if specified in `self.args`. Args: - preds (torch.Tensor): Raw predictions from the model. + preds (list): List of [predictions, extra] from the model. img (torch.Tensor): Processed input images. orig_imgs (list or torch.Tensor): Original, unprocessed images. @@ -46,6 +46,9 @@ class RTDETRPredictor(BasePredictor): (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores, and class labels. """ + if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference + preds = [preds, None] + nd = preds[0].shape[-1] bboxes, scores = preds[0].split((4, nd - 4), dim=-1) diff --git a/ultralytics/models/rtdetr/val.py b/ultralytics/models/rtdetr/val.py index 5f38f817..88bb0aee 100644 --- a/ultralytics/models/rtdetr/val.py +++ b/ultralytics/models/rtdetr/val.py @@ -94,6 +94,9 @@ class RTDETRValidator(DetectionValidator): def postprocess(self, preds): """Apply Non-maximum suppression to prediction outputs.""" + if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference + preds = [preds, None] + bs, _, nd = preds[0].shape bboxes, scores = preds[0].split((4, nd - 4), dim=-1) bboxes *= self.args.imgsz diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py index 1ab031ba..106fa4e1 100644 --- a/ultralytics/utils/checks.py +++ b/ultralytics/utils/checks.py @@ -493,7 +493,7 @@ def check_file(file, suffix="", download=True, hard=True): downloads.safe_download(url=url, file=file, unzip=False) return file else: # search - files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file + files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file if not files and hard: raise FileNotFoundError(f"'{file}' does not exist") elif len(files) > 1 and hard: diff --git a/ultralytics/utils/files.py b/ultralytics/utils/files.py index ae8e90c2..719cacae 100644 --- a/ultralytics/utils/files.py +++ b/ultralytics/utils/files.py @@ -145,3 +145,44 @@ def get_latest_run(search_dir="."): """Return path to most recent 'last.pt' in /runs (i.e. to --resume from).""" last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) return max(last_list, key=os.path.getctime) if last_list else "" + + +def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False): + """ + Updates and re-saves specified YOLO models in an 'updated_models' subdirectory. + + Args: + model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt"). + source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory. + update_names (bool, optional): Update model names from a data YAML. + + Example: + ```python + from ultralytics.utils.files import update_models + + model_names = (f"rtdetr-{size}.pt" for size in "lx") + update_models(model_names) + ``` + """ + from ultralytics import YOLO + from ultralytics.nn.autobackend import default_class_names + + target_dir = source_dir / "updated_models" + target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists + + for model_name in model_names: + model_path = source_dir / model_name + print(f"Loading model from {model_path}") + + # Load model + model = YOLO(model_path) + model.half() + if update_names: # update model names from a dataset YAML + model.model.names = default_class_names("coco8.yaml") + + # Define new save path + save_path = target_dir / model_name + + # Save model using model.save() + print(f"Re-saving {model_name} model to {save_path}") + model.save(save_path, use_dill=False)