ultralytics 8.1.27 batched tracking fixes (#8842)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Authored by Laughing on 2024-03-12 02:29:41 +08:00; committed by GitHub
parent 3555785167
commit 2ea6b2b889
5 changed files with 38 additions and 31 deletions

tests/test_python.py

@@ -301,7 +301,7 @@ def test_predict_callback_and_setup():
     def on_predict_batch_end(predictor):
         """Callback function that handles operations at the end of a prediction batch."""
-        path, im0s, _, _ = predictor.batch
+        path, im0s, _ = predictor.batch
         im0s = im0s if isinstance(im0s, list) else [im0s]
         bs = [predictor.dataset.bs for _ in range(len(path))]
         predictor.results = zip(predictor.results, im0s, bs)  # results is List[batch_size]
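
After this change predictor.batch is a 3-tuple of (paths, images, info strings) rather than the
previous 4-tuple that also carried a per-item is_video flag. A minimal sketch of a user callback
under the new layout (the print statement and example image URL are illustrative, not part of the
commit):

    from ultralytics import YOLO

    def on_predict_batch_end(predictor):
        """Unpack the new 3-tuple batch layout: paths, images, info strings."""
        paths, im0s, info = predictor.batch
        print(len(paths), len(im0s), info)

    model = YOLO("yolov8n.pt")  # example weights
    model.add_callback("on_predict_batch_end", on_predict_batch_end)
    model.predict("https://ultralytics.com/images/bus.jpg")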

ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.1.26"
+__version__ = "8.1.27"
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

ultralytics/data/loaders.py

@@ -80,8 +80,6 @@ class LoadStreams:
         self.imgs = [[] for _ in range(n)]  # images
         self.shape = [[] for _ in range(n)]  # image shapes
         self.sources = [ops.clean_str(x) for x in sources]  # clean source names for later
-        self.info = [""] * n
-        self.is_video = [True] * n
         for i, s in enumerate(sources):  # index, source
             # Start thread to read frames from video stream
             st = f"{i + 1}/{n}: {s}... "
@@ -178,7 +176,7 @@ class LoadStreams:
             images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
             x.clear()
-        return self.sources, images, self.is_video, self.info
+        return self.sources, images, [""] * self.bs

     def __len__(self):
         """Return the length of the sources object."""
@@ -227,6 +225,7 @@ class LoadScreenshots:
         self.frame = 0
         self.sct = mss.mss()
         self.bs = 1
+        self.fps = 30

         # Parse monitor shape
         monitor = self.sct.monitors[self.screen]
@@ -246,7 +245,7 @@ class LoadScreenshots:
         s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
         self.frame += 1
-        return [str(self.screen)], [im0], [True], [s]  # screen, img, is_video, string
+        return [str(self.screen)], [im0], [s]  # screen, img, string


 class LoadImagesAndVideos:
@@ -298,6 +297,7 @@ class LoadImagesAndVideos:
         self.files = images + videos
         self.nf = ni + nv  # number of files
+        self.ni = ni  # number of images
         self.video_flag = [False] * ni + [True] * nv
         self.mode = "image"
         self.vid_stride = vid_stride  # video frame-rate stride
@@ -319,11 +319,11 @@ class LoadImagesAndVideos:
     def __next__(self):
         """Returns the next batch of images or video frames along with their paths and metadata."""
-        paths, imgs, is_video, info = [], [], [], []
+        paths, imgs, info = [], [], []
         while len(imgs) < self.bs:
             if self.count >= self.nf:  # end of file list
                 if len(imgs) > 0:
-                    return paths, imgs, is_video, info  # return last partial batch
+                    return paths, imgs, info  # return last partial batch
                 else:
                     raise StopIteration
@@ -344,7 +344,6 @@ class LoadImagesAndVideos:
                         self.frame += 1
                         paths.append(path)
                         imgs.append(im0)
-                        is_video.append(True)
                         info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ")
                         if self.frame == self.frames:  # end of video
                             self.count += 1
@@ -363,16 +362,18 @@ class LoadImagesAndVideos:
                     raise FileNotFoundError(f"Image Not Found {path}")
                 paths.append(path)
                 imgs.append(im0)
-                is_video.append(False)  # no capture object for images
                 info.append(f"image {self.count + 1}/{self.nf} {path}: ")
                 self.count += 1  # move to the next file
+                if self.count >= self.ni:  # end of image list
+                    break

-        return paths, imgs, is_video, info
+        return paths, imgs, info

     def _new_video(self, path):
         """Creates a new video capture object for the given path."""
         self.frame = 0
         self.cap = cv2.VideoCapture(path)
+        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
         if not self.cap.isOpened():
             raise FileNotFoundError(f"Failed to open video {path}")
         self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
@@ -429,7 +430,7 @@ class LoadPilAndNumpy:
         if self.count == 1:  # loop only once as it's batch inference
             raise StopIteration
         self.count += 1
-        return self.paths, self.im0, [False] * self.bs, [""] * self.bs
+        return self.paths, self.im0, [""] * self.bs

     def __iter__(self):
         """Enables iteration for class LoadPilAndNumpy."""
@@ -494,7 +495,7 @@ class LoadTensor:
         if self.count == 1:
             raise StopIteration
         self.count += 1
-        return self.paths, self.im0, [False] * self.bs, [""] * self.bs
+        return self.paths, self.im0, [""] * self.bs

     def __len__(self):
         """Returns the batch size."""

ultralytics/engine/predictor.py

@@ -30,6 +30,7 @@ Usage - formats:
 """

 import platform
+import re
 import threading
 from pathlib import Path
@@ -236,7 +237,7 @@ class BasePredictor:
         self.run_callbacks("on_predict_start")
         for self.batch in self.dataset:
             self.run_callbacks("on_predict_batch_start")
-            paths, im0s, is_video, s = self.batch
+            paths, im0s, s = self.batch

             # Preprocess
             with profilers[0]:
@@ -264,7 +265,7 @@ class BasePredictor:
                     "postprocess": profilers[2].dt * 1e3 / n,
                 }
                 if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
-                    s[i] += self.write_results(i, Path(paths[i]), im, is_video)
+                    s[i] += self.write_results(i, Path(paths[i]), im, s)

             # Print batch results
             if self.args.verbose:
@@ -308,7 +309,7 @@ class BasePredictor:
         self.args.half = self.model.fp16  # update half
         self.model.eval()

-    def write_results(self, i, p, im, is_video):
+    def write_results(self, i, p, im, s):
         """Write inference results to a file or directory."""
         string = ""  # print string
         if len(im.shape) == 3:
@@ -317,9 +318,10 @@ class BasePredictor:
             string += f"{i}: "
             frame = self.dataset.count
         else:
-            frame = getattr(self.dataset, "frame", 0) - len(self.results) + i
+            match = re.search(r"frame (\d+)/", s[i])
+            frame = int(match.group(1)) if match else None  # 0 if frame undetermined

-        self.txt_path = self.save_dir / "labels" / (p.stem + (f"_{frame}" if is_video[i] else ""))
+        self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
         string += "%gx%g " % im.shape[2:]
         result = self.results[i]
         result.save_dir = self.save_dir.__str__()  # used in other locations
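
The frame index is now recovered from the loader's info string instead of a per-item is_video
flag. A standalone sketch of that parsing (the sample string below mimics the format emitted by
LoadImagesAndVideos; the path in it is a placeholder):

    import re

    s = "video 3/5 (frame 42/900) path/to/clip.mp4: "
    match = re.search(r"frame (\d+)/", s)
    frame = int(match.group(1)) if match else None
    print(frame)  # -> 42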
@@ -341,18 +343,19 @@ class BasePredictor:
         if self.args.save_crop:
             result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
         if self.args.show:
-            self.show(str(p), is_video[i])
+            self.show(str(p))
         if self.args.save:
-            self.save_predicted_images(str(self.save_dir / p.name), is_video[i], frame)
+            self.save_predicted_images(str(self.save_dir / p.name), frame)

         return string

-    def save_predicted_images(self, save_path="", is_video=False, frame=0):
+    def save_predicted_images(self, save_path="", frame=0):
         """Save video predictions as mp4 at specified path."""
         im = self.plotted_img

         # Save videos and streams
-        if is_video:
+        if self.dataset.mode in {"stream", "video"}:
+            fps = self.dataset.fps if self.dataset.mode == "video" else 30
             frames_path = f'{save_path.split(".", 1)[0]}_frames/'
             if save_path not in self.vid_writer:  # new video
                 if self.args.save_frames:
@@ -361,7 +364,7 @@ class BasePredictor:
                 self.vid_writer[save_path] = cv2.VideoWriter(
                     filename=str(Path(save_path).with_suffix(suffix)),
                     fourcc=cv2.VideoWriter_fourcc(*fourcc),
-                    fps=30,  # integer required, floats produce error in MP4 codec
+                    fps=fps,  # integer required, floats produce error in MP4 codec
                     frameSize=(im.shape[1], im.shape[0]),  # (width, height)
                 )
@@ -374,7 +377,7 @@ class BasePredictor:
         else:
             cv2.imwrite(save_path, im)

-    def show(self, p="", is_video=False):
+    def show(self, p=""):
         """Display an image in a window using OpenCV imshow()."""
         im = self.plotted_img
         if platform.system() == "Linux" and p not in self.windows:
@@ -382,7 +385,7 @@ class BasePredictor:
             cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
             cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
         cv2.imshow(p, im)
-        cv2.waitKey(1 if is_video else 500)  # 1 millisecond
+        cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # 1 millisecond

     def run_callbacks(self, event: str):
         """Runs all registered callbacks for a specific event."""

ultralytics/trackers/track.py

@@ -38,6 +38,8 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
     for _ in range(predictor.dataset.bs):
         tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
         trackers.append(tracker)
+        if predictor.dataset.mode != "stream":  # only need one tracker for other modes.
+            break
     predictor.trackers = trackers
     predictor.vid_path = [None] * predictor.dataset.bs  # for determining when to reset tracker on new video
@@ -50,20 +52,21 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
        predictor (object): The predictor object containing the predictions.
        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
     """
-    bs = predictor.dataset.bs
     path, im0s = predictor.batch[:2]

     is_obb = predictor.args.task == "obb"
-    for i in range(bs):
+    is_stream = predictor.dataset.mode == "stream"
+    for i in range(len(im0s)):
+        tracker = predictor.trackers[i if is_stream else 0]
         vid_path = predictor.save_dir / Path(path[i]).name
-        if not persist and predictor.vid_path[i] != vid_path:  # new video
-            predictor.trackers[i].reset()
-            predictor.vid_path[i] = vid_path
+        if not persist and predictor.vid_path[i if is_stream else 0] != vid_path:
+            tracker.reset()
+            predictor.vid_path[i if is_stream else 0] = vid_path

         det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
         if len(det) == 0:
             continue
-        tracks = predictor.trackers[i].update(det, im0s[i])
+        tracks = tracker.update(det, im0s[i])
         if len(tracks) == 0:
             continue
         idx = tracks[:, -1].astype(int)
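
Taken together, the tracking hooks now keep one tracker per source only in stream mode; for video
and image sources a single shared tracker (index 0) is used and reset whenever a new video begins.
A usage sketch of tracking after this release (the weights file and video path are placeholders,
not part of the commit):

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    # persist=False (the default) lets the callback above reset the shared tracker on each new video
    for result in model.track(source="path/to/video.mp4", stream=True, persist=False):
        print(result.boxes.id)  # tracker IDs for the current frame, or None if nothing is tracked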