Mirror of https://github.com/THU-MIG/yolov10.git
RTDETRDetectionModel TorchScript, ONNX Predict and Val support (#8818)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 911d18e4f3
commit af6c02c39b
.github/workflows/ci.yaml (vendored): 2 changed lines
@@ -118,7 +118,7 @@ jobs:
         run: |
           yolo checks
           pip list
-      - name: Benchmark World DetectionModel
+      - name: Benchmark YOLOWorld DetectionModel
         shell: bash
         run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318
       - name: Benchmark SegmentationModel
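For reference, the benchmark the renamed CI step runs can also be reproduced from Python. A minimal sketch, assuming the `ultralytics.utils.benchmarks.benchmark` helper and a locally available `yolov8s-worldv2.pt` checkpoint (paths and arguments are illustrative, not taken from the commit):

```python
from ultralytics.utils.benchmarks import benchmark

# Mirrors the CI step above: benchmark the YOLOWorld detection weights across
# export formats at a small image size, requiring mAP >= 0.318.
benchmark(model="yolov8s-worldv2.pt", imgsz=160, verbose=0.318)
```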
@@ -65,8 +65,8 @@ def test_export(model, format):
 def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"):
     """Test the RTDETR functionality with the Ultralytics framework."""
     # Warning: MUST use imgsz=640
-    run(f"yolo train {task} model={model} data={data} --imgsz= 640 epochs =1, cache = disk")  # add coma, spaces to args
-    run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=640 save save_crop save_txt")
+    run(f"yolo train {task} model={model} data={data} --imgsz= 160 epochs =1, cache = disk")  # add coma, spaces to args
+    run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt")


 @pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM Clip is not supported in Python 3.12")
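The CLI calls exercised by this test map onto the Python API roughly as follows; a sketch assuming the same config and dataset names resolve on your install, not part of the commit itself:

```python
from ultralytics import YOLO
from ultralytics.utils import ASSETS

# Rough Python equivalent of the CLI test above (illustrative).
model = YOLO("yolov8n-rtdetr.yaml")  # build the RT-DETR model from its config
model.train(data="coco8.yaml", imgsz=160, epochs=1, cache="disk")
model.predict(ASSETS / "bus.jpg", imgsz=160, save=True, save_crop=True, save_txt=True)
```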
@ -36,8 +36,6 @@ class RTDETR(Model):
|
||||
Raises:
|
||||
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
|
||||
"""
|
||||
if model and Path(model).suffix not in (".pt", ".yaml", ".yml"):
|
||||
raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
|
||||
super().__init__(model=model, task="detect")
|
||||
|
||||
@property
|
||||
|
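With the suffix check removed, format handling falls through to the base `Model` machinery, so exported RT-DETR weights can be passed to the wrapper directly. A minimal sketch, assuming locally available weights (paths are illustrative):

```python
from ultralytics import RTDETR

# Export PyTorch RT-DETR weights, then reload the exported file through the same
# wrapper; previously anything but .pt/.yaml/.yml raised NotImplementedError.
model = RTDETR("rtdetr-l.pt")
onnx_file = model.export(format="onnx", imgsz=640)  # also works with format="torchscript"

onnx_model = RTDETR(onnx_file)
results = onnx_model.predict("bus.jpg", imgsz=640)
```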
@@ -38,7 +38,7 @@ class RTDETRPredictor(BasePredictor):
         The method filters detections based on confidence and class if specified in `self.args`.

         Args:
-            preds (torch.Tensor): Raw predictions from the model.
+            preds (list): List of [predictions, extra] from the model.
             img (torch.Tensor): Processed input images.
             orig_imgs (list or torch.Tensor): Original, unprocessed images.
@@ -46,6 +46,9 @@ class RTDETRPredictor(BasePredictor):
             (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
                 and class labels.
         """
+        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
+            preds = [preds, None]
+
         nd = preds[0].shape[-1]
         bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
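The same guard is added to the validator in the next hunk. As a standalone sketch of the pattern with a fabricated tensor (the `as_pred_list` helper and the shapes are illustrative, not part of the codebase):

```python
import torch

def as_pred_list(preds):
    # Exported backends (TorchScript/ONNX) return a bare tensor, while PyTorch
    # inference returns a [predictions, extra] list; normalize to the latter.
    if not isinstance(preds, (list, tuple)):
        preds = [preds, None]
    return preds

raw = torch.rand(1, 300, 84)  # fake export-style output: (batch, queries, 4 box + 80 class scores)
preds = as_pred_list(raw)
nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
print(bboxes.shape, scores.shape)  # torch.Size([1, 300, 4]) torch.Size([1, 300, 80])
```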
@@ -94,6 +94,9 @@ class RTDETRValidator(DetectionValidator):

     def postprocess(self, preds):
         """Apply Non-maximum suppression to prediction outputs."""
+        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
+            preds = [preds, None]
+
         bs, _, nd = preds[0].shape
         bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
         bboxes *= self.args.imgsz
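With the validator tolerating bare tensors, an exported checkpoint can be scored directly. A minimal sketch, assuming locally available exported weights and the coco8 dataset (names illustrative):

```python
from ultralytics import RTDETR

# Validate a TorchScript export; RTDETRValidator.postprocess now wraps the raw
# tensor output as [preds, None] before decoding boxes and scores.
metrics = RTDETR("rtdetr-l.torchscript").val(data="coco8.yaml", imgsz=640)
print(metrics.box.map)  # mAP50-95
```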
@@ -493,7 +493,7 @@ def check_file(file, suffix="", download=True, hard=True):
             downloads.safe_download(url=url, file=file, unzip=False)
         return file
     else:  # search
-        files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True)  # find file
+        files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))  # find file
         if not files and hard:
             raise FileNotFoundError(f"'{file}' does not exist")
         elif len(files) > 1 and hard:
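The broadened glob means `check_file` now resolves bare filenames anywhere under the package root and one level above it, not only under `cfg/`. A quick sketch of how it might be exercised (filename illustrative):

```python
from ultralytics.utils.checks import check_file

# Resolve a bundled file by bare name; with the old pattern only files under
# ROOT/cfg/** were found, now ROOT/** and ROOT.parent are searched as well.
path = check_file("yolov8n-rtdetr.yaml")
print(path)
```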
@@ -145,3 +145,44 @@ def get_latest_run(search_dir="."):
     """Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
     last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
     return max(last_list, key=os.path.getctime) if last_list else ""
+
+
+def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False):
+    """
+    Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
+
+    Args:
+        model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt").
+        source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory.
+        update_names (bool, optional): Update model names from a data YAML.
+
+    Example:
+        ```python
+        from ultralytics.utils.files import update_models
+
+        model_names = (f"rtdetr-{size}.pt" for size in "lx")
+        update_models(model_names)
+        ```
+    """
+    from ultralytics import YOLO
+    from ultralytics.nn.autobackend import default_class_names
+
+    target_dir = source_dir / "updated_models"
+    target_dir.mkdir(parents=True, exist_ok=True)  # Ensure target directory exists
+
+    for model_name in model_names:
+        model_path = source_dir / model_name
+        print(f"Loading model from {model_path}")
+
+        # Load model
+        model = YOLO(model_path)
+        model.half()
+        if update_names:  # update model names from a dataset YAML
+            model.model.names = default_class_names("coco8.yaml")
+
+        # Define new save path
+        save_path = target_dir / model_name
+
+        # Save model using model.save()
+        print(f"Re-saving {model_name} model to {save_path}")
+        model.save(save_path, use_dill=False)
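A possible invocation of the new helper, following its own docstring example but also exercising the `update_names` flag (paths illustrative; the checkpoints are assumed to exist in the current directory):

```python
from pathlib import Path

from ultralytics.utils.files import update_models

# Re-save RT-DETR checkpoints into ./updated_models, refreshing their
# class-name maps from coco8.yaml.
update_models(
    model_names=tuple(f"rtdetr-{size}.pt" for size in "lx"),
    source_dir=Path("."),
    update_names=True,
)
```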