mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
ultralytics 8.1.5
add OBB Tracking support (#7731)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Hassaan Farooq <103611273+hassaanfarooq01@users.noreply.github.com>
This commit is contained in:
parent
12a741c76f
commit
f56dd0f48e
@ -119,6 +119,10 @@ keywords: Ultralytics YOLO, Utility Operations, segment2box, make_divisible, cli
|
||||
|
||||
<br><br>
|
||||
|
||||
## ::: ultralytics.utils.ops.regularize_rboxes
|
||||
|
||||
<br><br>
|
||||
|
||||
## ::: ultralytics.utils.ops.masks2segments
|
||||
|
||||
<br><br>
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.1.4"
|
||||
__version__ = "8.1.5"
|
||||
|
||||
from ultralytics.data.explorer.explorer import Explorer
|
||||
from ultralytics.models import RTDETR, SAM, YOLO
|
||||
|
@ -115,7 +115,7 @@ class Results(SimpleClass):
|
||||
if v is not None:
|
||||
return len(v)
|
||||
|
||||
def update(self, boxes=None, masks=None, probs=None):
|
||||
def update(self, boxes=None, masks=None, probs=None, obb=None):
|
||||
"""Update the boxes, masks, and probs attributes of the Results object."""
|
||||
if boxes is not None:
|
||||
self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape)
|
||||
@ -123,6 +123,8 @@ class Results(SimpleClass):
|
||||
self.masks = Masks(masks, self.orig_shape)
|
||||
if probs is not None:
|
||||
self.probs = probs
|
||||
if obb is not None:
|
||||
self.obb = OBB(obb, self.orig_shape)
|
||||
|
||||
def _apply(self, fn, *args, **kwargs):
|
||||
"""
|
||||
|
@ -225,14 +225,14 @@ class HUBTrainingSession:
|
||||
break # Timeout reached, exit loop
|
||||
|
||||
response = request_func(*args, **kwargs)
|
||||
if progress_total:
|
||||
self._show_upload_progress(progress_total, response)
|
||||
|
||||
if response is None:
|
||||
LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
|
||||
time.sleep(2**i) # Exponential backoff before retrying
|
||||
continue # Skip further processing and retry
|
||||
|
||||
if progress_total:
|
||||
self._show_upload_progress(progress_total, response)
|
||||
|
||||
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
|
||||
return response # Success, no need to retry
|
||||
|
||||
|
@ -45,8 +45,9 @@ class OBBPredictor(DetectionPredictor):
|
||||
|
||||
results = []
|
||||
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
|
||||
rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
|
||||
rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
|
||||
# xywh, r, conf, cls
|
||||
obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
|
||||
obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
|
||||
results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
|
||||
return results
|
||||
|
@ -5,6 +5,8 @@ import numpy as np
|
||||
from .basetrack import BaseTrack, TrackState
|
||||
from .utils import matching
|
||||
from .utils.kalman_filter import KalmanFilterXYAH
|
||||
from ..utils.ops import xywh2ltwh
|
||||
from ..utils import LOGGER
|
||||
|
||||
|
||||
class STrack(BaseTrack):
|
||||
@ -35,18 +37,18 @@ class STrack(BaseTrack):
|
||||
activate(kalman_filter, frame_id): Activate a new tracklet.
|
||||
re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
|
||||
update(new_track, frame_id): Update the state of a matched track.
|
||||
convert_coords(tlwh): Convert bounding box to x-y-angle-height format.
|
||||
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
|
||||
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
|
||||
tlbr_to_tlwh(tlbr): Convert tlbr bounding box to tlwh format.
|
||||
tlwh_to_tlbr(tlwh): Convert tlwh bounding box to tlbr format.
|
||||
"""
|
||||
|
||||
shared_kalman = KalmanFilterXYAH()
|
||||
|
||||
def __init__(self, tlwh, score, cls):
|
||||
def __init__(self, xywh, score, cls):
|
||||
"""Initialize new STrack instance."""
|
||||
super().__init__()
|
||||
self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32)
|
||||
# xywh+idx or xywha+idx
|
||||
assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
|
||||
self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
|
||||
self.kalman_filter = None
|
||||
self.mean, self.covariance = None, None
|
||||
self.is_activated = False
|
||||
@ -54,7 +56,8 @@ class STrack(BaseTrack):
|
||||
self.score = score
|
||||
self.tracklet_len = 0
|
||||
self.cls = cls
|
||||
self.idx = tlwh[-1]
|
||||
self.idx = xywh[-1]
|
||||
self.angle = xywh[4] if len(xywh) == 6 else None
|
||||
|
||||
def predict(self):
|
||||
"""Predicts mean and covariance using Kalman filter."""
|
||||
@ -123,6 +126,7 @@ class STrack(BaseTrack):
|
||||
self.track_id = self.next_id()
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
self.angle = new_track.angle
|
||||
self.idx = new_track.idx
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
@ -145,10 +149,11 @@ class STrack(BaseTrack):
|
||||
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
self.angle = new_track.angle
|
||||
self.idx = new_track.idx
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
"""Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent."""
|
||||
"""Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
|
||||
return self.tlwh_to_xyah(tlwh)
|
||||
|
||||
@property
|
||||
@ -162,7 +167,7 @@ class STrack(BaseTrack):
|
||||
return ret
|
||||
|
||||
@property
|
||||
def tlbr(self):
|
||||
def xyxy(self):
|
||||
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
@ -178,19 +183,26 @@ class STrack(BaseTrack):
|
||||
ret[2] /= ret[3]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlbr_to_tlwh(tlbr):
|
||||
"""Converts top-left bottom-right format to top-left width height format."""
|
||||
ret = np.asarray(tlbr).copy()
|
||||
ret[2:] -= ret[:2]
|
||||
@property
|
||||
def xywh(self):
|
||||
"""Get current position in bounding box format (center x, center y, width, height)."""
|
||||
ret = np.asarray(self.tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_tlbr(tlwh):
|
||||
"""Converts tlwh bounding box format to tlbr format."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
@property
|
||||
def xywha(self):
|
||||
"""Get current position in bounding box format (center x, center y, width, height, angle)."""
|
||||
if self.angle is None:
|
||||
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
|
||||
return self.xywh
|
||||
return np.concatenate([self.xywh, self.angle[None]])
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
"""Get current tracking results."""
|
||||
coords = self.xyxy if self.angle is None else self.xywha
|
||||
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
|
||||
@ -247,7 +259,7 @@ class BYTETracker:
|
||||
removed_stracks = []
|
||||
|
||||
scores = results.conf
|
||||
bboxes = results.xyxy
|
||||
bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
|
||||
# Add index
|
||||
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
|
||||
cls = results.cls
|
||||
@ -349,10 +361,8 @@ class BYTETracker:
|
||||
self.removed_stracks.extend(removed_stracks)
|
||||
if len(self.removed_stracks) > 1000:
|
||||
self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum
|
||||
return np.asarray(
|
||||
[x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns a Kalman filter object for tracking bounding boxes."""
|
||||
|
@ -25,8 +25,6 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
|
||||
Raises:
|
||||
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
||||
"""
|
||||
if predictor.args.task == "obb":
|
||||
raise NotImplementedError("ERROR ❌ OBB task does not support track mode!")
|
||||
if hasattr(predictor, "trackers") and persist:
|
||||
return
|
||||
|
||||
@ -54,11 +52,12 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
|
||||
bs = predictor.dataset.bs
|
||||
path, im0s = predictor.batch[:2]
|
||||
|
||||
is_obb = predictor.args.task == "obb"
|
||||
for i in range(bs):
|
||||
if not persist and predictor.vid_path[i] != str(predictor.save_dir / Path(path[i]).name): # new video
|
||||
predictor.trackers[i].reset()
|
||||
|
||||
det = predictor.results[i].boxes.cpu().numpy()
|
||||
det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
|
||||
if len(det) == 0:
|
||||
continue
|
||||
tracks = predictor.trackers[i].update(det, im0s[i])
|
||||
@ -66,7 +65,10 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
|
||||
continue
|
||||
idx = tracks[:, -1].astype(int)
|
||||
predictor.results[i] = predictor.results[i][idx]
|
||||
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
|
||||
|
||||
update_args = dict()
|
||||
update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1])
|
||||
predictor.results[i].update(**update_args)
|
||||
|
||||
|
||||
def register_tracker(model: object, persist: bool) -> None:
|
||||
|
@ -4,7 +4,7 @@ import numpy as np
|
||||
import scipy
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from ultralytics.utils.metrics import bbox_ioa
|
||||
from ultralytics.utils.metrics import bbox_ioa, batch_probiou
|
||||
|
||||
try:
|
||||
import lap # for linear_assignment
|
||||
@ -74,13 +74,21 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
|
||||
atlbrs = atracks
|
||||
btlbrs = btracks
|
||||
else:
|
||||
atlbrs = [track.tlbr for track in atracks]
|
||||
btlbrs = [track.tlbr for track in btracks]
|
||||
atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
|
||||
btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]
|
||||
|
||||
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
|
||||
if len(atlbrs) and len(btlbrs):
|
||||
if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:
|
||||
ious = batch_probiou(
|
||||
np.ascontiguousarray(atlbrs, dtype=np.float32),
|
||||
np.ascontiguousarray(btlbrs, dtype=np.float32),
|
||||
).numpy()
|
||||
else:
|
||||
ious = bbox_ioa(
|
||||
np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32), iou=True
|
||||
np.ascontiguousarray(atlbrs, dtype=np.float32),
|
||||
np.ascontiguousarray(btlbrs, dtype=np.float32),
|
||||
iou=True,
|
||||
)
|
||||
return 1 - ious # cost matrix
|
||||
|
||||
|
@ -46,7 +46,7 @@ def on_model_save(trainer):
|
||||
# Upload checkpoints with rate limiting
|
||||
is_best = trainer.best_fitness == trainer.fitness
|
||||
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
|
||||
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}")
|
||||
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}")
|
||||
session.upload_model(trainer.epoch, trainer.last, is_best)
|
||||
session.timers["ckpt"] = time() # reset timer
|
||||
|
||||
|
@ -239,13 +239,16 @@ def batch_probiou(obb1, obb2, eps=1e-7):
|
||||
Calculate the prob iou between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
|
||||
|
||||
Args:
|
||||
obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
|
||||
obb2 (torch.Tensor): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
|
||||
obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
|
||||
obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
|
||||
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): A tensor of shape (N, M) representing obb similarities.
|
||||
"""
|
||||
obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
|
||||
obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2
|
||||
|
||||
x1, y1 = obb1[..., :2].split(1, dim=-1)
|
||||
x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
|
||||
a1, b1, c1 = _get_covariance_matrix(obb1)
|
||||
|
@ -774,6 +774,24 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
|
||||
return coords
|
||||
|
||||
|
||||
def regularize_rboxes(rboxes):
|
||||
"""
|
||||
Regularize rotated boxes in range [0, pi/2].
|
||||
|
||||
Args:
|
||||
rboxes (torch.Tensor): (N, 5), xywhr.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The regularized boxes.
|
||||
"""
|
||||
x, y, w, h, t = rboxes.unbind(dim=-1)
|
||||
# Swap edge and angle if h >= w
|
||||
w_ = torch.where(w > h, w, h)
|
||||
h_ = torch.where(w > h, h, w)
|
||||
t = torch.where(w > h, t, t + math.pi / 2) % math.pi
|
||||
return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes
|
||||
|
||||
|
||||
def masks2segments(masks, strategy="largest"):
|
||||
"""
|
||||
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
|
||||
|
Loading…
x
Reference in New Issue
Block a user