Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-23 21:44:22 +08:00

Fix save_txt in track mode and add Keypoints and Probs (#2921)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

This commit is contained in:
parent 6c65934b55
commit f4b34fc30b
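The commit reworks the user-facing Results API: classification probabilities and pose keypoints are wrapped in new `Probs` and `Keypoints` containers, and `save_txt` now also works on results returned by `model.track()`. A minimal usage sketch against that API (the weights name and image path are illustrative, and it assumes the ultralytics package at this commit):

import cv2
from ultralytics import YOLO

model = YOLO('yolov8n-pose.pt')  # any of the yolov8n*.pt weights exercised by the tests below would do
im = cv2.imread('bus.jpg')  # illustrative local image

# Plain prediction: keypoints come back as a Keypoints object with xy/xyn accessors.
res = model(im)[0]
print(res.keypoints.xyn)  # normalized (x, y) per detection

# Track mode: save_txt now writes label files, optionally with confidence and track id.
res = model.track(source=im)[0]
res.save_txt('label.txt', save_conf=True)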
@@ -207,42 +207,34 @@ def test_predict_callback_and_setup():
         print(boxes)


-def test_result():
-    model = YOLO('yolov8n-pose.pt')
-    res = model([SOURCE, SOURCE])
-    res[0].plot(conf=True, boxes=False)
-    res[0].plot(pil=True)
-    res[0] = res[0].cpu().numpy()
-    print(res[0].path, res[0].keypoints)
-
-    model = YOLO('yolov8n-seg.pt')
-    res = model([SOURCE, SOURCE])
-    res[0].plot(conf=True, boxes=False, masks=True)
-    res[0].plot(pil=True)
-    res[0] = res[0].cpu().numpy()
-    print(res[0].path, res[0].masks.data)
-
-    model = YOLO('yolov8n.pt')
-    res = model(SOURCE)
-    res[0].plot(pil=True)
-    res[0].plot()
-    res[0] = res[0].cpu().numpy()
-    print(res[0].path)
-
-    model = YOLO('yolov8n-cls.pt')
-    res = model(SOURCE)
-    res[0].plot(probs=False)
-    res[0].plot(pil=True)
-    res[0].plot()
-    res[0] = res[0].cpu().numpy()
-    print(res[0].path)
+def _test_results_api(res):
+    # General apis except plot
+    res = res.cpu().numpy()
+    # res = res.cuda()
+    res = res.to(device='cpu', dtype=torch.float32)
+    res.save_txt('label.txt', save_conf=False)
+    res.save_txt('label.txt', save_conf=True)
+    res.save_crop('crops/')
+    res.tojson(normalize=False)
+    res.tojson(normalize=True)
+    res.plot(pil=True)
+    res.plot(conf=True, boxes=False)
+    res.plot()
+    print(res.path)
+    for k in res.keys:
+        print(getattr(res, k).data)
+
+
+def test_results():
+    for m in ['yolov8n-pose.pt', 'yolov8n-seg.pt', 'yolov8n.pt', 'yolov8n-cls.pt']:
+        model = YOLO(m)
+        res = model([SOURCE, SOURCE])
+        _test_results_api(res[0])


 def test_track():
     im = cv2.imread(str(SOURCE))
-    model = YOLO(MODEL)
-    seg_model = YOLO('yolov8n-seg.pt')
-    pose_model = YOLO('yolov8n-pose.pt')
-    model.track(source=im)
-    seg_model.track(source=im)
-    pose_model.track(source=im)
+    for m in ['yolov8n-pose.pt', 'yolov8n-seg.pt', 'yolov8n.pt']:
+        model = YOLO(m)
+        res = model.track(source=im)
+        _test_results_api(res[0])
@@ -23,7 +23,13 @@ class BaseTensor(SimpleClass):
     """

     def __init__(self, data, orig_shape) -> None:
-        """Initialize BaseTensor with data and original shape."""
+        """Initialize BaseTensor with data and original shape.
+
+        Args:
+            data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
+            orig_shape (tuple): Original shape of image.
+        """
+        assert isinstance(data, (torch.Tensor, np.ndarray))
         self.data = data
         self.orig_shape = orig_shape

@@ -34,19 +40,19 @@ class BaseTensor(SimpleClass):

     def cpu(self):
         """Return a copy of the tensor on CPU memory."""
-        return self.__class__(self.data.cpu(), self.orig_shape)
+        return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)

     def numpy(self):
         """Return a copy of the tensor as a numpy array."""
-        return self.__class__(self.data.numpy(), self.orig_shape)
+        return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)

     def cuda(self):
         """Return a copy of the tensor on GPU memory."""
-        return self.__class__(self.data.cuda(), self.orig_shape)
+        return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)

     def to(self, *args, **kwargs):
         """Return a copy of the tensor with the specified device and dtype."""
-        return self.__class__(self.data.to(*args, **kwargs), self.orig_shape)
+        return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)

     def __len__(self):  # override len(results)
         """Return the length of the data tensor."""
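After this hunk the conversion helpers are safe to chain regardless of the backing array type: `cpu()` and `numpy()` return `self` when the data is already a numpy array, while `cuda()` and `to()` first go through `torch.as_tensor`. A small sketch of that behaviour, instantiating BaseTensor directly (the module path `ultralytics.yolo.engine.results` is an assumption based on the codebase layout at this commit):

import numpy as np
import torch
from ultralytics.yolo.engine.results import BaseTensor  # assumed path at this commit

t = BaseTensor(np.zeros((2, 4), dtype=np.float32), orig_shape=(480, 640))
t = t.cpu().numpy()  # both are no-ops: the data is already a numpy array
t = t.to(device='cpu', dtype=torch.float32)  # torch.as_tensor() converts back to a tensor
print(type(t.data))  # <class 'torch.Tensor'>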
@@ -90,8 +96,8 @@ class Results(SimpleClass):
         self.orig_shape = orig_img.shape[:2]
         self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None  # native size boxes
         self.masks = Masks(masks, self.orig_shape) if masks is not None else None  # native size or imgsz masks
-        self.probs = probs if probs is not None else None
-        self.keypoints = keypoints if keypoints is not None else None
+        self.probs = Probs(probs) if probs is not None else None
+        self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
         self.speed = {'preprocess': None, 'inference': None, 'postprocess': None}  # milliseconds per image
         self.names = names
         self.path = path
@@ -229,13 +235,11 @@ class Results(SimpleClass):
                 annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))

         if pred_probs is not None and show_probs:
-            n5 = min(len(names), 5)
-            top5i = pred_probs.argsort(0, descending=True)[:n5].tolist()  # top 5 indices
-            text = f"{', '.join(f'{names[j] if names else j} {pred_probs[j]:.2f}' for j in top5i)}, "
+            text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
             annotator.text((32, 32), text, txt_color=(255, 255, 255))  # TODO: allow setting colors

         if keypoints is not None:
-            for k in reversed(keypoints):
+            for k in reversed(keypoints.data):
                 annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)

         return annotator.result()
@@ -250,9 +254,7 @@ class Results(SimpleClass):
         if len(self) == 0:
             return log_string if probs is not None else f'{log_string}(no detections), '
         if probs is not None:
-            n5 = min(len(self.names), 5)
-            top5i = probs.argsort(0, descending=True)[:n5].tolist()  # top 5 indices
-            log_string += f"{', '.join(f'{self.names[j]} {probs[j]:.2f}' for j in top5i)}, "
+            log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
         if boxes:
             for c in boxes.cls.unique():
                 n = (boxes.cls == c).sum()  # detections per class
@@ -274,9 +276,7 @@ class Results(SimpleClass):
         texts = []
         if probs is not None:
             # Classify
-            n5 = min(len(self.names), 5)
-            top5i = probs.argsort(0, descending=True)[:n5].tolist()  # top 5 indices
-            [texts.append(f'{probs[j]:.2f} {self.names[j]}') for j in top5i]
+            [texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5]
         elif boxes:
             # Detect/segment/pose
             for j, d in enumerate(boxes):
@@ -286,7 +286,7 @@ class Results(SimpleClass):
                     seg = masks[j].xyn[0].copy().reshape(-1)  # reversed mask.xyn, (n,2) to (n*2)
                     line = (c, *seg)
                 if kpts is not None:
-                    kpt = (kpts[j][:, :2].cpu() / d.orig_shape[[1, 0]]).reshape(-1).tolist()
+                    kpt = kpts[j].xyn.reshape(-1).tolist()
                     line += (*kpt, )
                 line += (conf, ) * save_conf + (() if id is None else (id, ))
                 texts.append(('%g ' * len(line)).rstrip() % line)
@@ -322,6 +322,10 @@ class Results(SimpleClass):

     def tojson(self, normalize=False):
         """Convert the object to JSON format."""
+        if self.probs is not None:
+            LOGGER.warning('Warning: Classify task do not support `tojson` yet.')
+            return
+
         import json

         # Create list of detection dictionaries
@@ -338,7 +342,7 @@ class Results(SimpleClass):
                 x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1]  # numpy array
                 result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()}
             if self.keypoints is not None:
-                x, y, visible = self.keypoints[i].cpu().unbind(dim=1)  # torch Tensor
+                x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1)  # torch Tensor
                 result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()}
             results.append(result)

@@ -386,8 +390,7 @@ class Boxes(BaseTensor):
         assert n in (6, 7), f'expected `n` in [6, 7], but got {n}'  # xyxy, (track_id), conf, cls
         super().__init__(boxes, orig_shape)
         self.is_track = n == 7
-        self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \
-            else np.asarray(orig_shape)
+        self.orig_shape = orig_shape

     @property
     def xyxy(self):
@@ -419,13 +422,19 @@ class Boxes(BaseTensor):
     @lru_cache(maxsize=2)
     def xyxyn(self):
         """Return the boxes in xyxy format normalized by original image size."""
-        return self.xyxy / self.orig_shape[[1, 0, 1, 0]]
+        xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy)
+        xyxy[..., [0, 2]] /= self.orig_shape[1]
+        xyxy[..., [1, 3]] /= self.orig_shape[0]
+        return xyxy

     @property
     @lru_cache(maxsize=2)
     def xywhn(self):
         """Return the boxes in xywh format normalized by original image size."""
-        return self.xywh / self.orig_shape[[1, 0, 1, 0]]
+        xywh = ops.xyxy2xywh(self.xyxy)
+        xywh[..., [0, 2]] /= self.orig_shape[1]
+        xywh[..., [1, 3]] /= self.orig_shape[0]
+        return xywh

     @property
     def boxes(self):
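The normalized box properties no longer rely on `orig_shape` being indexable like a tensor: x coordinates are divided by the image width (`orig_shape[1]`), y coordinates by the height (`orig_shape[0]`), and `xywhn` is derived from `xyxy` via `ops.xyxy2xywh`. A worked example with one box on a 480x640 image (module path assumed as in the sketch above):

import torch
from ultralytics.yolo.engine.results import Boxes  # assumed path at this commit

# One row of xyxy + conf + cls (n == 6) on an image of height 480, width 640.
b = Boxes(torch.tensor([[100., 50., 200., 150., 0.9, 0.0]]), orig_shape=(480, 640))
print(b.xyxyn)  # ~[[0.1562, 0.1042, 0.3125, 0.3125]]  (x / 640, y / 480)
print(b.xywhn)  # ~[[0.2344, 0.2083, 0.1562, 0.2083]]  (normalized cx, cy, w, h)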
@@ -439,11 +448,11 @@ class Masks(BaseTensor):
     A class for storing and manipulating detection masks.

     Args:
-        masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
+        masks (torch.Tensor | np.ndarray): A tensor containing the detection masks, with shape (num_masks, height, width).
         orig_shape (tuple): Original image size, in the format (height, width).

     Attributes:
-        masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
+        masks (torch.Tensor | np.ndarray): A tensor containing the detection masks, with shape (num_masks, height, width).
         orig_shape (tuple): Original image size, in the format (height, width).

     Properties:
@@ -496,3 +505,100 @@ class Masks(BaseTensor):
     def pandas(self):
         """Convert the object to a pandas DataFrame (not yet implemented)."""
         LOGGER.warning("WARNING ⚠️ 'Masks.pandas' method is not yet implemented.")
+
+
+class Keypoints(BaseTensor):
+    """
+    A class for storing and manipulating detection keypoints.
+
+    Args:
+        keypoints (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_dets, num_kpts, 2/3).
+        orig_shape (tuple): Original image size, in the format (height, width).
+
+    Attributes:
+        keypoints (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_dets, num_kpts, 2/3).
+        orig_shape (tuple): Original image size, in the format (height, width).
+
+    Properties:
+        xy (list): A list of keypoints (pixels) which includes x, y keypoints of each detection.
+        xyn (list): A list of keypoints (normalized) which includes x, y keypoints of each detection.
+
+    Methods:
+        cpu(): Returns a copy of the keypoints tensor on CPU memory.
+        numpy(): Returns a copy of the keypoints tensor as a numpy array.
+        cuda(): Returns a copy of the keypoints tensor on GPU memory.
+        to(): Returns a copy of the keypoints tensor with the specified device and dtype.
+    """
+
+    def __init__(self, keypoints, orig_shape) -> None:
+        if keypoints.ndim == 2:
+            keypoints = keypoints[None, :]
+        super().__init__(keypoints, orig_shape)
+        self.has_visible = self.data.shape[-1] == 3
+
+    @property
+    @lru_cache(maxsize=1)
+    def xy(self):
+        return self.data[..., :2]
+
+    @property
+    @lru_cache(maxsize=1)
+    def xyn(self):
+        xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
+        xy[..., 0] /= self.orig_shape[1]
+        xy[..., 1] /= self.orig_shape[0]
+        return xy
+
+    @property
+    @lru_cache(maxsize=1)
+    def conf(self):
+        return self.data[..., 3] if self.has_visible else None
+
+
+class Probs(BaseTensor):
+    """
+    A class for storing and manipulating classification predictions.
+
+    Args:
+        probs (torch.Tensor | np.ndarray): A tensor containing the classification probabilities, with shape (num_class, ).
+
+    Attributes:
+        probs (torch.Tensor | np.ndarray): A tensor containing the classification probabilities, with shape (num_class, ).
+
+    Properties:
+        top5 (list[int]): Top 5 indices.
+        top1 (int): Top 1 index.
+
+    Methods:
+        cpu(): Returns a copy of the probs tensor on CPU memory.
+        numpy(): Returns a copy of the probs tensor as a numpy array.
+        cuda(): Returns a copy of the probs tensor on GPU memory.
+        to(): Returns a copy of the probs tensor with the specified device and dtype.
+    """
+
+    def __init__(self, probs, orig_shape=None) -> None:
+        super().__init__(probs, orig_shape)
+
+    @property
+    @lru_cache(maxsize=1)
+    def top5(self):
+        """Return the indices of top 5."""
+        return (-self.data).argsort(0)[:5].tolist()  # this way works with both torch and numpy.
+
+    @property
+    @lru_cache(maxsize=1)
+    def top1(self):
+        """Return the index of top 1."""
+        return int(self.data.argmax())
+
+    @property
+    @lru_cache(maxsize=1)
+    def top5conf(self):
+        """Return the confidences of top 5."""
+        return self.data[self.top5]
+
+    @property
+    @lru_cache(maxsize=1)
+    def top1conf(self):
+        """Return the confidence of top 1."""
+        return self.data[self.top1]
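The new containers replace the manual `argsort`/normalization code removed in the earlier hunks: `Probs` exposes the top-1/top-5 indices and confidences directly, and `Keypoints` exposes pixel and normalized coordinates. A short sketch of the accessors, built from raw tensors purely to show the API (in practice these objects are attached to a Results instance; module path assumed as above):

import torch
from ultralytics.yolo.engine.results import Keypoints, Probs  # assumed path at this commit

probs = Probs(torch.tensor([0.05, 0.70, 0.12, 0.08, 0.03, 0.02]))
print(probs.top1, probs.top1conf)  # 1  tensor(0.7000)
print(probs.top5, probs.top5conf)  # [1, 2, 3, 0, 4]  tensor([0.70, 0.12, 0.08, 0.05, 0.03])

# One detection with two (x, y, visible) keypoints on a 480x640 image.
kpts = Keypoints(torch.tensor([[[320., 240., 1.], [160., 120., 1.]]]), orig_shape=(480, 640))
print(kpts.xy)   # pixel coordinates, shape (1, 2, 2)
print(kpts.xyn)  # normalized: x / 640, y / 480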