This commit is contained in:
wa22 2024-05-23 08:09:57 +00:00
parent 81f5ea9f80
commit 182f0b09a9
4 changed files with 31 additions and 28 deletions

View File

@ -6,24 +6,19 @@ from ultralytics.engine.results import Results
class YOLOv10DetectionPredictor(DetectionPredictor): class YOLOv10DetectionPredictor(DetectionPredictor):
def postprocess(self, preds, img, orig_imgs): def postprocess(self, preds, img, orig_imgs):
if not isinstance(preds, (list, tuple)): if isinstance(preds, (list, tuple)):
preds = [preds, None] preds = preds[0]
prediction = preds[0].transpose(-1, -2) preds = preds.transpose(-1, -2)
_, _, nd = prediction.shape bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det)
nc = nd - 4
bboxes, scores = prediction.split((4, nd-4), dim=-1)
bboxes = ops.xywh2xyxy(bboxes) bboxes = ops.xywh2xyxy(bboxes)
scores, index = torch.topk(scores.flatten(1), self.args.max_det, axis=-1)
labels = index % nc
index = torch.div(index, nc, rounding_mode='floor')
bboxes = bboxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bboxes.shape[-1]))
preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
assert(preds.shape[0] == 1)
mask = preds[..., 4] > self.args.conf mask = preds[..., 4] > self.args.conf
preds = preds[mask].unsqueeze(0)
b, _, c = preds.shape
preds = preds.view(-1, preds.shape[-1])[mask.view(-1)]
preds = preds.view(b, -1, c)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

View File

@ -11,19 +11,10 @@ class YOLOv10DetectionValidator(DetectionValidator):
if self.training: if self.training:
preds = preds["one2one"] preds = preds["one2one"]
if not isinstance(preds, (list, tuple)): if isinstance(preds, (list, tuple)):
preds = [preds, None] preds = preds[0]
prediction = preds[0].transpose(-1, -2)
_, _, nd = prediction.shape
nc = nd - 4
assert(self.nc == nc)
bboxes, scores = prediction.split((4, nd-4), dim=-1)
bboxes = ops.xywh2xyxy(bboxes)
scores, index = torch.topk(scores.flatten(1), self.args.max_det, axis=-1)
labels = index % self.nc
index = torch.div(index, self.nc, rounding_mode='floor')
bboxes = bboxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bboxes.shape[-1]))
preds = preds.transpose(-1, -2)
boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det)
bboxes = ops.xywh2xyxy(boxes)
return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)

View File

@ -101,6 +101,8 @@ class Detect(nn.Module):
def decode_bboxes(self, bboxes, anchors): def decode_bboxes(self, bboxes, anchors):
"""Decode bounding boxes.""" """Decode bounding boxes."""
if self.export:
return dist2bbox(bboxes, anchors, xywh=False, dim=1)
return dist2bbox(bboxes, anchors, xywh=True, dim=1) return dist2bbox(bboxes, anchors, xywh=True, dim=1)

View File

@ -847,3 +847,18 @@ def clean_str(s):
(str): a string with special characters replaced by an underscore _ (str): a string with special characters replaced by an underscore _
""" """
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s) return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
def v10postprocess(preds, max_det):
nc = preds.shape[-1] - 4
boxes, scores = preds.split([4, nc], dim=-1)
max_scores = scores.amax(dim=-1)
max_scores, index = torch.topk(max_scores, max_det, axis=-1)
index = index.unsqueeze(-1)
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
labels = index % nc
index = index // nc
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
return boxes, scores, labels