mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
Make postprocessing compatible with nobuco (#204)
This commit is contained in:
parent
13f6ab770d
commit
d8777c1449
@ -852,12 +852,12 @@ def v10postprocess(preds, max_det, nc=80):
|
||||
assert(4 + nc == preds.shape[-1])
|
||||
boxes, scores = preds.split([4, nc], dim=-1)
|
||||
max_scores = scores.amax(dim=-1)
|
||||
max_scores, index = torch.topk(max_scores, max_det, axis=-1)
|
||||
max_scores, index = torch.topk(max_scores, max_det, dim=-1)
|
||||
index = index.unsqueeze(-1)
|
||||
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
|
||||
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
|
||||
|
||||
scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
|
||||
scores, index = torch.topk(scores.flatten(1), max_det, dim=-1)
|
||||
labels = index % nc
|
||||
index = index // nc
|
||||
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user