diff --git a/ultralytics/models/yolov10/predict.py b/ultralytics/models/yolov10/predict.py index e3670f20..77644e94 100644 --- a/ultralytics/models/yolov10/predict.py +++ b/ultralytics/models/yolov10/predict.py @@ -21,7 +21,9 @@ class YOLOv10DetectionPredictor(DetectionPredictor): preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) mask = preds[..., 4] > self.args.conf - + if self.args.classes is not None: + mask = mask & (preds[..., 5:6] == torch.tensor(self.args.classes, device=preds.device).unsqueeze(0)).any(2) + preds = [p[mask[idx]] for idx, p in enumerate(preds)] if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list