Merge a45ab7029012da80547dc864fc7f6509c93afb67 into 453c6e38a51e9d1d5a2aa5fb7f1014a711913397

This commit is contained in:
Rıza Semih Koca 2025-03-26 06:50:32 +00:00 committed by GitHub
commit 43bcb74d21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 82 additions and 0 deletions

View File

@ -115,6 +115,7 @@ class Results(SimpleClass):
self.probs = Probs(probs) if probs is not None else None
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
self.obb = OBB(obb, self.orig_shape) if obb is not None else None
self.centroids = Centroids(self.boxes.xyxy, self.orig_shape) if self.boxes is not None else None
self.speed = {"preprocess": None, "inference": None, "postprocess": None} # milliseconds per image
self.names = names
self.path = path
@ -198,6 +199,7 @@ class Results(SimpleClass):
boxes=True,
masks=True,
probs=True,
centroids=True,
show=False,
save=False,
filename=None,
@ -219,6 +221,7 @@ class Results(SimpleClass):
boxes (bool): Whether to plot the bounding boxes.
masks (bool): Whether to plot the masks.
probs (bool): Whether to plot classification probability
centroids (bool): Whether to plot the centroids.
show (bool): Whether to display the annotated image directly.
save (bool): Whether to save the annotated image to `filename`.
filename (str): Filename to save image to if save is True.
@ -248,6 +251,7 @@ class Results(SimpleClass):
pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes
pred_masks, show_masks = self.masks, masks
pred_probs, show_probs = self.probs, probs
pred_centroids, show_centroids = self.centroids, centroids
annotator = Annotator(
deepcopy(self.orig_img if img is None else img),
line_width,
@ -291,6 +295,16 @@ class Results(SimpleClass):
for k in reversed(self.keypoints.data):
annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)
# Plot centroids
if pred_centroids is not None and show_centroids:
for i, centroid in enumerate(pred_centroids.xy):
if pred_boxes is not None:
c = int(pred_boxes.cls[i])
color = colors(c, True)
else:
color = (0, 255, 0) # Default to green if no class information
annotator.dot(centroid, color=color, radius=5)
# Show results
if show:
annotator.show(self.path)
@ -419,6 +433,13 @@ class Results(SimpleClass):
"y": (y / h).numpy().round(decimals).tolist(),
"visible": visible.numpy().round(decimals).tolist(),
}
if self.centroids is not None:
result["centroid"] = {
"x": round(self.centroids.xy[i][0].item() / w, decimals),
"y": round(self.centroids.xy[i][1].item() / h, decimals)
}
results.append(result)
return results
@ -741,3 +762,47 @@ class OBB(BaseTensor):
y2 = self.xyxyxyxy[..., 1].max(1).values
xyxy = [x1, y1, x2, y2]
return np.stack(xyxy, axis=-1) if isinstance(self.data, np.ndarray) else torch.stack(xyxy, dim=-1)
class Centroids(BaseTensor):
    """
    A class for storing and manipulating detection centroids.

    A centroid is the midpoint of a bounding box: ((x1 + x2) / 2, (y1 + y2) / 2).

    Attributes:
        data (torch.Tensor | np.ndarray): Centroid coordinates, one (x, y) row per detection.
        orig_shape (tuple): Original image size in (height, width) format.
        xy (torch.Tensor | np.ndarray): Pixel x, y coordinates of the centroid of each detection.
        xyn (torch.Tensor | np.ndarray): Copy of xy divided by the image width (x) and height (y).

    Notes:
        Device/array conversion helpers such as cpu(), numpy(), cuda() and to() are expected to be
        inherited from BaseTensor, which is defined outside this section of the file.
    """

    def __init__(self, boxes, orig_shape) -> None:
        """
        Initialize the Centroids object from bounding boxes and the original image size.

        Args:
            boxes (torch.Tensor | np.ndarray): Bounding boxes in xyxy format with shape (N, 4) or (4,).
                An array whose last dimension is 2 is taken to already hold (x, y) centroids and is
                stored unchanged.
            orig_shape (tuple): Original image size in (height, width) format.
        """
        if boxes.ndim == 1:
            # Promote a single box / single point to a batch of one.
            boxes = boxes[None, :]
        if boxes.shape[-1] == 2:
            # Already centroid data. NOTE(review): BaseTensor is not visible here; if its conversion
            # helpers rebuild instances via `self.__class__(self.data, self.orig_shape)`, this branch
            # is what keeps cpu()/numpy()/to()/indexing from failing on (N, 2) data - confirm.
            centroids = boxes
        else:
            # Vectorised midpoint. Works for both torch tensors and NumPy arrays and keeps the
            # input's device, instead of allocating a torch-only float32 buffer.
            centroids = (boxes[..., :2] + boxes[..., 2:4]) / 2
        super().__init__(centroids, orig_shape)

    @property
    def xy(self):
        """Return the pixel x, y coordinates of the centroids (the stored data, not a copy)."""
        return self.data

    @property
    @lru_cache(maxsize=1)
    def xyn(self):
        """Return a copy of the centroids with x divided by image width and y by image height."""
        xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
        xy[..., 0] /= self.orig_shape[1]  # orig_shape is (height, width): x scales by width
        xy[..., 1] /= self.orig_shape[0]  # y scales by height
        return xy

View File

@ -593,6 +593,23 @@ class Annotator:
cv2.circle(self.im, center_bbox, pins_radius, color, -1)
cv2.line(self.im, center_point, center_bbox, color, thickness)
def dot(self, xy, color=(255, 0, 0), radius=5):
    """
    Draw a filled dot on the image.

    Args:
        xy (tuple | list | np.ndarray | torch.Tensor): The (x, y) coordinates of the dot's center.
            Only the first two elements are read.
        color (tuple, optional): The fill color, passed unchanged to the drawing backend.
            Defaults to (255, 0, 0).
        radius (int, optional): The radius of the dot in pixels. Defaults to 5.
    """
    # Coerce once to plain Python floats so tensor / NumPy scalar coordinates behave the same
    # in both backends (the PIL branch previously forwarded raw element arithmetic).
    x, y = float(xy[0]), float(xy[1])
    if self.pil:
        # PIL backend: draw a filled ellipse inside the square bounding box around the center.
        draw = ImageDraw.Draw(self.im)
        draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color, outline=color)
    else:
        # OpenCV backend: integer pixel center, thickness -1 fills the circle.
        cv2.circle(self.im, (int(x), int(y)), radius, color, -1, lineType=cv2.LINE_AA)
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()