From 6e1f617f8fa14580f94d3ba467dcbc26b426a50e Mon Sep 17 00:00:00 2001 From: peter Date: Wed, 5 Jun 2024 18:59:34 +0100 Subject: [PATCH] Make postprocessing compatible with nobuco --- ultralytics/utils/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index a539fe50..edbb1031 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -852,12 +852,12 @@ def v10postprocess(preds, max_det, nc=80): assert(4 + nc == preds.shape[-1]) boxes, scores = preds.split([4, nc], dim=-1) max_scores = scores.amax(dim=-1) - max_scores, index = torch.topk(max_scores, max_det, axis=-1) + max_scores, index = torch.topk(max_scores, max_det, dim=-1) index = index.unsqueeze(-1) boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1])) scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1])) - scores, index = torch.topk(scores.flatten(1), max_det, axis=-1) + scores, index = torch.topk(scores.flatten(1), max_det, dim=-1) labels = index % nc index = index // nc boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))