mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-05-23 13:34:23 +08:00)
ultralytics 8.0.199 *.npy image loading exception handling (#5683)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: snyk-bot <snyk-bot@snyk.io>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 5b3c4cfc0e
commit cedce60f8c
20  .github/workflows/ci.yaml  (vendored)
@@ -183,11 +183,11 @@ jobs:
         shell: bash  # for Windows compatibility
         run: |  # CoreML must be installed before export due to protobuf error from AutoInstall
           python -m pip install --upgrade pip wheel
+          torch=""
           if [ "${{ matrix.torch }}" == "1.8.0" ]; then
-            pip install -e . torch==1.8.0 torchvision==0.9.0 pytest-cov "coremltools>=7.0" --extra-index-url https://download.pytorch.org/whl/cpu
-          else
-            pip install -e . pytest-cov "coremltools>=7.0" --extra-index-url https://download.pytorch.org/whl/cpu
+            torch="torch==1.8.0 torchvision==0.9.0"
           fi
+          pip install -e . $torch pytest-cov "coremltools>=7.0" --extra-index-url https://download.pytorch.org/whl/cpu
       - name: Check environment
         run: |
           yolo checks
@@ -202,7 +202,13 @@ jobs:
           pip list
       - name: Pytest tests
         shell: bash  # for Windows compatibility
-        run: pytest --cov=ultralytics/ --cov-report xml tests/
+        run: |
+          slow=""
+          if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
+            pip install mlflow pycocotools 'ray[tune]'
+            slow="--slow"
+          fi
+          pytest $slow --cov=ultralytics/ --cov-report xml tests/
       - name: Upload Coverage Reports to CodeCov
         if: github.repository == 'ultralytics/ultralytics'  # && matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11'
         uses: codecov/codecov-action@v3
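For context, a `--slow` flag like the one passed above is usually wired up through a small `conftest.py` hook that skips tests carrying `@pytest.mark.slow` unless the flag is present. The sketch below is a generic illustration of that pattern, not necessarily the exact implementation in the ultralytics test suite; the option name and marker are assumptions.

```python
# conftest.py -- minimal sketch of a '--slow' gate for pytest.mark.slow tests (assumed pattern)
import pytest


def pytest_addoption(parser):
    # Register the custom command-line flag used as 'pytest --slow'
    parser.addoption('--slow', action='store_true', default=False, help='run slow tests')


def pytest_collection_modifyitems(config, items):
    # Skip any test marked @pytest.mark.slow unless --slow was passed
    if config.getoption('--slow'):
        return
    skip_slow = pytest.mark.skip(reason='need --slow option to run')
    for item in items:
        if 'slow' in item.keywords:
            item.add_marker(skip_slow)
```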
@@ -264,10 +270,10 @@ jobs:
           conda config --set solver libmamba
       - name: Install Ultralytics package from conda-forge
         run: |
-          conda install -c pytorch -c conda-forge pytorch torchvision ultralytics
+          conda install -c pytorch -c conda-forge pytorch torchvision ultralytics openvino
       - name: Install pip packages
         run: |
-          pip install pytest 'coremltools>=7.0'  # 'openvino-dev>=2023.0'
+          pip install pytest 'coremltools>=7.0'
       - name: Check environment
         run: |
           echo "RUNNER_OS is ${{ runner.os }}"
@@ -297,7 +303,7 @@ jobs:
       - name: PyTest
         run: |
           git clone https://github.com/ultralytics/ultralytics
-          pytest ultralytics/tests/test_cli.py  # full tests fail due to openvino export failure
+          pytest ultralytics/tests

   Summary:
     runs-on: ubuntu-latest
@@ -1,5 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 # Pre-commit hooks. For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md
+# Optionally remove from local hooks with 'rm .git/hooks/pre-commit'

 # exclude: 'docs/'
 # Define bot property if installed via https://github.com/marketplace/pre-commit-ci
@@ -3,7 +3,7 @@
 # Image is CPU-optimized for ONNX, OpenVINO and PyTorch YOLOv8 deployments

 # Start FROM Ubuntu image https://hub.docker.com/_/ubuntu
-FROM ubuntu:lunar-20230615
+FROM ubuntu:23.04

 # Downloads to user config dir
 ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/
@@ -8,12 +8,16 @@ keywords: Ultralytics, YOLOv8, Roboflow, vector analysis, confusion matrix, data

 [Roboflow](https://roboflow.com/?ref=ultralytics) has everything you need to build and deploy computer vision models. Connect Roboflow at any step in your pipeline with APIs and SDKs, or use the end-to-end interface to automate the entire process from image to inference. Whether you’re in need of [data labeling](https://roboflow.com/annotate?ref=ultralytics), [model training](https://roboflow.com/train?ref=ultralytics), or [model deployment](https://roboflow.com/deploy?ref=ultralytics), Roboflow gives you building blocks to bring custom computer vision solutions to your project.

+!!! warning
+
+    Roboflow users can use Ultralytics under the [AGPL license](https://github.com/ultralytics/ultralytics/blob/main/LICENSE) or procure an [Enterprise license](https://ultralytics.com/license) directly from Ultralytics. Be aware that Roboflow does **not** provide Ultralytics licenses, and it is the responsibility of the user to ensure appropriate licensing.
+
 In this guide, we are going to showcase how to find, label, and organize data for use in training a custom Ultralytics YOLOv8 model. Use the table of contents below to jump directly to a specific section:

 - Gather data for training a custom YOLOv8 model
 - Upload, convert and label data for YOLOv8 format
 - Pre-process and augment data for model robustness
-- Dataset management for YOLOv8
+- Dataset management for [YOLOv8](https://docs.ultralytics.com/models/yolov8/)
 - Export data in 40+ formats for model training
 - Upload custom YOLOv8 model weights for testing and deployment
 - Gather Data for Training a Custom YOLOv8 Model
1  setup.py

@@ -68,6 +68,7 @@ setup(
         'dev': [
             'ipython',
             'check-manifest',
+            'pre-commit',
             'pytest',
             'pytest-cov',
             'coverage',
@@ -3,8 +3,8 @@
 import pytest
 import torch

-from ultralytics import YOLO, download
-from ultralytics.utils import ASSETS, DATASETS_DIR, WEIGHTS_DIR, checks
+from ultralytics import YOLO
+from ultralytics.utils import ASSETS, WEIGHTS_DIR, checks

 CUDA_IS_AVAILABLE = checks.cuda_is_available()
 CUDA_DEVICE_COUNT = checks.cuda_device_count()
@@ -27,6 +27,7 @@ def test_train():
     YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=device)  # requires imgsz>=64


+@pytest.mark.slow
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
 def test_predict_multiple_devices():
     """Validate model prediction on multiple devices."""
@@ -102,42 +103,3 @@ def test_predict_sam():

     # Reset image
     predictor.reset_image()
-
-
-@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
-def test_model_tune():
-    """Tune YOLO model for performance."""
-    YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
-    YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
-
-
-@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
-def test_pycocotools():
-    """Validate model predictions using pycocotools."""
-    from ultralytics.models.yolo.detect import DetectionValidator
-    from ultralytics.models.yolo.pose import PoseValidator
-    from ultralytics.models.yolo.segment import SegmentationValidator
-
-    # Download annotations after each dataset downloads first
-    url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'
-
-    args = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64}
-    validator = DetectionValidator(args=args)
-    validator()
-    validator.is_coco = True
-    download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8/annotations')
-    _ = validator.eval_json(validator.stats)
-
-    args = {'model': 'yolov8n-seg.pt', 'data': 'coco8-seg.yaml', 'save_json': True, 'imgsz': 64}
-    validator = SegmentationValidator(args=args)
-    validator()
-    validator.is_coco = True
-    download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8-seg/annotations')
-    _ = validator.eval_json(validator.stats)
-
-    args = {'model': 'yolov8n-pose.pt', 'data': 'coco8-pose.yaml', 'save_json': True, 'imgsz': 64}
-    validator = PoseValidator(args=args)
-    validator()
-    validator.is_coco = True
-    download(f'{url}person_keypoints_val2017.json', dir=DATASETS_DIR / 'coco8-pose/annotations')
-    _ = validator.eval_json(validator.stats)
@@ -1,12 +1,20 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-
+import contextlib
+from pathlib import Path
+
 import pytest

-from ultralytics import YOLO
-from ultralytics.utils import SETTINGS, checks
+from ultralytics import YOLO, download
+from ultralytics.utils import ASSETS, DATASETS_DIR, ROOT, SETTINGS, WEIGHTS_DIR
+from ultralytics.utils.checks import check_requirements
+
+MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt'  # test spaces in path
+CFG = 'yolov8n.yaml'
+SOURCE = ASSETS / 'bus.jpg'
+TMP = (ROOT / '../tests/tmp').resolve()  # temp directory for test files


-@pytest.mark.skipif(not checks.check_requirements('ray', install=False), reason='RayTune not installed')
+@pytest.mark.skipif(not check_requirements('ray', install=False), reason='ray[tune] not installed')
 def test_model_ray_tune():
     """Tune YOLO model with Ray optimization library."""
     YOLO('yolov8n-cls.yaml').tune(use_ray=True,
@@ -19,8 +27,90 @@ def test_model_ray_tune():
                                   device='cpu')


-@pytest.mark.skipif(not checks.check_requirements('mlflow', install=False), reason='MLflow not installed')
+@pytest.mark.skipif(not check_requirements('mlflow', install=False), reason='mlflow not installed')
 def test_mlflow():
     """Test training with MLflow tracking enabled."""
     SETTINGS['mlflow'] = True
     YOLO('yolov8n-cls.yaml').train(data='imagenet10', imgsz=32, epochs=3, plots=False, device='cpu')
+
+
+@pytest.mark.skipif(not check_requirements('tritonclient', install=False), reason='tritonclient[all] not installed')
+def test_triton():
+    """Test NVIDIA Triton Server functionalities."""
+    check_requirements('tritonclient[all]')
+    import subprocess
+    import time
+
+    from tritonclient.http import InferenceServerClient  # noqa
+
+    # Create variables
+    model_name = 'yolo'
+    triton_repo_path = TMP / 'triton_repo'
+    triton_model_path = triton_repo_path / model_name
+
+    # Export model to ONNX
+    f = YOLO(MODEL).export(format='onnx', dynamic=True)
+
+    # Prepare Triton repo
+    (triton_model_path / '1').mkdir(parents=True, exist_ok=True)
+    Path(f).rename(triton_model_path / '1' / 'model.onnx')
+    (triton_model_path / 'config.pdtxt').touch()
+
+    # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
+    tag = 'nvcr.io/nvidia/tritonserver:23.09-py3'  # 6.4 GB
+
+    # Pull the image
+    subprocess.call(f'docker pull {tag}', shell=True)
+
+    # Run the Triton server and capture the container ID
+    container_id = subprocess.check_output(
+        f'docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models',
+        shell=True).decode('utf-8').strip()
+
+    # Wait for the Triton server to start
+    triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False)
+
+    # Wait until model is ready
+    for _ in range(10):
+        with contextlib.suppress(Exception):
+            assert triton_client.is_model_ready(model_name)
+            break
+        time.sleep(1)
+
+    # Check Triton inference
+    YOLO(f'http://localhost:8000/{model_name}', 'detect')(SOURCE)  # exported model inference
+
+    # Kill and remove the container at the end of the test
+    subprocess.call(f'docker kill {container_id}', shell=True)
+
+
+@pytest.mark.skipif(not check_requirements('pycocotools', install=False), reason='pycocotools not installed')
+def test_pycocotools():
+    """Validate model predictions using pycocotools."""
+    from ultralytics.models.yolo.detect import DetectionValidator
+    from ultralytics.models.yolo.pose import PoseValidator
+    from ultralytics.models.yolo.segment import SegmentationValidator
+
+    # Download annotations after each dataset downloads first
+    url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'
+
+    args = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64}
+    validator = DetectionValidator(args=args)
+    validator()
+    validator.is_coco = True
+    download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8/annotations')
+    _ = validator.eval_json(validator.stats)
+
+    args = {'model': 'yolov8n-seg.pt', 'data': 'coco8-seg.yaml', 'save_json': True, 'imgsz': 64}
+    validator = SegmentationValidator(args=args)
+    validator()
+    validator.is_coco = True
+    download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8-seg/annotations')
+    _ = validator.eval_json(validator.stats)
+
+    args = {'model': 'yolov8n-pose.pt', 'data': 'coco8-pose.yaml', 'save_json': True, 'imgsz': 64}
+    validator = PoseValidator(args=args)
+    validator()
+    validator.is_coco = True
+    download(f'{url}person_keypoints_val2017.json', dir=DATASETS_DIR / 'coco8-pose/annotations')
+    _ = validator.eval_json(validator.stats)
@@ -495,50 +495,7 @@ def test_hub():

 @pytest.mark.slow
 @pytest.mark.skipif(not ONLINE, reason='environment is offline')
-def test_triton():
-    """Test NVIDIA Triton Server functionalities."""
-    checks.check_requirements('tritonclient[all]')
-    import subprocess
-    import time
-
-    from tritonclient.http import InferenceServerClient  # noqa
-
-    # Create variables
-    model_name = 'yolo'
-    triton_repo_path = TMP / 'triton_repo'
-    triton_model_path = triton_repo_path / model_name
-
-    # Export model to ONNX
-    f = YOLO(MODEL).export(format='onnx', dynamic=True)
-
-    # Prepare Triton repo
-    (triton_model_path / '1').mkdir(parents=True, exist_ok=True)
-    Path(f).rename(triton_model_path / '1' / 'model.onnx')
-    (triton_model_path / 'config.pdtxt').touch()
-
-    # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
-    tag = 'nvcr.io/nvidia/tritonserver:23.09-py3'  # 6.4 GB
-
-    # Pull the image
-    subprocess.call(f'docker pull {tag}', shell=True)
-
-    # Run the Triton server and capture the container ID
-    container_id = subprocess.check_output(
-        f'docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models',
-        shell=True).decode('utf-8').strip()
-
-    # Wait for the Triton server to start
-    triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False)
-
-    # Wait until model is ready
-    for _ in range(10):
-        with contextlib.suppress(Exception):
-            assert triton_client.is_model_ready(model_name)
-            break
-        time.sleep(1)
-
-    # Check Triton inference
-    YOLO(f'http://localhost:8000/{model_name}', 'detect')(SOURCE)  # exported model inference
-
-    # Kill and remove the container at the end of the test
-    subprocess.call(f'docker kill {container_id}', shell=True)
+def test_model_tune():
+    """Tune YOLO model for performance."""
+    YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
+    YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = '8.0.198'
+__version__ = '8.0.199'

 from ultralytics.models import RTDETR, SAM, YOLO
 from ultralytics.models.fastsam import FastSAM
@@ -61,8 +61,8 @@ class BaseDataset(Dataset):
                  single_cls=False,
                  classes=None,
                  fraction=1.0):
-        super().__init__()
         """Initialize BaseDataset with given configuration and options."""
+        super().__init__()
         self.img_path = img_path
         self.imgsz = imgsz
         self.augment = augment
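The reordering above matters because Python only treats a string literal as a docstring when it is the first statement in the function body. A minimal standalone illustration of that rule (hypothetical classes, not the repository's code):

```python
class Before:
    def __init__(self):
        super().__init__()
        """Not a docstring: just a string expression after another statement."""


class After:
    def __init__(self):
        """Recognized docstring because it is the first statement."""
        super().__init__()


print(Before.__init__.__doc__)  # None
print(After.__init__.__doc__)   # 'Recognized docstring because it is the first statement.'
```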
@@ -85,7 +85,7 @@ class BaseDataset(Dataset):
         self.buffer = []  # buffer size = batch size
         self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

-        # Cache stuff
+        # Cache images
         if cache == 'ram' and not self.check_cache_ram():
             cache = False
         self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
@@ -123,7 +123,7 @@ class BaseDataset(Dataset):
         return im_files

     def update_labels(self, include_class: Optional[list]):
-        """include_class, filter labels to include only these classes (optional)."""
+        """Update labels to include only these classes (optional)."""
         include_class_array = np.array(include_class).reshape(1, -1)
         for i in range(len(self.labels)):
             if include_class is not None:
@@ -146,11 +146,17 @@ class BaseDataset(Dataset):
         im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
         if im is None:  # not cached in RAM
             if fn.exists():  # load npy
-                im = np.load(fn)
+                try:
+                    im = np.load(fn)
+                except Exception as e:
+                    LOGGER.warning(f'{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}')
+                    Path(fn).unlink(missing_ok=True)
+                    im = cv2.imread(f)  # BGR
             else:  # read image
                 im = cv2.imread(f)  # BGR
             if im is None:
                 raise FileNotFoundError(f'Image Not Found {f}')

             h0, w0 = im.shape[:2]  # orig hw
             if rect_mode:  # resize long side to imgsz while maintaining aspect ratio
                 r = self.imgsz / max(h0, w0)  # ratio
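The hunk above is the core of this commit: a cached `.npy` array that fails to load is logged, deleted, and the original image is read instead of crashing training. A standalone sketch of the same fallback pattern, with placeholder paths and a plain `logging` logger rather than the repository's `LOGGER`:

```python
import logging
from pathlib import Path

import cv2  # assumes opencv-python is installed
import numpy as np

LOGGER = logging.getLogger(__name__)


def load_image(img_file: str, npy_file: str):
    """Load a cached .npy image if possible, falling back to the source file on corruption."""
    fn = Path(npy_file)
    if fn.exists():
        try:
            return np.load(fn)
        except Exception as e:
            LOGGER.warning('Removing corrupt *.npy image file %s due to: %s', fn, e)
            fn.unlink(missing_ok=True)  # drop the bad cache so it can be regenerated later
    im = cv2.imread(img_file)  # BGR array, or None if the file is missing/unreadable
    if im is None:
        raise FileNotFoundError(f'Image Not Found {img_file}')
    return im
```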
@@ -1,5 +1,12 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-"""RT-DETR model interface."""
+"""
+Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time
+performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient
+hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
+
+For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf
+"""

 from ultralytics.engine.model import Model
 from ultralytics.nn.tasks import RTDETRDetectionModel

@@ -9,17 +16,36 @@ from .val import RTDETRValidator


 class RTDETR(Model):
-    """RTDETR model interface."""
+    """
+    Interface for Baidu's RT-DETR model. This Vision Transformer-based object detector provides real-time performance
+    with high accuracy. It supports efficient hybrid encoding, IoU-aware query selection, and adaptable inference speed.
+
+    Attributes:
+        model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
+    """

     def __init__(self, model='rtdetr-l.pt') -> None:
-        """Initializes the RTDETR model with the given model file, defaulting to 'rtdetr-l.pt'."""
+        """
+        Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.
+
+        Args:
+            model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
+
+        Raises:
+            NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
+        """
         if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
-            raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.')
+            raise NotImplementedError('RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.')
         super().__init__(model=model, task='detect')

     @property
-    def task_map(self):
-        """Returns a dictionary mapping task names to corresponding Ultralytics task classes for RTDETR model."""
+    def task_map(self) -> dict:
+        """
+        Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+
+        Returns:
+            dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
+        """
         return {
             'detect': {
                 'predictor': RTDETRPredictor,
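As a quick orientation for readers unfamiliar with the class documented above, RT-DETR is driven through the same high-level Ultralytics API as the YOLO models. A minimal usage sketch; the weights file and image path are illustrative, not mandated by this commit:

```python
from ultralytics import RTDETR

# Load a pre-trained RT-DETR model (downloads 'rtdetr-l.pt' if not cached locally)
model = RTDETR('rtdetr-l.pt')

# Run prediction on an example image; returns a list of Results objects
results = model.predict('bus.jpg', conf=0.25)
for r in results:
    print(r.boxes.xyxy, r.boxes.conf, r.boxes.cls)  # boxes, confidences, class ids
```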
@@ -10,7 +10,11 @@ from ultralytics.utils import ops

 class RTDETRPredictor(BasePredictor):
     """
-    A class extending the BasePredictor class for prediction based on an RT-DETR detection model.
+    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions using
+    Baidu's RT-DETR model.
+
+    This class leverages the power of Vision Transformers to provide real-time object detection while maintaining
+    high accuracy. It supports key features like efficient hybrid encoding and IoU-aware query selection.

     Example:
         ```python
@@ -21,10 +25,27 @@ class RTDETRPredictor(BasePredictor):
         predictor = RTDETRPredictor(overrides=args)
         predictor.predict_cli()
         ```
+
+    Attributes:
+        imgsz (int): Image size for inference (must be square and scale-filled).
+        args (dict): Argument overrides for the predictor.
     """

     def postprocess(self, preds, img, orig_imgs):
-        """Postprocess predictions and returns a list of Results objects."""
+        """
+        Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
+
+        The method filters detections based on confidence and class if specified in `self.args`.
+
+        Args:
+            preds (torch.Tensor): Raw predictions from the model.
+            img (torch.Tensor): Processed input images.
+            orig_imgs (list or torch.Tensor): Original, unprocessed images.
+
+        Returns:
+            (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
+                and class labels.
+        """
         nd = preds[0].shape[-1]
         bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

@@ -49,15 +70,14 @@ class RTDETRPredictor(BasePredictor):

     def pre_transform(self, im):
         """
-        Pre-transform input image before inference.
+        Pre-transforms the input images before feeding them into the model for inference. The input images are
+        letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scaleFilled.

         Args:
-            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
-
-        Notes: The size must be square(640) and scaleFilled.
+            im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list.

         Returns:
-            (list): A list of transformed imgs.
+            (list): List of pre-transformed images ready for model inference.
         """
         letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
         return [letterbox(image=x) for x in im]
@@ -13,10 +13,12 @@ from .val import RTDETRDataset, RTDETRValidator

 class RTDETRTrainer(DetectionTrainer):
     """
-    A class extending the DetectionTrainer class for training based on an RT-DETR detection model.
+    Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
+    class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
+    Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.

     Notes:
-        - F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
+        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
         - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

     Example:
@@ -30,7 +32,17 @@ class RTDETRTrainer(DetectionTrainer):
     """

     def get_model(self, cfg=None, weights=None, verbose=True):
-        """Return a YOLO detection model."""
+        """
+        Initialize and return an RT-DETR model for object detection tasks.
+
+        Args:
+            cfg (dict, optional): Model configuration. Defaults to None.
+            weights (str, optional): Path to pre-trained model weights. Defaults to None.
+            verbose (bool): Verbose logging if True. Defaults to True.
+
+        Returns:
+            (RTDETRDetectionModel): Initialized model.
+        """
         model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
@@ -38,31 +50,46 @@ class RTDETRTrainer(DetectionTrainer):

     def build_dataset(self, img_path, mode='val', batch=None):
         """
-        Build RTDETR Dataset.
+        Build and return an RT-DETR dataset for training or validation.

         Args:
             img_path (str): Path to the folder containing images.
-            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+            mode (str): Dataset mode, either 'train' or 'val'.
+            batch (int, optional): Batch size for rectangle training. Defaults to None.
+
+        Returns:
+            (RTDETRDataset): Dataset object for the specific mode.
         """
-        return RTDETRDataset(
-            img_path=img_path,
-            imgsz=self.args.imgsz,
-            batch_size=batch,
-            augment=mode == 'train',  # no augmentation
-            hyp=self.args,
-            rect=False,  # no rect
-            cache=self.args.cache or None,
-            prefix=colorstr(f'{mode}: '),
-            data=self.data)
+        return RTDETRDataset(img_path=img_path,
+                             imgsz=self.args.imgsz,
+                             batch_size=batch,
+                             augment=mode == 'train',
+                             hyp=self.args,
+                             rect=False,
+                             cache=self.args.cache or None,
+                             prefix=colorstr(f'{mode}: '),
+                             data=self.data)

     def get_validator(self):
-        """Returns a DetectionValidator for RTDETR model validation."""
+        """
+        Returns a DetectionValidator suitable for RT-DETR model validation.
+
+        Returns:
+            (RTDETRValidator): Validator object for model validation.
+        """
         self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
         return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

     def preprocess_batch(self, batch):
-        """Preprocesses a batch of images by scaling and converting to float."""
+        """
+        Preprocess a batch of images. Scales and converts the images to float format.
+
+        Args:
+            batch (dict): Dictionary containing a batch of images, bboxes, and labels.
+
+        Returns:
+            (dict): Preprocessed batch.
+        """
         batch = super().preprocess_batch(batch)
         bs = len(batch['img'])
         batch_idx = batch['batch_idx']
@@ -3,6 +3,4 @@
 from .model import SAM
 from .predict import Predictor

-# from .build import build_sam
-
 __all__ = 'SAM', 'Predictor'  # tuple or list
@@ -1,5 +1,18 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-"""SAM model interface."""
+"""
+SAM model interface.
+
+This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image
+segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis,
+and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
+image distributions and tasks without prior knowledge.
+
+Key Features:
+    - Promptable segmentation
+    - Real-time performance
+    - Zero-shot transfer capabilities
+    - Trained on SA-1B dataset
+"""

 from pathlib import Path

@@ -11,40 +24,94 @@ from .predict import Predictor


 class SAM(Model):
-    """SAM model interface."""
+    """
+    SAM (Segment Anything Model) interface class.
+
+    SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
+    bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
+    dataset.
+    """

     def __init__(self, model='sam_b.pt') -> None:
-        """Initializes the SAM model instance with the specified pre-trained model file."""
+        """
+        Initializes the SAM model with a pre-trained model file.
+
+        Args:
+            model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
+
+        Raises:
+            NotImplementedError: If the model file extension is not .pt or .pth.
+        """
         if model and Path(model).suffix not in ('.pt', '.pth'):
             raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
         super().__init__(model=model, task='segment')

     def _load(self, weights: str, task=None):
-        """Loads the provided weights into the SAM model."""
+        """
+        Loads the specified weights into the SAM model.
+
+        Args:
+            weights (str): Path to the weights file.
+            task (str, optional): Task name. Defaults to None.
+        """
         self.model = build_sam(weights)

     def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
-        """Predicts and returns segmentation masks for given image or video source."""
+        """
+        Performs segmentation prediction on the given image or video source.
+
+        Args:
+            source: Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
+            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
+            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
+            points (list, optional): List of points for prompted segmentation. Defaults to None.
+            labels (list, optional): List of labels for prompted segmentation. Defaults to None.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            The segmentation masks.
+        """
         overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
         kwargs.update(overrides)
         prompts = dict(bboxes=bboxes, points=points, labels=labels)
         return super().predict(source, stream, prompts=prompts, **kwargs)

     def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
-        """Calls the 'predict' function with given arguments to perform object detection."""
+        """
+        Alias for the 'predict' method.
+
+        Args:
+            source: Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
+            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
+            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
+            points (list, optional): List of points for prompted segmentation. Defaults to None.
+            labels (list, optional): List of labels for prompted segmentation. Defaults to None.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            The segmentation masks.
+        """
         return self.predict(source, stream, bboxes, points, labels, **kwargs)

     def info(self, detailed=False, verbose=True):
         """
-        Logs model info.
+        Logs information about the SAM model.

         Args:
-            detailed (bool): Show detailed information about model.
-            verbose (bool): Controls verbosity.
+            detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
+            verbose (bool, optional): If True, displays information on the console. Defaults to True.
+
+        Returns:
+            (tuple): A tuple containing the model's information.
         """
         return model_info(self.model, detailed=detailed, verbose=verbose)

     @property
     def task_map(self):
-        """Returns a dictionary mapping the 'segment' task to its corresponding 'Predictor'."""
+        """
+        Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
+
+        Returns:
+            dict: A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
+        """
         return {'segment': {'predictor': Predictor}}
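The docstrings above describe both prompted and un-prompted inference. A minimal usage sketch of the documented interface; the weights file, image path, and prompt coordinates are illustrative only:

```python
from ultralytics import SAM

model = SAM('sam_b.pt')

# Prompted segmentation: a single bounding box in XYXY pixel coordinates
results = model('bus.jpg', bboxes=[100, 100, 400, 500])

# Prompted segmentation: a point prompt with a foreground label (1)
results = model('bus.jpg', points=[[300, 300]], labels=[1])

# No prompts: the predictor falls back to segment-everything mode
results = model('bus.jpg')
```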
@@ -1,4 +1,12 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
+"""
+Generate predictions using the Segment Anything Model (SAM).
+
+SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance.
+This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation
+using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image
+segmentation tasks.
+"""

 import numpy as np
 import torch
@@ -18,71 +26,86 @@ from .build import build_sam

 class Predictor(BasePredictor):
     """
-    A prediction class for segmentation tasks, extending the BasePredictor.
+    Predictor class for the Segment Anything Model (SAM), extending BasePredictor.

-    This class serves as an interface for model inference for segmentation tasks.
-    It can preprocess input images, perform inference, and postprocess the output.
-    It also supports handling various types of input prompts including bounding boxes,
-    points, and low-resolution masks for better prediction results.
+    The class provides an interface for model inference tailored to image segmentation tasks.
+    With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time
+    mask generation. The class is capable of working with various types of prompts such as bounding boxes,
+    points, and low-resolution masks.

     Attributes:
-        cfg (dict): Configuration dictionary.
-        overrides (dict): Dictionary of overriding values.
-        _callbacks (dict): Dictionary of callback functions.
-        args (namespace): Argument namespace.
-        im (torch.Tensor): Preprocessed image for current prediction.
-        features (torch.Tensor): Image features.
-        prompts (dict): Dictionary of prompts like bboxes, points, masks.
-        segment_all (bool): Whether to perform segmentation on all objects or not.
+        cfg (dict): Configuration dictionary specifying model and task-related parameters.
+        overrides (dict): Dictionary containing values that override the default configuration.
+        _callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
+        args (namespace): Namespace to hold command-line arguments or other operational variables.
+        im (torch.Tensor): Preprocessed input image tensor.
+        features (torch.Tensor): Extracted image features used for inference.
+        prompts (dict): Collection of various prompt types, such as bounding boxes and points.
+        segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initializes the Predictor class with default or provided configuration, overrides, and callbacks."""
+        """
+        Initialize the Predictor with configuration, overrides, and callbacks.
+
+        The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
+        initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.
+
+        Args:
+            cfg (dict): Configuration dictionary.
+            overrides (dict, optional): Dictionary of values to override default configuration.
+            _callbacks (dict, optional): Dictionary of callback functions to customize behavior.
+        """
         if overrides is None:
             overrides = {}
         overrides.update(dict(task='segment', mode='predict', imgsz=1024))
         super().__init__(cfg, overrides, _callbacks)
-        # SAM needs retina_masks=True, or the results would be a mess.
         self.args.retina_masks = True
-        # Args for set_image
         self.im = None
         self.features = None
-        # Args for set_prompts
         self.prompts = {}
-        # Args for segment everything
         self.segment_all = False

     def preprocess(self, im):
         """
-        Prepares input image before inference.
+        Preprocess the input image for model inference.
+
+        The method prepares the input image by applying transformations and normalization.
+        It supports both torch.Tensor and list of np.ndarray as input formats.

         Args:
-            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
+            im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.
+
+        Returns:
+            torch.Tensor: The preprocessed image tensor.
         """
         if self.im is not None:
             return self.im
         not_tensor = not isinstance(im, torch.Tensor)
         if not_tensor:
             im = np.stack(self.pre_transform(im))
-            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
-            im = np.ascontiguousarray(im)  # contiguous
+            im = im[..., ::-1].transpose((0, 3, 1, 2))
+            im = np.ascontiguousarray(im)
             im = torch.from_numpy(im)

         im = im.to(self.device)
-        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
+        im = im.half() if self.model.fp16 else im.float()
         if not_tensor:
             im = (im - self.mean) / self.std
         return im

     def pre_transform(self, im):
         """
-        Pre-transform input image before inference.
+        Perform initial transformations on the input image for preprocessing.
+
+        The method applies transformations such as resizing to prepare the image for further preprocessing.
+        Currently, batched inference is not supported; hence the list length should be 1.

         Args:
-            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
+            im (List[np.ndarray]): List containing images in HWC numpy array format.

         Returns:
-            (list): A list of transformed images.
+            List[np.ndarray]: List of transformed images.
         """
         assert len(im) == 1, 'SAM model does not currently support batched inference'
         letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
@@ -90,69 +113,52 @@ class Predictor(BasePredictor):

     def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
         """
-        Predict masks for the given input prompts, using the currently set image.
+        Perform image segmentation inference based on the given input cues, using the currently loaded image. This
+        method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
+        mask decoder for real-time and promptable segmentation tasks.

         Args:
-            im (torch.Tensor): The preprocessed image, (N, C, H, W).
-            bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
-            points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
-            labels (np.ndarray | List, None): (N, ), labels for the point prompts.
-                1 indicates a foreground point and 0 indicates a background point.
-            masks (np.ndarray, None): A low resolution mask input to the model, typically
-                coming from a previous prediction iteration. Has form (N, H, W), where
-                for SAM, H=W=256.
-            multimask_output (bool): If true, the model will return three masks.
-                For ambiguous input prompts (such as a single click), this will often
-                produce better masks than a single prediction. If only a single
-                mask is needed, the model's predicted quality score can be used
-                to select the best mask. For non-ambiguous prompts, such as multiple
-                input prompts, multimask_output=False can give better results.
+            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background.
+            masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256.
+            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False.

         Returns:
-            (np.ndarray): The output masks in CxHxW format, where C is the
-                number of masks, and (H, W) is the original image size.
-            (np.ndarray): An array of length C containing the model's
-                predictions for the quality of each mask.
-            (np.ndarray): An array of shape CxHxW, where C is the number
-                of masks and H=W=256. These low resolution logits can be passed to
-                a subsequent iteration as mask input.
+            tuple: Contains the following three elements.
+                - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
+                - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
+                - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
         """
-        # Get prompts from self.prompts first
+        # Override prompts if any stored in self.prompts
         bboxes = self.prompts.pop('bboxes', bboxes)
         points = self.prompts.pop('points', points)
         masks = self.prompts.pop('masks', masks)

         if all(i is None for i in [bboxes, points, masks]):
             return self.generate(im, *args, **kwargs)

         return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)

     def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
         """
-        Predict masks for the given input prompts, using the currently set image.
+        Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
+        Leverages SAM's specialized architecture for prompt-based, real-time segmentation.

         Args:
-            im (torch.Tensor): The preprocessed image, (N, C, H, W).
-            bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
-            points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
-            labels (np.ndarray | List, None): (N, ), labels for the point prompts.
-                1 indicates a foreground point and 0 indicates a background point.
-            masks (np.ndarray, None): A low resolution mask input to the model, typically
-                coming from a previous prediction iteration. Has form (N, H, W), where
-                for SAM, H=W=256.
-            multimask_output (bool): If true, the model will return three masks.
-                For ambiguous input prompts (such as a single click), this will often
-                produce better masks than a single prediction. If only a single
-                mask is needed, the model's predicted quality score can be used
-                to select the best mask. For non-ambiguous prompts, such as multiple
-                input prompts, multimask_output=False can give better results.
+            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background.
+            masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256.
+            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False.

         Returns:
-            (np.ndarray): The output masks in CxHxW format, where C is the
-                number of masks, and (H, W) is the original image size.
-            (np.ndarray): An array of length C containing the model's
-                predictions for the quality of each mask.
-            (np.ndarray): An array of shape CxHxW, where C is the number
-                of masks and H=W=256. These low resolution logits can be passed to
-                a subsequent iteration as mask input.
+            tuple: Contains the following three elements.
+                - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
+                - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
+                - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
         """
         features = self.model.image_encoder(im) if self.features is None else self.features

@@ -178,11 +184,7 @@ class Predictor(BasePredictor):

        points = (points, labels) if points is not None else None
        # Embed prompts
-        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
-            points=points,
-            boxes=bboxes,
-            masks=masks,
-        )
+        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)

        # Predict masks
        pred_masks, pred_scores = self.model.mask_decoder(
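The prompt-encoding and mask-decoding steps above are what run when SAM is called with explicit prompts. A minimal usage sketch of that path through the high-level SAM wrapper, with an illustrative weights file, image path, and coordinates (not part of this diff):

from ultralytics import SAM

# Load SAM weights (file name and image path are illustrative)
model = SAM('sam_b.pt')
model('assets/zidane.jpg', bboxes=[439, 437, 524, 709])    # box prompt in XYXY pixel coordinates
model('assets/zidane.jpg', points=[900, 370], labels=[1])  # single foreground point prompt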
@@ -210,46 +212,35 @@ class Predictor(BasePredictor):
                 stability_score_offset=0.95,
                 crop_nms_thresh=0.7):
        """
-        Segment the whole image.
+        Perform image segmentation using the Segment Anything Model (SAM).
+
+        This function segments an entire image into constituent parts by leveraging SAM's advanced architecture
+        and real-time performance capabilities. It can optionally work on image crops for finer segmentation.

        Args:
-            im (torch.Tensor): The preprocessed image, (N, C, H, W).
-            crop_n_layers (int): If >0, mask prediction will be run again on
-                crops of the image. Sets the number of layers to run, where each
-                layer has 2**i_layer number of image crops.
-            crop_overlap_ratio (float): Sets the degree to which crops overlap.
-                In the first crop layer, crops will overlap by this fraction of
-                the image length. Later layers with more crops scale down this overlap.
-            crop_downscale_factor (int): The number of points-per-side
-                sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
-            point_grids (list(np.ndarray), None): A list over explicit grids
-                of points used for sampling, normalized to [0,1]. The nth grid in the
-                list is used in the nth crop layer. Exclusive with points_per_side.
-            points_stride (int, None): The number of points to be sampled
-                along one side of the image. The total number of points is
-                points_per_side**2. If None, 'point_grids' must provide explicit
-                point sampling.
-            points_batch_size (int): Sets the number of points run simultaneously
-                by the model. Higher numbers may be faster but use more GPU memory.
-            conf_thres (float): A filtering threshold in [0,1], using the
-                model's predicted mask quality.
-            stability_score_thresh (float): A filtering threshold in [0,1], using
-                the stability of the mask under changes to the cutoff used to binarize
-                the model's mask predictions.
-            stability_score_offset (float): The amount to shift the cutoff when
-                calculated the stability score.
-            crop_nms_thresh (float): The box IoU cutoff used by non-maximal
-                suppression to filter duplicate masks between different crops.
+            im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
+            crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
+                Each layer produces 2**i_layer number of image crops.
+            crop_overlap_ratio (float): Determines the extent of overlap between crops. Scaled down in subsequent layers.
+            crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
+            point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1].
+                Used in the nth crop layer.
+            points_stride (int, optional): Number of points to sample along each side of the image.
+                Exclusive with 'point_grids'.
+            points_batch_size (int): Batch size for the number of points processed simultaneously.
+            conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction.
+            stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability.
+            stability_score_offset (float): Offset value for calculating stability score.
+            crop_nms_thresh (float): IoU cutoff for Non-Maximum Suppression (NMS) to remove duplicate masks between crops.
+
+        Returns:
+            tuple: A tuple containing segmented masks, confidence scores, and bounding boxes.
        """
        self.segment_all = True
        ih, iw = im.shape[2:]
        crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
        if point_grids is None:
-            point_grids = build_all_layer_point_grids(
-                points_stride,
-                crop_n_layers,
-                crop_downscale_factor,
-            )
+            point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor)
        pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
        for crop_region, layer_idx in zip(crop_regions, layer_idxs):
            x1, y1, x2, y2 = crop_region
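When no prompts are supplied, inference falls through to generate() above, which tiles the image into crops and samples a point grid per crop. A hedged sketch of driving that path through the SAM Predictor (weights file, image path, and tuning values are illustrative):

from ultralytics.models.sam import Predictor as SAMPredictor

# Build a predictor for "segment everything" mode
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024, model='mobile_sam.pt')
predictor = SAMPredictor(overrides=overrides)

# Extra keyword arguments are forwarded to generate(): one extra crop layer, a 64x64 point grid
results = predictor(source='assets/zidane.jpg', crop_n_layers=1, points_stride=64)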
@@ -312,7 +303,22 @@ class Predictor(BasePredictor):
        return pred_masks, pred_scores, pred_bboxes

    def setup_model(self, model, verbose=True):
-        """Set up YOLO model with specified thresholds and device."""
+        """
+        Initializes the Segment Anything Model (SAM) for inference.
+
+        This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
+        parameters for image normalization and other Ultralytics compatibility settings.
+
+        Args:
+            model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
+            verbose (bool): If True, prints selected device information.
+
+        Attributes:
+            model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
+            device (torch.device): The device to which the model and tensors are allocated.
+            mean (torch.Tensor): The mean values for image normalization.
+            std (torch.Tensor): The standard deviation values for image normalization.
+        """
        device = select_device(self.args.device, verbose=verbose)
        if model is None:
            model = build_sam(self.args.model)
@@ -321,7 +327,8 @@ class Predictor(BasePredictor):
        self.device = device
        self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
        self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
-        # TODO: Temporary settings for compatibility
+
+        # Ultralytics compatibility settings
        self.model.pt = False
        self.model.triton = False
        self.model.stride = 32
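The mean/std buffers registered above are the pixel statistics SAM expects on 0-255 RGB input, so preprocessing amounts to (x - mean) / std. A standalone sketch of that step, with values copied from the lines above (not the predictor's exact preprocess code):

import torch

mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

def normalize(im: torch.Tensor) -> torch.Tensor:
    """Normalize a (N, 3, H, W) tensor of 0-255 RGB pixel values."""
    return (im - mean) / std

dummy = torch.randint(0, 256, (1, 3, 1024, 1024)).float()
print(normalize(dummy).shape)  # torch.Size([1, 3, 1024, 1024])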
@@ -329,7 +336,20 @@ class Predictor(BasePredictor):
        self.done_warmup = True

    def postprocess(self, preds, img, orig_imgs):
-        """Post-processes inference output predictions to create detection masks for objects."""
+        """
+        Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
+
+        The method scales masks and boxes to the original image size and applies a threshold to the mask predictions. The
+        SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
+
+        Args:
+            preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
+            img (torch.Tensor): The processed input image tensor.
+            orig_imgs (list | torch.Tensor): The original, unprocessed images.
+
+        Returns:
+            (list): List of Results objects containing detection masks, bounding boxes, and other metadata.
+        """
        # (N, 1, H, W), (N, 1)
        pred_masks, pred_scores = preds[:2]
        pred_bboxes = preds[2] if self.segment_all else None
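As a rough illustration of the scale-and-threshold step the docstring describes (a sketch only, not the method's exact code), low-resolution mask logits can be upsampled to the original image size and binarized at zero:

import torch
import torch.nn.functional as F

logits = torch.randn(3, 256, 256)   # hypothetical logits for 3 masks at SAM's internal resolution
orig_h, orig_w = 1080, 810          # assumed original image size

masks = F.interpolate(logits[None], size=(orig_h, orig_w), mode='bilinear', align_corners=False)[0]
binary = masks > 0.0                # threshold logits to boolean masks
print(binary.shape, binary.dtype)   # torch.Size([3, 1080, 810]) torch.bool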
@@ -355,15 +375,30 @@ class Predictor(BasePredictor):
        return results

    def setup_source(self, source):
-        """Sets up source and inference mode."""
+        """
+        Sets up the data source for inference.
+
+        This method configures the data source from which images will be fetched for inference. The source could be a
+        directory, a video file, or other types of image data sources.
+
+        Args:
+            source (str | Path): The path to the image data source for inference.
+        """
        if source is not None:
            super().setup_source(source)

    def set_image(self, image):
-        """Set image in advance.
-        Args:
-            image (str | np.ndarray): image file path or np.ndarray image by cv2.
+        """
+        Preprocesses and sets a single image for inference.
+
+        This function sets up the model if not already initialized, configures the data source to the specified image,
+        and preprocesses the image for feature extraction. Only one image can be set at a time.
+
+        Args:
+            image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.
+
+        Raises:
+            AssertionError: If more than one image is set.
        """
        if self.model is None:
            model = build_sam(self.args.model)
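set_image() caches image features so several prompts can be run against one encoding. A minimal sketch of that workflow through the SAM Predictor (weights file and image path are illustrative):

from ultralytics.models.sam import Predictor as SAMPredictor

overrides = dict(task='segment', mode='predict', imgsz=1024, model='sam_b.pt')
predictor = SAMPredictor(overrides=overrides)

# Encode the image once, then issue multiple prompts against the cached features
predictor.set_image('assets/zidane.jpg')
results_box = predictor(bboxes=[439, 437, 524, 709])
results_pts = predictor(points=[900, 370], labels=[1])
predictor.reset_image()  # clear the cached image before setting a new one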
@@ -388,17 +423,20 @@ class Predictor(BasePredictor):
    @staticmethod
    def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
        """
-        Removes small disconnected regions and holes in masks, then reruns box NMS to remove any new duplicates.
-        Requires open-cv as a dependency.
+        Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
+        function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
+        Suppression (NMS) to eliminate any newly created duplicate boxes.

        Args:
-            masks (torch.Tensor): Masks, (N, H, W).
-            min_area (int): Minimum area threshold.
-            nms_thresh (float): NMS threshold.
+            masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
+                the number of masks, H is height, and W is width.
+            min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
+            nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.

        Returns:
-            new_masks (torch.Tensor): New Masks, (N, H, W).
-            keep (List[int]): The indices of the new masks, which can be used to filter
-                the corresponding boxes.
+            (Tuple[torch.Tensor, List[int]]):
+                - new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
+                - keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
        """
        if len(masks) == 0:
            return masks
@@ -420,10 +458,6 @@ class Predictor(BasePredictor):
        # Recalculate boxes and remove any new duplicates
        new_masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(new_masks)
-        keep = torchvision.ops.nms(
-            boxes.float(),
-            torch.as_tensor(scores),
-            nms_thresh,
-        )
+        keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)

        return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
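The static helper can also be called directly on any (N, H, W) boolean mask tensor. A hedged sketch on synthetic masks (sizes and thresholds chosen purely for illustration; the region cleanup relies on OpenCV being installed):

import torch
from ultralytics.models.sam import Predictor as SAMPredictor

# Two toy boolean masks; mask 0 has a 1-pixel speck that min_area should remove
masks = torch.zeros(2, 64, 64, dtype=torch.bool)
masks[0, 10:30, 10:30] = True
masks[0, 50, 50] = True
masks[1, 20:40, 20:40] = True

new_masks, keep = SAMPredictor.remove_small_regions(masks, min_area=25, nms_thresh=0.7)
print(new_masks.shape, keep)  # cleaned masks and indices kept after box NMS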