From 483d7a9050c4ec2a58b95fe0fdf05570294d979a Mon Sep 17 00:00:00 2001
From: wa22 <wa22@mails.tsinghua.edu.cn>
Date: Fri, 24 May 2024 06:23:19 +0000
Subject: [PATCH] update

---
 ultralytics/models/yolov10/predict.py | 2 +-
 ultralytics/models/yolov10/val.py     | 2 +-
 ultralytics/nn/modules/head.py        | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/ultralytics/models/yolov10/predict.py b/ultralytics/models/yolov10/predict.py
index 00de8537..98994f71 100644
--- a/ultralytics/models/yolov10/predict.py
+++ b/ultralytics/models/yolov10/predict.py
@@ -16,7 +16,7 @@ class YOLOv10DetectionPredictor(DetectionPredictor):
             pass
         else:
             preds = preds.transpose(-1, -2)
-            bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det)
+            bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, preds.shape[-1]-4)
             bboxes = ops.xywh2xyxy(bboxes)
             preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
 
diff --git a/ultralytics/models/yolov10/val.py b/ultralytics/models/yolov10/val.py
index bbe11992..1106b9eb 100644
--- a/ultralytics/models/yolov10/val.py
+++ b/ultralytics/models/yolov10/val.py
@@ -15,6 +15,6 @@ class YOLOv10DetectionValidator(DetectionValidator):
             preds = preds[0]
         
         preds = preds.transpose(-1, -2)
-        boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det)
+        boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
         bboxes = ops.xywh2xyxy(boxes)
         return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
\ No newline at end of file
diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py
index 54b59d6c..5bc7c068 100644
--- a/ultralytics/nn/modules/head.py
+++ b/ultralytics/nn/modules/head.py
@@ -519,7 +519,7 @@ class v10Detect(Detect):
                 return {"one2many": one2many, "one2one": one2one}
             else:
                 assert(self.max_det != -1)
-                boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det)
+                boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc)
                 return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
         else:
             return {"one2many": one2many, "one2one": one2one}