mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Merge a45ab7029012da80547dc864fc7f6509c93afb67 into 6fbaf42b23f6709f4e34a51430587673e70e151d
This commit is contained in:
commit
6cf1cc3a18
@ -115,6 +115,7 @@ class Results(SimpleClass):
|
|||||||
self.probs = Probs(probs) if probs is not None else None
|
self.probs = Probs(probs) if probs is not None else None
|
||||||
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
|
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
|
||||||
self.obb = OBB(obb, self.orig_shape) if obb is not None else None
|
self.obb = OBB(obb, self.orig_shape) if obb is not None else None
|
||||||
|
self.centroids = Centroids(self.boxes.xyxy, self.orig_shape) if self.boxes is not None else None
|
||||||
self.speed = {"preprocess": None, "inference": None, "postprocess": None} # milliseconds per image
|
self.speed = {"preprocess": None, "inference": None, "postprocess": None} # milliseconds per image
|
||||||
self.names = names
|
self.names = names
|
||||||
self.path = path
|
self.path = path
|
||||||
@ -198,6 +199,7 @@ class Results(SimpleClass):
|
|||||||
boxes=True,
|
boxes=True,
|
||||||
masks=True,
|
masks=True,
|
||||||
probs=True,
|
probs=True,
|
||||||
|
centroids=True,
|
||||||
show=False,
|
show=False,
|
||||||
save=False,
|
save=False,
|
||||||
filename=None,
|
filename=None,
|
||||||
@ -219,6 +221,7 @@ class Results(SimpleClass):
|
|||||||
boxes (bool): Whether to plot the bounding boxes.
|
boxes (bool): Whether to plot the bounding boxes.
|
||||||
masks (bool): Whether to plot the masks.
|
masks (bool): Whether to plot the masks.
|
||||||
probs (bool): Whether to plot classification probability
|
probs (bool): Whether to plot classification probability
|
||||||
|
centroids (bool): Whether to plot the centroids.
|
||||||
show (bool): Whether to display the annotated image directly.
|
show (bool): Whether to display the annotated image directly.
|
||||||
save (bool): Whether to save the annotated image to `filename`.
|
save (bool): Whether to save the annotated image to `filename`.
|
||||||
filename (str): Filename to save image to if save is True.
|
filename (str): Filename to save image to if save is True.
|
||||||
@ -248,6 +251,7 @@ class Results(SimpleClass):
|
|||||||
pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes
|
pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes
|
||||||
pred_masks, show_masks = self.masks, masks
|
pred_masks, show_masks = self.masks, masks
|
||||||
pred_probs, show_probs = self.probs, probs
|
pred_probs, show_probs = self.probs, probs
|
||||||
|
pred_centroids, show_centroids = self.centroids, centroids
|
||||||
annotator = Annotator(
|
annotator = Annotator(
|
||||||
deepcopy(self.orig_img if img is None else img),
|
deepcopy(self.orig_img if img is None else img),
|
||||||
line_width,
|
line_width,
|
||||||
@ -291,6 +295,16 @@ class Results(SimpleClass):
|
|||||||
for k in reversed(self.keypoints.data):
|
for k in reversed(self.keypoints.data):
|
||||||
annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)
|
annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)
|
||||||
|
|
||||||
|
# Plot centroids
|
||||||
|
if pred_centroids is not None and show_centroids:
|
||||||
|
for i, centroid in enumerate(pred_centroids.xy):
|
||||||
|
if pred_boxes is not None:
|
||||||
|
c = int(pred_boxes.cls[i])
|
||||||
|
color = colors(c, True)
|
||||||
|
else:
|
||||||
|
color = (0, 255, 0) # Default to green if no class information
|
||||||
|
annotator.dot(centroid, color=color, radius=5)
|
||||||
|
|
||||||
# Show results
|
# Show results
|
||||||
if show:
|
if show:
|
||||||
annotator.show(self.path)
|
annotator.show(self.path)
|
||||||
@ -419,6 +433,13 @@ class Results(SimpleClass):
|
|||||||
"y": (y / h).numpy().round(decimals).tolist(),
|
"y": (y / h).numpy().round(decimals).tolist(),
|
||||||
"visible": visible.numpy().round(decimals).tolist(),
|
"visible": visible.numpy().round(decimals).tolist(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.centroids is not None:
|
||||||
|
result["centroid"] = {
|
||||||
|
"x": round(self.centroids.xy[i][0].item() / w, decimals),
|
||||||
|
"y": round(self.centroids.xy[i][1].item() / h, decimals)
|
||||||
|
}
|
||||||
|
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@ -741,3 +762,47 @@ class OBB(BaseTensor):
|
|||||||
y2 = self.xyxyxyxy[..., 1].max(1).values
|
y2 = self.xyxyxyxy[..., 1].max(1).values
|
||||||
xyxy = [x1, y1, x2, y2]
|
xyxy = [x1, y1, x2, y2]
|
||||||
return np.stack(xyxy, axis=-1) if isinstance(self.data, np.ndarray) else torch.stack(xyxy, dim=-1)
|
return np.stack(xyxy, axis=-1) if isinstance(self.data, np.ndarray) else torch.stack(xyxy, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class Centroids(BaseTensor):
|
||||||
|
"""
|
||||||
|
A class for storing and manipulating detection centroids.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
xy (torch.Tensor): A tensor containing x, y coordinates of centroids for each detection.
|
||||||
|
xyn (torch.Tensor): A normalized version of xy with coordinates in the range [0, 1].
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
cpu(): Returns a copy of the centroids tensor on CPU memory.
|
||||||
|
numpy(): Returns a copy of the centroids tensor as a numpy array.
|
||||||
|
cuda(): Returns a copy of the centroids tensor on GPU memory.
|
||||||
|
to(device, dtype): Returns a copy of the centroids tensor with the specified device and dtype.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, boxes, orig_shape) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the Centroids object with bounding boxes and original image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
boxes (torch.Tensor): A tensor of bounding boxes in xyxy format.
|
||||||
|
orig_shape (tuple): Original image size in (height, width) format.
|
||||||
|
"""
|
||||||
|
centroids = torch.zeros(boxes.shape[0], 2, device=boxes.device)
|
||||||
|
centroids[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2 # x-coordinate
|
||||||
|
centroids[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2 # y-coordinate
|
||||||
|
super().__init__(centroids, orig_shape)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def xy(self):
|
||||||
|
"""Returns x, y coordinates of centroids."""
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def xyn(self):
|
||||||
|
"""Returns normalized x, y coordinates of centroids."""
|
||||||
|
xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
|
||||||
|
xy[:, 0] /= self.orig_shape[1]
|
||||||
|
xy[:, 1] /= self.orig_shape[0]
|
||||||
|
return xy
|
@ -593,6 +593,23 @@ class Annotator:
|
|||||||
cv2.circle(self.im, center_bbox, pins_radius, color, -1)
|
cv2.circle(self.im, center_bbox, pins_radius, color, -1)
|
||||||
cv2.line(self.im, center_point, center_bbox, color, thickness)
|
cv2.line(self.im, center_point, center_bbox, color, thickness)
|
||||||
|
|
||||||
|
def dot(self, xy, color=(255, 0, 0), radius=5):
|
||||||
|
"""
|
||||||
|
Draw a dot on the image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xy (tuple): The (x, y) coordinates of the dot's center.
|
||||||
|
color (tuple, optional): The color of the dot in BGR format. Defaults to (255, 0, 0) (blue).
|
||||||
|
radius (int, optional): The radius of the dot. Defaults to 5.
|
||||||
|
"""
|
||||||
|
if self.pil:
|
||||||
|
# Convert to PIL ImageDraw
|
||||||
|
draw = ImageDraw.Draw(self.im)
|
||||||
|
draw.ellipse([xy[0]-radius, xy[1]-radius, xy[0]+radius, xy[1]+radius], fill=color, outline=color)
|
||||||
|
else:
|
||||||
|
cv2.circle(self.im, (int(xy[0]), int(xy[1])), radius, color, -1, lineType=cv2.LINE_AA)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||||
@plt_settings()
|
@plt_settings()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user