mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
update
This commit is contained in:
parent
81f5ea9f80
commit
182f0b09a9
@ -6,24 +6,19 @@ from ultralytics.engine.results import Results
|
||||
|
||||
class YOLOv10DetectionPredictor(DetectionPredictor):
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
if not isinstance(preds, (list, tuple)):
|
||||
preds = [preds, None]
|
||||
if isinstance(preds, (list, tuple)):
|
||||
preds = preds[0]
|
||||
|
||||
prediction = preds[0].transpose(-1, -2)
|
||||
_, _, nd = prediction.shape
|
||||
nc = nd - 4
|
||||
bboxes, scores = prediction.split((4, nd-4), dim=-1)
|
||||
preds = preds.transpose(-1, -2)
|
||||
bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det)
|
||||
bboxes = ops.xywh2xyxy(bboxes)
|
||||
|
||||
scores, index = torch.topk(scores.flatten(1), self.args.max_det, axis=-1)
|
||||
labels = index % nc
|
||||
index = torch.div(index, nc, rounding_mode='floor')
|
||||
bboxes = bboxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bboxes.shape[-1]))
|
||||
|
||||
preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
|
||||
assert(preds.shape[0] == 1)
|
||||
|
||||
mask = preds[..., 4] > self.args.conf
|
||||
preds = preds[mask].unsqueeze(0)
|
||||
|
||||
b, _, c = preds.shape
|
||||
preds = preds.view(-1, preds.shape[-1])[mask.view(-1)]
|
||||
preds = preds.view(b, -1, c)
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
@ -11,19 +11,10 @@ class YOLOv10DetectionValidator(DetectionValidator):
|
||||
if self.training:
|
||||
preds = preds["one2one"]
|
||||
|
||||
if not isinstance(preds, (list, tuple)):
|
||||
preds = [preds, None]
|
||||
|
||||
prediction = preds[0].transpose(-1, -2)
|
||||
_, _, nd = prediction.shape
|
||||
nc = nd - 4
|
||||
assert(self.nc == nc)
|
||||
bboxes, scores = prediction.split((4, nd-4), dim=-1)
|
||||
bboxes = ops.xywh2xyxy(bboxes)
|
||||
|
||||
scores, index = torch.topk(scores.flatten(1), self.args.max_det, axis=-1)
|
||||
labels = index % self.nc
|
||||
index = torch.div(index, self.nc, rounding_mode='floor')
|
||||
bboxes = bboxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bboxes.shape[-1]))
|
||||
if isinstance(preds, (list, tuple)):
|
||||
preds = preds[0]
|
||||
|
||||
preds = preds.transpose(-1, -2)
|
||||
boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det)
|
||||
bboxes = ops.xywh2xyxy(boxes)
|
||||
return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
|
@ -101,6 +101,8 @@ class Detect(nn.Module):
|
||||
|
||||
def decode_bboxes(self, bboxes, anchors):
|
||||
"""Decode bounding boxes."""
|
||||
if self.export:
|
||||
return dist2bbox(bboxes, anchors, xywh=False, dim=1)
|
||||
return dist2bbox(bboxes, anchors, xywh=True, dim=1)
|
||||
|
||||
|
||||
|
@ -847,3 +847,18 @@ def clean_str(s):
|
||||
(str): a string with special characters replaced by an underscore _
|
||||
"""
|
||||
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
|
||||
|
||||
def v10postprocess(preds, max_det):
|
||||
nc = preds.shape[-1] - 4
|
||||
boxes, scores = preds.split([4, nc], dim=-1)
|
||||
max_scores = scores.amax(dim=-1)
|
||||
max_scores, index = torch.topk(max_scores, max_det, axis=-1)
|
||||
index = index.unsqueeze(-1)
|
||||
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
|
||||
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
|
||||
|
||||
scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
|
||||
labels = index % nc
|
||||
index = index // nc
|
||||
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
|
||||
return boxes, scores, labels
|
Loading…
x
Reference in New Issue
Block a user