mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
Support prediction of list of sources, in-memory dataset and other improvements (#685)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a5410ed79e
commit
0609561549
@ -3,10 +3,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.yolo.data.build import load_inference_source
|
||||
from ultralytics.yolo.utils import ROOT, SETTINGS
|
||||
|
||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
||||
@ -40,6 +42,7 @@ def test_predict_dir():
|
||||
|
||||
|
||||
def test_predict_img():
|
||||
|
||||
model = YOLO(MODEL)
|
||||
img = Image.open(str(SOURCE))
|
||||
output = model(source=img, save=True, verbose=True) # PIL
|
||||
@ -54,6 +57,16 @@ def test_predict_img():
|
||||
tens = torch.zeros(320, 640, 3)
|
||||
output = model(tens.numpy())
|
||||
assert len(output) == 1, "predict test failed"
|
||||
# test multiple source
|
||||
imgs = [
|
||||
SOURCE, # filename
|
||||
Path(SOURCE), # Path
|
||||
'https://ultralytics.com/images/zidane.jpg', # URI
|
||||
cv2.imread(str(SOURCE)), # OpenCV
|
||||
Image.open(SOURCE), # PIL
|
||||
np.zeros((320, 640, 3))] # numpy
|
||||
output = model(imgs)
|
||||
assert len(output) == 6, "predict test failed!"
|
||||
|
||||
|
||||
def test_val():
|
||||
@ -129,3 +142,28 @@ def test_workflow():
|
||||
model.val()
|
||||
model.predict(SOURCE)
|
||||
model.export(format="onnx", opset=12) # export a model to ONNX format
|
||||
|
||||
|
||||
def test_predict_callback_and_setup():
    """
    Run streaming prediction on a pre-built dataset with a custom batch-end callback.

    The callback replaces predictor.results with an iterator of
    (result, original image, batch size) triples, which the loop below unpacks.
    """

    def on_predict_batch_end(predictor):
        # results -> List[batch_size]
        path, _, im0s, _, _ = predictor.batch
        batch_sizes = [predictor.bs] * len(path)
        predictor.results = zip(predictor.results, im0s, batch_sizes)

    model = YOLO("yolov8n.pt")
    model.add_callback("on_predict_batch_end", on_predict_batch_end)

    # Build the dataset up front so predict() receives an already set-up source.
    dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
    bs = dataset.bs  # access predictor properties
    results = model.predict(dataset, stream=True)  # source already setup
    for result, im0, bs in results:
        print('test_callback', im0.shape)
        print('test_callback', bs)
        boxes = result.boxes  # Boxes object for bbox outputs
        print(boxes)
|
||||
|
||||
|
||||
test_predict_img()
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
from .base import BaseDataset
|
||||
from .build import build_classification_dataloader, build_dataloader
|
||||
from .build import build_classification_dataloader, build_dataloader, load_inference_source
|
||||
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
||||
from .dataset_wrappers import MixAndRectDataset
|
||||
|
@ -2,11 +2,18 @@
|
||||
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, dataloader, distributed
|
||||
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
|
||||
LoadStreams, SourceTypes, autocast_list)
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils.checks import check_file
|
||||
|
||||
from ..utils import LOGGER, colorstr
|
||||
from ..utils.torch_utils import torch_distributed_zero_first
|
||||
from .dataset import ClassificationDataset, YOLODataset
|
||||
@ -123,3 +130,63 @@ def build_classification_dataloader(path,
|
||||
pin_memory=PIN_MEMORY,
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator) # or DataLoader(persistent_workers=True)
|
||||
|
||||
|
||||
def check_source(source):
    """
    Classify an inference source and normalise it for the matching loader.

    Args:
        source: One of
            - str, int or Path: converted to str, then inspected for a known image/video suffix,
              a URL scheme, a numeric value, a '.streams' suffix or a leading 'screen';
            - an instance of one of the classes in LOADERS (an already-built dataset);
            - a list or tuple, whose items are converted by autocast_list;
            - a single PIL image or numpy array.

    Returns:
        tuple: (source, webcam, screenshot, from_img, in_memory). `source` is the possibly
        converted input; the four booleans tell the caller which loader applies.

    Raises:
        TypeError: if `source` is none of the supported types. TypeError is a subclass of
            Exception, so callers catching Exception keep working.
    """
    webcam, screenshot, from_img, in_memory = False, False, False, False
    if isinstance(source, (str, int, Path)):  # int for local usb camera
        source = str(source)
        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
        is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
        # A URL that does not point at a known media file is treated as a stream.
        webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
        screenshot = source.lower().startswith('screen')
        if is_url and is_file:
            source = check_file(source)  # download
    elif isinstance(source, tuple(LOADERS)):
        in_memory = True
    elif isinstance(source, (list, tuple)):
        source = autocast_list(source)  # convert all list elements to PIL or np arrays
        from_img = True
    elif isinstance(source, (Image.Image, np.ndarray)):
        from_img = True
    else:
        raise TypeError(
            "Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")

    return source, webcam, screenshot, from_img, in_memory
|
||||
|
||||
|
||||
def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
    """
    Build the inference dataset matching `source` and tag it with its source type.

    Args:
        source: Anything accepted by check_source (path/URL/stream spec, a loader instance,
            a list or tuple of images, a PIL image or a numpy array).
        transforms: Passed through unchanged to the loader that gets constructed.
        imgsz: Passed through to the loader.
        vid_stride: Passed through to LoadStreams and LoadImages only.
        stride: Passed through to the loader.
        auto: Passed through to the loader.

    Returns:
        The dataset object (the input itself when it already is a loader instance) with a
        `source_type` attribute holding a SourceTypes instance.
    """
    source, webcam, screenshot, from_img, in_memory = check_source(source)

    # Dataloader
    if in_memory:
        dataset = source
    elif webcam:
        dataset = LoadStreams(source,
                              imgsz=imgsz,
                              stride=stride,
                              auto=auto,
                              transforms=transforms,
                              vid_stride=vid_stride)
    elif screenshot:
        dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
    elif from_img:
        dataset = LoadPilAndNumpy(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
    else:
        dataset = LoadImages(source,
                             imgsz=imgsz,
                             stride=stride,
                             auto=auto,
                             transforms=transforms,
                             vid_stride=vid_stride)

    if in_memory:
        # NOTE(review): a loader the caller constructed directly presumably has not been tagged
        # by this function yet; fall back to deriving the flags from its class, mirroring the
        # dispatch above, instead of failing on a missing attribute.
        source_type = getattr(dataset, 'source_type', None)
        if source_type is None:
            source_type = SourceTypes(webcam=isinstance(dataset, LoadStreams),
                                      screenshot=isinstance(dataset, LoadScreenshots),
                                      from_img=isinstance(dataset, LoadPilAndNumpy))
    else:
        source_type = SourceTypes(webcam, screenshot, from_img)

    dataset.source_type = source_type  # attach source types

    return dataset
|
||||
|
@ -4,14 +4,16 @@ import glob
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from ultralytics.yolo.data.augment import LetterBox
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
@ -19,6 +21,13 @@ from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
|
||||
|
||||
@dataclass
class SourceTypes:
    """
    Boolean flags describing which kind of inference source a dataset came from.

    All flags default to False.
    """
    webcam: bool = False  # source is handled as a stream
    screenshot: bool = False  # source is a screen capture
    from_img: bool = False  # source is in-memory image data (PIL / numpy)
|
||||
|
||||
|
||||
class LoadStreams:
|
||||
# YOLOv8 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
||||
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||
@ -63,6 +72,8 @@ class LoadStreams:
|
||||
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
|
||||
self.auto = auto and self.rect
|
||||
self.transforms = transforms # optional
|
||||
self.bs = self.__len__()
|
||||
|
||||
if not self.rect:
|
||||
LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
|
||||
|
||||
@ -128,6 +139,7 @@ class LoadScreenshots:
|
||||
self.mode = 'stream'
|
||||
self.frame = 0
|
||||
self.sct = mss.mss()
|
||||
self.bs = 1
|
||||
|
||||
# Parse monitor shape
|
||||
monitor = self.sct.monitors[self.screen]
|
||||
@ -185,6 +197,7 @@ class LoadImages:
|
||||
self.auto = auto
|
||||
self.transforms = transforms # optional
|
||||
self.vid_stride = vid_stride # video frame-rate stride
|
||||
self.bs = 1
|
||||
if any(videos):
|
||||
self.orientation = None # rotation degrees
|
||||
self._new_video(videos[0]) # new video
|
||||
@ -276,6 +289,7 @@ class LoadPilAndNumpy:
|
||||
self.mode = 'image'
|
||||
# generate fake paths
|
||||
self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
|
||||
self.bs = 1
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im):
|
||||
@ -311,6 +325,25 @@ class LoadPilAndNumpy:
|
||||
return self
|
||||
|
||||
|
||||
def autocast_list(source):
    """
    Merge a list of sources of different types into a list of PIL images and numpy arrays.

    Args:
        source: Iterable whose items are file names or URIs (str or Path), PIL images or
            numpy arrays.

    Returns:
        list: One entry per input item, in order. str/Path items are opened with Image.open
        (fetched with requests first when they start with 'http'); PIL images and numpy
        arrays are passed through unchanged.

    Raises:
        TypeError: if an item is of none of the supported types. TypeError is a subclass of
            Exception, so callers catching Exception keep working.
    """
    files = []
    for im in source:
        if isinstance(im, (str, Path)):  # filename or uri
            location = str(im)
            # Remote images are streamed and handed to PIL as a file-like object.
            files.append(Image.open(requests.get(location, stream=True).raw if location.startswith('http') else im))
        elif isinstance(im, (Image.Image, np.ndarray)):  # PIL or np Image
            files.append(im)
        else:
            raise TypeError(
                "Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")

    return files
|
||||
|
||||
|
||||
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
|
||||
|
||||
if __name__ == "__main__":
|
||||
img = cv2.imread(str(ROOT / "assets/bus.jpg"))
|
||||
dataset = LoadPilAndNumpy(im0=img)
|
||||
|
@ -233,6 +233,13 @@ class YOLO:
|
||||
"""
|
||||
return self.model.names
|
||||
|
||||
@property
def transforms(self):
    """
    Return the `transforms` attribute of the loaded model, or None if the model has none.
    """
    return getattr(self.model, 'transforms', None)
|
||||
|
||||
@staticmethod
|
||||
def add_callback(event: str, func):
|
||||
"""
|
||||
|
@ -30,13 +30,13 @@ from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.data import load_inference_source
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||
|
||||
@ -76,6 +76,8 @@ class BasePredictor:
|
||||
if self.args.conf is None:
|
||||
self.args.conf = 0.25 # default conf=0.25
|
||||
self.done_warmup = False
|
||||
if self.args.show:
|
||||
self.args.show = check_imshow(warn=True)
|
||||
|
||||
# Usable if setup is done
|
||||
self.model = None
|
||||
@ -88,6 +90,7 @@ class BasePredictor:
|
||||
self.vid_path, self.vid_writer = None, None
|
||||
self.annotator = None
|
||||
self.data_path = None
|
||||
self.source_type = None
|
||||
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
@ -103,53 +106,6 @@ class BasePredictor:
|
||||
def postprocess(self, preds, img, orig_img, classes=None):
|
||||
return preds
|
||||
|
||||
def setup_source(self, source=None):
|
||||
if not self.model:
|
||||
raise Exception("setup model before setting up source!")
|
||||
# source
|
||||
source, webcam, screenshot, from_img = self.check_source(source)
|
||||
# model
|
||||
stride, pt = self.model.stride, self.model.pt
|
||||
imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size
|
||||
|
||||
# Dataloader
|
||||
bs = 1 # batch_size
|
||||
if webcam:
|
||||
self.args.show = check_imshow(warn=True)
|
||||
self.dataset = LoadStreams(source,
|
||||
imgsz=imgsz,
|
||||
stride=stride,
|
||||
auto=pt,
|
||||
transforms=getattr(self.model.model, 'transforms', None),
|
||||
vid_stride=self.args.vid_stride)
|
||||
bs = len(self.dataset)
|
||||
elif screenshot:
|
||||
self.dataset = LoadScreenshots(source,
|
||||
imgsz=imgsz,
|
||||
stride=stride,
|
||||
auto=pt,
|
||||
transforms=getattr(self.model.model, 'transforms', None))
|
||||
elif from_img:
|
||||
self.dataset = LoadPilAndNumpy(source,
|
||||
imgsz=imgsz,
|
||||
stride=stride,
|
||||
auto=pt,
|
||||
transforms=getattr(self.model.model, 'transforms', None))
|
||||
else:
|
||||
self.dataset = LoadImages(source,
|
||||
imgsz=imgsz,
|
||||
stride=stride,
|
||||
auto=pt,
|
||||
transforms=getattr(self.model.model, 'transforms', None),
|
||||
vid_stride=self.args.vid_stride)
|
||||
self.vid_path, self.vid_writer = [None] * bs, [None] * bs
|
||||
|
||||
self.webcam = webcam
|
||||
self.screenshot = screenshot
|
||||
self.from_img = from_img
|
||||
self.imgsz = imgsz
|
||||
self.bs = bs
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, source=None, model=None, stream=False):
|
||||
if stream:
|
||||
@ -163,14 +119,29 @@ class BasePredictor:
|
||||
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
||||
pass
|
||||
|
||||
def setup_source(self, source):
    """
    Prepare the dataset and per-batch video writer slots for a new prediction source.

    Sets self.imgsz, self.dataset, self.source_type, self.vid_path and self.vid_writer.

    Args:
        source: Inference source forwarded to load_inference_source.

    Raises:
        Exception: if the model has not been set up yet.
    """
    if not self.model:
        raise Exception("Model not initialized!")

    model, args = self.model, self.args
    self.imgsz = check_imgsz(args.imgsz, stride=model.stride, min_dim=2)  # check image size
    self.dataset = load_inference_source(source=source,
                                         transforms=getattr(model.model, 'transforms', None),
                                         imgsz=self.imgsz,
                                         vid_stride=args.vid_stride,
                                         stride=model.stride,
                                         auto=model.pt)
    self.source_type = self.dataset.source_type

    # One path slot and one writer slot per item in a batch.
    slots = self.dataset.bs
    self.vid_path, self.vid_writer = [None] * slots, [None] * slots
|
||||
|
||||
def stream_inference(self, source=None, model=None):
|
||||
self.run_callbacks("on_predict_start")
|
||||
|
||||
# setup model
|
||||
if not self.model:
|
||||
self.setup_model(model)
|
||||
# setup source. Run every time predict is called
|
||||
self.setup_source(source)
|
||||
# setup source every time predict is called
|
||||
self.setup_source(source if source is not None else self.args.source)
|
||||
|
||||
# check if save_dir/ label file exists
|
||||
if self.args.save or self.args.save_txt:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
@ -198,7 +169,7 @@ class BasePredictor:
|
||||
with self.dt[2]:
|
||||
self.results = self.postprocess(preds, im, im0s, self.classes)
|
||||
for i in range(len(im)):
|
||||
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
|
||||
p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s)
|
||||
p = Path(p)
|
||||
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
@ -237,21 +208,6 @@ class BasePredictor:
|
||||
self.device = device
|
||||
self.model.eval()
|
||||
|
||||
def check_source(self, source):
|
||||
source = source if source is not None else self.args.source
|
||||
webcam, screenshot, from_img = False, False, False
|
||||
if isinstance(source, (str, int, Path)): # int for local usb carame
|
||||
source = str(source)
|
||||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
|
||||
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
|
||||
screenshot = source.lower().startswith('screen')
|
||||
if is_url and is_file:
|
||||
source = check_file(source) # download
|
||||
else:
|
||||
from_img = True
|
||||
return source, webcam, screenshot, from_img
|
||||
|
||||
def show(self, p):
|
||||
im0 = self.annotator.result()
|
||||
if platform.system() == 'Linux' and p not in self.windows:
|
||||
|
@ -33,7 +33,7 @@ class ClassificationPredictor(BasePredictor):
|
||||
im = im[None] # expand for batch dim
|
||||
self.seen += 1
|
||||
im0 = im0.copy()
|
||||
if self.webcam or self.from_img: # batch_size >= 1
|
||||
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
|
||||
log_string += f'{idx}: '
|
||||
frame = self.dataset.count
|
||||
else:
|
||||
|
@ -42,7 +42,7 @@ class DetectionPredictor(BasePredictor):
|
||||
im = im[None] # expand for batch dim
|
||||
self.seen += 1
|
||||
im0 = im0.copy()
|
||||
if self.webcam or self.from_img: # batch_size >= 1
|
||||
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
|
||||
log_string += f'{idx}: '
|
||||
frame = self.dataset.count
|
||||
else:
|
||||
|
@ -43,7 +43,7 @@ class SegmentationPredictor(DetectionPredictor):
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
self.seen += 1
|
||||
if self.webcam or self.from_img: # batch_size >= 1
|
||||
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
|
||||
log_string += f'{idx}: '
|
||||
frame = self.dataset.count
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user