diff --git a/.gitignore b/.gitignore
index 64badb1c..0854267a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,7 +29,7 @@ MANIFEST
 # PyInstaller
 # Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
+# before PyInstaller builds the exe, so as to inject date/other info into it.
 *.manifest
 *.spec
diff --git a/docs/en/reference/data/loaders.md b/docs/en/reference/data/loaders.md
index 3ba4c162..99d7749a 100644
--- a/docs/en/reference/data/loaders.md
+++ b/docs/en/reference/data/loaders.md
@@ -23,7 +23,7 @@ keywords: Ultralytics, data loaders, LoadStreams, LoadImages, LoadTensor, YOLO,
 
 <br><br>
 
-## ::: ultralytics.data.loaders.LoadImages
+## ::: ultralytics.data.loaders.LoadImagesAndVideos
 
 <br><br>
diff --git a/docs/en/reference/utils/files.md b/docs/en/reference/utils/files.md
index 586373b1..e9bd16dd 100644
--- a/docs/en/reference/utils/files.md
+++ b/docs/en/reference/utils/files.md
@@ -38,3 +38,7 @@ keywords: Ultralytics, utility functions, file operations, working directory, fi
 ## ::: ultralytics.utils.files.get_latest_run
 
 <br><br>
+
+## ::: ultralytics.utils.files.update_models
+
+<br><br>
diff --git a/tests/test_python.py b/tests/test_python.py
index 0144ee8f..9450fb09 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -8,6 +8,7 @@
 import cv2
 import numpy as np
 import pytest
 import torch
+import yaml
 from PIL import Image
 from torchvision.transforms import ToTensor
@@ -169,8 +170,6 @@ def test_track_stream():
 
     Note imgsz=160 required for tracking for higher confidence and better matches
     """
-    import yaml
-
    video_url = "https://ultralytics.com/assets/decelera_portrait_min.mov"
    model = YOLO(MODEL)
    model.track(video_url, imgsz=160, tracker="bytetrack.yaml")
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index a7278b87..d02f1311 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.1.25"
+__version__ = "8.1.26"
 
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index 8a2de594..d9d73758 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -396,7 +396,7 @@ def handle_yolo_settings(args: List[str]) -> None:
 def handle_explorer():
     """Open the Ultralytics Explorer GUI."""
     checks.check_requirements("streamlit")
-    LOGGER.info(f"💡 Loading Explorer dashboard...")
+    LOGGER.info("💡 Loading Explorer dashboard...")
     subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"])
diff --git a/ultralytics/data/build.py b/ultralytics/data/build.py
index c441ee76..37c5fa41 100644
--- a/ultralytics/data/build.py
+++ b/ultralytics/data/build.py
@@ -11,7 +11,7 @@ from torch.utils.data import dataloader, distributed
 
 from ultralytics.data.loaders import (
     LOADERS,
-    LoadImages,
+    LoadImagesAndVideos,
     LoadPilAndNumpy,
     LoadScreenshots,
     LoadStreams,
@@ -150,34 +150,35 @@ def check_source(source):
     return source, webcam, screenshot, from_img, in_memory, tensor
 
 
-def load_inference_source(source=None, vid_stride=1, buffer=False):
+def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
     """
     Loads an inference source for object detection and applies necessary transformations.
 
     Args:
         source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
+        batch (int, optional): Batch size for dataloaders. Default is 1.
         vid_stride (int, optional): The frame interval for video sources. Default is 1.
         buffer (bool, optional): Determines whether stream frames will be buffered. Default is False.
 
     Returns:
         dataset (Dataset): A dataset object for the specified input source.
     """
-    source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
-    source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
+    source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
+    source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)
 
     # Dataloader
     if tensor:
         dataset = LoadTensor(source)
     elif in_memory:
         dataset = source
-    elif webcam:
+    elif stream:
         dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)
     elif screenshot:
         dataset = LoadScreenshots(source)
     elif from_img:
         dataset = LoadPilAndNumpy(source)
     else:
-        dataset = LoadImages(source, vid_stride=vid_stride)
+        dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride)
 
     # Attach source types to the dataset
     setattr(dataset, "source_type", source_type)
diff --git a/ultralytics/data/loaders.py b/ultralytics/data/loaders.py
index 6faf90c3..a0876432 100644
--- a/ultralytics/data/loaders.py
+++ b/ultralytics/data/loaders.py
@@ -24,7 +24,7 @@ from ultralytics.utils.checks import check_requirements
 class SourceTypes:
     """Class to represent various types of input sources for predictions."""
 
-    webcam: bool = False
+    stream: bool = False
     screenshot: bool = False
     from_img: bool = False
     tensor: bool = False
@@ -32,9 +32,7 @@ class SourceTypes:
 
 class LoadStreams:
     """
-    Stream Loader for various types of video streams.
-
-    Suitable for use with `yolo predict source='rtsp://example.com/media.mp4'`, supports RTSP, RTMP, HTTP, and TCP streams.
+    Stream Loader for various types of video streams. Supports RTSP, RTMP, HTTP, and TCP streams.
 
     Attributes:
         sources (str): The source input paths or URLs for the video streams.
@@ -57,6 +55,11 @@ class LoadStreams:
         __iter__: Returns an iterator object for the class.
         __next__: Returns source paths, transformed, and original images for processing.
         __len__: Return the length of the sources object.
+
+    Example:
+        ```bash
+        yolo predict source='rtsp://example.com/media.mp4'
+        ```
     """
 
     def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
@@ -69,6 +72,7 @@ class LoadStreams:
         sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
         n = len(sources)
+        self.bs = n
         self.fps = [0] * n  # frames per second
         self.frames = [0] * n
         self.threads = [None] * n
@@ -76,6 +80,8 @@ class LoadStreams:
         self.imgs = [[] for _ in range(n)]  # images
         self.shape = [[] for _ in range(n)]  # image shapes
         self.sources = [ops.clean_str(x) for x in sources]  # clean source names for later
+        self.info = [""] * n
+        self.is_video = [True] * n
         for i, s in enumerate(sources):  # index, source
             # Start thread to read frames from video stream
             st = f"{i + 1}/{n}: {s}... "
@@ -109,9 +115,6 @@ class LoadStreams:
             self.threads[i].start()
         LOGGER.info("")  # newline
 
-        # Check for common shapes
-        self.bs = self.__len__()
-
     def update(self, i, cap, stream):
         """Read stream `i` frames in daemon thread."""
         n, f = 0, self.frames[i]  # frame number, frame array
@@ -175,11 +178,11 @@ class LoadStreams:
             images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
             x.clear()
 
-        return self.sources, images, None, ""
+        return self.sources, images, self.is_video, self.info
 
     def __len__(self):
         """Return the length of the sources object."""
-        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years
+        return self.bs  # 1E12 frames = 32 streams at 30 FPS for 30 years
 
 
 class LoadScreenshots:
@@ -243,10 +246,10 @@ class LoadScreenshots:
 
         s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
         self.frame += 1
-        return [str(self.screen)], [im0], None, s  # screen, img, vid_cap, string
+        return [str(self.screen)], [im0], [True], [s]  # screen, img, is_video, string
 
 
-class LoadImages:
+class LoadImagesAndVideos:
     """
     YOLOv8 image/video dataloader.
 
@@ -269,7 +272,7 @@ class LoadImages:
         _new_video(path): Create a new cv2.VideoCapture object for a given video path.
     """
 
-    def __init__(self, path, vid_stride=1):
+    def __init__(self, path, batch=1, vid_stride=1):
         """Initialize the Dataloader and raise FileNotFoundError if file not found."""
         parent = None
         if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
@@ -298,7 +301,7 @@ class LoadImages:
         self.video_flag = [False] * ni + [True] * nv
         self.mode = "image"
         self.vid_stride = vid_stride  # video frame-rate stride
-        self.bs = 1
+        self.bs = batch
         if any(videos):
             self._new_video(videos[0])  # new video
         else:
@@ -315,49 +318,68 @@ class LoadImages:
         return self
 
     def __next__(self):
-        """Return next image, path and metadata from dataset."""
-        if self.count == self.nf:
-            raise StopIteration
-        path = self.files[self.count]
-
-        if self.video_flag[self.count]:
-            # Read video
-            self.mode = "video"
-            for _ in range(self.vid_stride):
-                self.cap.grab()
-            success, im0 = self.cap.retrieve()
-            while not success:
-                self.count += 1
-                self.cap.release()
-                if self.count == self.nf:  # last video
+        """Returns the next batch of images or video frames along with their paths and metadata."""
+        paths, imgs, is_video, info = [], [], [], []
+        while len(imgs) < self.bs:
+            if self.count >= self.nf:  # end of file list
+                if len(imgs) > 0:
+                    return paths, imgs, is_video, info  # return last partial batch
+                else:
                     raise StopIteration
-                path = self.files[self.count]
-                self._new_video(path)
-            success, im0 = self.cap.read()
-            self.frame += 1
-            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
-            s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "
 
+            path = self.files[self.count]
+            if self.video_flag[self.count]:
+                self.mode = "video"
+                if not self.cap or not self.cap.isOpened():
+                    self._new_video(path)
 
-        else:
-            # Read image
-            self.count += 1
-            im0 = cv2.imread(path)  # BGR
-            if im0 is None:
-                raise FileNotFoundError(f"Image Not Found {path}")
-            s = f"image {self.count}/{self.nf} {path}: "
+                for _ in range(self.vid_stride):
+                    success = self.cap.grab()
+                    if not success:
+                        break  # end of video or failure
 
-        return [path], [im0], self.cap, s
+                if success:
+                    success, im0 = self.cap.retrieve()
+                    if success:
+                        self.frame += 1
+                        paths.append(path)
+                        imgs.append(im0)
+                        is_video.append(True)
+                        info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ")
+                        if self.frame == self.frames:  # end of video
+                            self.count += 1
+                            self.cap.release()
+                else:
+                    # Move to the next file if the current video ended or failed to open
+                    self.count += 1
+                    if self.cap:
+                        self.cap.release()
+                    if self.count < self.nf:
+                        self._new_video(self.files[self.count])
+            else:
+                self.mode = "image"
+                im0 = cv2.imread(path)  # BGR
+                if im0 is None:
+                    raise FileNotFoundError(f"Image Not Found {path}")
+                paths.append(path)
+                imgs.append(im0)
+                is_video.append(False)  # images are not videos
+                info.append(f"image {self.count + 1}/{self.nf} {path}: ")
+                self.count += 1  # move to the next file
+
+        return paths, imgs, is_video, info
 
     def _new_video(self, path):
-        """Create a new video capture object."""
+        """Creates a new video capture object for the given path."""
         self.frame = 0
         self.cap = cv2.VideoCapture(path)
+        if not self.cap.isOpened():
+            raise FileNotFoundError(f"Failed to open video {path}")
         self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
 
     def __len__(self):
-        """Returns the number of files in the object."""
-        return self.nf  # number of files
+        """Returns the number of batches in the object."""
+        return math.ceil(self.nf / self.bs)  # number of batches
 
 
 class LoadPilAndNumpy:
@@ -373,7 +395,6 @@ class LoadPilAndNumpy:
         im0 (list): List of images stored as Numpy arrays.
         mode (str): Type of data being processed, defaults to 'image'.
         bs (int): Batch size, equivalent to the length of `im0`.
-        count (int): Counter for iteration, initialized at 0 during `__iter__()`.
 
     Methods:
         _single_check(im): Validate and format a single image to a Numpy array.
@@ -386,7 +407,6 @@ class LoadPilAndNumpy:
         self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
         self.im0 = [self._single_check(im) for im in im0]
         self.mode = "image"
-        # Generate fake paths
         self.bs = len(self.im0)
 
     @staticmethod
@@ -409,7 +429,7 @@ class LoadPilAndNumpy:
         if self.count == 1:  # loop only once as it's batch inference
             raise StopIteration
         self.count += 1
-        return self.paths, self.im0, None, ""
+        return self.paths, self.im0, [False] * self.bs, [""] * self.bs
 
     def __iter__(self):
         """Enables iteration for class LoadPilAndNumpy."""
@@ -474,7 +494,7 @@ class LoadTensor:
         if self.count == 1:
             raise StopIteration
         self.count += 1
-        return self.paths, self.im0, None, ""
+        return self.paths, self.im0, [False] * self.bs, [""] * self.bs
 
     def __len__(self):
         """Returns the batch size."""
@@ -498,9 +518,6 @@ def autocast_list(source):
     return files
 
 
-LOADERS = LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots  # tuple
-
-
 def get_best_youtube_url(url, use_pafy=True):
     """
     Retrieves the URL of the best quality MP4 video stream from a given YouTube video.
@@ -531,3 +548,7 @@ def get_best_youtube_url(url, use_pafy=True):
         good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
         if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4":
             return f.get("url")
+
+
+# Define constants
+LOADERS = (LoadStreams, LoadPilAndNumpy, LoadImagesAndVideos, LoadScreenshots)
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index 9debfe19..e32c3e45 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -423,7 +423,7 @@
             x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
         )
 
-        custom = {"conf": 0.25, "save": is_cli, "mode": "predict"}  # method defaults
+        custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"}  # method defaults
         args = {**self.overrides, **custom, **kwargs}  # highest priority args on the right
         prompts = args.pop("prompts", None)  # for SAM-type models
 
@@ -474,6 +474,7 @@
             register_tracker(self, persist)
         kwargs["conf"] = kwargs.get("conf") or 0.1  # ByteTrack-based method needs low confidence predictions as input
+        kwargs["batch"] = kwargs.get("batch") or 1  # batch-size 1 for tracking in videos
         kwargs["mode"] = "track"
         return self.predict(source=source, stream=stream, **kwargs)
diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py
index e925902f..2282eed6 100644
--- a/ultralytics/engine/predictor.py
+++ b/ultralytics/engine/predictor.py
@@ -73,9 +73,7 @@
         data (dict): Data configuration.
         device (torch.device): Device used for prediction.
         dataset (Dataset): Dataset used for prediction.
-        vid_path (str): Path to video file.
-        vid_writer (cv2.VideoWriter): Video writer for saving video output.
-        data_path (str): Path to data.
+        vid_writer (dict): Dictionary of {save_path: video_writer, ...} for saving video output.
""" def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): @@ -100,10 +98,11 @@ class BasePredictor: self.imgsz = None self.device = None self.dataset = None - self.vid_path, self.vid_writer, self.vid_frame = None, None, None + self.vid_writer = {} # dict of {save_path: video_writer, ...} self.plotted_img = None - self.data_path = None self.source_type = None + self.seen = 0 + self.windows = [] self.batch = None self.results = None self.transforms = None @@ -155,44 +154,6 @@ class BasePredictor: letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride) return [letterbox(image=x) for x in im] - def write_results(self, idx, results, batch): - """Write inference results to a file or directory.""" - p, im, _ = batch - log_string = "" - if len(im.shape) == 3: - im = im[None] # expand for batch dim - if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1 - log_string += f"{idx}: " - frame = self.dataset.count - else: - frame = getattr(self.dataset, "frame", 0) - self.data_path = p - self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}") - log_string += "%gx%g " % im.shape[2:] # print string - result = results[idx] - log_string += result.verbose() - - if self.args.save or self.args.show: # Add bbox to image - plot_args = { - "line_width": self.args.line_width, - "boxes": self.args.show_boxes, - "conf": self.args.show_conf, - "labels": self.args.show_labels, - } - if not self.args.retina_masks: - plot_args["im_gpu"] = im[idx] - self.plotted_img = result.plot(**plot_args) - # Write - if self.args.save_txt: - result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf) - if self.args.save_crop: - result.save_crop( - save_dir=self.save_dir / "crops", - file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"), - ) - - return log_string - def postprocess(self, preds, img, orig_imgs): """Post-processes predictions for an image and returns them.""" return preds @@ -228,18 +189,20 @@ class BasePredictor: else None ) self.dataset = load_inference_source( - source=source, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer + source=source, + batch=self.args.batch, + vid_stride=self.args.vid_stride, + buffer=self.args.stream_buffer, ) self.source_type = self.dataset.source_type if not getattr(self, "stream", True) and ( - self.dataset.mode == "stream" # streams - or len(self.dataset) > 1000 # images + self.source_type.stream + or self.source_type.screenshot + or len(self.dataset) > 1000 # many images or any(getattr(self.dataset, "video_flag", [False])) ): # videos LOGGER.warning(STREAM_WARNING) - self.vid_path = [None] * self.dataset.bs - self.vid_writer = [None] * self.dataset.bs - self.vid_frame = [None] * self.dataset.bs + self.vid_writer = {} @smart_inference_mode() def stream_inference(self, source=None, model=None, *args, **kwargs): @@ -271,10 +234,9 @@ class BasePredictor: ops.Profile(device=self.device), ) self.run_callbacks("on_predict_start") - for batch in self.dataset: + for self.batch in self.dataset: self.run_callbacks("on_predict_batch_start") - self.batch = batch - path, im0s, vid_cap, s = batch + paths, im0s, is_video, s = self.batch # Preprocess with profilers[0]: @@ -290,8 +252,8 @@ class BasePredictor: # Postprocess with profilers[2]: self.results = self.postprocess(preds, im, im0s) - self.run_callbacks("on_predict_postprocess_end") + # Visualize, save, write results n = len(im0s) for 
i in range(n): @@ -301,41 +263,32 @@ class BasePredictor: "inference": profilers[1].dt * 1e3 / n, "postprocess": profilers[2].dt * 1e3 / n, } - p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy() - p = Path(p) - if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: - s += self.write_results(i, self.results, (p, im, im0)) - if self.args.save or self.args.save_txt: - self.results[i].save_dir = self.save_dir.__str__() - if self.args.show and self.plotted_img is not None: - self.show(p) - if self.args.save and self.plotted_img is not None: - self.save_preds(vid_cap, i, str(self.save_dir / p.name)) + s[i] += self.write_results(i, Path(paths[i]), im, is_video) + + # Print batch results + if self.args.verbose: + LOGGER.info("\n".join(s)) self.run_callbacks("on_predict_batch_end") yield from self.results - # Print time (inference-only) - if self.args.verbose: - LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms") - # Release assets - if isinstance(self.vid_writer[-1], cv2.VideoWriter): - self.vid_writer[-1].release() # release final video writer + for v in self.vid_writer.values(): + if isinstance(v, cv2.VideoWriter): + v.release() - # Print results + # Print final results if self.args.verbose and self.seen: t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image LOGGER.info( f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape " - f"{(1, 3, *im.shape[2:])}" % t + f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t ) if self.args.save or self.args.save_txt or self.args.save_crop: nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else "" LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") - self.run_callbacks("on_predict_end") def setup_model(self, model, verbose=True): @@ -354,48 +307,81 @@ class BasePredictor: self.args.half = self.model.fp16 # update half self.model.eval() - def show(self, p): - """Display an image in a window using OpenCV imshow().""" - im0 = self.plotted_img - if platform.system() == "Linux" and p not in self.windows: - self.windows.append(p) - cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) - cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0]) - cv2.imshow(str(p), im0) - cv2.waitKey(500 if self.batch[3].startswith("image") else 1) # 1 millisecond + def write_results(self, i, p, im, is_video): + """Write inference results to a file or directory.""" + string = "" # print string + if len(im.shape) == 3: + im = im[None] # expand for batch dim + if self.source_type.stream or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1 + string += f"{i}: " + frame = self.dataset.count + else: + frame = getattr(self.dataset, "frame", 0) - len(self.results) + i - def save_preds(self, vid_cap, idx, save_path): + self.txt_path = self.save_dir / "labels" / (p.stem + f"_{frame}" if is_video[i] else "") + string += "%gx%g " % im.shape[2:] + result = self.results[i] + result.save_dir = self.save_dir.__str__() # used in other locations + string += result.verbose() + f"{result.speed['inference']:.1f}ms" + + # Add predictions to image + if self.args.save or self.args.show: + self.plotted_img = result.plot( + line_width=self.args.line_width, + boxes=self.args.show_boxes, + conf=self.args.show_conf, + labels=self.args.show_labels, + im_gpu=None if self.args.retina_masks else im[i], + ) + + # Save results + if 
self.args.save_txt: + result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf) + if self.args.save_crop: + result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem) + if self.args.show: + self.show(str(p), is_video[i]) + if self.args.save: + self.save_predicted_images(str(self.save_dir / p.name), is_video[i], frame) + + return string + + def save_predicted_images(self, save_path="", is_video=False, frame=0): """Save video predictions as mp4 at specified path.""" - im0 = self.plotted_img - # Save imgs - if self.dataset.mode == "image": - cv2.imwrite(save_path, im0) - else: # 'video' or 'stream' + im = self.plotted_img + + # Save videos and streams + if is_video: frames_path = f'{save_path.split(".", 1)[0]}_frames/' - if self.vid_path[idx] != save_path: # new video - self.vid_path[idx] = save_path + if save_path not in self.vid_writer: # new video if self.args.save_frames: Path(frames_path).mkdir(parents=True, exist_ok=True) - self.vid_frame[idx] = 0 - if isinstance(self.vid_writer[idx], cv2.VideoWriter): - self.vid_writer[idx].release() # release previous video writer - if vid_cap: # video - fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) # integer required, floats produce error in MP4 codec - w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - else: # stream - fps, w, h = 30, im0.shape[1], im0.shape[0] suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG") - self.vid_writer[idx] = cv2.VideoWriter( - str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h) + self.vid_writer[save_path] = cv2.VideoWriter( + filename=str(Path(save_path).with_suffix(suffix)), + fourcc=cv2.VideoWriter_fourcc(*fourcc), + fps=30, # integer required, floats produce error in MP4 codec + frameSize=(im.shape[1], im.shape[0]), # (width, height) ) - # Write video - self.vid_writer[idx].write(im0) - # Write frame + # Save video + self.vid_writer[save_path].write(im) if self.args.save_frames: - cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0) - self.vid_frame[idx] += 1 + cv2.imwrite(f"{frames_path}{frame}.jpg", im) + + # Save images + else: + cv2.imwrite(save_path, im) + + def show(self, p="", is_video=False): + """Display an image in a window using OpenCV imshow().""" + im = self.plotted_img + if platform.system() == "Linux" and p not in self.windows: + self.windows.append(p) + cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) + cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height) + cv2.imshow(p, im) + cv2.waitKey(1 if is_video else 500) # 1 millisecond def run_callbacks(self, event: str): """Runs all registered callbacks for a specific event.""" diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py index c80c54da..6c7d5ef0 100644 --- a/ultralytics/trackers/track.py +++ b/ultralytics/trackers/track.py @@ -39,6 +39,7 @@ def on_predict_start(predictor: object, persist: bool = False) -> None: tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) trackers.append(tracker) predictor.trackers = trackers + predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None: @@ -54,8 +55,10 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None is_obb = predictor.args.task == "obb" for i in range(bs): - if not persist and predictor.vid_path[i] != 
str(predictor.save_dir / Path(path[i]).name): # new video + vid_path = predictor.save_dir / Path(path[i]).name + if not persist and predictor.vid_path[i] != vid_path: # new video predictor.trackers[i].reset() + predictor.vid_path[i] = vid_path det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy() if len(det) == 0:
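
For context, a minimal sketch of the batched-inference flow this diff introduces. `yolov8n.pt` and `path/to/media` are stand-in examples, and the 4-tuple unpacking mirrors the `(paths, imgs, is_video, info)` batches that `LoadImagesAndVideos` now yields:

```python
from ultralytics import YOLO
from ultralytics.data.build import load_inference_source

# Model-level API: predict() now accepts a `batch` argument (default 1) that is
# forwarded through load_inference_source() to LoadImagesAndVideos.
model = YOLO("yolov8n.pt")  # stand-in weights
results = model.predict(source="path/to/media", batch=4)  # files grouped into batches of 4

# Loader-level view: each iteration yields one batch; the final batch may be partial.
dataset = load_inference_source(source="path/to/media", batch=4)
for paths, imgs, is_video, info in dataset:
    print(len(imgs), is_video)  # up to 4 images/frames, with per-item video flags
```

Note that `Model.track()` still pins `batch=1` (see the `engine/model.py` hunk above), since trackers are reset per video.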