fix onnx AssertionError

This commit is contained in:
kangsanha 2024-06-04 19:08:46 +09:00 committed by GitHub
parent e5f2f03c55
commit dbf759d925
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,7 +14,12 @@ class YOLOv10DetectionValidator(DetectionValidator):
if isinstance(preds, (list, tuple)):
preds = preds[0]
if preds.shape[-1] == 6:
pass
else:
preds = preds.transpose(-1, -2)
boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
bboxes = ops.xywh2xyxy(boxes)
return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, preds.shape[-1]-4)
bboxes = ops.xywh2xyxy(bboxes)
preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
return preds