Make postprocessing compatible with nobuco (#204)

This commit is contained in:
Pbatch 2024-06-06 00:57:59 +01:00 committed by GitHub
parent 13f6ab770d
commit d8777c1449
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -852,12 +852,12 @@ def v10postprocess(preds, max_det, nc=80):
assert(4 + nc == preds.shape[-1])
boxes, scores = preds.split([4, nc], dim=-1)
max_scores = scores.amax(dim=-1)
max_scores, index = torch.topk(max_scores, max_det, axis=-1)
max_scores, index = torch.topk(max_scores, max_det, dim=-1)
index = index.unsqueeze(-1)
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
scores, index = torch.topk(scores.flatten(1), max_det, dim=-1)
labels = index % nc
index = index // nc
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))