fix for predicting with batch_size > 1 (#47)

* fix for predicting with batch_size > 1
This commit is contained in:
Thomas Friedel 2024-05-26 13:13:12 +02:00 committed by GitHub
parent 7a8f9643d2
commit 799ff3be47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,7 +11,7 @@ class YOLOv10DetectionPredictor(DetectionPredictor):
if isinstance(preds, (list, tuple)): if isinstance(preds, (list, tuple)):
preds = preds[0] preds = preds[0]
if preds.shape[-1] == 6: if preds.shape[-1] == 6:
pass pass
else: else:
@ -22,9 +22,7 @@ class YOLOv10DetectionPredictor(DetectionPredictor):
mask = preds[..., 4] > self.args.conf mask = preds[..., 4] > self.args.conf
b, _, c = preds.shape preds = [p[mask[idx]] for idx, p in enumerate(preds)]
preds = preds.view(-1, preds.shape[-1])[mask.view(-1)]
preds = preds.view(b, -1, c)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)