mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
fix for predicting with batch_size > 1 (#47)
* fix for predicting with batch_size > 1
This commit is contained in:
parent
7a8f9643d2
commit
799ff3be47
@ -11,7 +11,7 @@ class YOLOv10DetectionPredictor(DetectionPredictor):
|
|||||||
|
|
||||||
if isinstance(preds, (list, tuple)):
|
if isinstance(preds, (list, tuple)):
|
||||||
preds = preds[0]
|
preds = preds[0]
|
||||||
|
|
||||||
if preds.shape[-1] == 6:
|
if preds.shape[-1] == 6:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@ -22,9 +22,7 @@ class YOLOv10DetectionPredictor(DetectionPredictor):
|
|||||||
|
|
||||||
mask = preds[..., 4] > self.args.conf
|
mask = preds[..., 4] > self.args.conf
|
||||||
|
|
||||||
b, _, c = preds.shape
|
preds = [p[mask[idx]] for idx, p in enumerate(preds)]
|
||||||
preds = preds.view(-1, preds.shape[-1])[mask.view(-1)]
|
|
||||||
preds = preds.view(b, -1, c)
|
|
||||||
|
|
||||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user