diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index edbb1031..481e006d 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -210,6 +210,10 @@ def non_max_suppression( # Checks assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" + + if isinstance(prediction, dict) and 'one2one' in prediction: # model.pt, output = {"one2many": torch.Tensor, "one2one": torch.Tensor } + prediction = prediction['one2one'] + if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) prediction = prediction[0] # select only inference output