mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
ultralytics 8.0.208
automatic thread-safe inference (#6185)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com> Co-authored-by: PIW <56834479+parkilwoo@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e7bd159a44
commit
795b95bdcb
@ -16,6 +16,7 @@ Here's a compilation of in-depth guides to help you master different aspects of
|
||||
|
||||
* [YOLO Common Issues](yolo-common-issues.md) ⭐ RECOMMENDED: Practical solutions and troubleshooting tips to the most frequently encountered issues when working with Ultralytics YOLO models.
|
||||
* [YOLO Performance Metrics](yolo-performance-metrics.md) ⭐ ESSENTIAL: Understand the key metrics like mAP, IoU, and F1 score used to evaluate the performance of your YOLO models. Includes practical examples and tips on how to improve detection accuracy and speed.
|
||||
* [Model Deployment Options](model-deployment-options.md): Overview of YOLO model deployment formats like ONNX, OpenVINO, and TensorRT, with pros and cons for each to inform your deployment strategy.
|
||||
* [K-Fold Cross Validation](kfold-cross-validation.md) 🚀 NEW: Learn how to improve model generalization using K-Fold cross-validation technique.
|
||||
* [Hyperparameter Tuning](hyperparameter-tuning.md) 🚀 NEW: Discover how to optimize your YOLO models by fine-tuning hyperparameters using the Tuner class and genetic evolution algorithms.
|
||||
* [SAHI Tiled Inference](sahi-tiled-inference.md) 🚀 NEW: Comprehensive guide on leveraging SAHI's sliced inference capabilities with YOLOv8 for object detection in high-resolution images.
|
||||
@ -24,6 +25,7 @@ Here's a compilation of in-depth guides to help you master different aspects of
|
||||
* [Docker Quickstart](docker-quickstart.md) 🚀 NEW: Complete guide to setting up and using Ultralytics YOLO models with [Docker](https://hub.docker.com/r/ultralytics/ultralytics). Learn how to install Docker, manage GPU support, and run YOLO models in isolated containers for consistent development and deployment.
|
||||
* [Raspberry Pi](raspberry-pi.md) 🚀 NEW: Quickstart tutorial to run YOLO models to the latest Raspberry Pi hardware.
|
||||
* [Triton Inference Server Integration](triton-inference-server.md) 🚀 NEW: Dive into the integration of Ultralytics YOLOv8 with NVIDIA's Triton Inference Server for scalable and efficient deep learning inference deployments.
|
||||
* [YOLO Thread-Safe Inference](yolo-thread-safe-inference.md) 🚀 NEW: Guidelines for performing inference with YOLO models in a thread-safe manner. Learn the importance of thread safety and best practices to prevent race conditions and ensure consistent predictions.
|
||||
|
||||
## Contribute to Our Guides
|
||||
|
||||
|
108
docs/guides/yolo-thread-safe-inference.md
Normal file
108
docs/guides/yolo-thread-safe-inference.md
Normal file
@ -0,0 +1,108 @@
|
||||
---
|
||||
comments: true
|
||||
description: This guide provides best practices for performing thread-safe inference with YOLO models, ensuring reliable and concurrent predictions in multi-threaded applications.
|
||||
keywords: thread-safe, YOLO inference, multi-threading, concurrent predictions, YOLO models, Ultralytics, Python threading, safe YOLO usage, AI concurrency
|
||||
---
|
||||
|
||||
# Thread-Safe Inference with YOLO Models
|
||||
|
||||
Running YOLO models in a multi-threaded environment requires careful consideration to ensure thread safety. Python's `threading` module allows you to run several threads concurrently, but when it comes to using YOLO models across these threads, there are important safety issues to be aware of. This page will guide you through creating thread-safe YOLO model inference.
|
||||
|
||||
## Understanding Python Threading
|
||||
|
||||
Python threads are a form of parallelism that allow your program to run multiple operations at once. However, Python's Global Interpreter Lock (GIL) means that only one thread can execute Python bytecode at a time.
|
||||
|
||||
<p align="center">
|
||||
<img width="800" src="https://user-images.githubusercontent.com/26833433/281418476-7f478570-fd77-4a40-bf3d-74b4db4d668c.png" alt="Single vs Multi-Thread Examples">
|
||||
</p>
|
||||
|
||||
While this sounds like a limitation, threads can still provide concurrency, especially for I/O-bound operations or when using operations that release the GIL, like those performed by YOLO's underlying C libraries.
|
||||
|
||||
## The Danger of Shared Model Instances
|
||||
|
||||
Instantiating a YOLO model outside your threads and sharing this instance across multiple threads can lead to race conditions, where the internal state of the model is inconsistently modified due to concurrent accesses. This is particularly problematic when the model or its components hold state that is not designed to be thread-safe.
|
||||
|
||||
### Non-Thread-Safe Example: Single Model Instance
|
||||
|
||||
When using threads in Python, it's important to recognize patterns that can lead to concurrency issues. Here is what you should avoid: sharing a single YOLO model instance across multiple threads.
|
||||
|
||||
```python
|
||||
# Unsafe: Sharing a single model instance across threads
|
||||
from ultralytics import YOLO
|
||||
from threading import Thread
|
||||
|
||||
# Instantiate the model outside the thread
|
||||
shared_model = YOLO("yolov8n.pt")
|
||||
|
||||
|
||||
def predict(image_path):
|
||||
results = shared_model.predict(image_path)
|
||||
# Process results
|
||||
|
||||
|
||||
# Starting threads that share the same model instance
|
||||
Thread(target=predict, args=("image1.jpg",)).start()
|
||||
Thread(target=predict, args=("image2.jpg",)).start()
|
||||
```
|
||||
|
||||
In the example above, the `shared_model` is used by multiple threads, which can lead to unpredictable results because `predict` could be executed simultaneously by multiple threads.
|
||||
|
||||
### Non-Thread-Safe Example: Multiple Model Instances
|
||||
|
||||
Similarly, here is an unsafe pattern with multiple YOLO model instances:
|
||||
|
||||
```python
|
||||
# Unsafe: Sharing multiple model instances across threads can still lead to issues
|
||||
from ultralytics import YOLO
|
||||
from threading import Thread
|
||||
|
||||
# Instantiate multiple models outside the thread
|
||||
shared_model_1 = YOLO("yolov8n_1.pt")
|
||||
shared_model_2 = YOLO("yolov8n_2.pt")
|
||||
|
||||
|
||||
def predict(model, image_path):
|
||||
results = model.predict(image_path)
|
||||
# Process results
|
||||
|
||||
|
||||
# Starting threads with individual model instances
|
||||
Thread(target=predict, args=(shared_model_1, "image1.jpg")).start()
|
||||
Thread(target=predict, args=(shared_model_2, "image2.jpg")).start()
|
||||
```
|
||||
|
||||
Even though there are two separate model instances, the risk of concurrency issues still exists. If the internal implementation of `YOLO` is not thread-safe, using separate instances might not prevent race conditions, especially if these instances share any underlying resources or states that are not thread-local.
|
||||
|
||||
## Thread-Safe Inference
|
||||
|
||||
To perform thread-safe inference, you should instantiate a separate YOLO model within each thread. This ensures that each thread has its own isolated model instance, eliminating the risk of race conditions.
|
||||
|
||||
### Thread-Safe Example
|
||||
|
||||
Here's how to instantiate a YOLO model inside each thread for safe parallel inference:
|
||||
|
||||
```python
|
||||
# Safe: Instantiating a single model inside each thread
|
||||
from ultralytics import YOLO
|
||||
from threading import Thread
|
||||
|
||||
|
||||
def thread_safe_predict(image_path):
|
||||
# Instantiate a new model inside the thread
|
||||
local_model = YOLO("yolov8n.pt")
|
||||
results = local_model.predict(image_path)
|
||||
# Process results
|
||||
|
||||
|
||||
# Starting threads that each have their own model instance
|
||||
Thread(target=thread_safe_predict, args=("image1.jpg",)).start()
|
||||
Thread(target=thread_safe_predict, args=("image2.jpg",)).start()
|
||||
```
|
||||
|
||||
In this example, each thread creates its own `YOLO` instance. This prevents any thread from interfering with the model state of another, thus ensuring that each thread performs inference safely and without unexpected interactions with the other threads.
|
||||
|
||||
## Conclusion
|
||||
|
||||
When using YOLO models with Python's `threading`, always instantiate your models within the thread that will use them to ensure thread safety. This practice avoids race conditions and makes sure that your inference tasks run reliably.
|
||||
|
||||
For more advanced scenarios and to further optimize your multi-threaded inference performance, consider using process-based parallelism with `multiprocessing` or leveraging a task queue with dedicated worker processes.
|
@ -6,6 +6,29 @@ keywords: Ultralytics, Android App, real-time object detection, YOLO models, Ten
|
||||
|
||||
# Ultralytics Android App: Real-time Object Detection with YOLO Models
|
||||
|
||||
<a href="https://bit.ly/ultralytics_hub" target="_blank">
|
||||
<img width="100%" src="https://user-images.githubusercontent.com/26833433/281124469-6b3b0945-dbb1-44c8-80a9-ef6bc778b299.jpg" alt="Ultralytics HUB preview image"></a>
|
||||
<br>
|
||||
<div align="center">
|
||||
<a href="https://github.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-github.png" width="3%" alt="Ultralytics GitHub"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.linkedin.com/company/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-linkedin.png" width="3%" alt="Ultralytics LinkedIn"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://twitter.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-twitter.png" width="3%" alt="Ultralytics Twitter"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://youtube.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-youtube.png" width="3%" alt="Ultralytics YouTube"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.tiktok.com/@ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-tiktok.png" width="3%" alt="Ultralytics TikTok"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.instagram.com/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-instagram.png" width="3%" alt="Ultralytics Instagram"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://ultralytics.com/discord"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-discord.png" width="3%" alt="Ultralytics Discord"></a>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://play.google.com/store/apps/details?id=com.ultralytics.ultralytics_app" style="text-decoration:none;">
|
||||
<img src="https://raw.githubusercontent.com/ultralytics/assets/master/app/google-play.svg" width="15%" alt="" /></a>
|
||||
</div>
|
||||
|
||||
The Ultralytics Android App is a powerful tool that allows you to run YOLO models directly on your Android device for real-time object detection. This app utilizes TensorFlow Lite for model optimization and various hardware delegates for acceleration, enabling fast and efficient object detection.
|
||||
|
||||
## Quantization and Acceleration
|
||||
|
@ -10,29 +10,25 @@ keywords: Ultralytics, HUB App, YOLOv5, YOLOv8, mobile AI, real-time object dete
|
||||
<img width="100%" src="https://github.com/ultralytics/assets/raw/main/im/ultralytics-hub.png" alt="Ultralytics HUB preview image"></a>
|
||||
<br>
|
||||
<div align="center">
|
||||
<a href="https://github.com/ultralytics" style="text-decoration:none;">
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-github.png" width="2%" alt="" /></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
|
||||
<a href="https://www.linkedin.com/company/ultralytics/" style="text-decoration:none;">
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-linkedin.png" width="2%" alt="" /></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
|
||||
<a href="https://twitter.com/ultralytics" style="text-decoration:none;">
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-twitter.png" width="2%" alt="" /></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
|
||||
<a href="https://youtube.com/ultralytics" style="text-decoration:none;">
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-youtube.png" width="2%" alt="" /></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
|
||||
<a href="https://www.tiktok.com/@ultralytics" style="text-decoration:none;">
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-tiktok.png" width="2%" alt="" /></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
|
||||
<a href="https://www.instagram.com/ultralytics/" style="text-decoration:none;">
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-instagram.png" width="2%" alt="" /></a>
|
||||
<a href="https://github.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-github.png" width="3%" alt="Ultralytics GitHub"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.linkedin.com/company/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-linkedin.png" width="3%" alt="Ultralytics LinkedIn"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://twitter.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-twitter.png" width="3%" alt="Ultralytics Twitter"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://youtube.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-youtube.png" width="3%" alt="Ultralytics YouTube"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.tiktok.com/@ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-tiktok.png" width="3%" alt="Ultralytics TikTok"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.instagram.com/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-instagram.png" width="3%" alt="Ultralytics Instagram"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://ultralytics.com/discord"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-discord.png" width="3%" alt="Ultralytics Discord"></a>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://play.google.com/store/apps/details?id=com.ultralytics.ultralytics_app" style="text-decoration:none;">
|
||||
<img src="https://raw.githubusercontent.com/ultralytics/assets/master/app/google-play.svg" width="15%" alt="" /></a>
|
||||
<a href="https://apps.apple.com/xk/app/ultralytics/id1583935240" style="text-decoration:none;">
|
||||
<img src="https://raw.githubusercontent.com/ultralytics/assets/master/app/app-store.svg" width="15%" alt="" /></a>
|
||||
<a href="https://play.google.com/store/apps/details?id=com.ultralytics.ultralytics_app" style="text-decoration:none;">
|
||||
<img src="https://raw.githubusercontent.com/ultralytics/assets/master/app/google-play.svg" width="15%" alt="" /></a>
|
||||
</div>
|
||||
|
||||
Welcome to the Ultralytics HUB App! We are excited to introduce this powerful mobile app that allows you to run YOLOv5 and YOLOv8 models directly on your [iOS](https://apps.apple.com/xk/app/ultralytics/id1583935240) and [Android](https://play.google.com/store/apps/details?id=com.ultralytics.ultralytics_app) devices. With the HUB App, you can utilize hardware acceleration features like Apple's Neural Engine (ANE) or Android GPU and Neural Network API (NNAPI) delegates to achieve impressive performance on your mobile device.
|
||||
|
@ -6,6 +6,29 @@ keywords: Ultralytics, iOS app, object detection, YOLO models, real time, Apple
|
||||
|
||||
# Ultralytics iOS App: Real-time Object Detection with YOLO Models
|
||||
|
||||
<a href="https://bit.ly/ultralytics_hub" target="_blank">
|
||||
<img width="100%" src="https://user-images.githubusercontent.com/26833433/281124469-6b3b0945-dbb1-44c8-80a9-ef6bc778b299.jpg" alt="Ultralytics HUB preview image"></a>
|
||||
<br>
|
||||
<div align="center">
|
||||
<a href="https://github.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-github.png" width="3%" alt="Ultralytics GitHub"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.linkedin.com/company/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-linkedin.png" width="3%" alt="Ultralytics LinkedIn"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://twitter.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-twitter.png" width="3%" alt="Ultralytics Twitter"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://youtube.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-youtube.png" width="3%" alt="Ultralytics YouTube"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.tiktok.com/@ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-tiktok.png" width="3%" alt="Ultralytics TikTok"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.instagram.com/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-instagram.png" width="3%" alt="Ultralytics Instagram"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://ultralytics.com/discord"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-discord.png" width="3%" alt="Ultralytics Discord"></a>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://apps.apple.com/xk/app/ultralytics/id1583935240" style="text-decoration:none;">
|
||||
<img src="https://raw.githubusercontent.com/ultralytics/assets/master/app/app-store.svg" width="15%" alt="" /></a>
|
||||
</div>
|
||||
|
||||
The Ultralytics iOS App is a powerful tool that allows you to run YOLO models directly on your iPhone or iPad for real-time object detection. This app utilizes the Apple Neural Engine and Core ML for model optimization and acceleration, enabling fast and efficient object detection.
|
||||
|
||||
## Quantization and Acceleration
|
||||
|
@ -9,14 +9,27 @@ keywords: Ultralytics HUB, YOLOv5, YOLOv8, model training, model deployment, pre
|
||||
<a href="https://bit.ly/ultralytics_hub" target="_blank">
|
||||
<img width="100%" src="https://github.com/ultralytics/assets/raw/main/im/ultralytics-hub.png" alt="Ultralytics HUB preview image"></a>
|
||||
<br>
|
||||
<br>
|
||||
<div align="center">
|
||||
<a href="https://github.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-github.png" width="3%" alt="Ultralytics GitHub"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.linkedin.com/company/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-linkedin.png" width="3%" alt="Ultralytics LinkedIn"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://twitter.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-twitter.png" width="3%" alt="Ultralytics Twitter"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://youtube.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-youtube.png" width="3%" alt="Ultralytics YouTube"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.tiktok.com/@ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-tiktok.png" width="3%" alt="Ultralytics TikTok"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.instagram.com/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-instagram.png" width="3%" alt="Ultralytics Instagram"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://ultralytics.com/discord"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-discord.png" width="3%" alt="Ultralytics Discord"></a>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://github.com/ultralytics/hub/actions/workflows/ci.yaml">
|
||||
<img src="https://github.com/ultralytics/hub/actions/workflows/ci.yaml/badge.svg" alt="CI CPU"></a>
|
||||
<a href="https://colab.research.google.com/github/ultralytics/hub/blob/master/hub.ipynb">
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
</div>
|
||||
<br>
|
||||
|
||||
👋 Hello from the [Ultralytics](https://ultralytics.com/) Team! We've been working hard these last few months to launch [Ultralytics HUB](https://bit.ly/ultralytics_hub), a new web tool for training and deploying all your YOLOv5 and YOLOv8 🚀 models from one spot!
|
||||
|
||||
|
@ -9,6 +9,21 @@ keywords: Ultralytics, YOLOv8, object detection, image segmentation, machine lea
|
||||
<a href="https://yolovision.ultralytics.com" target="_blank">
|
||||
<img width="1024" src="https://raw.githubusercontent.com/ultralytics/assets/main/yolov8/banner-yolov8.png" alt="Ultralytics YOLO banner"></a>
|
||||
</p>
|
||||
<a href="https://github.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-github.png" width="3%" alt="Ultralytics GitHub"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.linkedin.com/company/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-linkedin.png" width="3%" alt="Ultralytics LinkedIn"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://twitter.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-twitter.png" width="3%" alt="Ultralytics Twitter"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://youtube.com/ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-youtube.png" width="3%" alt="Ultralytics YouTube"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.tiktok.com/@ultralytics"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-tiktok.png" width="3%" alt="Ultralytics TikTok"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://www.instagram.com/ultralytics/"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-instagram.png" width="3%" alt="Ultralytics Instagram"></a>
|
||||
<img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="3%">
|
||||
<a href="https://ultralytics.com/discord"><img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-discord.png" width="3%" alt="Ultralytics Discord"></a>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://github.com/ultralytics/ultralytics/actions/workflows/ci.yaml"><img src="https://github.com/ultralytics/ultralytics/actions/workflows/ci.yaml/badge.svg" alt="Ultralytics CI"></a>
|
||||
<a href="https://codecov.io/github/ultralytics/ultralytics"><img src="https://codecov.io/github/ultralytics/ultralytics/branch/main/graph/badge.svg?token=HHW7IIVFVY" alt="Ultralytics Code Coverage"></a>
|
||||
<a href="https://zenodo.org/badge/latestdoi/264818686"><img src="https://zenodo.org/badge/264818686.svg" alt="YOLOv8 Citation"></a>
|
||||
|
@ -641,6 +641,33 @@ You can use the `plot()` method of a `Result` objects to visualize predictions.
|
||||
| `masks` | `bool` | Whether to plot the masks. | `True` |
|
||||
| `probs` | `bool` | Whether to plot classification probability | `True` |
|
||||
|
||||
## Thread-Safe Inference
|
||||
|
||||
Ensuring thread safety during inference is crucial when you are running multiple YOLO models in parallel across different threads. Thread-safe inference guarantees that each thread's predictions are isolated and do not interfere with one another, avoiding race conditions and ensuring consistent and reliable outputs.
|
||||
|
||||
When using YOLO models in a multi-threaded application, it's important to instantiate separate model objects for each thread or employ thread-local storage to prevent conflicts:
|
||||
|
||||
!!! example "Thread-Safe Inference"
|
||||
|
||||
Instantiate a single model inside each thread for thread-safe inference:
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
from threading import Thread
|
||||
|
||||
def thread_safe_predict(image_path):
|
||||
# Instantiate a new model inside the thread
|
||||
local_model = YOLO("yolov8n.pt")
|
||||
results = local_model.predict(image_path)
|
||||
# Process results
|
||||
|
||||
|
||||
# Starting threads that each have their own model instance
|
||||
Thread(target=thread_safe_predict, args=("image1.jpg",)).start()
|
||||
Thread(target=thread_safe_predict, args=("image2.jpg",)).start()
|
||||
```
|
||||
|
||||
For an in-depth look at thread-safe inference with YOLO models and step-by-step instructions, please refer to our [YOLO Thread-Safe Inference Guide](../guides/yolo-thread-safe-inference.md). This guide will provide you with all the necessary information to avoid common pitfalls and ensure that your multi-threaded inference runs smoothly.
|
||||
|
||||
## Streaming Source `for`-loop
|
||||
|
||||
Here's a Python script using OpenCV (`cv2`) and YOLOv8 to run inference on video frames. This script assumes you have already installed the necessary packages (`opencv-python` and `ultralytics`).
|
||||
|
@ -118,17 +118,6 @@ Ultralytics provides various installation methods including pip, conda, and Dock
|
||||
|
||||
See the `ultralytics` [requirements.txt](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) file for a list of dependencies. Note that all examples above install all required dependencies.
|
||||
|
||||
<p align="center">
|
||||
<br>
|
||||
<iframe width="720" height="405" src="https://www.youtube.com/embed/MWq1UxqTClU?si=nHAW-lYDzrz68jR0"
|
||||
title="YouTube video player" frameborder="0"
|
||||
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
|
||||
allowfullscreen>
|
||||
</iframe>
|
||||
<br>
|
||||
<strong>Watch:</strong> Ultralytics YOLO for Object Detection: Quickstart Guide for Installation and Setup.
|
||||
</p>
|
||||
|
||||
!!! tip "Tip"
|
||||
|
||||
PyTorch requirements vary by operating system and CUDA requirements, so it's recommended to install PyTorch first following instructions at [https://pytorch.org/get-started/locally](https://pytorch.org/get-started/locally).
|
||||
|
@ -8,6 +8,17 @@ keywords: Ultralytics, YOLO, CLI, train, validation, prediction, command line in
|
||||
|
||||
The YOLO command line interface (CLI) allows for simple single-line commands without the need for a Python environment. CLI requires no customization or Python code. You can simply run all tasks from the terminal with the `yolo` command.
|
||||
|
||||
<p align="center">
|
||||
<br>
|
||||
<iframe width="720" height="405" src="https://www.youtube.com/embed/GsXGnb-A4Kc"
|
||||
title="YouTube video player" frameborder="0"
|
||||
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
|
||||
allowfullscreen>
|
||||
</iframe>
|
||||
<br>
|
||||
<strong>Watch:</strong> Mastering Ultralytics YOLOv8: CLI & Python Usage and Live Inference
|
||||
</p>
|
||||
|
||||
!!! example
|
||||
|
||||
=== "Syntax"
|
||||
|
@ -215,7 +215,9 @@ nav:
|
||||
- Guides:
|
||||
- guides/index.md
|
||||
- YOLO Common Issues: guides/yolo-common-issues.md
|
||||
- Performance Metrics: guides/yolo-performance-metrics.md
|
||||
- YOLO Performance Metrics: guides/yolo-performance-metrics.md
|
||||
- YOLO Thread-Safe Inference: guides/yolo-thread-safe-inference.md
|
||||
- Model Deployment Options: guides/model-deployment-options.md
|
||||
- K-Fold Cross Validation: guides/kfold-cross-validation.md
|
||||
- Hyperparameter Tuning: guides/hyperparameter-tuning.md
|
||||
- SAHI Tiled Inference: guides/sahi-tiled-inference.md
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.207'
|
||||
__version__ = '8.0.208'
|
||||
|
||||
from ultralytics.models import RTDETR, SAM, YOLO
|
||||
from ultralytics.models.fastsam import FastSAM
|
||||
|
@ -300,7 +300,7 @@ def check_det_dataset(dataset, autodownload=True):
|
||||
data[k] = [str((path / x).resolve()) for x in data[k]]
|
||||
|
||||
# Parse YAML
|
||||
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
|
||||
val, s = (data.get(x) for x in ('val', 'download'))
|
||||
if val:
|
||||
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
||||
if not all(x.exists() for x in val):
|
||||
|
@ -28,6 +28,7 @@ Usage - formats:
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import platform
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
@ -106,6 +107,7 @@ class BasePredictor:
|
||||
self.transforms = None
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
self.txt_path = None
|
||||
self._lock = threading.Lock() # for automatic thread-safe inference
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def preprocess(self, im):
|
||||
@ -231,64 +233,66 @@ class BasePredictor:
|
||||
if not self.model:
|
||||
self.setup_model(model)
|
||||
|
||||
# Setup source every time predict is called
|
||||
self.setup_source(source if source is not None else self.args.source)
|
||||
with self._lock: # for thread-safe inference
|
||||
# Setup source every time predict is called
|
||||
self.setup_source(source if source is not None else self.args.source)
|
||||
|
||||
# Check if save_dir/ label file exists
|
||||
if self.args.save or self.args.save_txt:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
# Check if save_dir/ label file exists
|
||||
if self.args.save or self.args.save_txt:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Warmup model
|
||||
if not self.done_warmup:
|
||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
|
||||
self.done_warmup = True
|
||||
# Warmup model
|
||||
if not self.done_warmup:
|
||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
|
||||
self.done_warmup = True
|
||||
|
||||
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
|
||||
self.run_callbacks('on_predict_start')
|
||||
for batch in self.dataset:
|
||||
self.run_callbacks('on_predict_batch_start')
|
||||
self.batch = batch
|
||||
path, im0s, vid_cap, s = batch
|
||||
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
|
||||
self.run_callbacks('on_predict_start')
|
||||
|
||||
# Preprocess
|
||||
with profilers[0]:
|
||||
im = self.preprocess(im0s)
|
||||
for batch in self.dataset:
|
||||
self.run_callbacks('on_predict_batch_start')
|
||||
self.batch = batch
|
||||
path, im0s, vid_cap, s = batch
|
||||
|
||||
# Inference
|
||||
with profilers[1]:
|
||||
preds = self.inference(im, *args, **kwargs)
|
||||
# Preprocess
|
||||
with profilers[0]:
|
||||
im = self.preprocess(im0s)
|
||||
|
||||
# Postprocess
|
||||
with profilers[2]:
|
||||
self.results = self.postprocess(preds, im, im0s)
|
||||
self.run_callbacks('on_predict_postprocess_end')
|
||||
# Inference
|
||||
with profilers[1]:
|
||||
preds = self.inference(im, *args, **kwargs)
|
||||
|
||||
# Visualize, save, write results
|
||||
n = len(im0s)
|
||||
for i in range(n):
|
||||
self.seen += 1
|
||||
self.results[i].speed = {
|
||||
'preprocess': profilers[0].dt * 1E3 / n,
|
||||
'inference': profilers[1].dt * 1E3 / n,
|
||||
'postprocess': profilers[2].dt * 1E3 / n}
|
||||
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
|
||||
p = Path(p)
|
||||
# Postprocess
|
||||
with profilers[2]:
|
||||
self.results = self.postprocess(preds, im, im0s)
|
||||
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
s += self.write_results(i, self.results, (p, im, im0))
|
||||
if self.args.save or self.args.save_txt:
|
||||
self.results[i].save_dir = self.save_dir.__str__()
|
||||
if self.args.show and self.plotted_img is not None:
|
||||
self.show(p)
|
||||
if self.args.save and self.plotted_img is not None:
|
||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||
self.run_callbacks('on_predict_postprocess_end')
|
||||
# Visualize, save, write results
|
||||
n = len(im0s)
|
||||
for i in range(n):
|
||||
self.seen += 1
|
||||
self.results[i].speed = {
|
||||
'preprocess': profilers[0].dt * 1E3 / n,
|
||||
'inference': profilers[1].dt * 1E3 / n,
|
||||
'postprocess': profilers[2].dt * 1E3 / n}
|
||||
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
|
||||
p = Path(p)
|
||||
|
||||
self.run_callbacks('on_predict_batch_end')
|
||||
yield from self.results
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
s += self.write_results(i, self.results, (p, im, im0))
|
||||
if self.args.save or self.args.save_txt:
|
||||
self.results[i].save_dir = self.save_dir.__str__()
|
||||
if self.args.show and self.plotted_img is not None:
|
||||
self.show(p)
|
||||
if self.args.save and self.plotted_img is not None:
|
||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||
|
||||
# Print time (inference-only)
|
||||
if self.args.verbose:
|
||||
LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
|
||||
self.run_callbacks('on_predict_batch_end')
|
||||
yield from self.results
|
||||
|
||||
# Print time (inference-only)
|
||||
if self.args.verbose:
|
||||
LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
|
||||
|
||||
# Release assets
|
||||
if isinstance(self.vid_writer[-1], cv2.VideoWriter):
|
||||
|
Loading…
x
Reference in New Issue
Block a user