Mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-07-13 17:55:38 +08:00)
fix for predicting with batch_size > 1

parent f19f3e521f
commit 39c20fc953
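Why the old code breaks for batch_size > 1: the previous postprocess flattened the whole batch, applied the confidence mask, and then forced the batch dimension back with a fixed reshape, which only works when every image keeps the same number of boxes. Below is a minimal sketch of that failure mode, using made-up tensors and a made-up confidence threshold in place of real model output:

import torch

# Dummy post-processed predictions: batch of 2 images, 5 candidate boxes each,
# 6 values per box (x1, y1, x2, y2, score, label). Shapes and values are illustrative only.
preds = torch.rand(2, 5, 6)
preds[0, :, 4] = torch.tensor([0.9, 0.8, 0.1, 0.1, 0.1])  # image 0: 2 boxes above conf
preds[1, :, 4] = torch.tensor([0.9, 0.1, 0.1, 0.1, 0.1])  # image 1: 1 box above conf

conf = 0.5
mask = preds[..., 4] > conf

# Old approach: flatten the batch, filter, then force the batch dimension back.
b, _, c = preds.shape
flat = preds.view(-1, c)[mask.view(-1)]  # 3 boxes survive across the whole batch
try:
    flat.view(b, -1, c)  # 3 boxes cannot be split evenly back into 2 images
except RuntimeError as e:
    print("reshape fails:", e)

With a batch of one the reshape is always valid, which is why the issue only shows up for larger batches; and even when the surviving count happens to divide evenly, boxes would be silently reassigned to the wrong images.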
@@ -11,28 +11,33 @@ class YOLOv10DetectionPredictor(DetectionPredictor):
         if isinstance(preds, (list, tuple)):
             preds = preds[0]
 
-        if preds.shape[-1] == 6:
-            pass
-        else:
+        if preds.shape[-1] != 6:
             preds = preds.transpose(-1, -2)
-            bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, preds.shape[-1]-4)
+            bboxes, scores, labels = ops.v10postprocess(
+                preds, self.args.max_det, preds.shape[-1] - 4
+            )
             bboxes = ops.xywh2xyxy(bboxes)
-            preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
+            preds = torch.cat(
+                [bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1
+            )
 
         mask = preds[..., 4] > self.args.conf
-        b, _, c = preds.shape
-        preds = preds.view(-1, preds.shape[-1])[mask.view(-1)]
-        preds = preds.view(b, -1, c)
+        # Filter predictions using the mask and keep batch dimension
+        filtered_preds = [p[mask[idx]] for idx, p in enumerate(preds)]
 
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+        if not isinstance(
+            orig_imgs, list
+        ):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
 
         results = []
-        for i, pred in enumerate(preds):
+        for i, pred in enumerate(filtered_preds):
             orig_img = orig_imgs[i]
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
             img_path = self.batch[0][i]
-            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
+            results.append(
+                Results(orig_img, path=img_path, names=self.model.names, boxes=pred)
+            )
         return results
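The fix filters each image's predictions against its own row of the mask, so every image can keep a different number of boxes and the downstream per-image loop (scale_boxes, Results) stays unchanged. A minimal sketch of that per-image filtering, using the same kind of dummy tensors as above (names and values are illustrative only):

import torch

preds = torch.rand(2, 5, 6)
preds[..., 4] = torch.tensor([[0.9, 0.8, 0.1, 0.1, 0.1],
                              [0.9, 0.1, 0.1, 0.1, 0.1]])

mask = preds[..., 4] > 0.5

# New approach: one tensor per image, each with its own number of surviving boxes.
filtered_preds = [p[mask[idx]] for idx, p in enumerate(preds)]
print([fp.shape for fp in filtered_preds])  # [torch.Size([2, 6]), torch.Size([1, 6])]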