mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
fix predict with class filter
This commit is contained in:
parent
4dabf44202
commit
9f73bc7768
@ -21,7 +21,9 @@ class YOLOv10DetectionPredictor(DetectionPredictor):
|
|||||||
preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
|
preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
mask = preds[..., 4] > self.args.conf
|
mask = preds[..., 4] > self.args.conf
|
||||||
|
if self.args.classes is not None:
|
||||||
|
mask = mask & (preds[..., 5:6] == torch.tensor(self.args.classes, device=preds.device).unsqueeze(0)).any(2)
|
||||||
|
|
||||||
preds = [p[mask[idx]] for idx, p in enumerate(preds)]
|
preds = [p[mask[idx]] for idx, p in enumerate(preds)]
|
||||||
|
|
||||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||||
|
Loading…
x
Reference in New Issue
Block a user