mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 08:56:11 +08:00 
			
		
		
		
	Merge a45ab7029012da80547dc864fc7f6509c93afb67 into 6fbaf42b23f6709f4e34a51430587673e70e151d
This commit is contained in:
		
						commit
						6cf1cc3a18
					
				@ -115,6 +115,7 @@ class Results(SimpleClass):
 | 
			
		||||
        self.probs = Probs(probs) if probs is not None else None
 | 
			
		||||
        self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
 | 
			
		||||
        self.obb = OBB(obb, self.orig_shape) if obb is not None else None
 | 
			
		||||
        self.centroids = Centroids(self.boxes.xyxy, self.orig_shape) if self.boxes is not None else None
 | 
			
		||||
        self.speed = {"preprocess": None, "inference": None, "postprocess": None}  # milliseconds per image
 | 
			
		||||
        self.names = names
 | 
			
		||||
        self.path = path
 | 
			
		||||
@ -198,6 +199,7 @@ class Results(SimpleClass):
 | 
			
		||||
        boxes=True,
 | 
			
		||||
        masks=True,
 | 
			
		||||
        probs=True,
 | 
			
		||||
        centroids=True,
 | 
			
		||||
        show=False,
 | 
			
		||||
        save=False,
 | 
			
		||||
        filename=None,
 | 
			
		||||
@ -219,6 +221,7 @@ class Results(SimpleClass):
 | 
			
		||||
            boxes (bool): Whether to plot the bounding boxes.
 | 
			
		||||
            masks (bool): Whether to plot the masks.
 | 
			
		||||
            probs (bool): Whether to plot classification probability
 | 
			
		||||
            centroids (bool): Whether to plot the centroids.
 | 
			
		||||
            show (bool): Whether to display the annotated image directly.
 | 
			
		||||
            save (bool): Whether to save the annotated image to `filename`.
 | 
			
		||||
            filename (str): Filename to save image to if save is True.
 | 
			
		||||
@ -248,6 +251,7 @@ class Results(SimpleClass):
 | 
			
		||||
        pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes
 | 
			
		||||
        pred_masks, show_masks = self.masks, masks
 | 
			
		||||
        pred_probs, show_probs = self.probs, probs
 | 
			
		||||
        pred_centroids, show_centroids = self.centroids, centroids
 | 
			
		||||
        annotator = Annotator(
 | 
			
		||||
            deepcopy(self.orig_img if img is None else img),
 | 
			
		||||
            line_width,
 | 
			
		||||
@ -291,6 +295,16 @@ class Results(SimpleClass):
 | 
			
		||||
            for k in reversed(self.keypoints.data):
 | 
			
		||||
                annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)
 | 
			
		||||
 | 
			
		||||
        # Plot centroids
 | 
			
		||||
        if pred_centroids is not None and show_centroids:
 | 
			
		||||
            for i, centroid in enumerate(pred_centroids.xy):
 | 
			
		||||
                if pred_boxes is not None:
 | 
			
		||||
                    c = int(pred_boxes.cls[i])
 | 
			
		||||
                    color = colors(c, True)
 | 
			
		||||
                else:
 | 
			
		||||
                    color = (0, 255, 0)  # Default to green if no class information
 | 
			
		||||
                annotator.dot(centroid, color=color, radius=5)
 | 
			
		||||
 | 
			
		||||
        # Show results
 | 
			
		||||
        if show:
 | 
			
		||||
            annotator.show(self.path)
 | 
			
		||||
@ -419,6 +433,13 @@ class Results(SimpleClass):
 | 
			
		||||
                    "y": (y / h).numpy().round(decimals).tolist(),
 | 
			
		||||
                    "visible": visible.numpy().round(decimals).tolist(),
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
            if self.centroids is not None:
 | 
			
		||||
                result["centroid"] = {
 | 
			
		||||
                    "x": round(self.centroids.xy[i][0].item() / w, decimals),
 | 
			
		||||
                    "y": round(self.centroids.xy[i][1].item() / h, decimals)
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
            results.append(result)
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
@ -741,3 +762,47 @@ class OBB(BaseTensor):
 | 
			
		||||
        y2 = self.xyxyxyxy[..., 1].max(1).values
 | 
			
		||||
        xyxy = [x1, y1, x2, y2]
 | 
			
		||||
        return np.stack(xyxy, axis=-1) if isinstance(self.data, np.ndarray) else torch.stack(xyxy, dim=-1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Centroids(BaseTensor):
 | 
			
		||||
    """
 | 
			
		||||
    A class for storing and manipulating detection centroids.
 | 
			
		||||
 | 
			
		||||
    Attributes:
 | 
			
		||||
        xy (torch.Tensor): A tensor containing x, y coordinates of centroids for each detection.
 | 
			
		||||
        xyn (torch.Tensor): A normalized version of xy with coordinates in the range [0, 1].
 | 
			
		||||
 | 
			
		||||
    Methods:
 | 
			
		||||
        cpu(): Returns a copy of the centroids tensor on CPU memory.
 | 
			
		||||
        numpy(): Returns a copy of the centroids tensor as a numpy array.
 | 
			
		||||
        cuda(): Returns a copy of the centroids tensor on GPU memory.
 | 
			
		||||
        to(device, dtype): Returns a copy of the centroids tensor with the specified device and dtype.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, boxes, orig_shape) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Initialize the Centroids object with bounding boxes and original image size.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            boxes (torch.Tensor): A tensor of bounding boxes in xyxy format.
 | 
			
		||||
            orig_shape (tuple): Original image size in (height, width) format.
 | 
			
		||||
        """
 | 
			
		||||
        centroids = torch.zeros(boxes.shape[0], 2, device=boxes.device)
 | 
			
		||||
        centroids[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2  # x-coordinate
 | 
			
		||||
        centroids[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2  # y-coordinate
 | 
			
		||||
        super().__init__(centroids, orig_shape)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    @lru_cache(maxsize=1)
 | 
			
		||||
    def xy(self):
 | 
			
		||||
        """Returns x, y coordinates of centroids."""
 | 
			
		||||
        return self.data
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    @lru_cache(maxsize=1)
 | 
			
		||||
    def xyn(self):
 | 
			
		||||
        """Returns normalized x, y coordinates of centroids."""
 | 
			
		||||
        xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
 | 
			
		||||
        xy[:, 0] /= self.orig_shape[1]
 | 
			
		||||
        xy[:, 1] /= self.orig_shape[0]
 | 
			
		||||
        return xy
 | 
			
		||||
@ -593,6 +593,23 @@ class Annotator:
 | 
			
		||||
        cv2.circle(self.im, center_bbox, pins_radius, color, -1)
 | 
			
		||||
        cv2.line(self.im, center_point, center_bbox, color, thickness)
 | 
			
		||||
 | 
			
		||||
    def dot(self, xy, color=(255, 0, 0), radius=5):
 | 
			
		||||
        """
 | 
			
		||||
        Draw a dot on the image.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            xy (tuple): The (x, y) coordinates of the dot's center.
 | 
			
		||||
            color (tuple, optional): The color of the dot in BGR format. Defaults to (255, 0, 0) (blue).
 | 
			
		||||
            radius (int, optional): The radius of the dot. Defaults to 5.
 | 
			
		||||
        """
 | 
			
		||||
        if self.pil:
 | 
			
		||||
            # Convert to PIL ImageDraw
 | 
			
		||||
            draw = ImageDraw.Draw(self.im)
 | 
			
		||||
            draw.ellipse([xy[0]-radius, xy[1]-radius, xy[0]+radius, xy[1]+radius], fill=color, outline=color)
 | 
			
		||||
        else:
 | 
			
		||||
            cv2.circle(self.im, (int(xy[0]), int(xy[1])), radius, color, -1, lineType=cv2.LINE_AA)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
 | 
			
		||||
@plt_settings()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user