This commit is contained in:
wa22 2024-05-24 06:23:19 +00:00
parent b714fa8bae
commit 483d7a9050
3 changed files with 3 additions and 3 deletions

View File

@ -16,7 +16,7 @@ class YOLOv10DetectionPredictor(DetectionPredictor):
pass
else:
preds = preds.transpose(-1, -2)
bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det)
bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, preds.shape[-1]-4)
bboxes = ops.xywh2xyxy(bboxes)
preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)

View File

@ -15,6 +15,6 @@ class YOLOv10DetectionValidator(DetectionValidator):
preds = preds[0]
preds = preds.transpose(-1, -2)
boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det)
boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
bboxes = ops.xywh2xyxy(boxes)
return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)

View File

@ -519,7 +519,7 @@ class v10Detect(Detect):
return {"one2many": one2many, "one2one": one2one}
else:
assert(self.max_det != -1)
boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det)
boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc)
return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
else:
return {"one2many": one2many, "one2one": one2one}