Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-23 05:24:22 +08:00

ultralytics 8.0.199
*.npy image loading exception handling (#5683)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: snyk-bot <snyk-bot@snyk.io>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

This commit is contained in:
parent 5b3c4cfc0e
commit cedce60f8c

.github/workflows/ci.yaml (vendored): 20 changes
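The headline change is the `*.npy` exception handling shown in the `BaseDataset.load_image` hunk later in this diff. A condensed sketch of that pattern (simplified; logger prefix elided):

```python
# Condensed sketch of the new *.npy cache-loading behavior: if the cached array
# is corrupt, drop the cache file and fall back to reading the original image.
from pathlib import Path

import cv2
import numpy as np


def load_cached_or_original(fn: Path, f: str):
    """Load a cached .npy image, falling back to the source image if the cache is corrupt."""
    if fn.exists():
        try:
            return np.load(fn)  # fast path: cached array
        except Exception as e:
            print(f'WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}')
            fn.unlink(missing_ok=True)  # remove the bad cache file
    im = cv2.imread(f)  # BGR, re-read from the original image
    if im is None:
        raise FileNotFoundError(f'Image Not Found {f}')
    return im
```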
@@ -183,11 +183,11 @@ jobs:
      shell: bash # for Windows compatibility
      run: | # CoreML must be installed before export due to protobuf error from AutoInstall
        python -m pip install --upgrade pip wheel
        torch=""
        if [ "${{ matrix.torch }}" == "1.8.0" ]; then
        pip install -e . torch==1.8.0 torchvision==0.9.0 pytest-cov "coremltools>=7.0" --extra-index-url https://download.pytorch.org/whl/cpu
        else
        pip install -e . pytest-cov "coremltools>=7.0" --extra-index-url https://download.pytorch.org/whl/cpu
        torch="torch==1.8.0 torchvision==0.9.0"
        fi
        pip install -e . $torch pytest-cov "coremltools>=7.0" --extra-index-url https://download.pytorch.org/whl/cpu
    - name: Check environment
      run: |
        yolo checks
@@ -202,7 +202,13 @@ jobs:
        pip list
    - name: Pytest tests
      shell: bash # for Windows compatibility
      run: pytest --cov=ultralytics/ --cov-report xml tests/
      run: |
        slow=""
        if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
        pip install mlflow pycocotools 'ray[tune]'
        slow="--slow"
        fi
        pytest $slow --cov=ultralytics/ --cov-report xml tests/
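The new `--slow` flag assumes a matching custom pytest option in the test suite's conftest, which is not part of this diff. A hypothetical sketch of how such an option is typically wired:

```python
# Hypothetical conftest.py sketch: register a --slow flag and skip @pytest.mark.slow
# tests unless it is passed. The actual ultralytics conftest may differ.
import pytest


def pytest_addoption(parser):
    parser.addoption('--slow', action='store_true', default=False, help='run slow tests')


def pytest_collection_modifyitems(config, items):
    if config.getoption('--slow'):
        return  # --slow given: run everything
    skip_slow = pytest.mark.skip(reason='need --slow option to run')
    for item in items:
        if 'slow' in item.keywords:
            item.add_marker(skip_slow)
```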
    - name: Upload Coverage Reports to CodeCov
      if: github.repository == 'ultralytics/ultralytics' # && matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11'
      uses: codecov/codecov-action@v3
@@ -264,10 +270,10 @@ jobs:
        conda config --set solver libmamba
    - name: Install Ultralytics package from conda-forge
      run: |
        conda install -c pytorch -c conda-forge pytorch torchvision ultralytics
        conda install -c pytorch -c conda-forge pytorch torchvision ultralytics openvino
    - name: Install pip packages
      run: |
        pip install pytest 'coremltools>=7.0' # 'openvino-dev>=2023.0'
        pip install pytest 'coremltools>=7.0'
    - name: Check environment
      run: |
        echo "RUNNER_OS is ${{ runner.os }}"
@@ -297,7 +303,7 @@ jobs:
    - name: PyTest
      run: |
        git clone https://github.com/ultralytics/ultralytics
        pytest ultralytics/tests/test_cli.py # full tests fail due to openvino export failure
        pytest ultralytics/tests

  Summary:
    runs-on: ubuntu-latest

@@ -1,5 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Pre-commit hooks. For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md
# Optionally remove from local hooks with 'rm .git/hooks/pre-commit'

# exclude: 'docs/'
# Define bot property if installed via https://github.com/marketplace/pre-commit-ci

@@ -3,7 +3,7 @@
# Image is CPU-optimized for ONNX, OpenVINO and PyTorch YOLOv8 deployments

# Start FROM Ubuntu image https://hub.docker.com/_/ubuntu
FROM ubuntu:lunar-20230615
FROM ubuntu:23.04

# Downloads to user config dir
ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/

@@ -8,12 +8,16 @@ keywords: Ultralytics, YOLOv8, Roboflow, vector analysis, confusion matrix, data

[Roboflow](https://roboflow.com/?ref=ultralytics) has everything you need to build and deploy computer vision models. Connect Roboflow at any step in your pipeline with APIs and SDKs, or use the end-to-end interface to automate the entire process from image to inference. Whether you’re in need of [data labeling](https://roboflow.com/annotate?ref=ultralytics), [model training](https://roboflow.com/train?ref=ultralytics), or [model deployment](https://roboflow.com/deploy?ref=ultralytics), Roboflow gives you building blocks to bring custom computer vision solutions to your project.

!!! warning

    Roboflow users can use Ultralytics under the [AGPL license](https://github.com/ultralytics/ultralytics/blob/main/LICENSE) or procure an [Enterprise license](https://ultralytics.com/license) directly from Ultralytics. Be aware that Roboflow does **not** provide Ultralytics licenses, and it is the responsibility of the user to ensure appropriate licensing.

In this guide, we are going to showcase how to find, label, and organize data for use in training a custom Ultralytics YOLOv8 model. Use the table of contents below to jump directly to a specific section:

- Gather data for training a custom YOLOv8 model
- Upload, convert and label data for YOLOv8 format
- Pre-process and augment data for model robustness
- Dataset management for YOLOv8
- Dataset management for [YOLOv8](https://docs.ultralytics.com/models/yolov8/)
- Export data in 40+ formats for model training
- Upload custom YOLOv8 model weights for testing and deployment
- Gather Data for Training a Custom YOLOv8 Model

setup.py: 1 change
@@ -68,6 +68,7 @@ setup(
        'dev': [
            'ipython',
            'check-manifest',
            'pre-commit',
            'pytest',
            'pytest-cov',
            'coverage',

@@ -3,8 +3,8 @@
import pytest
import torch

from ultralytics import YOLO, download
from ultralytics.utils import ASSETS, DATASETS_DIR, WEIGHTS_DIR, checks
from ultralytics import YOLO
from ultralytics.utils import ASSETS, WEIGHTS_DIR, checks

CUDA_IS_AVAILABLE = checks.cuda_is_available()
CUDA_DEVICE_COUNT = checks.cuda_device_count()
@@ -27,6 +27,7 @@ def test_train():
    YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=device)  # requires imgsz>=64


@pytest.mark.slow
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
def test_predict_multiple_devices():
    """Validate model prediction on multiple devices."""
@@ -102,42 +103,3 @@ def test_predict_sam():

    # Reset image
    predictor.reset_image()


@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
def test_model_tune():
    """Tune YOLO model for performance."""
    YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
    YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')


@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
def test_pycocotools():
    """Validate model predictions using pycocotools."""
    from ultralytics.models.yolo.detect import DetectionValidator
    from ultralytics.models.yolo.pose import PoseValidator
    from ultralytics.models.yolo.segment import SegmentationValidator

    # Download annotations after each dataset downloads first
    url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'

    args = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64}
    validator = DetectionValidator(args=args)
    validator()
    validator.is_coco = True
    download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8/annotations')
    _ = validator.eval_json(validator.stats)

    args = {'model': 'yolov8n-seg.pt', 'data': 'coco8-seg.yaml', 'save_json': True, 'imgsz': 64}
    validator = SegmentationValidator(args=args)
    validator()
    validator.is_coco = True
    download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8-seg/annotations')
    _ = validator.eval_json(validator.stats)

    args = {'model': 'yolov8n-pose.pt', 'data': 'coco8-pose.yaml', 'save_json': True, 'imgsz': 64}
    validator = PoseValidator(args=args)
    validator()
    validator.is_coco = True
    download(f'{url}person_keypoints_val2017.json', dir=DATASETS_DIR / 'coco8-pose/annotations')
    _ = validator.eval_json(validator.stats)

@@ -1,12 +1,20 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from pathlib import Path

import pytest

from ultralytics import YOLO
from ultralytics.utils import SETTINGS, checks
from ultralytics import YOLO, download
from ultralytics.utils import ASSETS, DATASETS_DIR, ROOT, SETTINGS, WEIGHTS_DIR
from ultralytics.utils.checks import check_requirements

MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt'  # test spaces in path
CFG = 'yolov8n.yaml'
SOURCE = ASSETS / 'bus.jpg'
TMP = (ROOT / '../tests/tmp').resolve()  # temp directory for test files


@pytest.mark.skipif(not checks.check_requirements('ray', install=False), reason='RayTune not installed')
@pytest.mark.skipif(not check_requirements('ray', install=False), reason='ray[tune] not installed')
def test_model_ray_tune():
    """Tune YOLO model with Ray optimization library."""
    YOLO('yolov8n-cls.yaml').tune(use_ray=True,
@@ -19,8 +27,90 @@ def test_model_ray_tune():
                                  device='cpu')


@pytest.mark.skipif(not checks.check_requirements('mlflow', install=False), reason='MLflow not installed')
@pytest.mark.skipif(not check_requirements('mlflow', install=False), reason='mlflow not installed')
def test_mlflow():
    """Test training with MLflow tracking enabled."""
    SETTINGS['mlflow'] = True
    YOLO('yolov8n-cls.yaml').train(data='imagenet10', imgsz=32, epochs=3, plots=False, device='cpu')


@pytest.mark.skipif(not check_requirements('tritonclient', install=False), reason='tritonclient[all] not installed')
def test_triton():
    """Test NVIDIA Triton Server functionalities."""
    check_requirements('tritonclient[all]')
    import subprocess
    import time

    from tritonclient.http import InferenceServerClient  # noqa

    # Create variables
    model_name = 'yolo'
    triton_repo_path = TMP / 'triton_repo'
    triton_model_path = triton_repo_path / model_name

    # Export model to ONNX
    f = YOLO(MODEL).export(format='onnx', dynamic=True)

    # Prepare Triton repo
    (triton_model_path / '1').mkdir(parents=True, exist_ok=True)
    Path(f).rename(triton_model_path / '1' / 'model.onnx')
    (triton_model_path / 'config.pdtxt').touch()

    # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
    tag = 'nvcr.io/nvidia/tritonserver:23.09-py3'  # 6.4 GB

    # Pull the image
    subprocess.call(f'docker pull {tag}', shell=True)

    # Run the Triton server and capture the container ID
    container_id = subprocess.check_output(
        f'docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models',
        shell=True).decode('utf-8').strip()

    # Wait for the Triton server to start
    triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False)

    # Wait until model is ready
    for _ in range(10):
        with contextlib.suppress(Exception):
            assert triton_client.is_model_ready(model_name)
            break
        time.sleep(1)

    # Check Triton inference
    YOLO(f'http://localhost:8000/{model_name}', 'detect')(SOURCE)  # exported model inference

    # Kill and remove the container at the end of the test
    subprocess.call(f'docker kill {container_id}', shell=True)

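For reference, the readiness-poll and URL-inference calls above can also be used standalone once a Triton container is already serving the exported model; a short sketch reusing only the calls shown in the test (the localhost URL and 'yolo' model name are the test's own assumptions):

```python
# Standalone sketch reusing test_triton's calls; assumes a Triton container is
# already serving the exported 'yolo' model on localhost:8000.
from tritonclient.http import InferenceServerClient

from ultralytics import YOLO
from ultralytics.utils import ASSETS

client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False)
if client.is_model_ready('yolo'):
    results = YOLO('http://localhost:8000/yolo', 'detect')(ASSETS / 'bus.jpg')
```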
@pytest.mark.skipif(not check_requirements('pycocotools', install=False), reason='pycocotools not installed')
def test_pycocotools():
    """Validate model predictions using pycocotools."""
    from ultralytics.models.yolo.detect import DetectionValidator
    from ultralytics.models.yolo.pose import PoseValidator
    from ultralytics.models.yolo.segment import SegmentationValidator

    # Download annotations after each dataset downloads first
    url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'

    args = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64}
    validator = DetectionValidator(args=args)
    validator()
    validator.is_coco = True
    download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8/annotations')
    _ = validator.eval_json(validator.stats)

    args = {'model': 'yolov8n-seg.pt', 'data': 'coco8-seg.yaml', 'save_json': True, 'imgsz': 64}
    validator = SegmentationValidator(args=args)
    validator()
    validator.is_coco = True
    download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8-seg/annotations')
    _ = validator.eval_json(validator.stats)

    args = {'model': 'yolov8n-pose.pt', 'data': 'coco8-pose.yaml', 'save_json': True, 'imgsz': 64}
    validator = PoseValidator(args=args)
    validator()
    validator.is_coco = True
    download(f'{url}person_keypoints_val2017.json', dir=DATASETS_DIR / 'coco8-pose/annotations')
    _ = validator.eval_json(validator.stats)

@@ -495,50 +495,7 @@ def test_hub():

@pytest.mark.slow
@pytest.mark.skipif(not ONLINE, reason='environment is offline')
def test_triton():
    """Test NVIDIA Triton Server functionalities."""
    checks.check_requirements('tritonclient[all]')
    import subprocess
    import time

    from tritonclient.http import InferenceServerClient  # noqa

    # Create variables
    model_name = 'yolo'
    triton_repo_path = TMP / 'triton_repo'
    triton_model_path = triton_repo_path / model_name

    # Export model to ONNX
    f = YOLO(MODEL).export(format='onnx', dynamic=True)

    # Prepare Triton repo
    (triton_model_path / '1').mkdir(parents=True, exist_ok=True)
    Path(f).rename(triton_model_path / '1' / 'model.onnx')
    (triton_model_path / 'config.pdtxt').touch()

    # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
    tag = 'nvcr.io/nvidia/tritonserver:23.09-py3'  # 6.4 GB

    # Pull the image
    subprocess.call(f'docker pull {tag}', shell=True)

    # Run the Triton server and capture the container ID
    container_id = subprocess.check_output(
        f'docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models',
        shell=True).decode('utf-8').strip()

    # Wait for the Triton server to start
    triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False)

    # Wait until model is ready
    for _ in range(10):
        with contextlib.suppress(Exception):
            assert triton_client.is_model_ready(model_name)
            break
        time.sleep(1)

    # Check Triton inference
    YOLO(f'http://localhost:8000/{model_name}', 'detect')(SOURCE)  # exported model inference

    # Kill and remove the container at the end of the test
    subprocess.call(f'docker kill {container_id}', shell=True)
def test_model_tune():
    """Tune YOLO model for performance."""
    YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
    YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

__version__ = '8.0.198'
__version__ = '8.0.199'

from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM

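A quick sanity check that the bumped version string is what the package exposes (illustrative only):

```python
# Illustrative check of the version bump above.
import ultralytics

assert ultralytics.__version__ == '8.0.199'
```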
@@ -61,8 +61,8 @@ class BaseDataset(Dataset):
                 single_cls=False,
                 classes=None,
                 fraction=1.0):
        super().__init__()
        """Initialize BaseDataset with given configuration and options."""
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
@@ -85,7 +85,7 @@ class BaseDataset(Dataset):
        self.buffer = []  # buffer size = batch size
        self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

        # Cache stuff
        # Cache images
        if cache == 'ram' and not self.check_cache_ram():
            cache = False
        self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
@@ -123,7 +123,7 @@ class BaseDataset(Dataset):
        return im_files

    def update_labels(self, include_class: Optional[list]):
        """include_class, filter labels to include only these classes (optional)."""
        """Update labels to include only these classes (optional)."""
        include_class_array = np.array(include_class).reshape(1, -1)
        for i in range(len(self.labels)):
            if include_class is not None:
@@ -146,11 +146,17 @@ class BaseDataset(Dataset):
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
                try:
                    im = np.load(fn)
                except Exception as e:
                    LOGGER.warning(f'{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}')
                    Path(fn).unlink(missing_ok=True)
                    im = cv2.imread(f)  # BGR
            else:  # read image
                im = cv2.imread(f)  # BGR
                if im is None:
                    raise FileNotFoundError(f'Image Not Found {f}')
            if im is None:
                raise FileNotFoundError(f'Image Not Found {f}')

        h0, w0 = im.shape[:2]  # orig hw
        if rect_mode:  # resize long side to imgsz while maintaining aspect ratio
            r = self.imgsz / max(h0, w0)  # ratio
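Illustrative arithmetic for the `rect_mode` ratio computed in the last context line above (the subsequent resize is not shown in this hunk; the assumed behavior is that the long side is scaled to `imgsz` while keeping the aspect ratio):

```python
# Worked example for the rect_mode ratio (assumed resize behavior: long side -> imgsz).
imgsz = 640
h0, w0 = 960, 1280           # original image height and width
r = imgsz / max(h0, w0)      # 640 / 1280 = 0.5
print(round(w0 * r), round(h0 * r))  # -> 640 480
```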
@ -1,5 +1,12 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""RT-DETR model interface."""
|
||||
"""
|
||||
Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time
|
||||
performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient
|
||||
hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
|
||||
|
||||
For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf
|
||||
"""
|
||||
|
||||
from ultralytics.engine.model import Model
|
||||
from ultralytics.nn.tasks import RTDETRDetectionModel
|
||||
|
||||
@ -9,17 +16,36 @@ from .val import RTDETRValidator
|
||||
|
||||
|
||||
class RTDETR(Model):
|
||||
"""RTDETR model interface."""
|
||||
"""
|
||||
Interface for Baidu's RT-DETR model. This Vision Transformer-based object detector provides real-time performance
|
||||
with high accuracy. It supports efficient hybrid encoding, IoU-aware query selection, and adaptable inference speed.
|
||||
|
||||
Attributes:
|
||||
model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
|
||||
"""
|
||||
|
||||
def __init__(self, model='rtdetr-l.pt') -> None:
|
||||
"""Initializes the RTDETR model with the given model file, defaulting to 'rtdetr-l.pt'."""
|
||||
"""
|
||||
Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.
|
||||
|
||||
Args:
|
||||
model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
|
||||
"""
|
||||
if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
|
||||
raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.')
|
||||
raise NotImplementedError('RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.')
|
||||
super().__init__(model=model, task='detect')
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
"""Returns a dictionary mapping task names to corresponding Ultralytics task classes for RTDETR model."""
|
||||
def task_map(self) -> dict:
|
||||
"""
|
||||
Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
|
||||
"""
|
||||
return {
|
||||
'detect': {
|
||||
'predictor': RTDETRPredictor,
|
||||
|
@ -10,7 +10,11 @@ from ultralytics.utils import ops
|
||||
|
||||
class RTDETRPredictor(BasePredictor):
|
||||
"""
|
||||
A class extending the BasePredictor class for prediction based on an RT-DETR detection model.
|
||||
RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions using
|
||||
Baidu's RT-DETR model.
|
||||
|
||||
This class leverages the power of Vision Transformers to provide real-time object detection while maintaining
|
||||
high accuracy. It supports key features like efficient hybrid encoding and IoU-aware query selection.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@ -21,10 +25,27 @@ class RTDETRPredictor(BasePredictor):
|
||||
predictor = RTDETRPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
```
|
||||
|
||||
Attributes:
|
||||
imgsz (int): Image size for inference (must be square and scale-filled).
|
||||
args (dict): Argument overrides for the predictor.
|
||||
"""
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Postprocess predictions and returns a list of Results objects."""
|
||||
"""
|
||||
Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
|
||||
|
||||
The method filters detections based on confidence and class if specified in `self.args`.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions from the model.
|
||||
img (torch.Tensor): Processed input images.
|
||||
orig_imgs (list or torch.Tensor): Original, unprocessed images.
|
||||
|
||||
Returns:
|
||||
(list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
|
||||
and class labels.
|
||||
"""
|
||||
nd = preds[0].shape[-1]
|
||||
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
||||
|
||||
@ -49,15 +70,14 @@ class RTDETRPredictor(BasePredictor):
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
Pre-transform input image before inference.
|
||||
Pre-transforms the input images before feeding them into the model for inference. The input images are
|
||||
letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scaleFilled.
|
||||
|
||||
Args:
|
||||
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
|
||||
|
||||
Notes: The size must be square(640) and scaleFilled.
|
||||
im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list.
|
||||
|
||||
Returns:
|
||||
(list): A list of transformed imgs.
|
||||
(list): List of pre-transformed images ready for model inference.
|
||||
"""
|
||||
letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
|
||||
return [letterbox(image=x) for x in im]
|
||||
|
@ -13,10 +13,12 @@ from .val import RTDETRDataset, RTDETRValidator
|
||||
|
||||
class RTDETRTrainer(DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training based on an RT-DETR detection model.
|
||||
Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
|
||||
class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
|
||||
Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.
|
||||
|
||||
Notes:
|
||||
- F.grid_sample used in rt-detr does not support the `deterministic=True` argument.
|
||||
- F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
|
||||
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
|
||||
|
||||
Example:
|
||||
@ -30,7 +32,17 @@ class RTDETRTrainer(DetectionTrainer):
|
||||
"""
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Return a YOLO detection model."""
|
||||
"""
|
||||
Initialize and return an RT-DETR model for object detection tasks.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Model configuration. Defaults to None.
|
||||
weights (str, optional): Path to pre-trained model weights. Defaults to None.
|
||||
verbose (bool): Verbose logging if True. Defaults to True.
|
||||
|
||||
Returns:
|
||||
(RTDETRDetectionModel): Initialized model.
|
||||
"""
|
||||
model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
@ -38,31 +50,46 @@ class RTDETRTrainer(DetectionTrainer):
|
||||
|
||||
def build_dataset(self, img_path, mode='val', batch=None):
|
||||
"""
|
||||
Build RTDETR Dataset.
|
||||
Build and return an RT-DETR dataset for training or validation.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||
mode (str): Dataset mode, either 'train' or 'val'.
|
||||
batch (int, optional): Batch size for rectangle training. Defaults to None.
|
||||
|
||||
Returns:
|
||||
(RTDETRDataset): Dataset object for the specific mode.
|
||||
"""
|
||||
return RTDETRDataset(
|
||||
img_path=img_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch,
|
||||
augment=mode == 'train', # no augmentation
|
||||
hyp=self.args,
|
||||
rect=False, # no rect
|
||||
cache=self.args.cache or None,
|
||||
prefix=colorstr(f'{mode}: '),
|
||||
data=self.data)
|
||||
return RTDETRDataset(img_path=img_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch,
|
||||
augment=mode == 'train',
|
||||
hyp=self.args,
|
||||
rect=False,
|
||||
cache=self.args.cache or None,
|
||||
prefix=colorstr(f'{mode}: '),
|
||||
data=self.data)
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns a DetectionValidator for RTDETR model validation."""
|
||||
"""
|
||||
Returns a DetectionValidator suitable for RT-DETR model validation.
|
||||
|
||||
Returns:
|
||||
(RTDETRValidator): Validator object for model validation.
|
||||
"""
|
||||
self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
|
||||
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""Preprocesses a batch of images by scaling and converting to float."""
|
||||
"""
|
||||
Preprocess a batch of images. Scales and converts the images to float format.
|
||||
|
||||
Args:
|
||||
batch (dict): Dictionary containing a batch of images, bboxes, and labels.
|
||||
|
||||
Returns:
|
||||
(dict): Preprocessed batch.
|
||||
"""
|
||||
batch = super().preprocess_batch(batch)
|
||||
bs = len(batch['img'])
|
||||
batch_idx = batch['batch_idx']
|
||||
|
@ -3,6 +3,4 @@
|
||||
from .model import SAM
|
||||
from .predict import Predictor
|
||||
|
||||
# from .build import build_sam
|
||||
|
||||
__all__ = 'SAM', 'Predictor' # tuple or list
|
||||
|
@ -1,5 +1,18 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""SAM model interface."""
|
||||
"""
|
||||
SAM model interface.
|
||||
|
||||
This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image
|
||||
segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis,
|
||||
and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
|
||||
image distributions and tasks without prior knowledge.
|
||||
|
||||
Key Features:
|
||||
- Promptable segmentation
|
||||
- Real-time performance
|
||||
- Zero-shot transfer capabilities
|
||||
- Trained on SA-1B dataset
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@ -11,40 +24,94 @@ from .predict import Predictor
|
||||
|
||||
|
||||
class SAM(Model):
|
||||
"""SAM model interface."""
|
||||
"""
|
||||
SAM (Segment Anything Model) interface class.
|
||||
|
||||
SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
|
||||
bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
|
||||
dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, model='sam_b.pt') -> None:
|
||||
"""Initializes the SAM model instance with the specified pre-trained model file."""
|
||||
"""
|
||||
Initializes the SAM model with a pre-trained model file.
|
||||
|
||||
Args:
|
||||
model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the model file extension is not .pt or .pth.
|
||||
"""
|
||||
if model and Path(model).suffix not in ('.pt', '.pth'):
|
||||
raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
|
||||
super().__init__(model=model, task='segment')
|
||||
|
||||
def _load(self, weights: str, task=None):
|
||||
"""Loads the provided weights into the SAM model."""
|
||||
"""
|
||||
Loads the specified weights into the SAM model.
|
||||
|
||||
Args:
|
||||
weights (str): Path to the weights file.
|
||||
task (str, optional): Task name. Defaults to None.
|
||||
"""
|
||||
self.model = build_sam(weights)
|
||||
|
||||
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""Predicts and returns segmentation masks for given image or video source."""
|
||||
"""
|
||||
Performs segmentation prediction on the given image or video source.
|
||||
|
||||
Args:
|
||||
source: Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
|
||||
stream (bool, optional): If True, enables real-time streaming. Defaults to False.
|
||||
bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
|
||||
points (list, optional): List of points for prompted segmentation. Defaults to None.
|
||||
labels (list, optional): List of labels for prompted segmentation. Defaults to None.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
The segmentation masks.
|
||||
"""
|
||||
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
|
||||
kwargs.update(overrides)
|
||||
prompts = dict(bboxes=bboxes, points=points, labels=labels)
|
||||
return super().predict(source, stream, prompts=prompts, **kwargs)
|
||||
|
||||
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
"""
|
||||
Alias for the 'predict' method.
|
||||
|
||||
Args:
|
||||
source: Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
|
||||
stream (bool, optional): If True, enables real-time streaming. Defaults to False.
|
||||
bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
|
||||
points (list, optional): List of points for prompted segmentation. Defaults to None.
|
||||
labels (list, optional): List of labels for prompted segmentation. Defaults to None.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
The segmentation masks.
|
||||
"""
|
||||
return self.predict(source, stream, bboxes, points, labels, **kwargs)
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
"""
|
||||
Logs model info.
|
||||
Logs information about the SAM model.
|
||||
|
||||
Args:
|
||||
detailed (bool): Show detailed information about model.
|
||||
verbose (bool): Controls verbosity.
|
||||
detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
|
||||
verbose (bool, optional): If True, displays information on the console. Defaults to True.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing the model's information.
|
||||
"""
|
||||
return model_info(self.model, detailed=detailed, verbose=verbose)
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
"""Returns a dictionary mapping the 'segment' task to its corresponding 'Predictor'."""
|
||||
"""
|
||||
Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
|
||||
"""
|
||||
return {'segment': {'predictor': Predictor}}
|
||||
|
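A brief usage sketch matching the SAM interface documented above (the prompt values are illustrative only; 'sam_b.pt' is the constructor's documented default):

```python
# Usage sketch based on the SAM docstrings above; bboxes coordinates are illustrative.
from ultralytics import SAM
from ultralytics.utils import ASSETS

model = SAM('sam_b.pt')  # default pre-trained weights per __init__
model.info()             # log model information
results = model(ASSETS / 'bus.jpg', bboxes=[100, 200, 300, 400])  # box-prompted segmentation
```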
@ -1,4 +1,12 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Generate predictions using the Segment Anything Model (SAM).
|
||||
|
||||
SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance.
|
||||
This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation
|
||||
using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image
|
||||
segmentation tasks.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -18,71 +26,86 @@ from .build import build_sam
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
"""
|
||||
A prediction class for segmentation tasks, extending the BasePredictor.
|
||||
Predictor class for the Segment Anything Model (SAM), extending BasePredictor.
|
||||
|
||||
This class serves as an interface for model inference for segmentation tasks.
|
||||
It can preprocess input images, perform inference, and postprocess the output.
|
||||
It also supports handling various types of input prompts including bounding boxes,
|
||||
points, and low-resolution masks for better prediction results.
|
||||
The class provides an interface for model inference tailored to image segmentation tasks.
|
||||
With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time
|
||||
mask generation. The class is capable of working with various types of prompts such as bounding boxes,
|
||||
points, and low-resolution masks.
|
||||
|
||||
Attributes:
|
||||
cfg (dict): Configuration dictionary.
|
||||
overrides (dict): Dictionary of overriding values.
|
||||
_callbacks (dict): Dictionary of callback functions.
|
||||
args (namespace): Argument namespace.
|
||||
im (torch.Tensor): Preprocessed image for current prediction.
|
||||
features (torch.Tensor): Image features.
|
||||
prompts (dict): Dictionary of prompts like bboxes, points, masks.
|
||||
segment_all (bool): Whether to perform segmentation on all objects or not.
|
||||
cfg (dict): Configuration dictionary specifying model and task-related parameters.
|
||||
overrides (dict): Dictionary containing values that override the default configuration.
|
||||
_callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
|
||||
args (namespace): Namespace to hold command-line arguments or other operational variables.
|
||||
im (torch.Tensor): Preprocessed input image tensor.
|
||||
features (torch.Tensor): Extracted image features used for inference.
|
||||
prompts (dict): Collection of various prompt types, such as bounding boxes and points.
|
||||
segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""Initializes the Predictor class with default or provided configuration, overrides, and callbacks."""
|
||||
"""
|
||||
Initialize the Predictor with configuration, overrides, and callbacks.
|
||||
|
||||
The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
|
||||
initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.
|
||||
|
||||
Args:
|
||||
cfg (dict): Configuration dictionary.
|
||||
overrides (dict, optional): Dictionary of values to override default configuration.
|
||||
_callbacks (dict, optional): Dictionary of callback functions to customize behavior.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
# SAM needs retina_masks=True, or the results would be a mess.
|
||||
self.args.retina_masks = True
|
||||
# Args for set_image
|
||||
self.im = None
|
||||
self.features = None
|
||||
# Args for set_prompts
|
||||
self.prompts = {}
|
||||
# Args for segment everything
|
||||
self.segment_all = False
|
||||
|
||||
def preprocess(self, im):
|
||||
"""
|
||||
Prepares input image before inference.
|
||||
Preprocess the input image for model inference.
|
||||
|
||||
The method prepares the input image by applying transformations and normalization.
|
||||
It supports both torch.Tensor and list of np.ndarray as input formats.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
|
||||
im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The preprocessed image tensor.
|
||||
"""
|
||||
if self.im is not None:
|
||||
return self.im
|
||||
not_tensor = not isinstance(im, torch.Tensor)
|
||||
if not_tensor:
|
||||
im = np.stack(self.pre_transform(im))
|
||||
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
im = im[..., ::-1].transpose((0, 3, 1, 2))
|
||||
im = np.ascontiguousarray(im)
|
||||
im = torch.from_numpy(im)
|
||||
|
||||
im = im.to(self.device)
|
||||
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
|
||||
im = im.half() if self.model.fp16 else im.float()
|
||||
if not_tensor:
|
||||
im = (im - self.mean) / self.std
|
||||
return im
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
Pre-transform input image before inference.
|
||||
Perform initial transformations on the input image for preprocessing.
|
||||
|
||||
The method applies transformations such as resizing to prepare the image for further preprocessing.
|
||||
Currently, batched inference is not supported; hence the list length should be 1.
|
||||
|
||||
Args:
|
||||
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
|
||||
im (List[np.ndarray]): List containing images in HWC numpy array format.
|
||||
|
||||
Returns:
|
||||
(list): A list of transformed images.
|
||||
List[np.ndarray]: List of transformed images.
|
||||
"""
|
||||
assert len(im) == 1, 'SAM model does not currently support batched inference'
|
||||
letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
|
||||
@ -90,69 +113,52 @@ class Predictor(BasePredictor):
|
||||
|
||||
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
|
||||
"""
|
||||
Predict masks for the given input prompts, using the currently set image.
|
||||
Perform image segmentation inference based on the given input cues, using the currently loaded image. This
|
||||
method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
|
||||
mask decoder for real-time and promptable segmentation tasks.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed image, (N, C, H, W).
|
||||
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
|
||||
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
|
||||
1 indicates a foreground point and 0 indicates a background point.
|
||||
masks (np.ndarray, None): A low resolution mask input to the model, typically
|
||||
coming from a previous prediction iteration. Has form (N, H, W), where
|
||||
for SAM, H=W=256.
|
||||
multimask_output (bool): If true, the model will return three masks.
|
||||
For ambiguous input prompts (such as a single click), this will often
|
||||
produce better masks than a single prediction. If only a single
|
||||
mask is needed, the model's predicted quality score can be used
|
||||
to select the best mask. For non-ambiguous prompts, such as multiple
|
||||
input prompts, multimask_output=False can give better results.
|
||||
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates.
|
||||
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background.
|
||||
masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256.
|
||||
multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The output masks in CxHxW format, where C is the
|
||||
number of masks, and (H, W) is the original image size.
|
||||
(np.ndarray): An array of length C containing the model's
|
||||
predictions for the quality of each mask.
|
||||
(np.ndarray): An array of shape CxHxW, where C is the number
|
||||
of masks and H=W=256. These low resolution logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
tuple: Contains the following three elements.
|
||||
- np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
|
||||
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
|
||||
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
|
||||
"""
|
||||
# Get prompts from self.prompts first
|
||||
# Override prompts if any stored in self.prompts
|
||||
bboxes = self.prompts.pop('bboxes', bboxes)
|
||||
points = self.prompts.pop('points', points)
|
||||
masks = self.prompts.pop('masks', masks)
|
||||
|
||||
if all(i is None for i in [bboxes, points, masks]):
|
||||
return self.generate(im, *args, **kwargs)
|
||||
|
||||
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
|
||||
|
||||
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
|
||||
"""
|
||||
Predict masks for the given input prompts, using the currently set image.
|
||||
Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
|
||||
Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed image, (N, C, H, W).
|
||||
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
|
||||
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
|
||||
1 indicates a foreground point and 0 indicates a background point.
|
||||
masks (np.ndarray, None): A low resolution mask input to the model, typically
|
||||
coming from a previous prediction iteration. Has form (N, H, W), where
|
||||
for SAM, H=W=256.
|
||||
multimask_output (bool): If true, the model will return three masks.
|
||||
For ambiguous input prompts (such as a single click), this will often
|
||||
produce better masks than a single prediction. If only a single
|
||||
mask is needed, the model's predicted quality score can be used
|
||||
to select the best mask. For non-ambiguous prompts, such as multiple
|
||||
input prompts, multimask_output=False can give better results.
|
||||
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates.
|
||||
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background.
|
||||
masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256.
|
||||
multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The output masks in CxHxW format, where C is the
|
||||
number of masks, and (H, W) is the original image size.
|
||||
(np.ndarray): An array of length C containing the model's
|
||||
predictions for the quality of each mask.
|
||||
(np.ndarray): An array of shape CxHxW, where C is the number
|
||||
of masks and H=W=256. These low resolution logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
tuple: Contains the following three elements.
|
||||
- np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
|
||||
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
|
||||
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
|
||||
"""
|
||||
features = self.model.image_encoder(im) if self.features is None else self.features
|
||||
|
||||
@ -178,11 +184,7 @@ class Predictor(BasePredictor):
|
||||
|
||||
points = (points, labels) if points is not None else None
|
||||
# Embed prompts
|
||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
|
||||
points=points,
|
||||
boxes=bboxes,
|
||||
masks=masks,
|
||||
)
|
||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
|
||||
|
||||
# Predict masks
|
||||
pred_masks, pred_scores = self.model.mask_decoder(
|
||||
@ -210,46 +212,35 @@ class Predictor(BasePredictor):
|
||||
stability_score_offset=0.95,
|
||||
crop_nms_thresh=0.7):
|
||||
"""
|
||||
Segment the whole image.
|
||||
Perform image segmentation using the Segment Anything Model (SAM).
|
||||
|
||||
This function segments an entire image into constituent parts by leveraging SAM's advanced architecture
|
||||
and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed image, (N, C, H, W).
|
||||
crop_n_layers (int): If >0, mask prediction will be run again on
|
||||
crops of the image. Sets the number of layers to run, where each
|
||||
layer has 2**i_layer number of image crops.
|
||||
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
||||
In the first crop layer, crops will overlap by this fraction of
|
||||
the image length. Later layers with more crops scale down this overlap.
|
||||
crop_downscale_factor (int): The number of points-per-side
|
||||
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
point_grids (list(np.ndarray), None): A list over explicit grids
|
||||
of points used for sampling, normalized to [0,1]. The nth grid in the
|
||||
list is used in the nth crop layer. Exclusive with points_per_side.
|
||||
points_stride (int, None): The number of points to be sampled
|
||||
along one side of the image. The total number of points is
|
||||
points_per_side**2. If None, 'point_grids' must provide explicit
|
||||
point sampling.
|
||||
points_batch_size (int): Sets the number of points run simultaneously
|
||||
by the model. Higher numbers may be faster but use more GPU memory.
|
||||
conf_thres (float): A filtering threshold in [0,1], using the
|
||||
model's predicted mask quality.
|
||||
stability_score_thresh (float): A filtering threshold in [0,1], using
|
||||
the stability of the mask under changes to the cutoff used to binarize
|
||||
the model's mask predictions.
|
||||
stability_score_offset (float): The amount to shift the cutoff when
|
||||
calculated the stability score.
|
||||
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
||||
suppression to filter duplicate masks between different crops.
|
||||
im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
|
||||
crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
|
||||
Each layer produces 2**i_layer number of image crops.
|
||||
crop_overlap_ratio (float): Determines the extent of overlap between crops. Scaled down in subsequent layers.
|
||||
crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
|
||||
point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1].
|
||||
Used in the nth crop layer.
|
||||
points_stride (int, optional): Number of points to sample along each side of the image.
|
||||
Exclusive with 'point_grids'.
|
||||
points_batch_size (int): Batch size for the number of points processed simultaneously.
|
||||
conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction.
|
||||
stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability.
|
||||
stability_score_offset (float): Offset value for calculating stability score.
|
||||
crop_nms_thresh (float): IoU cutoff for Non-Maximum Suppression (NMS) to remove duplicate masks between crops.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing segmented masks, confidence scores, and bounding boxes.
|
||||
"""
|
||||
self.segment_all = True
|
||||
ih, iw = im.shape[2:]
|
||||
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
|
||||
if point_grids is None:
|
||||
point_grids = build_all_layer_point_grids(
|
||||
points_stride,
|
||||
crop_n_layers,
|
||||
crop_downscale_factor,
|
||||
)
|
||||
point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor)
|
||||
pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
|
||||
for crop_region, layer_idx in zip(crop_regions, layer_idxs):
|
||||
x1, y1, x2, y2 = crop_region
|
||||
@ -312,7 +303,22 @@ class Predictor(BasePredictor):
|
||||
return pred_masks, pred_scores, pred_bboxes
|
||||
|
||||
def setup_model(self, model, verbose=True):
|
||||
"""Set up YOLO model with specified thresholds and device."""
|
||||
"""
|
||||
Initializes the Segment Anything Model (SAM) for inference.
|
||||
|
||||
This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
|
||||
parameters for image normalization and other Ultralytics compatibility settings.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
|
||||
verbose (bool): If True, prints selected device information.
|
||||
|
||||
Attributes:
|
||||
model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
|
||||
device (torch.device): The device to which the model and tensors are allocated.
|
||||
mean (torch.Tensor): The mean values for image normalization.
|
||||
std (torch.Tensor): The standard deviation values for image normalization.
|
||||
"""
|
||||
device = select_device(self.args.device, verbose=verbose)
|
||||
if model is None:
|
||||
model = build_sam(self.args.model)
|
||||
@ -321,7 +327,8 @@ class Predictor(BasePredictor):
|
||||
self.device = device
|
||||
self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
|
||||
self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
|
||||
# TODO: Temporary settings for compatibility
|
||||
|
||||
# Ultralytics compatibility settings
|
||||
self.model.pt = False
|
||||
self.model.triton = False
|
||||
self.model.stride = 32
|
||||
@ -329,7 +336,20 @@ class Predictor(BasePredictor):
|
||||
self.done_warmup = True
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Post-processes inference output predictions to create detection masks for objects."""
|
||||
"""
|
||||
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
|
||||
|
||||
The method scales masks and boxes to the original image size and applies a threshold to the mask predictions. The
|
||||
SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
|
||||
|
||||
Args:
|
||||
preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
|
||||
img (torch.Tensor): The processed input image tensor.
|
||||
orig_imgs (list | torch.Tensor): The original, unprocessed images.
|
||||
|
||||
Returns:
|
||||
(list): List of Results objects containing detection masks, bounding boxes, and other metadata.
|
||||
"""
|
||||
# (N, 1, H, W), (N, 1)
|
||||
pred_masks, pred_scores = preds[:2]
|
||||
pred_bboxes = preds[2] if self.segment_all else None
|
||||
@ -355,15 +375,30 @@ class Predictor(BasePredictor):
|
||||
return results
|
||||
|
||||
def setup_source(self, source):
|
||||
"""Sets up source and inference mode."""
|
||||
"""
|
||||
Sets up the data source for inference.
|
||||
|
||||
This method configures the data source from which images will be fetched for inference. The source could be a
|
||||
directory, a video file, or other types of image data sources.
|
||||
|
||||
Args:
|
||||
source (str | Path): The path to the image data source for inference.
|
||||
"""
|
||||
if source is not None:
|
||||
super().setup_source(source)
|
||||
|
||||
def set_image(self, image):
|
||||
"""Set image in advance.
|
||||
Args:
|
||||
"""
|
||||
Preprocesses and sets a single image for inference.
|
||||
|
||||
image (str | np.ndarray): image file path or np.ndarray image by cv2.
|
||||
This function sets up the model if not already initialized, configures the data source to the specified image,
|
||||
and preprocesses the image for feature extraction. Only one image can be set at a time.
|
||||
|
||||
Args:
|
||||
image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.
|
||||
|
||||
Raises:
|
||||
AssertionError: If more than one image is set.
|
||||
"""
|
||||
if self.model is None:
|
||||
model = build_sam(self.args.model)
|
||||
@ -388,17 +423,20 @@ class Predictor(BasePredictor):
|
||||
@staticmethod
|
||||
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
|
||||
"""
|
||||
Removes small disconnected regions and holes in masks, then reruns box NMS to remove any new duplicates.
|
||||
Requires open-cv as a dependency.
|
||||
Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
|
||||
function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
|
||||
Suppression (NMS) to eliminate any newly created duplicate boxes.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): Masks, (N, H, W).
|
||||
min_area (int): Minimum area threshold.
|
||||
nms_thresh (float): NMS threshold.
|
||||
masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
|
||||
the number of masks, H is height, and W is width.
|
||||
min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
|
||||
nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.
|
||||
|
||||
Returns:
|
||||
new_masks (torch.Tensor): New Masks, (N, H, W).
|
||||
keep (List[int]): The indices of the new masks, which can be used to filter
|
||||
the corresponding boxes.
|
||||
(tuple[torch.Tensor, List[int]]):
|
||||
- new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
|
||||
- keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
|
||||
"""
|
||||
if len(masks) == 0:
|
||||
return masks
|
||||
@ -420,10 +458,6 @@ class Predictor(BasePredictor):
|
||||
# Recalculate boxes and remove any new duplicates
|
||||
new_masks = torch.cat(new_masks, dim=0)
|
||||
boxes = batched_mask_to_box(new_masks)
|
||||
keep = torchvision.ops.nms(
|
||||
boxes.float(),
|
||||
torch.as_tensor(scores),
|
||||
nms_thresh,
|
||||
)
|
||||
keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)
|
||||
|
||||
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
|
||||
|