From cedce60f8c43b83581f4e87ab1bcef7096d9d72c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 15 Oct 2023 18:24:06 +0200 Subject: [PATCH] `ulralytics 8.0.199` *.npy image loading exception handling (#5683) Signed-off-by: Glenn Jocher Co-authored-by: snyk-bot Co-authored-by: Yonghye Kwon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/ci.yaml | 20 +- .pre-commit-config.yaml | 1 + docker/Dockerfile-cpu | 2 +- docs/integrations/roboflow.md | 6 +- setup.py | 1 + tests/test_cuda.py | 44 +--- tests/test_integrations.py | 98 ++++++++- tests/test_python.py | 51 +---- ultralytics/__init__.py | 2 +- ultralytics/data/base.py | 18 +- ultralytics/models/rtdetr/model.py | 38 +++- ultralytics/models/rtdetr/predict.py | 34 +++- ultralytics/models/rtdetr/train.py | 63 ++++-- ultralytics/models/sam/__init__.py | 2 - ultralytics/models/sam/model.py | 87 +++++++- ultralytics/models/sam/predict.py | 292 +++++++++++++++------------ 16 files changed, 479 insertions(+), 280 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0d6d8959..30d36ff6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -183,11 +183,11 @@ jobs: shell: bash # for Windows compatibility run: | # CoreML must be installed before export due to protobuf error from AutoInstall python -m pip install --upgrade pip wheel + torch="" if [ "${{ matrix.torch }}" == "1.8.0" ]; then - pip install -e . torch==1.8.0 torchvision==0.9.0 pytest-cov "coremltools>=7.0" --extra-index-url https://download.pytorch.org/whl/cpu - else - pip install -e . pytest-cov "coremltools>=7.0" --extra-index-url https://download.pytorch.org/whl/cpu + torch="torch==1.8.0 torchvision==0.9.0" fi + pip install -e . $torch pytest-cov "coremltools>=7.0" --extra-index-url https://download.pytorch.org/whl/cpu - name: Check environment run: | yolo checks @@ -202,7 +202,13 @@ jobs: pip list - name: Pytest tests shell: bash # for Windows compatibility - run: pytest --cov=ultralytics/ --cov-report xml tests/ + run: | + slow="" + if [[ "${{ github.event_name }}" == "schedule" ]] || [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then + pip install mlflow pycocotools 'ray[tune]' + slow="--slow" + fi + pytest $slow --cov=ultralytics/ --cov-report xml tests/ - name: Upload Coverage Reports to CodeCov if: github.repository == 'ultralytics/ultralytics' # && matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' uses: codecov/codecov-action@v3 @@ -264,10 +270,10 @@ jobs: conda config --set solver libmamba - name: Install Ultralytics package from conda-forge run: | - conda install -c pytorch -c conda-forge pytorch torchvision ultralytics + conda install -c pytorch -c conda-forge pytorch torchvision ultralytics openvino - name: Install pip packages run: | - pip install pytest 'coremltools>=7.0' # 'openvino-dev>=2023.0' + pip install pytest 'coremltools>=7.0' - name: Check environment run: | echo "RUNNER_OS is ${{ runner.os }}" @@ -297,7 +303,7 @@ jobs: - name: PyTest run: | git clone https://github.com/ultralytics/ultralytics - pytest ultralytics/tests/test_cli.py # full tests fail due to openvino export failure + pytest ultralytics/tests Summary: runs-on: ubuntu-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a75167bc..061d96c7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,6 @@ # Ultralytics YOLO πŸš€, AGPL-3.0 license # Pre-commit hooks. 
For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md +# Optionally remove from local hooks with 'rm .git/hooks/pre-commit' # exclude: 'docs/' # Define bot property if installed via https://github.com/marketplace/pre-commit-ci diff --git a/docker/Dockerfile-cpu b/docker/Dockerfile-cpu index cdce1077..cc528c86 100644 --- a/docker/Dockerfile-cpu +++ b/docker/Dockerfile-cpu @@ -3,7 +3,7 @@ # Image is CPU-optimized for ONNX, OpenVINO and PyTorch YOLOv8 deployments # Start FROM Ubuntu image https://hub.docker.com/_/ubuntu -FROM ubuntu:lunar-20230615 +FROM ubuntu:23.04 # Downloads to user config dir ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/ diff --git a/docs/integrations/roboflow.md b/docs/integrations/roboflow.md index 1d1e9e39..ed6b8cd3 100644 --- a/docs/integrations/roboflow.md +++ b/docs/integrations/roboflow.md @@ -8,12 +8,16 @@ keywords: Ultralytics, YOLOv8, Roboflow, vector analysis, confusion matrix, data [Roboflow](https://roboflow.com/?ref=ultralytics) has everything you need to build and deploy computer vision models. Connect Roboflow at any step in your pipeline with APIs and SDKs, or use the end-to-end interface to automate the entire process from image to inference. Whether you’re in need of [data labeling](https://roboflow.com/annotate?ref=ultralytics), [model training](https://roboflow.com/train?ref=ultralytics), or [model deployment](https://roboflow.com/deploy?ref=ultralytics), Roboflow gives you building blocks to bring custom computer vision solutions to your project. +!!! warning + + Roboflow users can use Ultralytics under the [AGPL license](https://github.com/ultralytics/ultralytics/blob/main/LICENSE) or procure an [Enterprise license](https://ultralytics.com/license) directly from Ultralytics. Be aware that Roboflow does **not** provide Ultralytics licenses, and it is the responsibility of the user to ensure appropriate licensing. + In this guide, we are going to showcase how to find, label, and organize data for use in training a custom Ultralytics YOLOv8 model. 
Use the table of contents below to jump directly to a specific section: - Gather data for training a custom YOLOv8 model - Upload, convert and label data for YOLOv8 format - Pre-process and augment data for model robustness -- Dataset management for YOLOv8 +- Dataset management for [YOLOv8](https://docs.ultralytics.com/models/yolov8/) - Export data in 40+ formats for model training - Upload custom YOLOv8 model weights for testing and deployment - Gather Data for Training a Custom YOLOv8 Model diff --git a/setup.py b/setup.py index 8fb107c8..3b8d4643 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ setup( 'dev': [ 'ipython', 'check-manifest', + 'pre-commit', 'pytest', 'pytest-cov', 'coverage', diff --git a/tests/test_cuda.py b/tests/test_cuda.py index fe4e2da0..1ef1104f 100644 --- a/tests/test_cuda.py +++ b/tests/test_cuda.py @@ -3,8 +3,8 @@ import pytest import torch -from ultralytics import YOLO, download -from ultralytics.utils import ASSETS, DATASETS_DIR, WEIGHTS_DIR, checks +from ultralytics import YOLO +from ultralytics.utils import ASSETS, WEIGHTS_DIR, checks CUDA_IS_AVAILABLE = checks.cuda_is_available() CUDA_DEVICE_COUNT = checks.cuda_device_count() @@ -27,6 +27,7 @@ def test_train(): YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=device) # requires imgsz>=64 +@pytest.mark.slow @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available') def test_predict_multiple_devices(): """Validate model prediction on multiple devices.""" @@ -102,42 +103,3 @@ def test_predict_sam(): # Reset image predictor.reset_image() - - -@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available') -def test_model_tune(): - """Tune YOLO model for performance.""" - YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu') - YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu') - - -@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available') -def test_pycocotools(): - """Validate model predictions using pycocotools.""" - from ultralytics.models.yolo.detect import DetectionValidator - from ultralytics.models.yolo.pose import PoseValidator - from ultralytics.models.yolo.segment import SegmentationValidator - - # Download annotations after each dataset downloads first - url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/' - - args = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64} - validator = DetectionValidator(args=args) - validator() - validator.is_coco = True - download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8/annotations') - _ = validator.eval_json(validator.stats) - - args = {'model': 'yolov8n-seg.pt', 'data': 'coco8-seg.yaml', 'save_json': True, 'imgsz': 64} - validator = SegmentationValidator(args=args) - validator() - validator.is_coco = True - download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8-seg/annotations') - _ = validator.eval_json(validator.stats) - - args = {'model': 'yolov8n-pose.pt', 'data': 'coco8-pose.yaml', 'save_json': True, 'imgsz': 64} - validator = PoseValidator(args=args) - validator() - validator.is_coco = True - download(f'{url}person_keypoints_val2017.json', dir=DATASETS_DIR / 'coco8-pose/annotations') - _ = validator.eval_json(validator.stats) diff --git a/tests/test_integrations.py b/tests/test_integrations.py index 07062967..1911594b 100644 --- a/tests/test_integrations.py +++ b/tests/test_integrations.py @@ -1,12 +1,20 @@ # 
Ultralytics YOLO πŸš€, AGPL-3.0 license +import contextlib +from pathlib import Path import pytest -from ultralytics import YOLO -from ultralytics.utils import SETTINGS, checks +from ultralytics import YOLO, download +from ultralytics.utils import ASSETS, DATASETS_DIR, ROOT, SETTINGS, WEIGHTS_DIR +from ultralytics.utils.checks import check_requirements + +MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt' # test spaces in path +CFG = 'yolov8n.yaml' +SOURCE = ASSETS / 'bus.jpg' +TMP = (ROOT / '../tests/tmp').resolve() # temp directory for test files -@pytest.mark.skipif(not checks.check_requirements('ray', install=False), reason='RayTune not installed') +@pytest.mark.skipif(not check_requirements('ray', install=False), reason='ray[tune] not installed') def test_model_ray_tune(): """Tune YOLO model with Ray optimization library.""" YOLO('yolov8n-cls.yaml').tune(use_ray=True, @@ -19,8 +27,90 @@ def test_model_ray_tune(): device='cpu') -@pytest.mark.skipif(not checks.check_requirements('mlflow', install=False), reason='MLflow not installed') +@pytest.mark.skipif(not check_requirements('mlflow', install=False), reason='mlflow not installed') def test_mlflow(): """Test training with MLflow tracking enabled.""" SETTINGS['mlflow'] = True YOLO('yolov8n-cls.yaml').train(data='imagenet10', imgsz=32, epochs=3, plots=False, device='cpu') + + +@pytest.mark.skipif(not check_requirements('tritonclient', install=False), reason='tritonclient[all] not installed') +def test_triton(): + """Test NVIDIA Triton Server functionalities.""" + check_requirements('tritonclient[all]') + import subprocess + import time + + from tritonclient.http import InferenceServerClient # noqa + + # Create variables + model_name = 'yolo' + triton_repo_path = TMP / 'triton_repo' + triton_model_path = triton_repo_path / model_name + + # Export model to ONNX + f = YOLO(MODEL).export(format='onnx', dynamic=True) + + # Prepare Triton repo + (triton_model_path / '1').mkdir(parents=True, exist_ok=True) + Path(f).rename(triton_model_path / '1' / 'model.onnx') + (triton_model_path / 'config.pdtxt').touch() + + # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver + tag = 'nvcr.io/nvidia/tritonserver:23.09-py3' # 6.4 GB + + # Pull the image + subprocess.call(f'docker pull {tag}', shell=True) + + # Run the Triton server and capture the container ID + container_id = subprocess.check_output( + f'docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models', + shell=True).decode('utf-8').strip() + + # Wait for the Triton server to start + triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False) + + # Wait until model is ready + for _ in range(10): + with contextlib.suppress(Exception): + assert triton_client.is_model_ready(model_name) + break + time.sleep(1) + + # Check Triton inference + YOLO(f'http://localhost:8000/{model_name}', 'detect')(SOURCE) # exported model inference + + # Kill and remove the container at the end of the test + subprocess.call(f'docker kill {container_id}', shell=True) + + +@pytest.mark.skipif(not check_requirements('pycocotools', install=False), reason='pycocotools not installed') +def test_pycocotools(): + """Validate model predictions using pycocotools.""" + from ultralytics.models.yolo.detect import DetectionValidator + from ultralytics.models.yolo.pose import PoseValidator + from ultralytics.models.yolo.segment import SegmentationValidator + + # Download annotations after each dataset downloads first + url = 
'https://github.com/ultralytics/assets/releases/download/v0.0.0/' + + args = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64} + validator = DetectionValidator(args=args) + validator() + validator.is_coco = True + download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8/annotations') + _ = validator.eval_json(validator.stats) + + args = {'model': 'yolov8n-seg.pt', 'data': 'coco8-seg.yaml', 'save_json': True, 'imgsz': 64} + validator = SegmentationValidator(args=args) + validator() + validator.is_coco = True + download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8-seg/annotations') + _ = validator.eval_json(validator.stats) + + args = {'model': 'yolov8n-pose.pt', 'data': 'coco8-pose.yaml', 'save_json': True, 'imgsz': 64} + validator = PoseValidator(args=args) + validator() + validator.is_coco = True + download(f'{url}person_keypoints_val2017.json', dir=DATASETS_DIR / 'coco8-pose/annotations') + _ = validator.eval_json(validator.stats) diff --git a/tests/test_python.py b/tests/test_python.py index bea8afe1..395e756d 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -495,50 +495,7 @@ def test_hub(): @pytest.mark.slow @pytest.mark.skipif(not ONLINE, reason='environment is offline') -def test_triton(): - """Test NVIDIA Triton Server functionalities.""" - checks.check_requirements('tritonclient[all]') - import subprocess - import time - - from tritonclient.http import InferenceServerClient # noqa - - # Create variables - model_name = 'yolo' - triton_repo_path = TMP / 'triton_repo' - triton_model_path = triton_repo_path / model_name - - # Export model to ONNX - f = YOLO(MODEL).export(format='onnx', dynamic=True) - - # Prepare Triton repo - (triton_model_path / '1').mkdir(parents=True, exist_ok=True) - Path(f).rename(triton_model_path / '1' / 'model.onnx') - (triton_model_path / 'config.pdtxt').touch() - - # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver - tag = 'nvcr.io/nvidia/tritonserver:23.09-py3' # 6.4 GB - - # Pull the image - subprocess.call(f'docker pull {tag}', shell=True) - - # Run the Triton server and capture the container ID - container_id = subprocess.check_output( - f'docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models', - shell=True).decode('utf-8').strip() - - # Wait for the Triton server to start - triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False) - - # Wait until model is ready - for _ in range(10): - with contextlib.suppress(Exception): - assert triton_client.is_model_ready(model_name) - break - time.sleep(1) - - # Check Triton inference - YOLO(f'http://localhost:8000/{model_name}', 'detect')(SOURCE) # exported model inference - - # Kill and remove the container at the end of the test - subprocess.call(f'docker kill {container_id}', shell=True) +def test_model_tune(): + """Tune YOLO model for performance.""" + YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu') + YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu') diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index b28efe40..0b882285 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO πŸš€, AGPL-3.0 license -__version__ = '8.0.198' +__version__ = '8.0.199' from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models.fastsam import FastSAM diff --git 
a/ultralytics/data/base.py b/ultralytics/data/base.py index 462280a6..0762c5fb 100644 --- a/ultralytics/data/base.py +++ b/ultralytics/data/base.py @@ -61,8 +61,8 @@ class BaseDataset(Dataset): single_cls=False, classes=None, fraction=1.0): - super().__init__() """Initialize BaseDataset with given configuration and options.""" + super().__init__() self.img_path = img_path self.imgsz = imgsz self.augment = augment @@ -85,7 +85,7 @@ class BaseDataset(Dataset): self.buffer = [] # buffer size = batch size self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 - # Cache stuff + # Cache images if cache == 'ram' and not self.check_cache_ram(): cache = False self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni @@ -123,7 +123,7 @@ class BaseDataset(Dataset): return im_files def update_labels(self, include_class: Optional[list]): - """include_class, filter labels to include only these classes (optional).""" + """Update labels to include only these classes (optional).""" include_class_array = np.array(include_class).reshape(1, -1) for i in range(len(self.labels)): if include_class is not None: @@ -146,11 +146,17 @@ class BaseDataset(Dataset): im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] if im is None: # not cached in RAM if fn.exists(): # load npy - im = np.load(fn) + try: + im = np.load(fn) + except Exception as e: + LOGGER.warning(f'{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}') + Path(fn).unlink(missing_ok=True) + im = cv2.imread(f) # BGR else: # read image im = cv2.imread(f) # BGR - if im is None: - raise FileNotFoundError(f'Image Not Found {f}') + if im is None: + raise FileNotFoundError(f'Image Not Found {f}') + h0, w0 = im.shape[:2] # orig hw if rect_mode: # resize long side to imgsz while maintaining aspect ratio r = self.imgsz / max(h0, w0) # ratio diff --git a/ultralytics/models/rtdetr/model.py b/ultralytics/models/rtdetr/model.py index fa7d484e..6e582a8b 100644 --- a/ultralytics/models/rtdetr/model.py +++ b/ultralytics/models/rtdetr/model.py @@ -1,5 +1,12 @@ # Ultralytics YOLO πŸš€, AGPL-3.0 license -"""RT-DETR model interface.""" +""" +Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time +performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient +hybrid encoder and IoU-aware query selection for enhanced detection accuracy. + +For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf +""" + from ultralytics.engine.model import Model from ultralytics.nn.tasks import RTDETRDetectionModel @@ -9,17 +16,36 @@ from .val import RTDETRValidator class RTDETR(Model): - """RTDETR model interface.""" + """ + Interface for Baidu's RT-DETR model. This Vision Transformer-based object detector provides real-time performance + with high accuracy. It supports efficient hybrid encoding, IoU-aware query selection, and adaptable inference speed. + + Attributes: + model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'. + """ def __init__(self, model='rtdetr-l.pt') -> None: - """Initializes the RTDETR model with the given model file, defaulting to 'rtdetr-l.pt'.""" + """ + Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats. + + Args: + model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'. + + Raises: + NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'. 
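+
+        Example:
+            A minimal usage sketch (the image path is a placeholder):
+            ```python
+            from ultralytics import RTDETR
+
+            model = RTDETR('rtdetr-l.pt')
+            results = model('path/to/image.jpg')  # run inference
+            ```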
+ """ if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'): - raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.') + raise NotImplementedError('RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.') super().__init__(model=model, task='detect') @property - def task_map(self): - """Returns a dictionary mapping task names to corresponding Ultralytics task classes for RTDETR model.""" + def task_map(self) -> dict: + """ + Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes. + + Returns: + dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model. + """ return { 'detect': { 'predictor': RTDETRPredictor, diff --git a/ultralytics/models/rtdetr/predict.py b/ultralytics/models/rtdetr/predict.py index 1a2b0cbc..8ad92de8 100644 --- a/ultralytics/models/rtdetr/predict.py +++ b/ultralytics/models/rtdetr/predict.py @@ -10,7 +10,11 @@ from ultralytics.utils import ops class RTDETRPredictor(BasePredictor): """ - A class extending the BasePredictor class for prediction based on an RT-DETR detection model. + RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions using + Baidu's RT-DETR model. + + This class leverages the power of Vision Transformers to provide real-time object detection while maintaining + high accuracy. It supports key features like efficient hybrid encoding and IoU-aware query selection. Example: ```python @@ -21,10 +25,27 @@ class RTDETRPredictor(BasePredictor): predictor = RTDETRPredictor(overrides=args) predictor.predict_cli() ``` + + Attributes: + imgsz (int): Image size for inference (must be square and scale-filled). + args (dict): Argument overrides for the predictor. """ def postprocess(self, preds, img, orig_imgs): - """Postprocess predictions and returns a list of Results objects.""" + """ + Postprocess the raw predictions from the model to generate bounding boxes and confidence scores. + + The method filters detections based on confidence and class if specified in `self.args`. + + Args: + preds (torch.Tensor): Raw predictions from the model. + img (torch.Tensor): Processed input images. + orig_imgs (list or torch.Tensor): Original, unprocessed images. + + Returns: + (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores, + and class labels. + """ nd = preds[0].shape[-1] bboxes, scores = preds[0].split((4, nd - 4), dim=-1) @@ -49,15 +70,14 @@ class RTDETRPredictor(BasePredictor): def pre_transform(self, im): """ - Pre-transform input image before inference. + Pre-transforms the input images before feeding them into the model for inference. The input images are + letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scaleFilled. Args: - im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. - - Notes: The size must be square(640) and scaleFilled. + im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list. Returns: - (list): A list of transformed imgs. + (list): List of pre-transformed images ready for model inference. 
""" letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True) return [letterbox(image=x) for x in im] diff --git a/ultralytics/models/rtdetr/train.py b/ultralytics/models/rtdetr/train.py index 91d4729e..26b7ea68 100644 --- a/ultralytics/models/rtdetr/train.py +++ b/ultralytics/models/rtdetr/train.py @@ -13,10 +13,12 @@ from .val import RTDETRDataset, RTDETRValidator class RTDETRTrainer(DetectionTrainer): """ - A class extending the DetectionTrainer class for training based on an RT-DETR detection model. + Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer + class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision + Transformers and has capabilities like IoU-aware query selection and adaptable inference speed. Notes: - - F.grid_sample used in rt-detr does not support the `deterministic=True` argument. + - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument. - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching. Example: @@ -30,7 +32,17 @@ class RTDETRTrainer(DetectionTrainer): """ def get_model(self, cfg=None, weights=None, verbose=True): - """Return a YOLO detection model.""" + """ + Initialize and return an RT-DETR model for object detection tasks. + + Args: + cfg (dict, optional): Model configuration. Defaults to None. + weights (str, optional): Path to pre-trained model weights. Defaults to None. + verbose (bool): Verbose logging if True. Defaults to True. + + Returns: + (RTDETRDetectionModel): Initialized model. + """ model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) if weights: model.load(weights) @@ -38,31 +50,46 @@ class RTDETRTrainer(DetectionTrainer): def build_dataset(self, img_path, mode='val', batch=None): """ - Build RTDETR Dataset. + Build and return an RT-DETR dataset for training or validation. Args: img_path (str): Path to the folder containing images. - mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. - batch (int, optional): Size of batches, this is for `rect`. Defaults to None. + mode (str): Dataset mode, either 'train' or 'val'. + batch (int, optional): Batch size for rectangle training. Defaults to None. + + Returns: + (RTDETRDataset): Dataset object for the specific mode. """ - return RTDETRDataset( - img_path=img_path, - imgsz=self.args.imgsz, - batch_size=batch, - augment=mode == 'train', # no augmentation - hyp=self.args, - rect=False, # no rect - cache=self.args.cache or None, - prefix=colorstr(f'{mode}: '), - data=self.data) + return RTDETRDataset(img_path=img_path, + imgsz=self.args.imgsz, + batch_size=batch, + augment=mode == 'train', + hyp=self.args, + rect=False, + cache=self.args.cache or None, + prefix=colorstr(f'{mode}: '), + data=self.data) def get_validator(self): - """Returns a DetectionValidator for RTDETR model validation.""" + """ + Returns a DetectionValidator suitable for RT-DETR model validation. + + Returns: + (RTDETRValidator): Validator object for model validation. + """ self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss' return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) def preprocess_batch(self, batch): - """Preprocesses a batch of images by scaling and converting to float.""" + """ + Preprocess a batch of images. Scales and converts the images to float format. 
+ + Args: + batch (dict): Dictionary containing a batch of images, bboxes, and labels. + + Returns: + (dict): Preprocessed batch. + """ batch = super().preprocess_batch(batch) bs = len(batch['img']) batch_idx = batch['batch_idx'] diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py index 35f4efa8..abf2eef5 100644 --- a/ultralytics/models/sam/__init__.py +++ b/ultralytics/models/sam/__init__.py @@ -3,6 +3,4 @@ from .model import SAM from .predict import Predictor -# from .build import build_sam - __all__ = 'SAM', 'Predictor' # tuple or list diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py index 8a140b3f..40871044 100644 --- a/ultralytics/models/sam/model.py +++ b/ultralytics/models/sam/model.py @@ -1,5 +1,18 @@ # Ultralytics YOLO πŸš€, AGPL-3.0 license -"""SAM model interface.""" +""" +SAM model interface. + +This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image +segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis, +and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new +image distributions and tasks without prior knowledge. + +Key Features: + - Promptable segmentation + - Real-time performance + - Zero-shot transfer capabilities + - Trained on SA-1B dataset +""" from pathlib import Path @@ -11,40 +24,94 @@ from .predict import Predictor class SAM(Model): - """SAM model interface.""" + """ + SAM (Segment Anything Model) interface class. + + SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as + bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B + dataset. + """ def __init__(self, model='sam_b.pt') -> None: - """Initializes the SAM model instance with the specified pre-trained model file.""" + """ + Initializes the SAM model with a pre-trained model file. + + Args: + model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension. + + Raises: + NotImplementedError: If the model file extension is not .pt or .pth. + """ if model and Path(model).suffix not in ('.pt', '.pth'): raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.') super().__init__(model=model, task='segment') def _load(self, weights: str, task=None): - """Loads the provided weights into the SAM model.""" + """ + Loads the specified weights into the SAM model. + + Args: + weights (str): Path to the weights file. + task (str, optional): Task name. Defaults to None. + """ self.model = build_sam(weights) def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs): - """Predicts and returns segmentation masks for given image or video source.""" + """ + Performs segmentation prediction on the given image or video source. + + Args: + source: Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object. + stream (bool, optional): If True, enables real-time streaming. Defaults to False. + bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None. + points (list, optional): List of points for prompted segmentation. Defaults to None. + labels (list, optional): List of labels for prompted segmentation. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + The segmentation masks. 
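+
+        Example:
+            A minimal prompted-segmentation sketch (the image path and box coordinates are placeholders):
+            ```python
+            from ultralytics import SAM
+
+            model = SAM('sam_b.pt')
+            results = model.predict('path/to/image.jpg', bboxes=[100, 100, 200, 200])  # XYXY box prompt in pixels
+            ```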
+ """ overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024) kwargs.update(overrides) prompts = dict(bboxes=bboxes, points=points, labels=labels) return super().predict(source, stream, prompts=prompts, **kwargs) def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs): - """Calls the 'predict' function with given arguments to perform object detection.""" + """ + Alias for the 'predict' method. + + Args: + source: Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object. + stream (bool, optional): If True, enables real-time streaming. Defaults to False. + bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None. + points (list, optional): List of points for prompted segmentation. Defaults to None. + labels (list, optional): List of labels for prompted segmentation. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + The segmentation masks. + """ return self.predict(source, stream, bboxes, points, labels, **kwargs) def info(self, detailed=False, verbose=True): """ - Logs model info. + Logs information about the SAM model. Args: - detailed (bool): Show detailed information about model. - verbose (bool): Controls verbosity. + detailed (bool, optional): If True, displays detailed information about the model. Defaults to False. + verbose (bool, optional): If True, displays information on the console. Defaults to True. + + Returns: + (tuple): A tuple containing the model's information. """ return model_info(self.model, detailed=detailed, verbose=verbose) @property def task_map(self): - """Returns a dictionary mapping the 'segment' task to its corresponding 'Predictor'.""" + """ + Provides a mapping from the 'segment' task to its corresponding 'Predictor'. + + Returns: + dict: A dictionary mapping the 'segment' task to its corresponding 'Predictor'. + """ return {'segment': {'predictor': Predictor}} diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py index 26a49b54..8dce2be7 100644 --- a/ultralytics/models/sam/predict.py +++ b/ultralytics/models/sam/predict.py @@ -1,4 +1,12 @@ # Ultralytics YOLO πŸš€, AGPL-3.0 license +""" +Generate predictions using the Segment Anything Model (SAM). + +SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance. +This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation +using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image +segmentation tasks. +""" import numpy as np import torch @@ -18,71 +26,86 @@ from .build import build_sam class Predictor(BasePredictor): """ - A prediction class for segmentation tasks, extending the BasePredictor. + Predictor class for the Segment Anything Model (SAM), extending BasePredictor. - This class serves as an interface for model inference for segmentation tasks. - It can preprocess input images, perform inference, and postprocess the output. - It also supports handling various types of input prompts including bounding boxes, - points, and low-resolution masks for better prediction results. + The class provides an interface for model inference tailored to image segmentation tasks. + With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time + mask generation. 
The class is capable of working with various types of prompts such as bounding boxes, + points, and low-resolution masks. Attributes: - cfg (dict): Configuration dictionary. - overrides (dict): Dictionary of overriding values. - _callbacks (dict): Dictionary of callback functions. - args (namespace): Argument namespace. - im (torch.Tensor): Preprocessed image for current prediction. - features (torch.Tensor): Image features. - prompts (dict): Dictionary of prompts like bboxes, points, masks. - segment_all (bool): Whether to perform segmentation on all objects or not. + cfg (dict): Configuration dictionary specifying model and task-related parameters. + overrides (dict): Dictionary containing values that override the default configuration. + _callbacks (dict): Dictionary of user-defined callback functions to augment behavior. + args (namespace): Namespace to hold command-line arguments or other operational variables. + im (torch.Tensor): Preprocessed input image tensor. + features (torch.Tensor): Extracted image features used for inference. + prompts (dict): Collection of various prompt types, such as bounding boxes and points. + segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones. """ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): - """Initializes the Predictor class with default or provided configuration, overrides, and callbacks.""" + """ + Initialize the Predictor with configuration, overrides, and callbacks. + + The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It + initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results. + + Args: + cfg (dict): Configuration dictionary. + overrides (dict, optional): Dictionary of values to override default configuration. + _callbacks (dict, optional): Dictionary of callback functions to customize behavior. + """ if overrides is None: overrides = {} overrides.update(dict(task='segment', mode='predict', imgsz=1024)) super().__init__(cfg, overrides, _callbacks) - # SAM needs retina_masks=True, or the results would be a mess. self.args.retina_masks = True - # Args for set_image self.im = None self.features = None - # Args for set_prompts self.prompts = {} - # Args for segment everything self.segment_all = False def preprocess(self, im): """ - Prepares input image before inference. + Preprocess the input image for model inference. + + The method prepares the input image by applying transformations and normalization. + It supports both torch.Tensor and list of np.ndarray as input formats. Args: - im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list. + im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays. + + Returns: + torch.Tensor: The preprocessed image tensor. """ if self.im is not None: return self.im not_tensor = not isinstance(im, torch.Tensor) if not_tensor: im = np.stack(self.pre_transform(im)) - im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w) - im = np.ascontiguousarray(im) # contiguous + im = im[..., ::-1].transpose((0, 3, 1, 2)) + im = np.ascontiguousarray(im) im = torch.from_numpy(im) im = im.to(self.device) - im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32 + im = im.half() if self.model.fp16 else im.float() if not_tensor: im = (im - self.mean) / self.std return im def pre_transform(self, im): """ - Pre-transform input image before inference. 
+ Perform initial transformations on the input image for preprocessing. + + The method applies transformations such as resizing to prepare the image for further preprocessing. + Currently, batched inference is not supported; hence the list length should be 1. Args: - im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. + im (List[np.ndarray]): List containing images in HWC numpy array format. Returns: - (list): A list of transformed images. + List[np.ndarray]: List of transformed images. """ assert len(im) == 1, 'SAM model does not currently support batched inference' letterbox = LetterBox(self.args.imgsz, auto=False, center=False) @@ -90,69 +113,52 @@ class Predictor(BasePredictor): def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs): """ - Predict masks for the given input prompts, using the currently set image. + Perform image segmentation inference based on the given input cues, using the currently loaded image. This + method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and + mask decoder for real-time and promptable segmentation tasks. Args: - im (torch.Tensor): The preprocessed image, (N, C, H, W). - bboxes (np.ndarray | List, None): (N, 4), in XYXY format. - points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels. - labels (np.ndarray | List, None): (N, ), labels for the point prompts. - 1 indicates a foreground point and 0 indicates a background point. - masks (np.ndarray, None): A low resolution mask input to the model, typically - coming from a previous prediction iteration. Has form (N, H, W), where - for SAM, H=W=256. - multimask_output (bool): If true, the model will return three masks. - For ambiguous input prompts (such as a single click), this will often - produce better masks than a single prediction. If only a single - mask is needed, the model's predicted quality score can be used - to select the best mask. For non-ambiguous prompts, such as multiple - input prompts, multimask_output=False can give better results. + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates. + labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background. + masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256. + multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False. Returns: - (np.ndarray): The output masks in CxHxW format, where C is the - number of masks, and (H, W) is the original image size. - (np.ndarray): An array of length C containing the model's - predictions for the quality of each mask. - (np.ndarray): An array of shape CxHxW, where C is the number - of masks and H=W=256. These low resolution logits can be passed to - a subsequent iteration as mask input. + tuple: Contains the following three elements. + - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks. + - np.ndarray: An array of length C containing quality scores predicted by the model for each mask. + - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256. 
""" - # Get prompts from self.prompts first + # Override prompts if any stored in self.prompts bboxes = self.prompts.pop('bboxes', bboxes) points = self.prompts.pop('points', points) masks = self.prompts.pop('masks', masks) + if all(i is None for i in [bboxes, points, masks]): return self.generate(im, *args, **kwargs) + return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False): """ - Predict masks for the given input prompts, using the currently set image. + Internal function for image segmentation inference based on cues like bounding boxes, points, and masks. + Leverages SAM's specialized architecture for prompt-based, real-time segmentation. Args: - im (torch.Tensor): The preprocessed image, (N, C, H, W). - bboxes (np.ndarray | List, None): (N, 4), in XYXY format. - points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels. - labels (np.ndarray | List, None): (N, ), labels for the point prompts. - 1 indicates a foreground point and 0 indicates a background point. - masks (np.ndarray, None): A low resolution mask input to the model, typically - coming from a previous prediction iteration. Has form (N, H, W), where - for SAM, H=W=256. - multimask_output (bool): If true, the model will return three masks. - For ambiguous input prompts (such as a single click), this will often - produce better masks than a single prediction. If only a single - mask is needed, the model's predicted quality score can be used - to select the best mask. For non-ambiguous prompts, such as multiple - input prompts, multimask_output=False can give better results. + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates. + labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background. + masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256. + multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False. Returns: - (np.ndarray): The output masks in CxHxW format, where C is the - number of masks, and (H, W) is the original image size. - (np.ndarray): An array of length C containing the model's - predictions for the quality of each mask. - (np.ndarray): An array of shape CxHxW, where C is the number - of masks and H=W=256. These low resolution logits can be passed to - a subsequent iteration as mask input. + tuple: Contains the following three elements. + - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks. + - np.ndarray: An array of length C containing quality scores predicted by the model for each mask. + - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256. 
""" features = self.model.image_encoder(im) if self.features is None else self.features @@ -178,11 +184,7 @@ class Predictor(BasePredictor): points = (points, labels) if points is not None else None # Embed prompts - sparse_embeddings, dense_embeddings = self.model.prompt_encoder( - points=points, - boxes=bboxes, - masks=masks, - ) + sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) # Predict masks pred_masks, pred_scores = self.model.mask_decoder( @@ -210,46 +212,35 @@ class Predictor(BasePredictor): stability_score_offset=0.95, crop_nms_thresh=0.7): """ - Segment the whole image. + Perform image segmentation using the Segment Anything Model (SAM). + + This function segments an entire image into constituent parts by leveraging SAM's advanced architecture + and real-time performance capabilities. It can optionally work on image crops for finer segmentation. Args: - im (torch.Tensor): The preprocessed image, (N, C, H, W). - crop_n_layers (int): If >0, mask prediction will be run again on - crops of the image. Sets the number of layers to run, where each - layer has 2**i_layer number of image crops. - crop_overlap_ratio (float): Sets the degree to which crops overlap. - In the first crop layer, crops will overlap by this fraction of - the image length. Later layers with more crops scale down this overlap. - crop_downscale_factor (int): The number of points-per-side - sampled in layer n is scaled down by crop_n_points_downscale_factor**n. - point_grids (list(np.ndarray), None): A list over explicit grids - of points used for sampling, normalized to [0,1]. The nth grid in the - list is used in the nth crop layer. Exclusive with points_per_side. - points_stride (int, None): The number of points to be sampled - along one side of the image. The total number of points is - points_per_side**2. If None, 'point_grids' must provide explicit - point sampling. - points_batch_size (int): Sets the number of points run simultaneously - by the model. Higher numbers may be faster but use more GPU memory. - conf_thres (float): A filtering threshold in [0,1], using the - model's predicted mask quality. - stability_score_thresh (float): A filtering threshold in [0,1], using - the stability of the mask under changes to the cutoff used to binarize - the model's mask predictions. - stability_score_offset (float): The amount to shift the cutoff when - calculated the stability score. - crop_nms_thresh (float): The box IoU cutoff used by non-maximal - suppression to filter duplicate masks between different crops. + im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W). + crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops. + Each layer produces 2**i_layer number of image crops. + crop_overlap_ratio (float): Determines the extent of overlap between crops. Scaled down in subsequent layers. + crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer. + point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1]. + Used in the nth crop layer. + points_stride (int, optional): Number of points to sample along each side of the image. + Exclusive with 'point_grids'. + points_batch_size (int): Batch size for the number of points processed simultaneously. + conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction. 
+ stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability. + stability_score_offset (float): Offset value for calculating stability score. + crop_nms_thresh (float): IoU cutoff for Non-Maximum Suppression (NMS) to remove duplicate masks between crops. + + Returns: + tuple: A tuple containing segmented masks, confidence scores, and bounding boxes. """ self.segment_all = True ih, iw = im.shape[2:] crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) if point_grids is None: - point_grids = build_all_layer_point_grids( - points_stride, - crop_n_layers, - crop_downscale_factor, - ) + point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor) pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] for crop_region, layer_idx in zip(crop_regions, layer_idxs): x1, y1, x2, y2 = crop_region @@ -312,7 +303,22 @@ class Predictor(BasePredictor): return pred_masks, pred_scores, pred_bboxes def setup_model(self, model, verbose=True): - """Set up YOLO model with specified thresholds and device.""" + """ + Initializes the Segment Anything Model (SAM) for inference. + + This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary + parameters for image normalization and other Ultralytics compatibility settings. + + Args: + model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration. + verbose (bool): If True, prints selected device information. + + Attributes: + model (torch.nn.Module): The SAM model allocated to the chosen device for inference. + device (torch.device): The device to which the model and tensors are allocated. + mean (torch.Tensor): The mean values for image normalization. + std (torch.Tensor): The standard deviation values for image normalization. + """ device = select_device(self.args.device, verbose=verbose) if model is None: model = build_sam(self.args.model) @@ -321,7 +327,8 @@ class Predictor(BasePredictor): self.device = device self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device) self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device) - # TODO: Temporary settings for compatibility + + # Ultralytics compatibility settings self.model.pt = False self.model.triton = False self.model.stride = 32 @@ -329,7 +336,20 @@ class Predictor(BasePredictor): self.done_warmup = True def postprocess(self, preds, img, orig_imgs): - """Post-processes inference output predictions to create detection masks for objects.""" + """ + Post-processes SAM's inference outputs to generate object detection masks and bounding boxes. + + The method scales masks and boxes to the original image size and applies a threshold to the mask predictions. The + SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance. + + Args: + preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes. + img (torch.Tensor): The processed input image tensor. + orig_imgs (list | torch.Tensor): The original, unprocessed images. + + Returns: + (list): List of Results objects containing detection masks, bounding boxes, and other metadata. 
+ """ # (N, 1, H, W), (N, 1) pred_masks, pred_scores = preds[:2] pred_bboxes = preds[2] if self.segment_all else None @@ -355,15 +375,30 @@ class Predictor(BasePredictor): return results def setup_source(self, source): - """Sets up source and inference mode.""" + """ + Sets up the data source for inference. + + This method configures the data source from which images will be fetched for inference. The source could be a + directory, a video file, or other types of image data sources. + + Args: + source (str | Path): The path to the image data source for inference. + """ if source is not None: super().setup_source(source) def set_image(self, image): - """Set image in advance. - Args: + """ + Preprocesses and sets a single image for inference. - image (str | np.ndarray): image file path or np.ndarray image by cv2. + This function sets up the model if not already initialized, configures the data source to the specified image, + and preprocesses the image for feature extraction. Only one image can be set at a time. + + Args: + image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2. + + Raises: + AssertionError: If more than one image is set. """ if self.model is None: model = build_sam(self.args.model) @@ -388,17 +423,20 @@ class Predictor(BasePredictor): @staticmethod def remove_small_regions(masks, min_area=0, nms_thresh=0.7): """ - Removes small disconnected regions and holes in masks, then reruns box NMS to remove any new duplicates. - Requires open-cv as a dependency. + Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this + function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum + Suppression (NMS) to eliminate any newly created duplicate boxes. Args: - masks (torch.Tensor): Masks, (N, H, W). - min_area (int): Minimum area threshold. - nms_thresh (float): NMS threshold. + masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is + the number of masks, H is height, and W is width. + min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0. + nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7. + Returns: - new_masks (torch.Tensor): New Masks, (N, H, W). - keep (List[int]): The indices of the new masks, which can be used to filter - the corresponding boxes. + T(uple[torch.Tensor, List[int]]): + - new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W). + - keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes. """ if len(masks) == 0: return masks @@ -420,10 +458,6 @@ class Predictor(BasePredictor): # Recalculate boxes and remove any new duplicates new_masks = torch.cat(new_masks, dim=0) boxes = batched_mask_to_box(new_masks) - keep = torchvision.ops.nms( - boxes.float(), - torch.as_tensor(scores), - nms_thresh, - ) + keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh) return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
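
Note: the headline change above is the corrupt `*.npy` cache handling added to `BaseDataset.load_image` in `ultralytics/data/base.py`; previously, an unreadable cache file raised an unhandled exception during `np.load`. A rough standalone sketch of that fallback pattern (the function name, paths, prefix, and the use of `print` instead of the Ultralytics LOGGER are illustrative only):

```python
from pathlib import Path

import cv2
import numpy as np


def load_image(npy_file: Path, img_file: str, prefix: str = '') -> np.ndarray:
    """Load a cached *.npy image, falling back to the original image file if the cache is corrupt."""
    if npy_file.exists():
        try:
            return np.load(npy_file)  # fast path: use the cached array
        except Exception as e:
            # The patch logs this via LOGGER.warning; print keeps the sketch standalone
            print(f'{prefix}WARNING ⚠️ Removing corrupt *.npy image file {npy_file} due to: {e}')
            npy_file.unlink(missing_ok=True)  # delete the bad cache file
    im = cv2.imread(img_file)  # BGR, matching the dataset loader
    if im is None:
        raise FileNotFoundError(f'Image Not Found {img_file}')
    return im
```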