diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 2b5bf92f..0d6d8959 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -241,7 +241,7 @@ jobs:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
Conda:
- if: github.repository == 'ultralytics/ultralytics' && (github.event_name == 'schedule-disabled' || github.event.inputs.conda == 'true')
+ if: github.repository == 'ultralytics/ultralytics' && (github.event_name == 'schedule' || github.event.inputs.conda == 'true')
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
diff --git a/docker/Dockerfile b/docker/Dockerfile
index d97632c6..5973417a 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -3,7 +3,7 @@
# Image is CUDA-optimized for YOLOv8 single/multi-GPU training and inference
# Start FROM PyTorch image https://hub.docker.com/r/pytorch/pytorch or nvcr.io/nvidia/pytorch:23.03-py3
-FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
+FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
RUN pip install --no-cache nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com
# Downloads to user config dir
diff --git a/docs/guides/azureml-quickstart.md b/docs/guides/azureml-quickstart.md
index 5119c1a1..74c7cb57 100644
--- a/docs/guides/azureml-quickstart.md
+++ b/docs/guides/azureml-quickstart.md
@@ -77,7 +77,7 @@ Train a detection model for 10 epochs with an initial learning_rate of 0.01:
yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
```
-You can find more [instructions to use the Ultralytics cli here](https://docs.ultralytics.com/quickstart/#use-ultralytics-with-cli).
+You can find more [instructions to use the Ultralytics CLI here](https://docs.ultralytics.com/quickstart/#use-ultralytics-with-cli).
## Quickstart from a Notebook
diff --git a/docs/guides/index.md b/docs/guides/index.md
index 71bf350b..7204d3f5 100644
--- a/docs/guides/index.md
+++ b/docs/guides/index.md
@@ -22,6 +22,7 @@ Here's a compilation of in-depth guides to help you master different aspects of
* [Conda Quickstart](conda-quickstart.md) 🚀 NEW: Step-by-step guide to setting up a [Conda](https://anaconda.org/conda-forge/ultralytics) environment for Ultralytics. Learn how to install and start using the Ultralytics package efficiently with Conda.
* [Docker Quickstart](docker-quickstart.md) 🚀 NEW: Complete guide to setting up and using Ultralytics YOLO models with [Docker](https://hub.docker.com/r/ultralytics/ultralytics). Learn how to install Docker, manage GPU support, and run YOLO models in isolated containers for consistent development and deployment.
* [Raspberry Pi](raspberry-pi.md) 🚀 NEW: Quickstart tutorial to run YOLO models to the latest Raspberry Pi hardware.
+* [Triton Inference Server Integration](triton-inference-server.md) 🚀 NEW: Dive into the integration of Ultralytics YOLOv8 with NVIDIA's Triton Inference Server for scalable and efficient deep learning inference deployments.
## Contribute to Our Guides
diff --git a/docs/guides/raspberry-pi.md b/docs/guides/raspberry-pi.md
index 2cb985f7..c46a6604 100644
--- a/docs/guides/raspberry-pi.md
+++ b/docs/guides/raspberry-pi.md
@@ -37,47 +37,25 @@ You should see a video feed from your camera.
This guide offers you the flexibility to start with either [YOLOv5](https://github.com/ultralytics/yolov5) or [YOLOv8](https://github.com/ultralytics/ultralytics). Both versions have their unique advantages and use-cases. The choice is yours, but remember, the guide's aim is not just quick setup but also a robust foundation for your future work in object detection.
-## Hardware Specifics: Raspberry Pi 3 vs Raspberry Pi 4
+## Hardware Specifics: At a Glance
-Raspberry Pi 3 and Raspberry Pi 4 have distinct hardware specifications, and the YOLO installation and configuration process can vary slightly depending on which model you're using.
+To assist you in making an informed hardware decision, we've summarized the key hardware specifics of Raspberry Pi 3, 4, and 5 in the table below:
-### Raspberry Pi 3
-
-- **CPU**: 1.2GHz Quad-Core ARM Cortex-A53
-- **RAM**: 1GB LPDDR2
-- **USB Ports**: 4 x USB 2.0
-- **Network**: Ethernet & Wi-Fi 802.11n
-- **Performance**: Generally slower, may require lighter YOLO models for real-time processing
-- **Power Requirement**: 2.5A power supply
-- **Official Documentation**: [Raspberry Pi 3 Documentation](https://www.raspberrypi.org/documentation/hardware/raspberrypi/bcm2837/README.md)
-
-### Raspberry Pi 4
-
-- **CPU**: 1.5GHz Quad-core 64-bit ARM Cortex-A72 CPU
-- **RAM**: Options of 2GB, 4GB or 8GB LPDDR4
-- **USB Ports**: 2 x USB 2.0, 2 x USB 3.0
-- **Network**: Gigabit Ethernet & Wi-Fi 802.11ac
-- **Performance**: Faster, capable of running more complex YOLO models in real-time
-- **Power Requirement**: 3.0A USB-C power supply
-- **Official Documentation**: [Raspberry Pi 4 Documentation](https://www.raspberrypi.org/documentation/hardware/raspberrypi/bcm2711/README.md)
-
-### Raspberry Pi 5
-
-- **CPU**: 2.4GHz Quad-core 64-bit Arm Cortex-A76 CPU
-- **GPU**: VideoCore VII, supporting OpenGL ES 3.1, Vulkan 1.2
-- **Display Output**: Dual 4Kp60 HDMI
-- **Decoder**: 4Kp60 HEVC
-- **Network**: Gigabit Ethernet with PoE+ support, Dual-band 802.11ac Wi-Fi®, Bluetooth 5.0 / BLE
-- **USB Ports**: 2 x USB 3.0, 2 x USB 2.0
-- **Other Features**: High-speed microSD card interface with SDR104 mode, 2 × 4-lane MIPI camera/display transceivers, PCIe 2.0 x1 interface, standard 40-pin GPIO header, real-time clock, power button
-- **Power Requirement**: Specifics not yet available, expected to require a higher amperage supply
-- **Official Documentation**: [Raspberry Pi 5 Documentation](https://www.raspberrypi.com/news/introducing-raspberry-pi-5/)
+| Feature | Raspberry Pi 3 | Raspberry Pi 4 | Raspberry Pi 5 |
+|----------------------------|------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------|----------------------------------------------------------------------|
+| **CPU** | 1.2GHz Quad-Core ARM Cortex-A53 | 1.5GHz Quad-core 64-bit ARM Cortex-A72 | 2.4GHz Quad-core 64-bit Arm Cortex-A76 |
+| **RAM** | 1GB LPDDR2 | 2GB, 4GB or 8GB LPDDR4 | *Details not yet available* |
+| **USB Ports** | 4 x USB 2.0 | 2 x USB 2.0, 2 x USB 3.0 | 2 x USB 3.0, 2 x USB 2.0 |
+| **Network** | Ethernet & Wi-Fi 802.11n | Gigabit Ethernet & Wi-Fi 802.11ac | Gigabit Ethernet with PoE+ support, Dual-band 802.11ac Wi-Fi® |
+| **Performance** | Slower, may require lighter YOLO models | Faster, can run complex YOLO models | *Details not yet available* |
+| **Power Requirement** | 2.5A power supply | 3.0A USB-C power supply | *Details not yet available* |
+| **Official Documentation** | [Link](https://www.raspberrypi.org/documentation/hardware/raspberrypi/bcm2837/README.md) | [Link](https://www.raspberrypi.org/documentation/hardware/raspberrypi/bcm2711/README.md) | [Link](https://www.raspberrypi.com/news/introducing-raspberry-pi-5/) |
Please make sure to follow the instructions specific to your Raspberry Pi model to ensure a smooth setup process.
## Quick Start with YOLOv5
-This section outlines how to set up YOLOv5 on a Raspberry Pi 3 or 4 with a Pi Camera. These steps are designed to be compatible with the libcamera camera stack introduced in Raspberry Pi OS Bullseye.
+This section outlines how to set up YOLOv5 on a Raspberry Pi with a Pi Camera. These steps are designed to be compatible with the libcamera camera stack introduced in Raspberry Pi OS Bullseye.
### Install Necessary Packages
@@ -171,7 +149,7 @@ Follow this section if you are interested in setting up YOLOv8 instead. The step
sudo apt-get autoremove -y
```
-2. Install YOLOv8:
+2. Install the `ultralytics` Python package:
```bash
pip3 install ultralytics
@@ -183,28 +161,6 @@ Follow this section if you are interested in setting up YOLOv8 instead. The step
sudo reboot
```
-### Modify `build.py`
-
-Just like YOLOv5, YOLOv8 also needs minor modifications to accept TCP streams.
-
-1. Open `build.py` located in the Ultralytics package folder:
-
- ```bash
- sudo nano /home/pi/.local/lib/pythonX.X/site-packages/ultralytics/build.py
- ```
-
-2. Find and modify the `is_url` line to accept TCP streams:
-
- ```python
- is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://', 'tcp://'))
- ```
-
-3. Save and exit:
-
- ```bash
- CTRL + O -> ENTER -> CTRL + X
- ```
-
### Initiate TCP Stream with Libcamera
1. Start the TCP stream:
@@ -231,7 +187,7 @@ while True:
## Next Steps
-Congratulations on successfully setting up YOLO on your Raspberry Pi! For further learning and support, visit [Ultralytics](https://ultralytics.com/) and [KashmirWorldFoundation](https://www.kashmirworldfoundation.org/).
+Congratulations on successfully setting up YOLO on your Raspberry Pi! For further learning and support, visit [Ultralytics](https://ultralytics.com/) and [Kashmir World Foundation](https://www.kashmirworldfoundation.org/).
## Acknowledgements and Citations
diff --git a/docs/guides/triton-inference-server.md b/docs/guides/triton-inference-server.md
new file mode 100644
index 00000000..b8bcfab4
--- /dev/null
+++ b/docs/guides/triton-inference-server.md
@@ -0,0 +1,137 @@
+---
+comments: true
+description: A step-by-step guide on integrating Ultralytics YOLOv8 with Triton Inference Server for scalable and high-performance deep learning inference deployments.
+keywords: YOLOv8, Triton Inference Server, ONNX, Deep Learning Deployment, Scalable Inference, Ultralytics, NVIDIA, Object Detection, Cloud Inferencing
+---
+
+# Triton Inference Server with Ultralytics YOLOv8
+
+The [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server) (formerly known as TensorRT Inference Server) is an open-source software solution developed by NVIDIA. It provides a cloud inferencing solution optimized for NVIDIA GPUs. Triton simplifies the deployment of AI models at scale in production. Integrating Ultralytics YOLOv8 with Triton Inference Server allows you to deploy scalable, high-performance deep learning inference workloads. This guide provides steps to set up and test the integration.
+
+
+
+
+
+ Watch: Getting Started with NVIDIA Triton Inference Server.
+
+
+## What is Triton Inference Server?
+
+Triton Inference Server is designed to deploy a variety of AI models in production. It supports a wide range of deep learning and machine learning frameworks, including TensorFlow, PyTorch, ONNX Runtime, and many others. Its primary use cases are:
+
+- Serving multiple models from a single server instance.
+- Dynamic model loading and unloading without server restart.
+- Ensemble inferencing, allowing multiple models to be used together to achieve results.
+- Model versioning for A/B testing and rolling updates.
+
+## Prerequisites
+
+Ensure you have the following prerequisites before proceeding:
+
+- Docker installed on your machine.
+- Install `tritonclient`:
+ ```bash
+ pip install tritonclient[all]
+ ```
+
+## Exporting YOLOv8 to ONNX Format
+
+Before deploying the model on Triton, it must be exported to the ONNX format. ONNX (Open Neural Network Exchange) is a format that allows models to be transferred between different deep learning frameworks. Use the `export` function from the `YOLO` class:
+
+```python
+from ultralytics import YOLO
+
+# Load a model
+model = YOLO('yolov8n.pt') # load an official model
+
+# Export the model
+onnx_file = model.export(format='onnx', dynamic=True)
+```
+
+## Setting Up Triton Model Repository
+
+The Triton Model Repository is a storage location where Triton can access and load models.
+
+1. Create the necessary directory structure:
+
+ ```python
+ from pathlib import Path
+
+ # Define paths
+ triton_repo_path = Path('tmp') / 'triton_repo'
+ triton_model_path = triton_repo_path / 'yolo'
+
+ # Create directories
+ (triton_model_path / '1').mkdir(parents=True, exist_ok=True)
+ ```
+
+2. Move the exported ONNX model to the Triton repository:
+
+ ```python
+ from pathlib import Path
+
+ # Move ONNX model to Triton Model path
+ Path(onnx_file).rename(triton_model_path / '1' / 'model.onnx')
+
+ # Create config file
+ (triton_model_path / 'config.pdtxt').touch()
+ ```
+
+## Running Triton Inference Server
+
+Run the Triton Inference Server using Docker:
+
+```python
+import subprocess
+import time
+
+from tritonclient.http import InferenceServerClient
+
+# Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
+tag = 'nvcr.io/nvidia/tritonserver:23.09-py3' # 6.4 GB
+
+# Pull the image
+subprocess.call(f'docker pull {tag}', shell=True)
+
+# Run the Triton server and capture the container ID
+container_id = subprocess.check_output(
+ f'docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models',
+ shell=True).decode('utf-8').strip()
+
+# Wait for the Triton server to start
+triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False)
+
+# Wait until model is ready
+for _ in range(10):
+ with contextlib.suppress(Exception):
+ assert triton_client.is_model_ready(model_name)
+ break
+ time.sleep(1)
+```
+
+Then run inference using the Triton Server model:
+
+```python
+from ultralytics import YOLO
+
+# Load the Triton Server model
+model = YOLO(f'http://localhost:8000/yolo', task='detect')
+
+# Run inference on the server
+results = model('path/to/image.jpg')
+```
+
+Cleanup the container:
+
+```python
+# Kill and remove the container at the end of the test
+subprocess.call(f'docker kill {container_id}', shell=True)
+```
+
+---
+
+By following the above steps, you can deploy and run Ultralytics YOLOv8 models efficiently on Triton Inference Server, providing a scalable and high-performance solution for deep learning inference tasks. If you face any issues or have further queries, refer to the [official Triton documentation](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html) or reach out to the Ultralytics community for support.
\ No newline at end of file
diff --git a/docs/modes/export.md b/docs/modes/export.md
index 1ef79f4d..6c2c1e1d 100644
--- a/docs/modes/export.md
+++ b/docs/modes/export.md
@@ -57,7 +57,7 @@ Export a YOLOv8n model to a different format like ONNX or TensorRT. See Argument
# Load a model
model = YOLO('yolov8n.pt') # load an official model
- model = YOLO('path/to/best.pt') # load a custom trained
+ model = YOLO('path/to/best.pt') # load a custom trained model
# Export the model
model.export(format='onnx')
diff --git a/docs/reference/utils/triton.md b/docs/reference/utils/triton.md
new file mode 100644
index 00000000..df56b2ff
--- /dev/null
+++ b/docs/reference/utils/triton.md
@@ -0,0 +1,9 @@
+# Reference for `ultralytics/utils/triton.py`
+
+!!! note
+
+ Full source code for this file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/triton.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/triton.py). Help us fix any issues you see by submitting a [Pull Request](https://docs.ultralytics.com/help/contributing/) 🛠️. Thank you 🙏!
+
+---
+## ::: ultralytics.utils.triton.TritonRemoteModel
+
diff --git a/docs/tasks/classify.md b/docs/tasks/classify.md
index 708287b4..ab1e9af4 100644
--- a/docs/tasks/classify.md
+++ b/docs/tasks/classify.md
@@ -140,7 +140,7 @@ Export a YOLOv8n-cls model to a different format like ONNX, CoreML, etc.
# Load a model
model = YOLO('yolov8n-cls.pt') # load an official model
- model = YOLO('path/to/best.pt') # load a custom trained
+ model = YOLO('path/to/best.pt') # load a custom trained model
# Export the model
model.export(format='onnx')
diff --git a/docs/tasks/detect.md b/docs/tasks/detect.md
index 0cd30e2d..ba14278c 100644
--- a/docs/tasks/detect.md
+++ b/docs/tasks/detect.md
@@ -152,7 +152,7 @@ Export a YOLOv8n model to a different format like ONNX, CoreML, etc.
# Load a model
model = YOLO('yolov8n.pt') # load an official model
- model = YOLO('path/to/best.pt') # load a custom trained
+ model = YOLO('path/to/best.pt') # load a custom trained model
# Export the model
model.export(format='onnx')
diff --git a/docs/tasks/pose.md b/docs/tasks/pose.md
index 5c3035b7..5fe838d7 100644
--- a/docs/tasks/pose.md
+++ b/docs/tasks/pose.md
@@ -156,7 +156,7 @@ Export a YOLOv8n Pose model to a different format like ONNX, CoreML, etc.
# Load a model
model = YOLO('yolov8n-pose.pt') # load an official model
- model = YOLO('path/to/best.pt') # load a custom trained
+ model = YOLO('path/to/best.pt') # load a custom trained model
# Export the model
model.export(format='onnx')
diff --git a/docs/tasks/segment.md b/docs/tasks/segment.md
index b81465f0..f5b73330 100644
--- a/docs/tasks/segment.md
+++ b/docs/tasks/segment.md
@@ -157,7 +157,7 @@ Export a YOLOv8n-seg model to a different format like ONNX, CoreML, etc.
# Load a model
model = YOLO('yolov8n-seg.pt') # load an official model
- model = YOLO('path/to/best.pt') # load a custom trained
+ model = YOLO('path/to/best.pt') # load a custom trained model
# Export the model
model.export(format='onnx')
diff --git a/mkdocs.yml b/mkdocs.yml
index a8210e8d..578beec7 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -223,6 +223,7 @@ nav:
- Conda Quickstart: guides/conda-quickstart.md
- Docker Quickstart: guides/docker-quickstart.md
- Raspberry Pi: guides/raspberry-pi.md
+ - Triton Inference Server: guides/triton-inference-server.md
- Integrations:
- integrations/index.md
- OpenVINO: integrations/openvino.md
@@ -390,6 +391,7 @@ nav:
- plotting: reference/utils/plotting.md
- tal: reference/utils/tal.md
- torch_utils: reference/utils/torch_utils.md
+ - triton: reference/utils/triton.md
- tuner: reference/utils/tuner.md
- Help:
diff --git a/tests/test_python.py b/tests/test_python.py
index 1ee8c8bb..3e49f570 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -15,7 +15,7 @@ from ultralytics import RTDETR, YOLO
from ultralytics.cfg import TASK2DATA
from ultralytics.data.build import load_inference_source
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_PATH, LINUX, MACOS, ONLINE, ROOT, WEIGHTS_DIR, WINDOWS,
- is_dir_writeable)
+ checks, is_dir_writeable)
from ultralytics.utils.downloads import download
from ultralytics.utils.torch_utils import TORCH_1_9
@@ -343,17 +343,14 @@ def test_utils_init():
def test_utils_checks():
- from ultralytics.utils.checks import (check_imgsz, check_imshow, check_requirements, check_version,
- check_yolov5u_filename, git_describe, print_args)
-
- check_yolov5u_filename('yolov5n.pt')
- # check_imshow(warn=True)
- git_describe(ROOT)
- check_requirements() # check requirements.txt
- check_imgsz([600, 600], max_dim=1)
- check_imshow()
- check_version('ultralytics', '8.0.0')
- print_args()
+ checks.check_yolov5u_filename('yolov5n.pt')
+ checks.git_describe(ROOT)
+ checks.check_requirements() # check requirements.txt
+ checks.check_imgsz([600, 600], max_dim=1)
+ checks.check_imshow()
+ checks.check_version('ultralytics', '8.0.0')
+ checks.print_args()
+ # checks.check_imshow(warn=True)
def test_utils_benchmarks():
@@ -451,3 +448,53 @@ def test_hub():
export_fmts_hub()
logout()
smart_request('GET', 'http://github.com', progress=True)
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(not ONLINE, reason='environment is offline')
+def test_triton():
+ checks.check_requirements('tritonclient[all]')
+ import subprocess
+ import time
+
+ from tritonclient.http import InferenceServerClient # noqa
+
+ # Create variables
+ model_name = 'yolo'
+ triton_repo_path = TMP / 'triton_repo'
+ triton_model_path = triton_repo_path / model_name
+
+ # Export model to ONNX
+ f = YOLO(MODEL).export(format='onnx', dynamic=True)
+
+ # Prepare Triton repo
+ (triton_model_path / '1').mkdir(parents=True, exist_ok=True)
+ Path(f).rename(triton_model_path / '1' / 'model.onnx')
+ (triton_model_path / 'config.pdtxt').touch()
+
+ # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
+ tag = 'nvcr.io/nvidia/tritonserver:23.09-py3' # 6.4 GB
+
+ # Pull the image
+ subprocess.call(f'docker pull {tag}', shell=True)
+
+ # Run the Triton server and capture the container ID
+ container_id = subprocess.check_output(
+ f'docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models',
+ shell=True).decode('utf-8').strip()
+
+ # Wait for the Triton server to start
+ triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False)
+
+ # Wait until model is ready
+ for _ in range(10):
+ with contextlib.suppress(Exception):
+ assert triton_client.is_model_ready(model_name)
+ break
+ time.sleep(1)
+
+ # Check Triton inference
+ YOLO(f'http://localhost:8000/{model_name}', 'detect')(SOURCE) # exported model inference
+
+ # Kill and remove the container at the end of the test
+ subprocess.call(f'docker kill {container_id}', shell=True)
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 6a87a469..d2ae241c 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.194'
+__version__ = '8.0.195'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index d8c881a9..69de12b2 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -81,6 +81,12 @@ class Model(nn.Module):
self.session = HUBTrainingSession(model)
model = self.session.model_file
+ # Check if Triton Server model
+ elif self.is_triton_model(model):
+ self.model = model
+ self.task = task
+ return
+
# Load or create new YOLO model
suffix = Path(model).suffix
if not suffix and Path(model).stem in GITHUB_ASSETS_STEMS:
@@ -94,6 +100,13 @@ class Model(nn.Module):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
+ @staticmethod
+ def is_triton_model(model):
+ """Is model a Triton Server URL string, i.e. :////"""
+ from urllib.parse import urlsplit
+ url = urlsplit(model)
+ return url.netloc and url.path and url.scheme in {'http', 'grfc'}
+
@staticmethod
def is_hub_model(model):
"""Check if the provided model is a HUB model."""
diff --git a/ultralytics/models/fastsam/predict.py b/ultralytics/models/fastsam/predict.py
index f94a1738..4eac69f9 100644
--- a/ultralytics/models/fastsam/predict.py
+++ b/ultralytics/models/fastsam/predict.py
@@ -15,13 +15,14 @@ class FastSAMPredictor(DetectionPredictor):
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
- p = ops.non_max_suppression(preds[0],
- self.args.conf,
- self.args.iou,
- agnostic=self.args.agnostic_nms,
- max_det=self.args.max_det,
- nc=len(self.model.names),
- classes=self.args.classes)
+ p = ops.non_max_suppression(
+ preds[0],
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+ nc=1, # set to 1 class since SAM has no class predictions
+ classes=self.args.classes)
full_box = torch.zeros(p[0].shape[1], device=p[0].device)
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index 633aa9db..61ca6db6 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -7,7 +7,6 @@ import platform
import zipfile
from collections import OrderedDict, namedtuple
from pathlib import Path
-from urllib.parse import urlparse
import cv2
import numpy as np
@@ -32,8 +31,8 @@ def check_class_names(names):
raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
- map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map'] # human-readable names
- names = {k: map[v] for k, v in names.items()}
+ names_map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map'] # human-readable names
+ names = {k: names_map[v] for k, v in names.items()}
return names
@@ -274,13 +273,9 @@ class AutoBackend(nn.Module):
net.load_model(str(w.with_suffix('.bin')))
metadata = w.parent / 'metadata.yaml'
elif triton: # NVIDIA Triton Inference Server
- """TODO
check_requirements('tritonclient[all]')
- from utils.triton import TritonRemoteModel
- model = TritonRemoteModel(url=w)
- nhwc = model.runtime.startswith("tensorflow")
- """
- raise NotImplementedError('Triton Inference Server is not currently supported.')
+ from ultralytics.utils.triton import TritonRemoteModel
+ model = TritonRemoteModel(w)
else:
from ultralytics.engine.exporter import export_formats
raise TypeError(f"model='{w}' is not a supported model format. "
@@ -395,6 +390,7 @@ class AutoBackend(nn.Module):
ex.extract(output_name, mat_out)
y.append(np.array(mat_out)[None])
elif self.triton: # NVIDIA Triton Inference Server
+ im = im.cpu().numpy() # torch to numpy
y = self.model(im)
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.cpu().numpy()
@@ -498,6 +494,8 @@ class AutoBackend(nn.Module):
if any(types):
triton = False
else:
- url = urlparse(p) # if url may be Triton inference server
- triton = all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
+ from urllib.parse import urlsplit
+ url = urlsplit(p)
+ triton = url.netloc and url.path and url.scheme in {'http', 'grfc'}
+
return types + [triton]
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py
index 3021dd08..c4490136 100644
--- a/ultralytics/utils/__init__.py
+++ b/ultralytics/utils/__init__.py
@@ -705,7 +705,7 @@ def remove_colorstr(input_string):
>>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
>>> 'hello world'
"""
- ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
+ ansi_escape = re.compile(r'\x1B(?:[@-Z\\\-_]|\[[0-9]*[ -/]*[@-~])')
return ansi_escape.sub('', input_string)
diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py
index 86560bcc..574d4039 100644
--- a/ultralytics/utils/metrics.py
+++ b/ultralytics/utils/metrics.py
@@ -2,6 +2,7 @@
"""
Model validation metrics
"""
+
import math
import warnings
from pathlib import Path
diff --git a/ultralytics/utils/triton.py b/ultralytics/utils/triton.py
new file mode 100644
index 00000000..c48e418a
--- /dev/null
+++ b/ultralytics/utils/triton.py
@@ -0,0 +1,86 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from typing import List
+from urllib.parse import urlsplit
+
+import numpy as np
+
+
+class TritonRemoteModel:
+ """Client for interacting with a remote Triton Inference Server model.
+
+ Attributes:
+ endpoint (str): The name of the model on the Triton server.
+ url (str): The URL of the Triton server.
+ triton_client: The Triton client (either HTTP or gRPC).
+ InferInput: The input class for the Triton client.
+ InferRequestedOutput: The output request class for the Triton client.
+ input_formats (List[str]): The data types of the model inputs.
+ np_input_formats (List[type]): The numpy data types of the model inputs.
+ input_names (List[str]): The names of the model inputs.
+ output_names (List[str]): The names of the model outputs.
+ """
+
+ def __init__(self, url: str, endpoint: str = '', scheme: str = ''):
+ """
+ Initialize the TritonRemoteModel.
+
+ Arguments may be provided individually or parsed from a collective 'url' argument of the form
+ :////
+
+ Args:
+ url (str): The URL of the Triton server.
+ endpoint (str): The name of the model on the Triton server.
+ scheme (str): The communication scheme ('http' or 'grpc').
+ """
+ if not endpoint and not scheme: # Parse all args from URL string
+ splits = urlsplit(url)
+ endpoint = splits.path.strip('/').split('/')[0]
+ scheme = splits.scheme
+ url = splits.netloc
+
+ self.endpoint = endpoint
+ self.url = url
+
+ # Choose the Triton client based on the communication scheme
+ if scheme == 'http':
+ import tritonclient.http as client # noqa
+ self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
+ config = self.triton_client.get_model_config(endpoint)
+ else:
+ import tritonclient.grpc as client # noqa
+ self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
+ config = self.triton_client.get_model_config(endpoint, as_json=True)['config']
+
+ self.InferRequestedOutput = client.InferRequestedOutput
+ self.InferInput = client.InferInput
+
+ type_map = {'TYPE_FP32': np.float32, 'TYPE_FP16': np.float16, 'TYPE_UINT8': np.uint8}
+ self.input_formats = [x['data_type'] for x in config['input']]
+ self.np_input_formats = [type_map[x] for x in self.input_formats]
+ self.input_names = [x['name'] for x in config['input']]
+ self.output_names = [x['name'] for x in config['output']]
+
+ def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
+ """
+ Call the model with the given inputs.
+
+ Args:
+ *inputs (List[np.ndarray]): Input data to the model.
+
+ Returns:
+ List[np.ndarray]: Model outputs.
+ """
+ infer_inputs = []
+ input_format = inputs[0].dtype
+ for i, x in enumerate(inputs):
+ if x.dtype != self.np_input_formats[i]:
+ x = x.astype(self.np_input_formats[i])
+ infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace('TYPE_', ''))
+ infer_input.set_data_from_numpy(x)
+ infer_inputs.append(infer_input)
+
+ infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
+ outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)
+
+ return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]