mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
ultralytics 8.0.164
new StreamLoader stream_buffer
argument (#4596)
Co-authored-by: jgoo9410 <jjoohhnnggooddwwiinn@gmail.com> Co-authored-by: John Goodwin <johnf4g@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
bd96c0846b
commit
1121ef2409
4
.github/workflows/docker.yaml
vendored
4
.github/workflows/docker.yaml
vendored
@ -118,7 +118,9 @@ jobs:
|
||||
run: |
|
||||
docker push ultralytics/ultralytics:${{ matrix.tags }}
|
||||
if [[ "${{ matrix.tags }}" == "latest" ]]; then
|
||||
t=ultralytics/ultralytics:latest-runner && sudo docker build -f docker/Dockerfile-runner -t $t . && sudo docker push $t
|
||||
t=ultralytics/ultralytics:latest-runner
|
||||
docker build -f docker/Dockerfile-runner -t $t .
|
||||
docker push $t
|
||||
fi
|
||||
|
||||
- name: Notify on failure
|
||||
|
@ -54,21 +54,21 @@ YOLOv8 can process different types of input sources for inference, as shown in t
|
||||
|
||||
Use `stream=True` for processing long videos or large datasets to efficiently manage memory. When `stream=False`, the results for all frames or data points are stored in memory, which can quickly add up and cause out-of-memory errors for large inputs. In contrast, `stream=True` utilizes a generator, which only keeps the results of the current frame or data point in memory, significantly reducing memory consumption and preventing out-of-memory issues.
|
||||
|
||||
| Source | Argument | Type | Notes |
|
||||
|---------------|--------------------------------------------|-----------------|---------------------------------------------------------------------------------------------|
|
||||
| image | `'image.jpg'` | `str` or `Path` | Single image file. |
|
||||
| URL | `'https://ultralytics.com/images/bus.jpg'` | `str` | URL to an image. |
|
||||
| screenshot | `'screen'` | `str` | Capture a screenshot. |
|
||||
| PIL | `Image.open('im.jpg')` | `PIL.Image` | HWC format with RGB channels. |
|
||||
| OpenCV | `cv2.imread('im.jpg')` | `np.ndarray` | HWC format with BGR channels `uint8 (0-255)`. |
|
||||
| numpy | `np.zeros((640,1280,3))` | `np.ndarray` | HWC format with BGR channels `uint8 (0-255)`. |
|
||||
| torch | `torch.zeros(16,3,320,640)` | `torch.Tensor` | BCHW format with RGB channels `float32 (0.0-1.0)`. |
|
||||
| CSV | `'sources.csv'` | `str` or `Path` | CSV file containing paths to images, videos, or directories. |
|
||||
| video ✅ | `'video.mp4'` | `str` or `Path` | Video file in formats like MP4, AVI, etc. |
|
||||
| directory ✅ | `'path/'` | `str` or `Path` | Path to a directory containing images or videos. |
|
||||
| glob ✅ | `'path/*.jpg'` | `str` | Glob pattern to match multiple files. Use the `*` character as a wildcard. |
|
||||
| YouTube ✅ | `'https://youtu.be/Zgi9g1ksQHc'` | `str` | URL to a YouTube video. |
|
||||
| stream ✅ | `'rtsp://example.com/media.mp4'` | `str` | URL for streaming protocols such as RTSP, RTMP, or an IP address. |
|
||||
| Source | Argument | Type | Notes |
|
||||
|----------------|--------------------------------------------|-----------------|---------------------------------------------------------------------------------------------|
|
||||
| image | `'image.jpg'` | `str` or `Path` | Single image file. |
|
||||
| URL | `'https://ultralytics.com/images/bus.jpg'` | `str` | URL to an image. |
|
||||
| screenshot | `'screen'` | `str` | Capture a screenshot. |
|
||||
| PIL | `Image.open('im.jpg')` | `PIL.Image` | HWC format with RGB channels. |
|
||||
| OpenCV | `cv2.imread('im.jpg')` | `np.ndarray` | HWC format with BGR channels `uint8 (0-255)`. |
|
||||
| numpy | `np.zeros((640,1280,3))` | `np.ndarray` | HWC format with BGR channels `uint8 (0-255)`. |
|
||||
| torch | `torch.zeros(16,3,320,640)` | `torch.Tensor` | BCHW format with RGB channels `float32 (0.0-1.0)`. |
|
||||
| CSV | `'sources.csv'` | `str` or `Path` | CSV file containing paths to images, videos, or directories. |
|
||||
| video ✅ | `'video.mp4'` | `str` or `Path` | Video file in formats like MP4, AVI, etc. |
|
||||
| directory ✅ | `'path/'` | `str` or `Path` | Path to a directory containing images or videos. |
|
||||
| glob ✅ | `'path/*.jpg'` | `str` | Glob pattern to match multiple files. Use the `*` character as a wildcard. |
|
||||
| YouTube ✅ | `'https://youtu.be/Zgi9g1ksQHc'` | `str` | URL to a YouTube video. |
|
||||
| stream ✅ | `'rtsp://example.com/media.mp4'` | `str` | URL for streaming protocols such as RTSP, RTMP, or an IP address. |
|
||||
| multi-stream ✅ | `'list.streams'` | `str` or `Path` | `*.streams` text file with one stream URL per row, i.e. 8 streams will run at batch-size 8. |
|
||||
|
||||
Below are code examples for using each source type:
|
||||
@ -299,30 +299,31 @@ Below are code examples for using each source type:
|
||||
|
||||
All supported arguments:
|
||||
|
||||
| Name | Type | Default | Description |
|
||||
|----------------|----------------|------------------------|--------------------------------------------------------------------------------|
|
||||
| `source` | `str` | `'ultralytics/assets'` | source directory for images or videos |
|
||||
| `conf` | `float` | `0.25` | object confidence threshold for detection |
|
||||
| `iou` | `float` | `0.7` | intersection over union (IoU) threshold for NMS |
|
||||
| `imgsz` | `int or tuple` | `640` | image size as scalar or (h, w) list, i.e. (640, 480) |
|
||||
| `half` | `bool` | `False` | use half precision (FP16) |
|
||||
| `device` | `None or str` | `None` | device to run on, i.e. cuda device=0/1/2/3 or device=cpu |
|
||||
| `show` | `bool` | `False` | show results if possible |
|
||||
| `save` | `bool` | `False` | save images with results |
|
||||
| `save_txt` | `bool` | `False` | save results as .txt file |
|
||||
| `save_conf` | `bool` | `False` | save results with confidence scores |
|
||||
| `save_crop` | `bool` | `False` | save cropped images with results |
|
||||
| `hide_labels` | `bool` | `False` | hide labels |
|
||||
| `hide_conf` | `bool` | `False` | hide confidence scores |
|
||||
| `max_det` | `int` | `300` | maximum number of detections per image |
|
||||
| `vid_stride` | `bool` | `False` | video frame-rate stride |
|
||||
| `line_width` | `None or int` | `None` | The line width of the bounding boxes. If None, it is scaled to the image size. |
|
||||
| `visualize` | `bool` | `False` | visualize model features |
|
||||
| `augment` | `bool` | `False` | apply image augmentation to prediction sources |
|
||||
| `agnostic_nms` | `bool` | `False` | class-agnostic NMS |
|
||||
| `retina_masks` | `bool` | `False` | use high-resolution segmentation masks |
|
||||
| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
|
||||
| `boxes` | `bool` | `True` | Show boxes in segmentation predictions |
|
||||
| Name | Type | Default | Description |
|
||||
|-----------------|----------------|------------------------|--------------------------------------------------------------------------------|
|
||||
| `source` | `str` | `'ultralytics/assets'` | source directory for images or videos |
|
||||
| `conf` | `float` | `0.25` | object confidence threshold for detection |
|
||||
| `iou` | `float` | `0.7` | intersection over union (IoU) threshold for NMS |
|
||||
| `imgsz` | `int or tuple` | `640` | image size as scalar or (h, w) list, i.e. (640, 480) |
|
||||
| `half` | `bool` | `False` | use half precision (FP16) |
|
||||
| `device` | `None or str` | `None` | device to run on, i.e. cuda device=0/1/2/3 or device=cpu |
|
||||
| `show` | `bool` | `False` | show results if possible |
|
||||
| `save` | `bool` | `False` | save images with results |
|
||||
| `save_txt` | `bool` | `False` | save results as .txt file |
|
||||
| `save_conf` | `bool` | `False` | save results with confidence scores |
|
||||
| `save_crop` | `bool` | `False` | save cropped images with results |
|
||||
| `hide_labels` | `bool` | `False` | hide labels |
|
||||
| `hide_conf` | `bool` | `False` | hide confidence scores |
|
||||
| `max_det` | `int` | `300` | maximum number of detections per image |
|
||||
| `vid_stride` | `bool` | `False` | video frame-rate stride |
|
||||
| `stream_buffer` | `bool` | `False` | buffer all streaming frames (True) or return the most recent frame (False) |
|
||||
| `line_width` | `None or int` | `None` | The line width of the bounding boxes. If None, it is scaled to the image size. |
|
||||
| `visualize` | `bool` | `False` | visualize model features |
|
||||
| `augment` | `bool` | `False` | apply image augmentation to prediction sources |
|
||||
| `agnostic_nms` | `bool` | `False` | class-agnostic NMS |
|
||||
| `retina_masks` | `bool` | `False` | use high-resolution segmentation masks |
|
||||
| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
|
||||
| `boxes` | `bool` | `True` | Show boxes in segmentation predictions |
|
||||
|
||||
## Image and Video Formats
|
||||
|
||||
|
@ -133,29 +133,30 @@ The training settings for YOLO models encompass various hyperparameters and conf
|
||||
|
||||
The prediction settings for YOLO models encompass a range of hyperparameters and configurations that influence the model's performance, speed, and accuracy during inference on new data. Careful tuning and experimentation with these settings are essential to achieve optimal performance for a specific task. Key settings include the confidence threshold, Non-Maximum Suppression (NMS) threshold, and the number of classes considered. Additional factors affecting the prediction process are input data size and format, the presence of supplementary features such as masks or multiple labels per box, and the particular task the model is employed for.
|
||||
|
||||
| Key | Value | Description |
|
||||
|----------------|------------------------|--------------------------------------------------------------------------------|
|
||||
| `source` | `'ultralytics/assets'` | source directory for images or videos |
|
||||
| `conf` | `0.25` | object confidence threshold for detection |
|
||||
| `iou` | `0.7` | intersection over union (IoU) threshold for NMS |
|
||||
| `half` | `False` | use half precision (FP16) |
|
||||
| `device` | `None` | device to run on, i.e. cuda device=0/1/2/3 or device=cpu |
|
||||
| `show` | `False` | show results if possible |
|
||||
| `save` | `False` | save images with results |
|
||||
| `save_txt` | `False` | save results as .txt file |
|
||||
| `save_conf` | `False` | save results with confidence scores |
|
||||
| `save_crop` | `False` | save cropped images with results |
|
||||
| `show_labels` | `True` | show object labels in plots |
|
||||
| `show_conf` | `True` | show object confidence scores in plots |
|
||||
| `max_det` | `300` | maximum number of detections per image |
|
||||
| `vid_stride` | `False` | video frame-rate stride |
|
||||
| `line_width` | `None` | The line width of the bounding boxes. If None, it is scaled to the image size. |
|
||||
| `visualize` | `False` | visualize model features |
|
||||
| `augment` | `False` | apply image augmentation to prediction sources |
|
||||
| `agnostic_nms` | `False` | class-agnostic NMS |
|
||||
| `retina_masks` | `False` | use high-resolution segmentation masks |
|
||||
| `classes` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
|
||||
| `boxes` | `True` | Show boxes in segmentation predictions |
|
||||
| Key | Value | Description |
|
||||
|-----------------|------------------------|--------------------------------------------------------------------------------|
|
||||
| `source` | `'ultralytics/assets'` | source directory for images or videos |
|
||||
| `conf` | `0.25` | object confidence threshold for detection |
|
||||
| `iou` | `0.7` | intersection over union (IoU) threshold for NMS |
|
||||
| `half` | `False` | use half precision (FP16) |
|
||||
| `device` | `None` | device to run on, i.e. cuda device=0/1/2/3 or device=cpu |
|
||||
| `show` | `False` | show results if possible |
|
||||
| `save` | `False` | save images with results |
|
||||
| `save_txt` | `False` | save results as .txt file |
|
||||
| `save_conf` | `False` | save results with confidence scores |
|
||||
| `save_crop` | `False` | save cropped images with results |
|
||||
| `show_labels` | `True` | show object labels in plots |
|
||||
| `show_conf` | `True` | show object confidence scores in plots |
|
||||
| `max_det` | `300` | maximum number of detections per image |
|
||||
| `vid_stride` | `False` | video frame-rate stride |
|
||||
| `stream_buffer` | `bool` | buffer all streaming frames (True) or return the most recent frame (False) |
|
||||
| `line_width` | `None` | The line width of the bounding boxes. If None, it is scaled to the image size. |
|
||||
| `visualize` | `False` | visualize model features |
|
||||
| `augment` | `False` | apply image augmentation to prediction sources |
|
||||
| `agnostic_nms` | `False` | class-agnostic NMS |
|
||||
| `retina_masks` | `False` | use high-resolution segmentation masks |
|
||||
| `classes` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
|
||||
| `boxes` | `True` | Show boxes in segmentation predictions |
|
||||
|
||||
[Predict Guide](../modes/predict.md){ .md-button .md-button--primary}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.163'
|
||||
__version__ = '8.0.164'
|
||||
|
||||
from ultralytics.models import RTDETR, SAM, YOLO
|
||||
from ultralytics.models.fastsam import FastSAM
|
||||
|
@ -60,6 +60,7 @@ save_crop: False # (bool) save cropped images with results
|
||||
show_labels: True # (bool) show object labels in plots
|
||||
show_conf: True # (bool) show object confidence scores in plots
|
||||
vid_stride: 1 # (int) video frame-rate stride
|
||||
stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False)
|
||||
line_width: # (int, optional) line width of the bounding boxes, auto if missing
|
||||
visualize: False # (bool) visualize model features
|
||||
augment: False # (bool) apply image augmentation to prediction sources
|
||||
|
@ -135,7 +135,7 @@ def check_source(source):
|
||||
return source, webcam, screenshot, from_img, in_memory, tensor
|
||||
|
||||
|
||||
def load_inference_source(source=None, imgsz=640, vid_stride=1):
|
||||
def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=False):
|
||||
"""
|
||||
Loads an inference source for object detection and applies necessary transformations.
|
||||
|
||||
@ -143,6 +143,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1):
|
||||
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
|
||||
imgsz (int, optional): The size of the image for inference. Default is 640.
|
||||
vid_stride (int, optional): The frame interval for video sources. Default is 1.
|
||||
stream_buffer (bool, optional): Determined whether stream frames will be buffered. Default is False.
|
||||
|
||||
Returns:
|
||||
dataset (Dataset): A dataset object for the specified input source.
|
||||
@ -156,7 +157,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1):
|
||||
elif in_memory:
|
||||
dataset = source
|
||||
elif webcam:
|
||||
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride)
|
||||
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, stream_buffer=stream_buffer)
|
||||
elif screenshot:
|
||||
dataset = LoadScreenshots(source, imgsz=imgsz)
|
||||
elif from_img:
|
||||
|
@ -31,9 +31,10 @@ class SourceTypes:
|
||||
class LoadStreams:
|
||||
"""YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
|
||||
|
||||
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
|
||||
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, stream_buffer=False):
|
||||
"""Initialize instance variables and check for consistent input stream shapes."""
|
||||
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
||||
self.stream_buffer = stream_buffer # buffer input streams
|
||||
self.running = True # running flag for Thread
|
||||
self.mode = 'stream'
|
||||
self.imgsz = imgsz
|
||||
@ -81,7 +82,7 @@ class LoadStreams:
|
||||
n, f = 0, self.frames[i] # frame number, frame array
|
||||
while self.running and cap.isOpened() and n < (f - 1):
|
||||
# Only read a new frame if the buffer is empty
|
||||
if not self.imgs[i]:
|
||||
if not self.imgs[i] or not self.stream_buffer:
|
||||
n += 1
|
||||
cap.grab() # .read() = .grab() followed by .retrieve()
|
||||
if n % self.vid_stride == 0:
|
||||
@ -124,7 +125,16 @@ class LoadStreams:
|
||||
time.sleep(1 / min(self.fps))
|
||||
|
||||
# Get and remove the next frame from imgs buffer
|
||||
return self.sources, [x.pop(0) for x in self.imgs], None, ''
|
||||
if self.stream_buffer:
|
||||
images = [x.pop(0) for x in self.imgs]
|
||||
else:
|
||||
# Get the latest frame, and clear the rest from the imgs buffer
|
||||
images = []
|
||||
for x in self.imgs:
|
||||
images.append(x.pop(-1) if x else None)
|
||||
x.clear()
|
||||
|
||||
return self.sources, images, None, ''
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the sources object."""
|
||||
|
@ -209,7 +209,10 @@ class BasePredictor:
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||
self.transforms = getattr(self.model.model, 'transforms', classify_transforms(
|
||||
self.imgsz[0])) if self.args.task == 'classify' else None
|
||||
self.dataset = load_inference_source(source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride)
|
||||
self.dataset = load_inference_source(source=source,
|
||||
imgsz=self.imgsz,
|
||||
vid_stride=self.args.vid_stride,
|
||||
stream_buffer=self.args.stream_buffer)
|
||||
self.source_type = self.dataset.source_type
|
||||
if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
|
||||
len(self.dataset) > 1000 or # images
|
||||
|
Loading…
x
Reference in New Issue
Block a user