RTDETRDetectionModel TorchScript, ONNX Predict and Val support (#8818)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-03-09 23:25:01 +01:00 committed by GitHub
parent 911d18e4f3
commit af6c02c39b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 52 additions and 7 deletions

View File

@ -118,7 +118,7 @@ jobs:
run: | run: |
yolo checks yolo checks
pip list pip list
- name: Benchmark World DetectionModel - name: Benchmark YOLOWorld DetectionModel
shell: bash shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318 run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318
- name: Benchmark SegmentationModel - name: Benchmark SegmentationModel

View File

@ -65,8 +65,8 @@ def test_export(model, format):
def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"): def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"):
"""Test the RTDETR functionality with the Ultralytics framework.""" """Test the RTDETR functionality with the Ultralytics framework."""
# Warning: MUST use imgsz=640 # Warning: MUST use imgsz=640
run(f"yolo train {task} model={model} data={data} --imgsz= 640 epochs =1, cache = disk") # add coma, spaces to args run(f"yolo train {task} model={model} data={data} --imgsz= 160 epochs =1, cache = disk") # add coma, spaces to args
run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=640 save save_crop save_txt") run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt")
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM Clip is not supported in Python 3.12") @pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM Clip is not supported in Python 3.12")

View File

@ -36,8 +36,6 @@ class RTDETR(Model):
Raises: Raises:
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'. NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
""" """
if model and Path(model).suffix not in (".pt", ".yaml", ".yml"):
raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
super().__init__(model=model, task="detect") super().__init__(model=model, task="detect")
@property @property

View File

@ -38,7 +38,7 @@ class RTDETRPredictor(BasePredictor):
The method filters detections based on confidence and class if specified in `self.args`. The method filters detections based on confidence and class if specified in `self.args`.
Args: Args:
preds (torch.Tensor): Raw predictions from the model. preds (list): List of [predictions, extra] from the model.
img (torch.Tensor): Processed input images. img (torch.Tensor): Processed input images.
orig_imgs (list or torch.Tensor): Original, unprocessed images. orig_imgs (list or torch.Tensor): Original, unprocessed images.
@ -46,6 +46,9 @@ class RTDETRPredictor(BasePredictor):
(list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores, (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
and class labels. and class labels.
""" """
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]
nd = preds[0].shape[-1] nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1) bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

View File

@ -94,6 +94,9 @@ class RTDETRValidator(DetectionValidator):
def postprocess(self, preds): def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs.""" """Apply Non-maximum suppression to prediction outputs."""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]
bs, _, nd = preds[0].shape bs, _, nd = preds[0].shape
bboxes, scores = preds[0].split((4, nd - 4), dim=-1) bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
bboxes *= self.args.imgsz bboxes *= self.args.imgsz

View File

@ -493,7 +493,7 @@ def check_file(file, suffix="", download=True, hard=True):
downloads.safe_download(url=url, file=file, unzip=False) downloads.safe_download(url=url, file=file, unzip=False)
return file return file
else: # search else: # search
files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file
if not files and hard: if not files and hard:
raise FileNotFoundError(f"'{file}' does not exist") raise FileNotFoundError(f"'{file}' does not exist")
elif len(files) > 1 and hard: elif len(files) > 1 and hard:

View File

@ -145,3 +145,44 @@ def get_latest_run(search_dir="."):
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from).""" """Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
return max(last_list, key=os.path.getctime) if last_list else "" return max(last_list, key=os.path.getctime) if last_list else ""
def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False):
"""
Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
Args:
model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt").
source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory.
update_names (bool, optional): Update model names from a data YAML.
Example:
```python
from ultralytics.utils.files import update_models
model_names = (f"rtdetr-{size}.pt" for size in "lx")
update_models(model_names)
```
"""
from ultralytics import YOLO
from ultralytics.nn.autobackend import default_class_names
target_dir = source_dir / "updated_models"
target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists
for model_name in model_names:
model_path = source_dir / model_name
print(f"Loading model from {model_path}")
# Load model
model = YOLO(model_path)
model.half()
if update_names: # update model names from a dataset YAML
model.model.names = default_class_names("coco8.yaml")
# Define new save path
save_path = target_dir / model_name
# Save model using model.save()
print(f"Re-saving {model_name} model to {save_path}")
model.save(save_path, use_dill=False)