diff --git a/ultralytics/models/yolov10/predict.py b/ultralytics/models/yolov10/predict.py index ad1a7542..83f7bd41 100644 --- a/ultralytics/models/yolov10/predict.py +++ b/ultralytics/models/yolov10/predict.py @@ -6,24 +6,19 @@ from ultralytics.engine.results import Results class YOLOv10DetectionPredictor(DetectionPredictor): def postprocess(self, preds, img, orig_imgs): - if not isinstance(preds, (list, tuple)): - preds = [preds, None] + if isinstance(preds, (list, tuple)): + preds = preds[0] - prediction = preds[0].transpose(-1, -2) - _, _, nd = prediction.shape - nc = nd - 4 - bboxes, scores = prediction.split((4, nd-4), dim=-1) + preds = preds.transpose(-1, -2) + bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det) bboxes = ops.xywh2xyxy(bboxes) - - scores, index = torch.topk(scores.flatten(1), self.args.max_det, axis=-1) - labels = index % nc - index = torch.div(index, nc, rounding_mode='floor') - bboxes = bboxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bboxes.shape[-1])) - preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) - assert(preds.shape[0] == 1) + mask = preds[..., 4] > self.args.conf - preds = preds[mask].unsqueeze(0) + + b, _, c = preds.shape + preds = preds.view(-1, preds.shape[-1])[mask.view(-1)] + preds = preds.view(b, -1, c) if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) diff --git a/ultralytics/models/yolov10/val.py b/ultralytics/models/yolov10/val.py index 0993681c..92419b2d 100644 --- a/ultralytics/models/yolov10/val.py +++ b/ultralytics/models/yolov10/val.py @@ -11,19 +11,10 @@ class YOLOv10DetectionValidator(DetectionValidator): if self.training: preds = preds["one2one"] - if not isinstance(preds, (list, tuple)): - preds = [preds, None] - - prediction = preds[0].transpose(-1, -2) - _, _, nd = prediction.shape - nc = nd - 4 - assert(self.nc == nc) - bboxes, scores = prediction.split((4, nd-4), dim=-1) - bboxes = ops.xywh2xyxy(bboxes) - - scores, index = torch.topk(scores.flatten(1), self.args.max_det, axis=-1) - labels = index % self.nc - index = torch.div(index, self.nc, rounding_mode='floor') - bboxes = bboxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bboxes.shape[-1])) + if isinstance(preds, (list, tuple)): + preds = preds[0] + preds = preds.transpose(-1, -2) + boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det) + bboxes = ops.xywh2xyxy(boxes) return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) \ No newline at end of file diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index a94ceedb..6b2c3ede 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -101,6 +101,8 @@ class Detect(nn.Module): def decode_bboxes(self, bboxes, anchors): """Decode bounding boxes.""" + if self.export: + return dist2bbox(bboxes, anchors, xywh=False, dim=1) return dist2bbox(bboxes, anchors, xywh=True, dim=1) diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index 439f9440..e6aaf103 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -847,3 +847,18 @@ def clean_str(s): (str): a string with special characters replaced by an underscore _ """ return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s) + +def v10postprocess(preds, max_det): + nc = preds.shape[-1] - 4 + boxes, scores = preds.split([4, nc], dim=-1) + max_scores = scores.amax(dim=-1) + max_scores, index = torch.topk(max_scores, max_det, axis=-1) + index = index.unsqueeze(-1) + boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1])) + scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1])) + + scores, index = torch.topk(scores.flatten(1), max_det, axis=-1) + labels = index % nc + index = index // nc + boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) + return boxes, scores, labels \ No newline at end of file