mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-10-31 14:35:40 +08:00)

ultralytics 8.1.27 batched tracking fixes (#8842)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:

parent 3555785167
commit 2ea6b2b889
tests/test_python.py

@@ -301,7 +301,7 @@ def test_predict_callback_and_setup():
 
     def on_predict_batch_end(predictor):
         """Callback function that handles operations at the end of a prediction batch."""
-        path, im0s, _, _ = predictor.batch
+        path, im0s, _ = predictor.batch
         im0s = im0s if isinstance(im0s, list) else [im0s]
         bs = [predictor.dataset.bs for _ in range(len(path))]
         predictor.results = zip(predictor.results, im0s, bs)  # results is List[batch_size]
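For reference, a callback like the one tested above is attached through the public add_callback API before running prediction. A minimal sketch under that assumption (the weights file and source path below are placeholders, not part of this commit):

    from ultralytics import YOLO

    def on_predict_batch_end(predictor):
        """Re-bundle each batch so results can be consumed together with their source images."""
        path, im0s, _ = predictor.batch  # 3-tuple batch after this change
        im0s = im0s if isinstance(im0s, list) else [im0s]
        bs = [predictor.dataset.bs for _ in range(len(path))]
        predictor.results = zip(predictor.results, im0s, bs)

    model = YOLO("yolov8n.pt")  # placeholder weights
    model.add_callback("on_predict_batch_end", on_predict_batch_end)
    for result, im0, batch_size in model.predict(source="path/to/images/", stream=True):  # placeholder source
        print(result.boxes, im0.shape, batch_size)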
ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.1.26"
+__version__ = "8.1.27"
 
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
ultralytics/data/loaders.py
@@ -80,8 +80,6 @@ class LoadStreams:
         self.imgs = [[] for _ in range(n)]  # images
         self.shape = [[] for _ in range(n)]  # image shapes
         self.sources = [ops.clean_str(x) for x in sources]  # clean source names for later
-        self.info = [""] * n
-        self.is_video = [True] * n
         for i, s in enumerate(sources):  # index, source
             # Start thread to read frames from video stream
             st = f"{i + 1}/{n}: {s}... "
@@ -178,7 +176,7 @@ class LoadStreams:
                 images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
                 x.clear()
 
-        return self.sources, images, self.is_video, self.info
+        return self.sources, images, [""] * self.bs
 
     def __len__(self):
         """Return the length of the sources object."""
@@ -227,6 +225,7 @@ class LoadScreenshots:
         self.frame = 0
         self.sct = mss.mss()
         self.bs = 1
+        self.fps = 30
 
         # Parse monitor shape
         monitor = self.sct.monitors[self.screen]
@@ -246,7 +245,7 @@ class LoadScreenshots:
         s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
 
         self.frame += 1
-        return [str(self.screen)], [im0], [True], [s]  # screen, img, is_video, string
+        return [str(self.screen)], [im0], [s]  # screen, img, string
 
 
 class LoadImagesAndVideos:
@@ -298,6 +297,7 @@ class LoadImagesAndVideos:
 
         self.files = images + videos
         self.nf = ni + nv  # number of files
+        self.ni = ni  # number of images
         self.video_flag = [False] * ni + [True] * nv
         self.mode = "image"
         self.vid_stride = vid_stride  # video frame-rate stride
@@ -319,11 +319,11 @@ class LoadImagesAndVideos:
 
     def __next__(self):
         """Returns the next batch of images or video frames along with their paths and metadata."""
-        paths, imgs, is_video, info = [], [], [], []
+        paths, imgs, info = [], [], []
         while len(imgs) < self.bs:
             if self.count >= self.nf:  # end of file list
                 if len(imgs) > 0:
-                    return paths, imgs, is_video, info  # return last partial batch
+                    return paths, imgs, info  # return last partial batch
                 else:
                     raise StopIteration
 
@@ -344,7 +344,6 @@ class LoadImagesAndVideos:
                         self.frame += 1
                         paths.append(path)
                         imgs.append(im0)
-                        is_video.append(True)
                         info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ")
                         if self.frame == self.frames:  # end of video
                             self.count += 1
@@ -363,16 +362,18 @@ class LoadImagesAndVideos:
                     raise FileNotFoundError(f"Image Not Found {path}")
                 paths.append(path)
                 imgs.append(im0)
-                is_video.append(False)  # no capture object for images
                 info.append(f"image {self.count + 1}/{self.nf} {path}: ")
                 self.count += 1  # move to the next file
+                if self.count >= self.ni:  # end of image list
+                    break
 
-        return paths, imgs, is_video, info
+        return paths, imgs, info
 
     def _new_video(self, path):
         """Creates a new video capture object for the given path."""
         self.frame = 0
         self.cap = cv2.VideoCapture(path)
+        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
         if not self.cap.isOpened():
             raise FileNotFoundError(f"Failed to open video {path}")
         self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
@@ -429,7 +430,7 @@ class LoadPilAndNumpy:
         if self.count == 1:  # loop only once as it's batch inference
             raise StopIteration
         self.count += 1
-        return self.paths, self.im0, [False] * self.bs, [""] * self.bs
+        return self.paths, self.im0, [""] * self.bs
 
     def __iter__(self):
         """Enables iteration for class LoadPilAndNumpy."""
@@ -494,7 +495,7 @@ class LoadTensor:
         if self.count == 1:
             raise StopIteration
         self.count += 1
-        return self.paths, self.im0, [False] * self.bs, [""] * self.bs
+        return self.paths, self.im0, [""] * self.bs
 
     def __len__(self):
         """Returns the batch size."""
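After these loader changes, every dataloader yields a 3-tuple batch of (paths, images, info strings) instead of the old 4-tuple carrying a per-item is_video flag; the loader's mode and fps attributes carry that information instead. A rough sketch of consuming the new batch shape (the media path below is a hypothetical example, not from the commit):

    from ultralytics.data.loaders import LoadImagesAndVideos

    dataset = LoadImagesAndVideos("path/to/media/", batch=1, vid_stride=1)  # hypothetical path
    for paths, imgs, info in dataset:  # 3-tuple batches after this change
        for p, im, s in zip(paths, imgs, info):
            print(s, im.shape)  # e.g. "video 1/3 (frame 5/120) ...: " plus the image shape
        print(dataset.mode, getattr(dataset, "fps", None))  # "image" or "video"; fps is set by _new_video()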
ultralytics/engine/predictor.py
@@ -30,6 +30,7 @@ Usage - formats:
 """
 
 import platform
+import re
 import threading
 from pathlib import Path
 
@@ -236,7 +237,7 @@ class BasePredictor:
             self.run_callbacks("on_predict_start")
             for self.batch in self.dataset:
                 self.run_callbacks("on_predict_batch_start")
-                paths, im0s, is_video, s = self.batch
+                paths, im0s, s = self.batch
 
                 # Preprocess
                 with profilers[0]:
@@ -264,7 +265,7 @@ class BasePredictor:
                         "postprocess": profilers[2].dt * 1e3 / n,
                     }
                     if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
-                        s[i] += self.write_results(i, Path(paths[i]), im, is_video)
+                        s[i] += self.write_results(i, Path(paths[i]), im, s)
 
                 # Print batch results
                 if self.args.verbose:
@@ -308,7 +309,7 @@ class BasePredictor:
         self.args.half = self.model.fp16  # update half
         self.model.eval()
 
-    def write_results(self, i, p, im, is_video):
+    def write_results(self, i, p, im, s):
         """Write inference results to a file or directory."""
         string = ""  # print string
         if len(im.shape) == 3:
@@ -317,9 +318,10 @@ class BasePredictor:
             string += f"{i}: "
             frame = self.dataset.count
         else:
-            frame = getattr(self.dataset, "frame", 0) - len(self.results) + i
+            match = re.search(r"frame (\d+)/", s[i])
+            frame = int(match.group(1)) if match else None  # 0 if frame undetermined
 
-        self.txt_path = self.save_dir / "labels" / (p.stem + (f"_{frame}" if is_video[i] else ""))
+        self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
         string += "%gx%g " % im.shape[2:]
         result = self.results[i]
         result.save_dir = self.save_dir.__str__()  # used in other locations
@@ -341,18 +343,19 @@ class BasePredictor:
         if self.args.save_crop:
             result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
         if self.args.show:
-            self.show(str(p), is_video[i])
+            self.show(str(p))
         if self.args.save:
-            self.save_predicted_images(str(self.save_dir / p.name), is_video[i], frame)
+            self.save_predicted_images(str(self.save_dir / p.name), frame)
 
         return string
 
-    def save_predicted_images(self, save_path="", is_video=False, frame=0):
+    def save_predicted_images(self, save_path="", frame=0):
         """Save video predictions as mp4 at specified path."""
         im = self.plotted_img
 
         # Save videos and streams
-        if is_video:
+        if self.dataset.mode in {"stream", "video"}:
+            fps = self.dataset.fps if self.dataset.mode == "video" else 30
             frames_path = f'{save_path.split(".", 1)[0]}_frames/'
             if save_path not in self.vid_writer:  # new video
                 if self.args.save_frames:
@@ -361,7 +364,7 @@ class BasePredictor:
                 self.vid_writer[save_path] = cv2.VideoWriter(
                     filename=str(Path(save_path).with_suffix(suffix)),
                     fourcc=cv2.VideoWriter_fourcc(*fourcc),
-                    fps=30,  # integer required, floats produce error in MP4 codec
+                    fps=fps,  # integer required, floats produce error in MP4 codec
                     frameSize=(im.shape[1], im.shape[0]),  # (width, height)
                 )
 
@@ -374,7 +377,7 @@ class BasePredictor:
         else:
             cv2.imwrite(save_path, im)
 
-    def show(self, p="", is_video=False):
+    def show(self, p=""):
         """Display an image in a window using OpenCV imshow()."""
         im = self.plotted_img
         if platform.system() == "Linux" and p not in self.windows:
@@ -382,7 +385,7 @@ class BasePredictor:
             cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
             cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
         cv2.imshow(p, im)
-        cv2.waitKey(1 if is_video else 500)  # 1 millisecond
+        cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # 1 millisecond
 
     def run_callbacks(self, event: str):
         """Runs all registered callbacks for a specific event."""
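The predictor now derives the frame index from the loader's info string rather than from a per-item is_video flag. A small sketch of the regex used above (the sample string is illustrative only):

    import re

    s_i = "video 2/5 (frame 37/900) /data/clip.mp4: "  # illustrative info string from the loader
    match = re.search(r"frame (\d+)/", s_i)
    frame = int(match.group(1)) if match else None
    print(frame)  # -> 37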
ultralytics/trackers/track.py
@@ -38,6 +38,8 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
     for _ in range(predictor.dataset.bs):
         tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
         trackers.append(tracker)
+        if predictor.dataset.mode != "stream":  # only need one tracker for other modes.
+            break
     predictor.trackers = trackers
     predictor.vid_path = [None] * predictor.dataset.bs  # for determining when to reset tracker on new video
 
@@ -50,20 +52,21 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
         predictor (object): The predictor object containing the predictions.
         persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
     """
-    bs = predictor.dataset.bs
     path, im0s = predictor.batch[:2]
 
     is_obb = predictor.args.task == "obb"
-    for i in range(bs):
+    is_stream = predictor.dataset.mode == "stream"
+    for i in range(len(im0s)):
+        tracker = predictor.trackers[i if is_stream else 0]
         vid_path = predictor.save_dir / Path(path[i]).name
-        if not persist and predictor.vid_path[i] != vid_path:  # new video
-            predictor.trackers[i].reset()
-            predictor.vid_path[i] = vid_path
+        if not persist and predictor.vid_path[i if is_stream else 0] != vid_path:
+            tracker.reset()
+            predictor.vid_path[i if is_stream else 0] = vid_path
 
         det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
         if len(det) == 0:
             continue
-        tracks = predictor.trackers[i].update(det, im0s[i])
+        tracks = tracker.update(det, im0s[i])
         if len(tracks) == 0:
             continue
         idx = tracks[:, -1].astype(int)
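With this change, per-source trackers are kept only for stream mode; image and video batches share a single tracker. A small illustrative helper mirroring the indexing used above (predictor stands in for the real BasePredictor instance, it is not a function from this commit):

    def select_tracker(predictor, i):
        """Pick the tracker for batch item i: one per stream in stream mode, a shared one otherwise."""
        is_stream = predictor.dataset.mode == "stream"
        return predictor.trackers[i if is_stream else 0]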