mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 17:05:40 +08:00 
			
		
		
		
	Cleanup tracker and remove unused functions (#4374)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									4093c1fd64
								
							
						
					
					
						commit
						60cad0c592
					
				@ -3,10 +3,6 @@
 | 
			
		||||
import numpy as np
 | 
			
		||||
import scipy.linalg
 | 
			
		||||
 | 
			
		||||
# Table for the 0.95 quantile of the chi-square distribution with N degrees of freedom (contains values for N=1, ..., 9)
 | 
			
		||||
# Taken from MATLAB/Octave's chi2inv function and used as Mahalanobis gating threshold.
 | 
			
		||||
chi2inv95 = {1: 3.8415, 2: 5.9915, 3: 7.8147, 4: 9.4877, 5: 11.070, 6: 12.592, 7: 14.067, 8: 15.507, 9: 16.919}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KalmanFilterXYAH:
 | 
			
		||||
    """
 | 
			
		||||
@ -235,7 +231,7 @@ class KalmanFilterXYAH:
 | 
			
		||||
            raise ValueError('invalid distance metric')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KalmanFilterXYWH:
 | 
			
		||||
class KalmanFilterXYWH(KalmanFilterXYAH):
 | 
			
		||||
    """
 | 
			
		||||
    For BoT-SORT
 | 
			
		||||
    A simple Kalman filter for tracking bounding boxes in image space.
 | 
			
		||||
@ -253,22 +249,6 @@ class KalmanFilterXYWH:
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        """Initialize Kalman filter model matrices with motion and observation uncertainties."""
 | 
			
		||||
        ndim, dt = 4, 1.
 | 
			
		||||
 | 
			
		||||
        # Create Kalman filter model matrices.
 | 
			
		||||
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
 | 
			
		||||
        for i in range(ndim):
 | 
			
		||||
            self._motion_mat[i, ndim + i] = dt
 | 
			
		||||
        self._update_mat = np.eye(ndim, 2 * ndim)
 | 
			
		||||
 | 
			
		||||
        # Motion and observation uncertainty are chosen relative to the current
 | 
			
		||||
        # state estimate. These weights control the amount of uncertainty in
 | 
			
		||||
        # the model. This is a bit hacky.
 | 
			
		||||
        self._std_weight_position = 1. / 20
 | 
			
		||||
        self._std_weight_velocity = 1. / 160
 | 
			
		||||
 | 
			
		||||
    def initiate(self, measurement):
 | 
			
		||||
        """Create track from unassociated measurement.
 | 
			
		||||
 | 
			
		||||
@ -409,54 +389,4 @@ class KalmanFilterXYWH:
 | 
			
		||||
            Returns the measurement-corrected state distribution.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        projected_mean, projected_cov = self.project(mean, covariance)
 | 
			
		||||
 | 
			
		||||
        chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
 | 
			
		||||
        kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
 | 
			
		||||
                                             np.dot(covariance, self._update_mat.T).T,
 | 
			
		||||
                                             check_finite=False).T
 | 
			
		||||
        innovation = measurement - projected_mean
 | 
			
		||||
 | 
			
		||||
        new_mean = mean + np.dot(innovation, kalman_gain.T)
 | 
			
		||||
        new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
 | 
			
		||||
        return new_mean, new_covariance
 | 
			
		||||
 | 
			
		||||
    def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
 | 
			
		||||
        """Compute gating distance between state distribution and measurements.
 | 
			
		||||
        A suitable distance threshold can be obtained from `chi2inv95`. If
 | 
			
		||||
        `only_position` is False, the chi-square distribution has 4 degrees of
 | 
			
		||||
        freedom, otherwise 2.
 | 
			
		||||
        Parameters
 | 
			
		||||
        ----------
 | 
			
		||||
        mean : ndarray
 | 
			
		||||
            Mean vector over the state distribution (8 dimensional).
 | 
			
		||||
        covariance : ndarray
 | 
			
		||||
            Covariance of the state distribution (8x8 dimensional).
 | 
			
		||||
        measurements : ndarray
 | 
			
		||||
            An Nx4 dimensional matrix of N measurements, each in
 | 
			
		||||
            format (x, y, a, h) where (x, y) is the bounding box center
 | 
			
		||||
            position, a the aspect ratio, and h the height.
 | 
			
		||||
        only_position : Optional[bool]
 | 
			
		||||
            If True, distance computation is done with respect to the bounding
 | 
			
		||||
            box center position only.
 | 
			
		||||
        Returns
 | 
			
		||||
        -------
 | 
			
		||||
        ndarray
 | 
			
		||||
            Returns an array of length N, where the i-th element contains the
 | 
			
		||||
            squared Mahalanobis distance between (mean, covariance) and
 | 
			
		||||
            `measurements[i]`.
 | 
			
		||||
        """
 | 
			
		||||
        mean, covariance = self.project(mean, covariance)
 | 
			
		||||
        if only_position:
 | 
			
		||||
            mean, covariance = mean[:2], covariance[:2, :2]
 | 
			
		||||
            measurements = measurements[:, :2]
 | 
			
		||||
 | 
			
		||||
        d = measurements - mean
 | 
			
		||||
        if metric == 'gaussian':
 | 
			
		||||
            return np.sum(d * d, axis=1)
 | 
			
		||||
        elif metric == 'maha':
 | 
			
		||||
            cholesky_factor = np.linalg.cholesky(covariance)
 | 
			
		||||
            z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
 | 
			
		||||
            return np.sum(z * z, axis=0)  # square maha
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError('invalid distance metric')
 | 
			
		||||
        return super().update(mean, covariance, measurement)
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@ import numpy as np
 | 
			
		||||
import scipy
 | 
			
		||||
from scipy.spatial.distance import cdist
 | 
			
		||||
 | 
			
		||||
from .kalman_filter import chi2inv95
 | 
			
		||||
from ultralytics.utils.metrics import bbox_ioa
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    import lap  # for linear_assignment
 | 
			
		||||
@ -17,36 +17,6 @@ except (ImportError, AssertionError, AttributeError):
 | 
			
		||||
    import lap
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def merge_matches(m1, m2, shape):
 | 
			
		||||
    """Merge two sets of matches and return matched and unmatched indices."""
 | 
			
		||||
    O, P, Q = shape
 | 
			
		||||
    m1 = np.asarray(m1)
 | 
			
		||||
    m2 = np.asarray(m2)
 | 
			
		||||
 | 
			
		||||
    M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
 | 
			
		||||
    M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))
 | 
			
		||||
 | 
			
		||||
    mask = M1 * M2
 | 
			
		||||
    match = mask.nonzero()
 | 
			
		||||
    match = list(zip(match[0], match[1]))
 | 
			
		||||
    unmatched_O = tuple(set(range(O)) - {i for i, j in match})
 | 
			
		||||
    unmatched_Q = tuple(set(range(Q)) - {j for i, j in match})
 | 
			
		||||
 | 
			
		||||
    return match, unmatched_O, unmatched_Q
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _indices_to_matches(cost_matrix, indices, thresh):
 | 
			
		||||
    """Return matched and unmatched indices given a cost matrix, indices, and a threshold."""
 | 
			
		||||
    matched_cost = cost_matrix[tuple(zip(*indices))]
 | 
			
		||||
    matched_mask = (matched_cost <= thresh)
 | 
			
		||||
 | 
			
		||||
    matches = indices[matched_mask]
 | 
			
		||||
    unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
 | 
			
		||||
    unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
 | 
			
		||||
 | 
			
		||||
    return matches, unmatched_a, unmatched_b
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def linear_assignment(cost_matrix, thresh, use_lap=True):
 | 
			
		||||
    """Linear assignment implementations with scipy and lap.lapjv."""
 | 
			
		||||
    if cost_matrix.size == 0:
 | 
			
		||||
@ -70,26 +40,6 @@ def linear_assignment(cost_matrix, thresh, use_lap=True):
 | 
			
		||||
    return matches, unmatched_a, unmatched_b
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ious(atlbrs, btlbrs):
 | 
			
		||||
    """
 | 
			
		||||
    Compute cost based on IoU
 | 
			
		||||
    :type atlbrs: list[tlbr] | np.ndarray
 | 
			
		||||
    :type atlbrs: list[tlbr] | np.ndarray
 | 
			
		||||
 | 
			
		||||
    :rtype ious np.ndarray
 | 
			
		||||
    """
 | 
			
		||||
    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
 | 
			
		||||
    if ious.size == 0:
 | 
			
		||||
        return ious
 | 
			
		||||
    ious = bbox_ious(np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32))
 | 
			
		||||
 | 
			
		||||
    # TODO: replace bbox_ious() with numpy-capable update of utils.metrics.box_iou
 | 
			
		||||
    # from ...utils.metrics import box_iou
 | 
			
		||||
    # ious = box_iou()
 | 
			
		||||
 | 
			
		||||
    return ious
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def iou_distance(atracks, btracks):
 | 
			
		||||
    """
 | 
			
		||||
    Compute cost based on IoU
 | 
			
		||||
@ -106,26 +56,13 @@ def iou_distance(atracks, btracks):
 | 
			
		||||
    else:
 | 
			
		||||
        atlbrs = [track.tlbr for track in atracks]
 | 
			
		||||
        btlbrs = [track.tlbr for track in btracks]
 | 
			
		||||
    return 1 - ious(atlbrs, btlbrs)  # cost matrix
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def v_iou_distance(atracks, btracks):
 | 
			
		||||
    """
 | 
			
		||||
    Compute cost based on IoU
 | 
			
		||||
    :type atracks: list[STrack]
 | 
			
		||||
    :type btracks: list[STrack]
 | 
			
		||||
 | 
			
		||||
    :rtype cost_matrix np.ndarray
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
 | 
			
		||||
            or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
 | 
			
		||||
        atlbrs = atracks
 | 
			
		||||
        btlbrs = btracks
 | 
			
		||||
    else:
 | 
			
		||||
        atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks]
 | 
			
		||||
        btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks]
 | 
			
		||||
    return 1 - ious(atlbrs, btlbrs)  # cost matrix
 | 
			
		||||
    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
 | 
			
		||||
    if len(atlbrs) and len(btlbrs):
 | 
			
		||||
        ious = bbox_ioa(np.ascontiguousarray(atlbrs, dtype=np.float32),
 | 
			
		||||
                        np.ascontiguousarray(btlbrs, dtype=np.float32),
 | 
			
		||||
                        iou=True)
 | 
			
		||||
    return 1 - ious  # cost matrix
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def embedding_distance(tracks, detections, metric='cosine'):
 | 
			
		||||
@ -147,46 +84,6 @@ def embedding_distance(tracks, detections, metric='cosine'):
 | 
			
		||||
    return cost_matrix
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
 | 
			
		||||
    """Apply gating to the cost matrix based on predicted tracks and detected objects."""
 | 
			
		||||
    if cost_matrix.size == 0:
 | 
			
		||||
        return cost_matrix
 | 
			
		||||
    gating_dim = 2 if only_position else 4
 | 
			
		||||
    gating_threshold = chi2inv95[gating_dim]
 | 
			
		||||
    measurements = np.asarray([det.to_xyah() for det in detections])
 | 
			
		||||
    for row, track in enumerate(tracks):
 | 
			
		||||
        gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position)
 | 
			
		||||
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
 | 
			
		||||
    return cost_matrix
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
 | 
			
		||||
    """Fuse motion between tracks and detections with gating and Kalman filtering."""
 | 
			
		||||
    if cost_matrix.size == 0:
 | 
			
		||||
        return cost_matrix
 | 
			
		||||
    gating_dim = 2 if only_position else 4
 | 
			
		||||
    gating_threshold = chi2inv95[gating_dim]
 | 
			
		||||
    measurements = np.asarray([det.to_xyah() for det in detections])
 | 
			
		||||
    for row, track in enumerate(tracks):
 | 
			
		||||
        gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position, metric='maha')
 | 
			
		||||
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
 | 
			
		||||
        cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
 | 
			
		||||
    return cost_matrix
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fuse_iou(cost_matrix, tracks, detections):
 | 
			
		||||
    """Fuses ReID and IoU similarity matrices to yield a cost matrix for object tracking."""
 | 
			
		||||
    if cost_matrix.size == 0:
 | 
			
		||||
        return cost_matrix
 | 
			
		||||
    reid_sim = 1 - cost_matrix
 | 
			
		||||
    iou_dist = iou_distance(tracks, detections)
 | 
			
		||||
    iou_sim = 1 - iou_dist
 | 
			
		||||
    fuse_sim = reid_sim * (1 + iou_sim) / 2
 | 
			
		||||
    # det_scores = np.array([det.score for det in detections])
 | 
			
		||||
    # det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
 | 
			
		||||
    return 1 - fuse_sim  # fuse cost
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fuse_score(cost_matrix, detections):
 | 
			
		||||
    """Fuses cost matrix with detection scores to produce a single similarity matrix."""
 | 
			
		||||
    if cost_matrix.size == 0:
 | 
			
		||||
@ -196,36 +93,3 @@ def fuse_score(cost_matrix, detections):
 | 
			
		||||
    det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
 | 
			
		||||
    fuse_sim = iou_sim * det_scores
 | 
			
		||||
    return 1 - fuse_sim  # fuse_cost
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def bbox_ious(box1, box2, eps=1e-7):
 | 
			
		||||
    """
 | 
			
		||||
    Calculate the Intersection over Union (IoU) between pairs of bounding boxes.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        box1 (np.array): A numpy array of shape (n, 4) representing 'n' bounding boxes.
 | 
			
		||||
                         Each row is in the format (x1, y1, x2, y2).
 | 
			
		||||
        box2 (np.array): A numpy array of shape (m, 4) representing 'm' bounding boxes.
 | 
			
		||||
                         Each row is in the format (x1, y1, x2, y2).
 | 
			
		||||
        eps (float, optional): A small constant to prevent division by zero. Defaults to 1e-7.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        (np.array): A numpy array of shape (n, m) representing the IoU scores for each pair
 | 
			
		||||
                    of bounding boxes from box1 and box2.
 | 
			
		||||
 | 
			
		||||
    Note:
 | 
			
		||||
        The bounding box coordinates are expected to be in the format (x1, y1, x2, y2).
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # Get the coordinates of bounding boxes
 | 
			
		||||
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
 | 
			
		||||
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
 | 
			
		||||
 | 
			
		||||
    # Intersection area
 | 
			
		||||
    inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
 | 
			
		||||
                 (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
 | 
			
		||||
 | 
			
		||||
    # box2 area
 | 
			
		||||
    box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
 | 
			
		||||
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
 | 
			
		||||
    return inter_area / (box2_area + box1_area[:, None] - inter_area + eps)
 | 
			
		||||
 | 
			
		||||
@ -21,13 +21,14 @@ def box_area(box):
 | 
			
		||||
    return (box[2] - box[0]) * (box[3] - box[1])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def bbox_ioa(box1, box2, eps=1e-7):
 | 
			
		||||
def bbox_ioa(box1, box2, iou=False, eps=1e-7):
 | 
			
		||||
    """
 | 
			
		||||
    Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes.
 | 
			
		||||
        box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes.
 | 
			
		||||
        iou (bool): Calculate the standard iou if True else return inter_area/box2_area.
 | 
			
		||||
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
@ -43,10 +44,13 @@ def bbox_ioa(box1, box2, eps=1e-7):
 | 
			
		||||
                 (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
 | 
			
		||||
 | 
			
		||||
    # box2 area
 | 
			
		||||
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
 | 
			
		||||
    area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
 | 
			
		||||
    if iou:
 | 
			
		||||
        box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
 | 
			
		||||
        area = area + box1_area[:, None] - inter_area
 | 
			
		||||
 | 
			
		||||
    # Intersection over box2 area
 | 
			
		||||
    return inter_area / box2_area
 | 
			
		||||
    return inter_area / (area + eps)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def box_iou(box1, box2, eps=1e-7):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user