mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 06:14:55 +08:00
Support prediction of list of sources, in-memory dataset and other improvements (#685)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a5410ed79e
commit
0609561549
@ -3,10 +3,12 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
from ultralytics.yolo.data.build import load_inference_source
|
||||||
from ultralytics.yolo.utils import ROOT, SETTINGS
|
from ultralytics.yolo.utils import ROOT, SETTINGS
|
||||||
|
|
||||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
||||||
@ -40,6 +42,7 @@ def test_predict_dir():
|
|||||||
|
|
||||||
|
|
||||||
def test_predict_img():
|
def test_predict_img():
|
||||||
|
|
||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
img = Image.open(str(SOURCE))
|
img = Image.open(str(SOURCE))
|
||||||
output = model(source=img, save=True, verbose=True) # PIL
|
output = model(source=img, save=True, verbose=True) # PIL
|
||||||
@ -54,6 +57,16 @@ def test_predict_img():
|
|||||||
tens = torch.zeros(320, 640, 3)
|
tens = torch.zeros(320, 640, 3)
|
||||||
output = model(tens.numpy())
|
output = model(tens.numpy())
|
||||||
assert len(output) == 1, "predict test failed"
|
assert len(output) == 1, "predict test failed"
|
||||||
|
# test multiple source
|
||||||
|
imgs = [
|
||||||
|
SOURCE, # filename
|
||||||
|
Path(SOURCE), # Path
|
||||||
|
'https://ultralytics.com/images/zidane.jpg', # URI
|
||||||
|
cv2.imread(str(SOURCE)), # OpenCV
|
||||||
|
Image.open(SOURCE), # PIL
|
||||||
|
np.zeros((320, 640, 3))] # numpy
|
||||||
|
output = model(imgs)
|
||||||
|
assert len(output) == 6, "predict test failed!"
|
||||||
|
|
||||||
|
|
||||||
def test_val():
|
def test_val():
|
||||||
@ -129,3 +142,28 @@ def test_workflow():
|
|||||||
model.val()
|
model.val()
|
||||||
model.predict(SOURCE)
|
model.predict(SOURCE)
|
||||||
model.export(format="onnx", opset=12) # export a model to ONNX format
|
model.export(format="onnx", opset=12) # export a model to ONNX format
|
||||||
|
|
||||||
|
|
||||||
|
def test_predict_callback_and_setup():
|
||||||
|
|
||||||
|
def on_predict_batch_end(predictor):
|
||||||
|
# results -> List[batch_size]
|
||||||
|
path, _, im0s, _, _ = predictor.batch
|
||||||
|
# print('on_predict_batch_end', im0s[0].shape)
|
||||||
|
bs = [predictor.bs for i in range(0, len(path))]
|
||||||
|
predictor.results = zip(predictor.results, im0s, bs)
|
||||||
|
|
||||||
|
model = YOLO("yolov8n.pt")
|
||||||
|
model.add_callback("on_predict_batch_end", on_predict_batch_end)
|
||||||
|
|
||||||
|
dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
|
||||||
|
bs = dataset.bs # access predictor properties
|
||||||
|
results = model.predict(dataset, stream=True) # source already setup
|
||||||
|
for _, (result, im0, bs) in enumerate(results):
|
||||||
|
print('test_callback', im0.shape)
|
||||||
|
print('test_callback', bs)
|
||||||
|
boxes = result.boxes # Boxes object for bbox outputs
|
||||||
|
print(boxes)
|
||||||
|
|
||||||
|
|
||||||
|
test_predict_img()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
from .base import BaseDataset
|
from .base import BaseDataset
|
||||||
from .build import build_classification_dataloader, build_dataloader
|
from .build import build_classification_dataloader, build_dataloader, load_inference_source
|
||||||
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
||||||
from .dataset_wrappers import MixAndRectDataset
|
from .dataset_wrappers import MixAndRectDataset
|
||||||
|
@ -2,11 +2,18 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
from torch.utils.data import DataLoader, dataloader, distributed
|
from torch.utils.data import DataLoader, dataloader, distributed
|
||||||
|
|
||||||
|
from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
|
||||||
|
LoadStreams, SourceTypes, autocast_list)
|
||||||
|
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||||
|
from ultralytics.yolo.utils.checks import check_file
|
||||||
|
|
||||||
from ..utils import LOGGER, colorstr
|
from ..utils import LOGGER, colorstr
|
||||||
from ..utils.torch_utils import torch_distributed_zero_first
|
from ..utils.torch_utils import torch_distributed_zero_first
|
||||||
from .dataset import ClassificationDataset, YOLODataset
|
from .dataset import ClassificationDataset, YOLODataset
|
||||||
@ -123,3 +130,63 @@ def build_classification_dataloader(path,
|
|||||||
pin_memory=PIN_MEMORY,
|
pin_memory=PIN_MEMORY,
|
||||||
worker_init_fn=seed_worker,
|
worker_init_fn=seed_worker,
|
||||||
generator=generator) # or DataLoader(persistent_workers=True)
|
generator=generator) # or DataLoader(persistent_workers=True)
|
||||||
|
|
||||||
|
|
||||||
|
def check_source(source):
|
||||||
|
webcam, screenshot, from_img, in_memory = False, False, False, False
|
||||||
|
if isinstance(source, (str, int, Path)): # int for local usb carame
|
||||||
|
source = str(source)
|
||||||
|
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||||
|
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
|
||||||
|
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
|
||||||
|
screenshot = source.lower().startswith('screen')
|
||||||
|
if is_url and is_file:
|
||||||
|
source = check_file(source) # download
|
||||||
|
elif isinstance(source, tuple(LOADERS)):
|
||||||
|
in_memory = True
|
||||||
|
elif isinstance(source, (list, tuple)):
|
||||||
|
source = autocast_list(source) # convert all list elements to PIL or np arrays
|
||||||
|
from_img = True
|
||||||
|
elif isinstance(source, ((Image.Image, np.ndarray))):
|
||||||
|
from_img = True
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
|
||||||
|
|
||||||
|
return source, webcam, screenshot, from_img, in_memory
|
||||||
|
|
||||||
|
|
||||||
|
def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
|
||||||
|
"""
|
||||||
|
TODO: docs
|
||||||
|
"""
|
||||||
|
# source
|
||||||
|
source, webcam, screenshot, from_img, in_memory = check_source(source)
|
||||||
|
source_type = SourceTypes(webcam, screenshot, from_img) if not in_memory else source.source_type
|
||||||
|
|
||||||
|
# Dataloader
|
||||||
|
if in_memory:
|
||||||
|
dataset = source
|
||||||
|
elif webcam:
|
||||||
|
dataset = LoadStreams(source,
|
||||||
|
imgsz=imgsz,
|
||||||
|
stride=stride,
|
||||||
|
auto=auto,
|
||||||
|
transforms=transforms,
|
||||||
|
vid_stride=vid_stride)
|
||||||
|
|
||||||
|
elif screenshot:
|
||||||
|
dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
|
||||||
|
elif from_img:
|
||||||
|
dataset = LoadPilAndNumpy(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
|
||||||
|
else:
|
||||||
|
dataset = LoadImages(source,
|
||||||
|
imgsz=imgsz,
|
||||||
|
stride=stride,
|
||||||
|
auto=auto,
|
||||||
|
transforms=transforms,
|
||||||
|
vid_stride=vid_stride)
|
||||||
|
|
||||||
|
setattr(dataset, 'source_type', source_type) # attach source types
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
@ -4,14 +4,16 @@ import glob
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from ultralytics.yolo.data.augment import LetterBox
|
from ultralytics.yolo.data.augment import LetterBox
|
||||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||||
@ -19,6 +21,13 @@ from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
|
|||||||
from ultralytics.yolo.utils.checks import check_requirements
|
from ultralytics.yolo.utils.checks import check_requirements
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SourceTypes:
|
||||||
|
webcam: bool = False
|
||||||
|
screenshot: bool = False
|
||||||
|
from_img: bool = False
|
||||||
|
|
||||||
|
|
||||||
class LoadStreams:
|
class LoadStreams:
|
||||||
# YOLOv8 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
# YOLOv8 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
||||||
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||||
@ -63,6 +72,8 @@ class LoadStreams:
|
|||||||
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
|
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
|
||||||
self.auto = auto and self.rect
|
self.auto = auto and self.rect
|
||||||
self.transforms = transforms # optional
|
self.transforms = transforms # optional
|
||||||
|
self.bs = self.__len__()
|
||||||
|
|
||||||
if not self.rect:
|
if not self.rect:
|
||||||
LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
|
LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
|
||||||
|
|
||||||
@ -128,6 +139,7 @@ class LoadScreenshots:
|
|||||||
self.mode = 'stream'
|
self.mode = 'stream'
|
||||||
self.frame = 0
|
self.frame = 0
|
||||||
self.sct = mss.mss()
|
self.sct = mss.mss()
|
||||||
|
self.bs = 1
|
||||||
|
|
||||||
# Parse monitor shape
|
# Parse monitor shape
|
||||||
monitor = self.sct.monitors[self.screen]
|
monitor = self.sct.monitors[self.screen]
|
||||||
@ -185,6 +197,7 @@ class LoadImages:
|
|||||||
self.auto = auto
|
self.auto = auto
|
||||||
self.transforms = transforms # optional
|
self.transforms = transforms # optional
|
||||||
self.vid_stride = vid_stride # video frame-rate stride
|
self.vid_stride = vid_stride # video frame-rate stride
|
||||||
|
self.bs = 1
|
||||||
if any(videos):
|
if any(videos):
|
||||||
self.orientation = None # rotation degrees
|
self.orientation = None # rotation degrees
|
||||||
self._new_video(videos[0]) # new video
|
self._new_video(videos[0]) # new video
|
||||||
@ -276,6 +289,7 @@ class LoadPilAndNumpy:
|
|||||||
self.mode = 'image'
|
self.mode = 'image'
|
||||||
# generate fake paths
|
# generate fake paths
|
||||||
self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
|
self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
|
||||||
|
self.bs = 1
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _single_check(im):
|
def _single_check(im):
|
||||||
@ -311,6 +325,25 @@ class LoadPilAndNumpy:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def autocast_list(source):
|
||||||
|
"""
|
||||||
|
Merges a list of source of different types into a list of numpy arrays or PIL images
|
||||||
|
"""
|
||||||
|
files = []
|
||||||
|
for _, im in enumerate(source):
|
||||||
|
if isinstance(im, (str, Path)): # filename or uri
|
||||||
|
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
|
||||||
|
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
|
||||||
|
files.append(im)
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
img = cv2.imread(str(ROOT / "assets/bus.jpg"))
|
img = cv2.imread(str(ROOT / "assets/bus.jpg"))
|
||||||
dataset = LoadPilAndNumpy(im0=img)
|
dataset = LoadPilAndNumpy(im0=img)
|
||||||
|
@ -233,6 +233,13 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
return self.model.names
|
return self.model.names
|
||||||
|
|
||||||
|
@property
|
||||||
|
def transforms(self):
|
||||||
|
"""
|
||||||
|
Returns transform of the loaded model.
|
||||||
|
"""
|
||||||
|
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_callback(event: str, func):
|
def add_callback(event: str, func):
|
||||||
"""
|
"""
|
||||||
|
@ -30,13 +30,13 @@ from collections import defaultdict
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import torch
|
||||||
|
|
||||||
from ultralytics.nn.autobackend import AutoBackend
|
from ultralytics.nn.autobackend import AutoBackend
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
|
from ultralytics.yolo.data import load_inference_source
|
||||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||||
|
|
||||||
@ -76,6 +76,8 @@ class BasePredictor:
|
|||||||
if self.args.conf is None:
|
if self.args.conf is None:
|
||||||
self.args.conf = 0.25 # default conf=0.25
|
self.args.conf = 0.25 # default conf=0.25
|
||||||
self.done_warmup = False
|
self.done_warmup = False
|
||||||
|
if self.args.show:
|
||||||
|
self.args.show = check_imshow(warn=True)
|
||||||
|
|
||||||
# Usable if setup is done
|
# Usable if setup is done
|
||||||
self.model = None
|
self.model = None
|
||||||
@ -88,6 +90,7 @@ class BasePredictor:
|
|||||||
self.vid_path, self.vid_writer = None, None
|
self.vid_path, self.vid_writer = None, None
|
||||||
self.annotator = None
|
self.annotator = None
|
||||||
self.data_path = None
|
self.data_path = None
|
||||||
|
self.source_type = None
|
||||||
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
@ -103,53 +106,6 @@ class BasePredictor:
|
|||||||
def postprocess(self, preds, img, orig_img, classes=None):
|
def postprocess(self, preds, img, orig_img, classes=None):
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
def setup_source(self, source=None):
|
|
||||||
if not self.model:
|
|
||||||
raise Exception("setup model before setting up source!")
|
|
||||||
# source
|
|
||||||
source, webcam, screenshot, from_img = self.check_source(source)
|
|
||||||
# model
|
|
||||||
stride, pt = self.model.stride, self.model.pt
|
|
||||||
imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size
|
|
||||||
|
|
||||||
# Dataloader
|
|
||||||
bs = 1 # batch_size
|
|
||||||
if webcam:
|
|
||||||
self.args.show = check_imshow(warn=True)
|
|
||||||
self.dataset = LoadStreams(source,
|
|
||||||
imgsz=imgsz,
|
|
||||||
stride=stride,
|
|
||||||
auto=pt,
|
|
||||||
transforms=getattr(self.model.model, 'transforms', None),
|
|
||||||
vid_stride=self.args.vid_stride)
|
|
||||||
bs = len(self.dataset)
|
|
||||||
elif screenshot:
|
|
||||||
self.dataset = LoadScreenshots(source,
|
|
||||||
imgsz=imgsz,
|
|
||||||
stride=stride,
|
|
||||||
auto=pt,
|
|
||||||
transforms=getattr(self.model.model, 'transforms', None))
|
|
||||||
elif from_img:
|
|
||||||
self.dataset = LoadPilAndNumpy(source,
|
|
||||||
imgsz=imgsz,
|
|
||||||
stride=stride,
|
|
||||||
auto=pt,
|
|
||||||
transforms=getattr(self.model.model, 'transforms', None))
|
|
||||||
else:
|
|
||||||
self.dataset = LoadImages(source,
|
|
||||||
imgsz=imgsz,
|
|
||||||
stride=stride,
|
|
||||||
auto=pt,
|
|
||||||
transforms=getattr(self.model.model, 'transforms', None),
|
|
||||||
vid_stride=self.args.vid_stride)
|
|
||||||
self.vid_path, self.vid_writer = [None] * bs, [None] * bs
|
|
||||||
|
|
||||||
self.webcam = webcam
|
|
||||||
self.screenshot = screenshot
|
|
||||||
self.from_img = from_img
|
|
||||||
self.imgsz = imgsz
|
|
||||||
self.bs = bs
|
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def __call__(self, source=None, model=None, stream=False):
|
def __call__(self, source=None, model=None, stream=False):
|
||||||
if stream:
|
if stream:
|
||||||
@ -163,14 +119,29 @@ class BasePredictor:
|
|||||||
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def setup_source(self, source):
|
||||||
|
if not self.model:
|
||||||
|
raise Exception("Model not initialized!")
|
||||||
|
|
||||||
|
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||||
|
self.dataset = load_inference_source(source=source,
|
||||||
|
transforms=getattr(self.model.model, 'transforms', None),
|
||||||
|
imgsz=self.imgsz,
|
||||||
|
vid_stride=self.args.vid_stride,
|
||||||
|
stride=self.model.stride,
|
||||||
|
auto=self.model.pt)
|
||||||
|
self.source_type = self.dataset.source_type
|
||||||
|
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
|
||||||
|
|
||||||
def stream_inference(self, source=None, model=None):
|
def stream_inference(self, source=None, model=None):
|
||||||
self.run_callbacks("on_predict_start")
|
self.run_callbacks("on_predict_start")
|
||||||
|
|
||||||
# setup model
|
# setup model
|
||||||
if not self.model:
|
if not self.model:
|
||||||
self.setup_model(model)
|
self.setup_model(model)
|
||||||
# setup source. Run every time predict is called
|
# setup source every time predict is called
|
||||||
self.setup_source(source)
|
self.setup_source(source if source is not None else self.args.source)
|
||||||
|
|
||||||
# check if save_dir/ label file exists
|
# check if save_dir/ label file exists
|
||||||
if self.args.save or self.args.save_txt:
|
if self.args.save or self.args.save_txt:
|
||||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||||
@ -198,7 +169,7 @@ class BasePredictor:
|
|||||||
with self.dt[2]:
|
with self.dt[2]:
|
||||||
self.results = self.postprocess(preds, im, im0s, self.classes)
|
self.results = self.postprocess(preds, im, im0s, self.classes)
|
||||||
for i in range(len(im)):
|
for i in range(len(im)):
|
||||||
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
|
p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s)
|
||||||
p = Path(p)
|
p = Path(p)
|
||||||
|
|
||||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||||
@ -237,21 +208,6 @@ class BasePredictor:
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
def check_source(self, source):
|
|
||||||
source = source if source is not None else self.args.source
|
|
||||||
webcam, screenshot, from_img = False, False, False
|
|
||||||
if isinstance(source, (str, int, Path)): # int for local usb carame
|
|
||||||
source = str(source)
|
|
||||||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
|
||||||
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
|
|
||||||
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
|
|
||||||
screenshot = source.lower().startswith('screen')
|
|
||||||
if is_url and is_file:
|
|
||||||
source = check_file(source) # download
|
|
||||||
else:
|
|
||||||
from_img = True
|
|
||||||
return source, webcam, screenshot, from_img
|
|
||||||
|
|
||||||
def show(self, p):
|
def show(self, p):
|
||||||
im0 = self.annotator.result()
|
im0 = self.annotator.result()
|
||||||
if platform.system() == 'Linux' and p not in self.windows:
|
if platform.system() == 'Linux' and p not in self.windows:
|
||||||
|
@ -33,7 +33,7 @@ class ClassificationPredictor(BasePredictor):
|
|||||||
im = im[None] # expand for batch dim
|
im = im[None] # expand for batch dim
|
||||||
self.seen += 1
|
self.seen += 1
|
||||||
im0 = im0.copy()
|
im0 = im0.copy()
|
||||||
if self.webcam or self.from_img: # batch_size >= 1
|
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
|
||||||
log_string += f'{idx}: '
|
log_string += f'{idx}: '
|
||||||
frame = self.dataset.count
|
frame = self.dataset.count
|
||||||
else:
|
else:
|
||||||
|
@ -42,7 +42,7 @@ class DetectionPredictor(BasePredictor):
|
|||||||
im = im[None] # expand for batch dim
|
im = im[None] # expand for batch dim
|
||||||
self.seen += 1
|
self.seen += 1
|
||||||
im0 = im0.copy()
|
im0 = im0.copy()
|
||||||
if self.webcam or self.from_img: # batch_size >= 1
|
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
|
||||||
log_string += f'{idx}: '
|
log_string += f'{idx}: '
|
||||||
frame = self.dataset.count
|
frame = self.dataset.count
|
||||||
else:
|
else:
|
||||||
|
@ -43,7 +43,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||||||
if len(im.shape) == 3:
|
if len(im.shape) == 3:
|
||||||
im = im[None] # expand for batch dim
|
im = im[None] # expand for batch dim
|
||||||
self.seen += 1
|
self.seen += 1
|
||||||
if self.webcam or self.from_img: # batch_size >= 1
|
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
|
||||||
log_string += f'{idx}: '
|
log_string += f'{idx}: '
|
||||||
frame = self.dataset.count
|
frame = self.dataset.count
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user