From f4b34fc30b6bde1650054279e1e9f205b0bccad6 Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Thu, 1 Jun 2023 06:40:10 +0800 Subject: [PATCH] Fix `save_txt` in track mode and add Keypoints and Probs (#2921) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/test_python.py | 58 +++++------ ultralytics/yolo/engine/results.py | 156 ++++++++++++++++++++++++----- 2 files changed, 156 insertions(+), 58 deletions(-) diff --git a/tests/test_python.py b/tests/test_python.py index 4b64eead..d51f091a 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -207,42 +207,34 @@ def test_predict_callback_and_setup(): print(boxes) -def test_result(): - model = YOLO('yolov8n-pose.pt') - res = model([SOURCE, SOURCE]) - res[0].plot(conf=True, boxes=False) - res[0].plot(pil=True) - res[0] = res[0].cpu().numpy() - print(res[0].path, res[0].keypoints) +def _test_results_api(res): + # General apis except plot + res = res.cpu().numpy() + # res = res.cuda() + res = res.to(device='cpu', dtype=torch.float32) + res.save_txt('label.txt', save_conf=False) + res.save_txt('label.txt', save_conf=True) + res.save_crop('crops/') + res.tojson(normalize=False) + res.tojson(normalize=True) + res.plot(pil=True) + res.plot(conf=True, boxes=False) + res.plot() + print(res.path) + for k in res.keys: + print(getattr(res, k).data) - model = YOLO('yolov8n-seg.pt') - res = model([SOURCE, SOURCE]) - res[0].plot(conf=True, boxes=False, masks=True) - res[0].plot(pil=True) - res[0] = res[0].cpu().numpy() - print(res[0].path, res[0].masks.data) - model = YOLO('yolov8n.pt') - res = model(SOURCE) - res[0].plot(pil=True) - res[0].plot() - res[0] = res[0].cpu().numpy() - print(res[0].path) - - model = YOLO('yolov8n-cls.pt') - res = model(SOURCE) - res[0].plot(probs=False) - res[0].plot(pil=True) - res[0].plot() - res[0] = res[0].cpu().numpy() - print(res[0].path) +def test_results(): + for m in ['yolov8n-pose.pt', 'yolov8n-seg.pt', 'yolov8n.pt', 'yolov8n-cls.pt']: + model = YOLO(m) + res = model([SOURCE, SOURCE]) + _test_results_api(res[0]) def test_track(): im = cv2.imread(str(SOURCE)) - model = YOLO(MODEL) - seg_model = YOLO('yolov8n-seg.pt') - pose_model = YOLO('yolov8n-pose.pt') - model.track(source=im) - seg_model.track(source=im) - pose_model.track(source=im) + for m in ['yolov8n-pose.pt', 'yolov8n-seg.pt', 'yolov8n.pt']: + model = YOLO(m) + res = model.track(source=im) + _test_results_api(res[0]) diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py index 030edabf..8d1e981d 100644 --- a/ultralytics/yolo/engine/results.py +++ b/ultralytics/yolo/engine/results.py @@ -23,7 +23,13 @@ class BaseTensor(SimpleClass): """ def __init__(self, data, orig_shape) -> None: - """Initialize BaseTensor with data and original shape.""" + """Initialize BaseTensor with data and original shape. + + Args: + data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints. + orig_shape (tuple): Original shape of image. + """ + assert isinstance(data, (torch.Tensor, np.ndarray)) self.data = data self.orig_shape = orig_shape @@ -34,19 +40,19 @@ class BaseTensor(SimpleClass): def cpu(self): """Return a copy of the tensor on CPU memory.""" - return self.__class__(self.data.cpu(), self.orig_shape) + return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape) def numpy(self): """Return a copy of the tensor as a numpy array.""" - return self.__class__(self.data.numpy(), self.orig_shape) + return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape) def cuda(self): """Return a copy of the tensor on GPU memory.""" - return self.__class__(self.data.cuda(), self.orig_shape) + return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape) def to(self, *args, **kwargs): """Return a copy of the tensor with the specified device and dtype.""" - return self.__class__(self.data.to(*args, **kwargs), self.orig_shape) + return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape) def __len__(self): # override len(results) """Return the length of the data tensor.""" @@ -90,8 +96,8 @@ class Results(SimpleClass): self.orig_shape = orig_img.shape[:2] self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks - self.probs = probs if probs is not None else None - self.keypoints = keypoints if keypoints is not None else None + self.probs = Probs(probs) if probs is not None else None + self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image self.names = names self.path = path @@ -229,13 +235,11 @@ class Results(SimpleClass): annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if pred_probs is not None and show_probs: - n5 = min(len(names), 5) - top5i = pred_probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices - text = f"{', '.join(f'{names[j] if names else j} {pred_probs[j]:.2f}' for j in top5i)}, " + text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, " annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors if keypoints is not None: - for k in reversed(keypoints): + for k in reversed(keypoints.data): annotator.kpts(k, self.orig_shape, kpt_line=kpt_line) return annotator.result() @@ -250,9 +254,7 @@ class Results(SimpleClass): if len(self) == 0: return log_string if probs is not None else f'{log_string}(no detections), ' if probs is not None: - n5 = min(len(self.names), 5) - top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices - log_string += f"{', '.join(f'{self.names[j]} {probs[j]:.2f}' for j in top5i)}, " + log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, " if boxes: for c in boxes.cls.unique(): n = (boxes.cls == c).sum() # detections per class @@ -274,9 +276,7 @@ class Results(SimpleClass): texts = [] if probs is not None: # Classify - n5 = min(len(self.names), 5) - top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices - [texts.append(f'{probs[j]:.2f} {self.names[j]}') for j in top5i] + [texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5] elif boxes: # Detect/segment/pose for j, d in enumerate(boxes): @@ -286,7 +286,7 @@ class Results(SimpleClass): seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2) line = (c, *seg) if kpts is not None: - kpt = (kpts[j][:, :2].cpu() / d.orig_shape[[1, 0]]).reshape(-1).tolist() + kpt = kpts[j].xyn.reshape(-1).tolist() line += (*kpt, ) line += (conf, ) * save_conf + (() if id is None else (id, )) texts.append(('%g ' * len(line)).rstrip() % line) @@ -322,6 +322,10 @@ class Results(SimpleClass): def tojson(self, normalize=False): """Convert the object to JSON format.""" + if self.probs is not None: + LOGGER.warning('Warning: Classify task do not support `tojson` yet.') + return + import json # Create list of detection dictionaries @@ -338,7 +342,7 @@ class Results(SimpleClass): x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()} if self.keypoints is not None: - x, y, visible = self.keypoints[i].cpu().unbind(dim=1) # torch Tensor + x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()} results.append(result) @@ -386,8 +390,7 @@ class Boxes(BaseTensor): assert n in (6, 7), f'expected `n` in [6, 7], but got {n}' # xyxy, (track_id), conf, cls super().__init__(boxes, orig_shape) self.is_track = n == 7 - self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \ - else np.asarray(orig_shape) + self.orig_shape = orig_shape @property def xyxy(self): @@ -419,13 +422,19 @@ class Boxes(BaseTensor): @lru_cache(maxsize=2) def xyxyn(self): """Return the boxes in xyxy format normalized by original image size.""" - return self.xyxy / self.orig_shape[[1, 0, 1, 0]] + xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy) + xyxy[..., [0, 2]] /= self.orig_shape[1] + xyxy[..., [1, 3]] /= self.orig_shape[0] + return xyxy @property @lru_cache(maxsize=2) def xywhn(self): """Return the boxes in xywh format normalized by original image size.""" - return self.xywh / self.orig_shape[[1, 0, 1, 0]] + xywh = ops.xyxy2xywh(self.xyxy) + xywh[..., [0, 2]] /= self.orig_shape[1] + xywh[..., [1, 3]] /= self.orig_shape[0] + return xywh @property def boxes(self): @@ -439,11 +448,11 @@ class Masks(BaseTensor): A class for storing and manipulating detection masks. Args: - masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width). + masks (torch.Tensor | np.ndarray): A tensor containing the detection masks, with shape (num_masks, height, width). orig_shape (tuple): Original image size, in the format (height, width). Attributes: - masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width). + masks (torch.Tensor | np.ndarray): A tensor containing the detection masks, with shape (num_masks, height, width). orig_shape (tuple): Original image size, in the format (height, width). Properties: @@ -496,3 +505,100 @@ class Masks(BaseTensor): def pandas(self): """Convert the object to a pandas DataFrame (not yet implemented).""" LOGGER.warning("WARNING ⚠️ 'Masks.pandas' method is not yet implemented.") + + +class Keypoints(BaseTensor): + """ + A class for storing and manipulating detection keypoints. + + Args: + keypoints (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_dets, num_kpts, 2/3). + orig_shape (tuple): Original image size, in the format (height, width). + + Attributes: + keypoints (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_dets, num_kpts, 2/3). + orig_shape (tuple): Original image size, in the format (height, width). + + Properties: + xy (list): A list of keypoints (pixels) which includes x, y keypoints of each detection. + xyn (list): A list of keypoints (normalized) which includes x, y keypoints of each detection. + + Methods: + cpu(): Returns a copy of the keypoints tensor on CPU memory. + numpy(): Returns a copy of the keypoints tensor as a numpy array. + cuda(): Returns a copy of the keypoints tensor on GPU memory. + to(): Returns a copy of the keypoints tensor with the specified device and dtype. + """ + + def __init__(self, keypoints, orig_shape) -> None: + if keypoints.ndim == 2: + keypoints = keypoints[None, :] + super().__init__(keypoints, orig_shape) + self.has_visible = self.data.shape[-1] == 3 + + @property + @lru_cache(maxsize=1) + def xy(self): + return self.data[..., :2] + + @property + @lru_cache(maxsize=1) + def xyn(self): + xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy) + xy[..., 0] /= self.orig_shape[1] + xy[..., 1] /= self.orig_shape[0] + return xy + + @property + @lru_cache(maxsize=1) + def conf(self): + return self.data[..., 3] if self.has_visible else None + + +class Probs(BaseTensor): + """ + A class for storing and manipulating classify predictions. + + Args: + probs (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_class, ). + + Attributes: + probs (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_class). + + Properties: + top5 (list[int]): Top 1 indice. + top1 (int): Top 5 indices. + + Methods: + cpu(): Returns a copy of the probs tensor on CPU memory. + numpy(): Returns a copy of the probs tensor as a numpy array. + cuda(): Returns a copy of the probs tensor on GPU memory. + to(): Returns a copy of the probs tensor with the specified device and dtype. + """ + + def __init__(self, probs, orig_shape=None) -> None: + super().__init__(probs, orig_shape) + + @property + @lru_cache(maxsize=1) + def top5(self): + """Return the indices of top 5.""" + return (-self.data).argsort(0)[:5].tolist() # this way works with both torch and numpy. + + @property + @lru_cache(maxsize=1) + def top1(self): + """Return the indices of top 1.""" + return int(self.data.argmax()) + + @property + @lru_cache(maxsize=1) + def top5conf(self): + """Return the confidences of top 5.""" + return self.data[self.top5] + + @property + @lru_cache(maxsize=1) + def top1conf(self): + """Return the confidences of top 1.""" + return self.data[self.top1]