From 9f73bc77684326578bb61a3d3f88b7b445c0f92b Mon Sep 17 00:00:00 2001 From: wa22 Date: Mon, 27 May 2024 13:43:44 +0800 Subject: [PATCH] fix predict with class filter --- ultralytics/models/yolov10/predict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ultralytics/models/yolov10/predict.py b/ultralytics/models/yolov10/predict.py index e3670f20..77644e94 100644 --- a/ultralytics/models/yolov10/predict.py +++ b/ultralytics/models/yolov10/predict.py @@ -21,7 +21,9 @@ class YOLOv10DetectionPredictor(DetectionPredictor): preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) mask = preds[..., 4] > self.args.conf - + if self.args.classes is not None: + mask = mask & (preds[..., 5:6] == torch.tensor(self.args.classes, device=preds.device).unsqueeze(0)).any(2) + preds = [p[mask[idx]] for idx, p in enumerate(preds)] if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list