mirror of https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00

Predictor support (#65)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

This commit is contained in:
parent 479992093c
commit e6737f1207
BIN  ultralytics/assets/bus.jpg     Normal file (476 KiB; binary file not shown)
BIN  ultralytics/assets/zidane.jpg  Normal file (165 KiB; binary file not shown)
@@ -39,9 +39,9 @@ def cli(cfg):
         module_function = module_file.train
     elif cfg.mode.lower() == "val":
         module_function = module_file.val
-    elif cfg.mode.lower() == "infer":
-        module_function = module_file.infer
+    elif cfg.mode.lower() == "predict":
+        module_function = module_file.predict

     if not module_function:
-        raise Exception("mode not recognized. Choices are `'train', 'val', 'infer'`")
+        raise Exception("mode not recognized. Choices are `'train', 'val', 'predict'`")
     module_function(cfg)
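In user-facing terms, this hunk renames the CLI mode from `infer` to `predict`; following the invocation style documented in the new predictor engine later in this commit, a run would look like:

    $ yolo task=detect mode=predict model=s.pt --source 0  # webcam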
ultralytics/yolo/data/augment.py:
@@ -459,7 +459,7 @@ class LetterBox:
         self.stride = stride

     def __call__(self, labels={}, image=None):
-        img = image or labels["img"]
+        img = labels.get("img") if image is None else image
         shape = img.shape[:2]  # current shape [height, width]
         new_shape = labels.pop("rect_shape", self.new_shape)
         if isinstance(new_shape, int):
@@ -491,10 +491,13 @@ class LetterBox:
         img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
                                  value=(114, 114, 114))  # add border

-        labels = self._update_labels(labels, ratio, dw, dh)
-        labels["img"] = img
-        labels["resized_shape"] = new_shape
-        return labels
+        if len(labels):
+            labels = self._update_labels(labels, ratio, dw, dh)
+            labels["img"] = img
+            labels["resized_shape"] = new_shape
+            return labels
+        else:
+            return img

     def _update_labels(self, labels, ratio, padw, padh):
        """Update labels"""
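The rewritten `__call__` lets LetterBox serve double duty: a labels dict during training, a bare image at inference (the old `image or labels["img"]` would also trip over numpy's ambiguous truth value for arrays). A minimal sketch of the inference-side call, assuming the `LetterBox(img_size, auto, stride=stride)(image=x)` constructor usage seen in the new stream loaders below:

    import cv2
    from ultralytics.yolo.data.augment import LetterBox

    im0 = cv2.imread('ultralytics/assets/bus.jpg')  # BGR test image added by this commit
    im = LetterBox(640, False, stride=32)(image=im0)  # empty labels, so only the padded image returns
    print(im0.shape, '->', im.shape)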
254  ultralytics/yolo/data/dataloaders/stream_loaders.py  Normal file
@@ -0,0 +1,254 @@
import glob
import math
import os
import time
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse

import cv2
import numpy as np
import torch

from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import LOGGER, is_colab, is_kaggle, ops
from ultralytics.yolo.utils.checks import check_requirements


class LoadStreams:
    # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP streams`
    def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride
        self.vid_stride = vid_stride  # video frame-rate stride
        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.sources = [ops.clean_str(x) for x in sources]  # clean source names for later
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f'{i + 1}/{n}: {s}... '
            if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'):  # if source is YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
                check_requirements(('pafy', 'youtube_dl==2020.12.2'))
                import pafy
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0:
                assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
                assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'{st}Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
            LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info('')  # newline

        # check for common shapes
        s = np.stack([LetterBox(img_size, auto, stride=stride)(image=x).shape for x in self.imgs])
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        self.auto = auto and self.rect
        self.transforms = transforms  # optional
        if not self.rect:
            LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')

    def update(self, i, cap, stream):
        # Read stream `i` frames in daemon thread
        n, f = 0, self.frames[i]  # frame number, frame array
        while cap.isOpened() and n < f:
            n += 1
            cap.grab()  # .read() = .grab() followed by .retrieve()
            if n % self.vid_stride == 0:
                success, im = cap.retrieve()
                if success:
                    self.imgs[i] = im
                else:
                    LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
                    self.imgs[i] = np.zeros_like(self.imgs[i])
                    cap.open(stream)  # re-open stream if signal was lost
            time.sleep(0.0)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        im0 = self.imgs.copy()
        if self.transforms:
            im = np.stack([self.transforms(x) for x in im0])  # transforms
        else:
            im = np.stack([LetterBox(self.img_size, self.auto, stride=self.stride)(image=x) for x in im0])
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
            im = np.ascontiguousarray(im)  # contiguous

        return self.sources, im, im0, None, ''

    def __len__(self):
        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years


class LoadScreenshots:
    # YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
    def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
        # source = [screen_number left top width height] (pixels)
        check_requirements('mss')
        import mss

        source, *params = source.split()
        self.screen, left, top, width, height = 0, None, None, None, None  # default to full screen 0
        if len(params) == 1:
            self.screen = int(params[0])
        elif len(params) == 4:
            left, top, width, height = (int(x) for x in params)
        elif len(params) == 5:
            self.screen, left, top, width, height = (int(x) for x in params)
        self.img_size = img_size
        self.stride = stride
        self.transforms = transforms
        self.auto = auto
        self.mode = 'stream'
        self.frame = 0
        self.sct = mss.mss()

        # Parse monitor shape
        monitor = self.sct.monitors[self.screen]
        self.top = monitor["top"] if top is None else (monitor["top"] + top)
        self.left = monitor["left"] if left is None else (monitor["left"] + left)
        self.width = width or monitor["width"]
        self.height = height or monitor["height"]
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}

    def __iter__(self):
        return self

    def __next__(self):
        # mss screen capture: get raw pixels from the screen as np array
        im0 = np.array(self.sct.grab(self.monitor))[:, :, :3]  # [:, :, :3] BGRA to BGR
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = LetterBox(self.img_size, self.auto, stride=self.stride)(image=im0)
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous
        self.frame += 1
        return str(self.screen), im, im0, None, s  # screen, img, original img, im0s, s


class LoadImages:
    # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
            path = Path(path).read_text().rsplit()
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if '*' in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f'{p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            for _ in range(self.vid_stride):
                self.cap.grab()
            ret_val, im0 = self.cap.retrieve()
            while not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                ret_val, im0 = self.cap.read()

            self.frame += 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            assert im0 is not None, f'Image Not Found {path}'
            s = f'image {self.count}/{self.nf} {path}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = LetterBox(self.img_size, self.auto, stride=self.stride)(image=im0)
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, im, im0, self.cap, s

    def _new_video(self, path):
        # Create a new video capture object
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

    def _cv2_rotate(self, im):
        # Rotate a cv2 video manually
        if self.orientation == 0:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        elif self.orientation == 180:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.orientation == 90:
            return cv2.rotate(im, cv2.ROTATE_180)
        return im

    def __len__(self):
        return self.nf  # number of files
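Every loader yields the same 5-tuple `(path, im, im0, cap, s)`, which is what lets the predictor engine below treat files, screenshots, and live streams uniformly. A usage sketch for the file/video loader:

    from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages

    dataset = LoadImages('ultralytics/assets/', img_size=640, stride=32, auto=True)
    for path, im, im0, cap, s in dataset:
        print(s, im.shape, im0.shape)  # im: letterboxed CHW RGB array, im0: original BGR frame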
201  ultralytics/yolo/engine/predictor.py  Normal file
@@ -0,0 +1,201 @@
# predictor engine by Ultralytics
"""
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
Usage - sources:
    $ yolo task=... mode=predict model=s.pt --source 0  # webcam
                                              img.jpg  # image
                                              vid.mp4  # video
                                              screen  # screenshot
                                              path/  # directory
                                              list.txt  # list of images
                                              list.streams  # list of streams
                                              'path/*.jpg'  # glob
                                              'https://youtu.be/Zgi9g1ksQHc'  # YouTube
                                              'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
Usage - formats:
    $ yolo task=... mode=predict --weights yolov5s.pt  # PyTorch
                                 yolov5s.torchscript  # TorchScript
                                 yolov5s.onnx  # ONNX Runtime or OpenCV DNN with --dnn
                                 yolov5s_openvino_model  # OpenVINO
                                 yolov5s.engine  # TensorRT
                                 yolov5s.mlmodel  # CoreML (macOS-only)
                                 yolov5s_saved_model  # TensorFlow SavedModel
                                 yolov5s.pb  # TensorFlow GraphDef
                                 yolov5s.tflite  # TensorFlow Lite
                                 yolov5s_edgetpu.tflite  # TensorFlow Edge TPU
                                 yolov5s_paddle_model  # PaddlePaddle
"""
import platform
from pathlib import Path

import cv2
import torch

from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS, check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr, ops
from ultralytics.yolo.utils.checks import check_file, check_imshow
from ultralytics.yolo.utils.configs import get_config
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
from ultralytics.yolo.utils.plotting import Annotator
from ultralytics.yolo.utils.torch_utils import check_img_size, select_device, smart_inference_mode

DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"


class BasePredictor:

    def __init__(self, config=DEFAULT_CONFIG, overrides={}):
        self.args = get_config(config, overrides)
        self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
        (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

        self.done_setup = False

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.device = None
        self.dataset = None
        self.vid_path, self.vid_writer = None, None
        self.view_img = None
        self.annotator = None
        self.data_path = None

    def preprocess(self, img):
        pass

    def get_annotator(self, img):
        raise NotImplementedError("get_annotator function needs to be implemented")

    def write_results(self, pred, batch, print_string):
        raise NotImplementedError("write_results function needs to be implemented")

    def postprocess(self, preds, img, orig_img):
        return preds

    def setup(self, source=None, model=None):
        # source
        source = str(source or self.args.source)
        self.save_img = not self.args.nosave and not source.endswith('.txt')
        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
        is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
        webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
        screenshot = source.lower().startswith('screen')
        if is_url and is_file:
            source = check_file(source)  # download

        # data
        if self.data:
            if self.data.endswith(".yaml"):
                self.data = check_dataset_yaml(self.data)
            else:
                self.data = check_dataset(self.data)

        # model
        device = select_device(self.args.device)
        model = model or self.args.model
        self.args.half &= device.type != 'cpu'  # half precision only supported on CUDA
        model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)  # NOTE: not passing data
        stride, pt = model.stride, model.pt
        imgsz = check_img_size(self.args.img_size, s=stride)  # check image size

        # Dataloader
        bs = 1  # batch_size
        if webcam:
            self.view_img = check_imshow(warn=True)
            self.dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride)
            bs = len(self.dataset)
        elif screenshot:
            self.dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
        else:
            self.dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride)
        self.vid_path, self.vid_writer = [None] * bs, [None] * bs
        model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup

        self.model = model
        self.webcam = webcam
        self.screenshot = screenshot
        self.imgsz = imgsz
        self.done_setup = True
        self.device = device

        return model

    @smart_inference_mode()
    def __call__(self, source=None, model=None):
        if not self.done_setup:
            model = self.setup(source, model)
        else:
            model = self.model

        self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
        for batch in self.dataset:
            path, im, im0s, vid_cap, s = batch
            visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
            with self.dt[0]:
                im = self.preprocess(im)
                if len(im.shape) == 3:
                    im = im[None]  # expand for batch dim

            # Inference
            with self.dt[1]:
                preds = model(im, augment=self.args.augment, visualize=visualize)

            # postprocess
            with self.dt[2]:
                preds = self.postprocess(preds, im, im0s)

            for i in range(len(im)):
                if self.webcam:
                    path, im0s = path[i], im0s[i]
                p = Path(path)
                s += self.write_results(i, preds, (p, im, im0s))

                if self.args.view_img:
                    self.show(p)

                if self.save_img:
                    self.save_preds(vid_cap, i, str(self.save_dir / p.name))

            # Print time (inference-only)
            LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")

        # Print results
        t = tuple(x.t / self.seen * 1E3 for x in self.dt)  # speeds per image
        LOGGER.info(
            f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape {(1, 3, *self.imgsz)}'
            % t)
        if self.args.save_txt or self.save_img:
            s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")

    def show(self, p):
        im0 = self.annotator.result()
        if platform.system() == 'Linux' and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
        cv2.imshow(str(p), im0)
        cv2.waitKey(1)  # 1 millisecond

    def save_preds(self, vid_cap, idx, save_path):
        im0 = self.annotator.result()
        # save imgs
        if self.dataset.mode == 'image':
            cv2.imwrite(save_path, im0)
        else:  # 'video' or 'stream'
            if self.vid_path[idx] != save_path:  # new video
                self.vid_path[idx] = save_path
                if isinstance(self.vid_writer[idx], cv2.VideoWriter):
                    self.vid_writer[idx].release()  # release previous video writer
                if vid_cap:  # video
                    fps = vid_cap.get(cv2.CAP_PROP_FPS)
                    w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                else:  # stream
                    fps, w, h = 30, im0.shape[1], im0.shape[0]
                save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
            self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
            self.vid_writer[idx].write(im0)
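BasePredictor is a template class: `setup()` resolves source, model, and dataloader once, `__call__` drives the shared loop, and task-specific predictors override `preprocess`, `postprocess`, `write_results`, and `get_annotator` (the v8 classify/detect/segment files below do exactly this). A minimal hypothetical subclass, for illustration only:

    import torch
    from ultralytics.yolo.engine.predictor import BasePredictor

    class EchoPredictor(BasePredictor):  # hypothetical; not part of this commit

        def preprocess(self, img):
            img = torch.from_numpy(img).to(self.model.device).float()
            return img / 255  # uint8 0-255 -> float 0.0-1.0

        def write_results(self, idx, preds, batch):
            p, im, im0 = batch
            self.seen += 1  # __call__ averages timings over self.seen
            return f'{p.name}: raw output shape {tuple(preds[idx].shape)} '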
ultralytics/yolo/engine/trainer.py:
@@ -15,7 +15,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import OmegaConf
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import lr_scheduler
@@ -26,7 +26,9 @@ import ultralytics.yolo.utils.callbacks as callbacks
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
 from ultralytics.yolo.utils.checks import check_file, print_args
+from ultralytics.yolo.utils.configs import get_config
 from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
+from ultralytics.yolo.utils.modeling import get_model
 from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer

 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -36,7 +38,7 @@ RANK = int(os.getenv('RANK', -1))
 class BaseTrainer:

     def __init__(self, config=DEFAULT_CONFIG, overrides={}):
-        self.args = self._get_config(config, overrides)
+        self.args = get_config(config, overrides)
         self.check_resume()
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

@@ -84,25 +86,6 @@ class BaseTrainer:
         self.add_callback(callback, func)
         callbacks.add_integration_callbacks(self)

-    def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
-        """
-        Accepts yaml file name or DictConfig containing experiment configuration.
-        Returns training args namespace
-        :param config: Optional file name or DictConfig object
-        """
-        if isinstance(config, (str, Path)):
-            config = OmegaConf.load(config)
-        elif isinstance(config, Dict):
-            config = OmegaConf.create(config)
-
-        # override
-        if isinstance(overrides, str):
-            overrides = OmegaConf.load(overrides)
-        elif isinstance(overrides, Dict):
-            overrides = OmegaConf.create(overrides)
-
-        return OmegaConf.merge(config, overrides)
-
     def add_callback(self, onevent: str, callback):
         """
         appends the given callback
@@ -46,8 +46,8 @@ class BaseValidator:
             self.args.half &= self.device.type != 'cpu'
             model = model.half() if self.args.half else model.float()
             self.model = model
-            loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
-        else:  # TODO: handle this when detectMultiBackend is supported
+            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
+        else:
             assert model is not None, "Either trainer or model is needed for validation"
             self.device = select_device(self.args.device, self.args.batch_size)
             self.args.half &= self.device.type != 'cpu'
@@ -90,13 +90,11 @@ class BaseValidator:
             # inference
             with dt[1]:
                 preds = model(batch["img"])
-                # TODO: remember to add native augmentation support when implementing model, like:
-                # preds, train_out = model(im, augment=augment)

             # loss
             with dt[2]:
                 if self.training:
-                    loss += trainer.criterion(preds, batch)[1]
+                    self.loss += trainer.criterion(preds, batch)[1]

             # pre-process predictions
             with dt[3]:
@@ -123,7 +121,7 @@ class BaseValidator:
         model.float()
         # TODO: implement save json

-        return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
+        return stats | trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val") \
             if self.training else stats

     def get_dataloader(self, dataset_path, batch_size):
ultralytics/yolo/utils/__init__.py:
@@ -6,6 +6,8 @@ import sys
 import threading
 from pathlib import Path

+import IPython
+
 # Constants
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[2]  # YOLO
@@ -29,6 +31,23 @@ def is_kaggle():
     return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'


+def is_notebook():
+    # Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
+    ipython_type = str(type(IPython.get_ipython()))
+    return 'colab' in ipython_type or 'zmqshell' in ipython_type
+
+
+def is_docker() -> bool:
+    """Check if the process runs inside a docker container."""
+    if Path("/.dockerenv").exists():
+        return True
+    try:  # check if docker is in control groups
+        with open("/proc/self/cgroup") as file:
+            return any("docker" in line for line in file)
+    except OSError:
+        return False
+
+
 def is_writeable(dir, test=False):
     # Return True if directory has write permissions, test opening a file with write permissions if test=True
     if not test:
ultralytics/yolo/utils/checks.py:
@@ -6,10 +6,13 @@ from pathlib import Path
 from subprocess import check_output
 from typing import Optional

+import cv2
+import numpy as np
 import pkg_resources as pkg
 import torch

-from ultralytics.yolo.utils import AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis
+from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
+                                    is_docker, is_notebook)


 def is_ascii(s=''):
@@ -131,6 +134,22 @@ def check_yaml(file, suffix=('.yaml', '.yml')):
     return check_file(file, suffix)


+def check_imshow(warn=False):
+    # Check if environment supports image displays
+    try:
+        assert not is_notebook()
+        assert not is_docker()
+        cv2.imshow('test', np.zeros((1, 1, 3)))
+        cv2.waitKey(1)
+        cv2.destroyAllWindows()
+        cv2.waitKey(1)
+        return True
+    except Exception as e:
+        if warn:
+            LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
+        return False
+
+
 def git_describe(path=ROOT):  # path must be a directory
     # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
     try:
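`check_imshow` probes an actual `cv2.imshow` call, and `BasePredictor.setup` uses it to decide whether live display is possible for webcam sources. A quick sketch:

    from ultralytics.yolo.utils.checks import check_imshow

    if check_imshow(warn=True):  # False in notebooks, Docker, and headless sessions
        print('GUI display available; predictions can be shown live')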
23  ultralytics/yolo/utils/configs/__init__.py  Normal file
@@ -0,0 +1,23 @@
from pathlib import Path
from typing import Dict, Union

from omegaconf import DictConfig, OmegaConf


def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
    """
    Accepts yaml file name or DictConfig containing experiment configuration.
    Returns training args namespace
    :param config: Optional file name or DictConfig object
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)
    elif isinstance(config, Dict):
        config = OmegaConf.create(config)
    # override
    if isinstance(overrides, str):
        overrides = OmegaConf.load(overrides)
    elif isinstance(overrides, Dict):
        overrides = OmegaConf.create(overrides)

    return OmegaConf.merge(config, overrides)
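`get_config` is now the single config entry point shared by trainer and predictor: load the base config, load or wrap the overrides, merge. A usage sketch:

    from omegaconf import OmegaConf
    from ultralytics.yolo.utils.configs import get_config

    base = OmegaConf.create({'mode': 'train', 'img_size': 640})
    args = get_config(base, {'mode': 'predict'})
    print(args.mode, args.img_size)  # predict 640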
ultralytics/yolo/utils/configs/default.yaml:
@@ -46,7 +46,22 @@ max_det: 300
 half: True
 dnn: False  # use OpenCV DNN for ONNX inference
 plots: False
+
+# Prediction settings:
+source: "ultralytics/assets/"
+view_img: False
 save_txt: False
+save_conf: False
+save_crop: False
+hide_labels: False  # hide labels
+hide_conf: False
+vid_stride: 1  # video frame-rate stride
+line_thickness: 3  # bounding box thickness (pixels)
+update: False  # Update all models
+visualize: False
+augment: False
+agnostic_nms: False  # class-agnostic NMS
+retina_masks: False

 # Hyperparameters ------------------------------------------------------------------------------------------------------
 lr0: 0.01  # initial learning rate (SGD=1E-2, Adam=1E-3)
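Each of these keys surfaces as an attribute of `self.args` inside `BasePredictor` (e.g. `self.args.vid_stride`, `self.args.line_thickness`) and can be overridden per run; assuming the hydra `key=value` override syntax used by the new predict entry points:

    $ yolo task=detect mode=predict model=s.pt source=vid.mp4 vid_stride=2 hide_conf=True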
@@ -1,7 +1,6 @@
 import contextlib

 import torchvision
-import yaml

 from ultralytics.yolo.utils.downloads import attempt_download
 from ultralytics.yolo.utils.modeling.modules import *
ultralytics/yolo/utils/ops.py:
@@ -1,5 +1,6 @@
 import contextlib
 import math
+import re
 import time

 import cv2
@@ -374,3 +375,75 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
     if upsample:
         masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
     return masks.gt_(0.5)
+
+
+def process_mask_native(protos, masks_in, bboxes, shape):
+    """
+    Crop after upsample.
+    protos: [mask_dim, mask_h, mask_w]
+    masks_in: [n, mask_dim], n is number of masks after nms
+    bboxes: [n, 4], n is number of masks after nms
+    shape: input_image_size, (h, w)
+    return: h, w, n
+    """
+    c, mh, mw = protos.shape  # CHW
+    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
+    gain = min(mh / shape[0], mw / shape[1])  # gain = old / new
+    pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2  # wh padding
+    top, left = int(pad[1]), int(pad[0])  # y, x
+    bottom, right = int(mh - pad[1]), int(mw - pad[0])
+    masks = masks[:, top:bottom, left:right]
+
+    masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
+    masks = crop_mask(masks, bboxes)  # CHW
+    return masks.gt_(0.5)
+
+
+def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
+    # Rescale coords (xyxy) from img1_shape to img0_shape
+    if ratio_pad is None:  # calculate from img0_shape
+        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
+        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+    else:
+        gain = ratio_pad[0][0]
+        pad = ratio_pad[1]
+
+    segments[:, 0] -= pad[0]  # x padding
+    segments[:, 1] -= pad[1]  # y padding
+    segments /= gain
+    clip_segments(segments, img0_shape)
+    if normalize:
+        segments[:, 0] /= img0_shape[1]  # width
+        segments[:, 1] /= img0_shape[0]  # height
+    return segments
+
+
+def masks2segments(masks, strategy='largest'):
+    # Convert masks(n,160,160) into segments(n,xy)
+    segments = []
+    for x in masks.int().cpu().numpy().astype('uint8'):
+        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+        if c:
+            if strategy == 'concat':  # concatenate all segments
+                c = np.concatenate([x.reshape(-1, 2) for x in c])
+            elif strategy == 'largest':  # select largest segment
+                c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
+        else:
+            c = np.zeros((0, 2))  # no segments found
+        segments.append(c.astype('float32'))
+    return segments
+
+
+def clip_segments(segments, shape):
+    # Clip segments (xy1,xy2,...) to image shape (height, width)
+    if isinstance(segments, torch.Tensor):  # faster individually
+        segments[:, 0].clamp_(0, shape[1])  # x
+        segments[:, 1].clamp_(0, shape[0])  # y
+    else:  # np.array (faster grouped)
+        segments[:, 0] = segments[:, 0].clip(0, shape[1])  # x
+        segments[:, 1] = segments[:, 1].clip(0, shape[0])  # y
+
+
+def clean_str(s):
+    # Cleans a string by replacing special characters with underscore _
+    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
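The new helpers form the `save_txt` pipeline used by SegmentationPredictor below: `masks2segments` vectorizes binary masks into polygons, `scale_segments` maps the points back to original-image coordinates, and `clip_segments` keeps them in-bounds. A sketch with synthetic data:

    import torch
    from ultralytics.yolo.utils import ops

    masks = torch.zeros(1, 160, 160)
    masks[0, 40:120, 40:120] = 1  # one square mask
    segs = ops.masks2segments(masks, strategy='largest')  # list of (n, 2) float32 arrays
    seg = ops.scale_segments((160, 160), segs[0], (320, 320))  # rescale to a 320x320 original
    print(seg.min(), seg.max())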
ultralytics/yolo/utils/torch_utils.py:
@@ -36,6 +36,14 @@ def torch_distributed_zero_first(local_rank: int):
         dist.barrier(device_ids=[0])


+def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
+    # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
+    def decorate(fn):
+        return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
+
+    return decorate
+
+
 def DDP_model(model):
     # Model DDP creation with checks
     assert not check_version(torch.__version__, '1.12.0', pinned=True), \
@@ -192,14 +200,6 @@ def copy_attr(a, b, include=(), exclude=()):
             setattr(a, k, v)


-def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
-    # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
-    def decorate(fn):
-        return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
-
-    return decorate
-
-
 def intersect_state_dicts(da, db, exclude=()):
     # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
     return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
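The hunk pair only relocates `smart_inference_mode` earlier in the module; behavior is unchanged. The decorator picks `torch.inference_mode()` on torch>=1.9 and falls back to `torch.no_grad()`:

    from ultralytics.yolo.utils.torch_utils import smart_inference_mode

    @smart_inference_mode()
    def run_forward(model, x):
        return model(x)  # gradients disabled under either decorator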
ultralytics/yolo/v8/classify/__init__.py:
@@ -1,4 +1,3 @@
+from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predict
 from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
 from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
-
-__all__ = ["train"]
68  ultralytics/yolo/v8/classify/predict.py  Normal file
@@ -0,0 +1,68 @@
import hydra
import torch

from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box


class ClassificationPredictor(BasePredictor):

    def get_annotator(self, img):
        return Annotator(img, example=str(self.model.names), pil=True)

    def preprocess(self, img):
        img = torch.Tensor(img).to(self.model.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        return img

    def write_results(self, idx, preds, batch):
        p, im, im0 = batch
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        im0 = im0.copy()
        if self.webcam:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        self.data_path = p
        # save_path = str(self.save_dir / p.name)  # im.jpg
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        prob = preds[idx]
        # Print results
        top5i = prob.argsort(0, descending=True)[:5].tolist()  # top 5 indices
        log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, "

        # write
        text = '\n'.join(f'{prob[j]:.2f} {self.model.names[j]}' for j in top5i)
        if self.save_img or self.args.view_img:  # Add text to image
            self.annotator.text((32, 32), text, txt_color=(255, 255, 255))
        if self.args.save_txt:  # Write to file
            with open(f'{self.txt_path}.txt', 'a') as f:
                f.write(text + '\n')

        return log_string


@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def predict(cfg):
    cfg.model = cfg.model or "squeezenet1_0"
    sz = cfg.img_size
    if type(sz) != int:  # received ListConfig
        cfg.img_size = [sz[0], sz[0]] if len(cfg.img_size) == 1 else [sz[0], sz[1]]  # expand
    else:
        cfg.img_size = [sz, sz]
    predictor = ClassificationPredictor(cfg)
    predictor()


if __name__ == "__main__":
    predict()
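Run directly, hydra resolves DEFAULT_CONFIG and applies `key=value` overrides; an invocation sketch (override syntax assumed from the hydra entry point above):

    $ python ultralytics/yolo/v8/classify/predict.py model=squeezenet1_0 source=ultralytics/assets/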
ultralytics/yolo/v8/detect/__init__.py:
@@ -1,2 +1,3 @@
+from ultralytics.yolo.v8.detect.predict import DetectionPredictor, predict
 from ultralytics.yolo.v8.detect.train import DetectionTrainer, train
 from ultralytics.yolo.v8.detect.val import DetectionValidator, val
97  ultralytics/yolo/v8/detect/predict.py  Normal file
@@ -0,0 +1,97 @@
import hydra
import torch

from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box


class DetectionPredictor(BasePredictor):

    def get_annotator(self, img):
        return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))

    def preprocess(self, img):
        img = torch.from_numpy(img).to(self.model.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        img /= 255  # 0 - 255 to 0.0 - 1.0
        return img

    def postprocess(self, preds, img, orig_img):
        preds = ops.non_max_suppression(preds,
                                        self.args.conf_thres,
                                        self.args.iou_thres,
                                        agnostic=self.args.agnostic_nms,
                                        max_det=self.args.max_det)

        for i, pred in enumerate(preds):
            shape = orig_img[i].shape if self.webcam else orig_img.shape
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()

        return preds

    def write_results(self, idx, preds, batch):
        p, im, im0 = batch
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        im0 = im0.copy()
        if self.webcam:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        self.data_path = p
        # save_path = str(self.save_dir / p.name)  # im.jpg
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        det = preds[idx]
        if len(det) == 0:
            return log_string
        for c in det[:, 5].unique():
            n = (det[:, 5] == c).sum()  # detections per class
            log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "

        # write
        gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
        for *xyxy, conf, cls in reversed(det):
            if self.args.save_txt:  # Write to file
                xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                line = (cls, *xywh, conf) if self.args.save_conf else (cls, *xywh)  # label format
                with open(f'{self.txt_path}.txt', 'a') as f:
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')

            if self.save_img or self.args.save_crop or self.args.view_img:  # Add bbox to image
                c = int(cls)  # integer class
                label = None if self.args.hide_labels else (
                    self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
                self.annotator.box_label(xyxy, label, color=colors(c, True))
            if self.args.save_crop:
                imc = im0.copy()
                save_one_box(xyxy,
                             imc,
                             file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg',
                             BGR=True)

        return log_string


@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def predict(cfg):
    cfg.model = cfg.model or "n.pt"
    sz = cfg.img_size
    if type(sz) != int:  # received ListConfig
        cfg.img_size = [sz[0], sz[0]] if len(cfg.img_size) == 1 else [sz[0], sz[1]]  # expand
    else:
        cfg.img_size = [sz, sz]
    predictor = DetectionPredictor(cfg)
    predictor()


if __name__ == "__main__":
    predict()
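Because the predictor is config-driven, it can also be used from Python by handing it an OmegaConf object; a sketch assuming the default.yaml path implied by `ROOT` above:

    from omegaconf import OmegaConf
    from ultralytics.yolo.v8.detect.predict import DetectionPredictor

    cfg = OmegaConf.load('ultralytics/yolo/utils/configs/default.yaml')
    cfg.model = 'n.pt'
    cfg.img_size = [640, 640]
    DetectionPredictor(cfg)(source='ultralytics/assets/bus.jpg')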
ultralytics/yolo/v8/detect/val.py:
@@ -63,7 +63,7 @@ class DetectionValidator(BaseValidator):
         self.seen = 0
         self.confusion_matrix = ConfusionMatrix(nc=self.nc)
         self.metrics = Metric()
-        self.loss = torch.zeros(4, device=self.device)
+        self.loss = torch.zeros(3, device=self.device)
         self.jdict = []
         self.stats = []
ultralytics/yolo/v8/segment/__init__.py:
@@ -1,2 +1,3 @@
+from ultralytics.yolo.v8.segment.predict import SegmentationPredictor, predict
 from ultralytics.yolo.v8.segment.train import SegmentationTrainer, train
 from ultralytics.yolo.v8.segment.val import SegmentationValidator, val
115  ultralytics/yolo/v8/segment/predict.py  Normal file
@@ -0,0 +1,115 @@
from pathlib import Path

import hydra
import torch

from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import ROOT, ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box

from ..detect.predict import DetectionPredictor


class SegmentationPredictor(DetectionPredictor):

    def postprocess(self, preds, img, orig_img):
        masks = []
        if len(preds) == 2:  # eval
            p, proto = preds
        else:  # len(preds) == 3, train
            p, proto, _ = preds
        # TODO: filter by classes
        p = ops.non_max_suppression(p,
                                    self.args.conf_thres,
                                    self.args.iou_thres,
                                    agnostic=self.args.agnostic_nms,
                                    max_det=self.args.max_det,
                                    nm=32)
        for i, pred in enumerate(p):
            shape = orig_img[i].shape if self.webcam else orig_img.shape
            if not len(pred):
                continue
            if self.args.retina_masks:
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
                masks.append(ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], shape[:2]))  # HWC
            else:
                masks.append(ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True))  # HWC
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()

        return (p, masks)

    def write_results(self, idx, preds, batch):
        p, im, im0 = batch
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        if self.webcam:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        self.data_path = p
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        preds, masks = preds
        det = preds[idx]
        if len(det) == 0:
            return log_string
        # Segments
        mask = masks[idx]
        if self.args.save_txt:
            segments = [
                ops.scale_segments(im0.shape if self.args.retina_masks else im.shape[2:], x, im0.shape, normalize=True)
                for x in reversed(ops.masks2segments(mask))]

        # Print results
        for c in det[:, 5].unique():
            n = (det[:, 5] == c).sum()  # detections per class
            log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "  # add to string

        # Mask plotting
        self.annotator.masks(
            mask,
            colors=[colors(x, True) for x in det[:, 5]],
            im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
            255 if self.args.retina_masks else im[idx])

        # Write results
        for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
            if self.args.save_txt:  # Write to file
                seg = segments[j].reshape(-1)  # (n,2) to (n*2)
                line = (cls, *seg, conf) if self.args.save_conf else (cls, *seg)  # label format
                with open(f'{self.txt_path}.txt', 'a') as f:
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')

            if self.save_img or self.args.save_crop or self.args.view_img:
                c = int(cls)  # integer class
                label = None if self.args.hide_labels else (
                    self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
                self.annotator.box_label(xyxy, label, color=colors(c, True))
                # annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
            if self.args.save_crop:
                imc = im0.copy()
                save_one_box(xyxy, imc, file=self.save_dir / 'crops' / self.model.names[c] / f'{p.stem}.jpg', BGR=True)

        return log_string


@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def predict(cfg):
    cfg.model = cfg.model or "n.pt"
    sz = cfg.img_size
    if type(sz) != int:  # received ListConfig
        cfg.img_size = [sz[0], sz[0]] if len(cfg.img_size) == 1 else [sz[0], sz[1]]  # expand
    else:
        cfg.img_size = [sz, sz]
    predictor = SegmentationPredictor(cfg)
    predictor()


if __name__ == "__main__":
    predict()
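`retina_masks` trades speed for mask fidelity: with it on, boxes are rescaled first and masks come out at native image resolution via `process_mask_native`; with it off, masks are upsampled from prototype resolution. It toggles like any other config key, e.g.:

    $ python ultralytics/yolo/v8/segment/predict.py model=n.pt source=ultralytics/assets/bus.jpg retina_masks=True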