mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
RTDETRDetectionModel
TorchScript, ONNX Predict and Val support (#8818)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
911d18e4f3
commit
af6c02c39b
2
.github/workflows/ci.yaml
vendored
2
.github/workflows/ci.yaml
vendored
@ -118,7 +118,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
yolo checks
|
yolo checks
|
||||||
pip list
|
pip list
|
||||||
- name: Benchmark World DetectionModel
|
- name: Benchmark YOLOWorld DetectionModel
|
||||||
shell: bash
|
shell: bash
|
||||||
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318
|
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318
|
||||||
- name: Benchmark SegmentationModel
|
- name: Benchmark SegmentationModel
|
||||||
|
@ -65,8 +65,8 @@ def test_export(model, format):
|
|||||||
def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"):
|
def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"):
|
||||||
"""Test the RTDETR functionality with the Ultralytics framework."""
|
"""Test the RTDETR functionality with the Ultralytics framework."""
|
||||||
# Warning: MUST use imgsz=640
|
# Warning: MUST use imgsz=640
|
||||||
run(f"yolo train {task} model={model} data={data} --imgsz= 640 epochs =1, cache = disk") # add coma, spaces to args
|
run(f"yolo train {task} model={model} data={data} --imgsz= 160 epochs =1, cache = disk") # add coma, spaces to args
|
||||||
run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=640 save save_crop save_txt")
|
run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM Clip is not supported in Python 3.12")
|
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM Clip is not supported in Python 3.12")
|
||||||
|
@ -36,8 +36,6 @@ class RTDETR(Model):
|
|||||||
Raises:
|
Raises:
|
||||||
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
|
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
|
||||||
"""
|
"""
|
||||||
if model and Path(model).suffix not in (".pt", ".yaml", ".yml"):
|
|
||||||
raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
|
|
||||||
super().__init__(model=model, task="detect")
|
super().__init__(model=model, task="detect")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -38,7 +38,7 @@ class RTDETRPredictor(BasePredictor):
|
|||||||
The method filters detections based on confidence and class if specified in `self.args`.
|
The method filters detections based on confidence and class if specified in `self.args`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
preds (torch.Tensor): Raw predictions from the model.
|
preds (list): List of [predictions, extra] from the model.
|
||||||
img (torch.Tensor): Processed input images.
|
img (torch.Tensor): Processed input images.
|
||||||
orig_imgs (list or torch.Tensor): Original, unprocessed images.
|
orig_imgs (list or torch.Tensor): Original, unprocessed images.
|
||||||
|
|
||||||
@ -46,6 +46,9 @@ class RTDETRPredictor(BasePredictor):
|
|||||||
(list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
|
(list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
|
||||||
and class labels.
|
and class labels.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
|
||||||
|
preds = [preds, None]
|
||||||
|
|
||||||
nd = preds[0].shape[-1]
|
nd = preds[0].shape[-1]
|
||||||
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
||||||
|
|
||||||
|
@ -94,6 +94,9 @@ class RTDETRValidator(DetectionValidator):
|
|||||||
|
|
||||||
def postprocess(self, preds):
|
def postprocess(self, preds):
|
||||||
"""Apply Non-maximum suppression to prediction outputs."""
|
"""Apply Non-maximum suppression to prediction outputs."""
|
||||||
|
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
|
||||||
|
preds = [preds, None]
|
||||||
|
|
||||||
bs, _, nd = preds[0].shape
|
bs, _, nd = preds[0].shape
|
||||||
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
||||||
bboxes *= self.args.imgsz
|
bboxes *= self.args.imgsz
|
||||||
|
@ -493,7 +493,7 @@ def check_file(file, suffix="", download=True, hard=True):
|
|||||||
downloads.safe_download(url=url, file=file, unzip=False)
|
downloads.safe_download(url=url, file=file, unzip=False)
|
||||||
return file
|
return file
|
||||||
else: # search
|
else: # search
|
||||||
files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file
|
files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file
|
||||||
if not files and hard:
|
if not files and hard:
|
||||||
raise FileNotFoundError(f"'{file}' does not exist")
|
raise FileNotFoundError(f"'{file}' does not exist")
|
||||||
elif len(files) > 1 and hard:
|
elif len(files) > 1 and hard:
|
||||||
|
@ -145,3 +145,44 @@ def get_latest_run(search_dir="."):
|
|||||||
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
|
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
|
||||||
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
|
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
|
||||||
return max(last_list, key=os.path.getctime) if last_list else ""
|
return max(last_list, key=os.path.getctime) if last_list else ""
|
||||||
|
|
||||||
|
|
||||||
|
def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False):
|
||||||
|
"""
|
||||||
|
Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt").
|
||||||
|
source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory.
|
||||||
|
update_names (bool, optional): Update model names from a data YAML.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from ultralytics.utils.files import update_models
|
||||||
|
|
||||||
|
model_names = (f"rtdetr-{size}.pt" for size in "lx")
|
||||||
|
update_models(model_names)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
from ultralytics import YOLO
|
||||||
|
from ultralytics.nn.autobackend import default_class_names
|
||||||
|
|
||||||
|
target_dir = source_dir / "updated_models"
|
||||||
|
target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists
|
||||||
|
|
||||||
|
for model_name in model_names:
|
||||||
|
model_path = source_dir / model_name
|
||||||
|
print(f"Loading model from {model_path}")
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model = YOLO(model_path)
|
||||||
|
model.half()
|
||||||
|
if update_names: # update model names from a dataset YAML
|
||||||
|
model.model.names = default_class_names("coco8.yaml")
|
||||||
|
|
||||||
|
# Define new save path
|
||||||
|
save_path = target_dir / model_name
|
||||||
|
|
||||||
|
# Save model using model.save()
|
||||||
|
print(f"Re-saving {model_name} model to {save_path}")
|
||||||
|
model.save(save_path, use_dill=False)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user