mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Refactor val code into new self.match_predictions()
method (#4265)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
e9c9b82c42
commit
dde89c744c
@ -22,6 +22,7 @@ import json
|
|||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@ -199,6 +200,33 @@ class BaseValidator:
|
|||||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
def match_predictions(self, pred_classes: torch.Tensor, true_classes: torch.Tensor,
|
||||||
|
iou: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred_classes (torch.Tensor): Predicted class indices of shape(N,).
|
||||||
|
true_classes (torch.Tensor): Target class indices of shape(M,).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
|
||||||
|
"""
|
||||||
|
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
|
||||||
|
correct_class = true_classes[:, None] == pred_classes
|
||||||
|
for i in range(len(self.iouv)):
|
||||||
|
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
||||||
|
if x[0].shape[0]:
|
||||||
|
# Concatenate [label, detect, iou]
|
||||||
|
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
|
||||||
|
if x[0].shape[0] > 1:
|
||||||
|
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||||
|
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||||
|
# matches = matches[matches[:, 2].argsort()[::-1]]
|
||||||
|
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||||
|
correct[matches[:, 1].astype(int), i] = True
|
||||||
|
return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
|
||||||
|
|
||||||
def add_callback(self, event: str, callback):
|
def add_callback(self, event: str, callback):
|
||||||
"""Appends the given callback."""
|
"""Appends the given callback."""
|
||||||
self.callbacks[event].append(callback)
|
self.callbacks[event].append(callback)
|
||||||
|
@ -150,20 +150,7 @@ class FastSAMValidator(DetectionValidator):
|
|||||||
else: # boxes
|
else: # boxes
|
||||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||||
|
|
||||||
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
|
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
|
||||||
correct_class = labels[:, 0:1] == detections[:, 5]
|
|
||||||
for i in range(len(self.iouv)):
|
|
||||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
|
||||||
if x[0].shape[0]:
|
|
||||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
|
|
||||||
1).cpu().numpy() # [label, detect, iou]
|
|
||||||
if x[0].shape[0] > 1:
|
|
||||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
|
||||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
|
||||||
# matches = matches[matches[:, 2].argsort()[::-1]]
|
|
||||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
|
||||||
correct[matches[:, 1].astype(int), i] = True
|
|
||||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
|
||||||
|
|
||||||
def plot_val_samples(self, batch, ni):
|
def plot_val_samples(self, batch, ni):
|
||||||
"""Plots validation samples with bounding box labels."""
|
"""Plots validation samples with bounding box labels."""
|
||||||
|
@ -163,20 +163,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
correct (array[N, 10]), for 10 IoU levels
|
correct (array[N, 10]), for 10 IoU levels
|
||||||
"""
|
"""
|
||||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||||
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
|
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
|
||||||
correct_class = labels[:, 0:1] == detections[:, 5]
|
|
||||||
for i in range(len(self.iouv)):
|
|
||||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
|
||||||
if x[0].shape[0]:
|
|
||||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
|
|
||||||
1).cpu().numpy() # [label, detect, iou]
|
|
||||||
if x[0].shape[0] > 1:
|
|
||||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
|
||||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
|
||||||
# matches = matches[matches[:, 2].argsort()[::-1]]
|
|
||||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
|
||||||
correct[matches[:, 1].astype(int), i] = True
|
|
||||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
|
||||||
|
|
||||||
def build_dataset(self, img_path, mode='val', batch=None):
|
def build_dataset(self, img_path, mode='val', batch=None):
|
||||||
"""Build YOLO Dataset
|
"""Build YOLO Dataset
|
||||||
|
@ -128,20 +128,7 @@ class PoseValidator(DetectionValidator):
|
|||||||
else: # boxes
|
else: # boxes
|
||||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||||
|
|
||||||
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
|
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
|
||||||
correct_class = labels[:, 0:1] == detections[:, 5]
|
|
||||||
for i in range(len(self.iouv)):
|
|
||||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
|
||||||
if x[0].shape[0]:
|
|
||||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
|
|
||||||
1).cpu().numpy() # [label, detect, iou]
|
|
||||||
if x[0].shape[0] > 1:
|
|
||||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
|
||||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
|
||||||
# matches = matches[matches[:, 2].argsort()[::-1]]
|
|
||||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
|
||||||
correct[matches[:, 1].astype(int), i] = True
|
|
||||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
|
||||||
|
|
||||||
def plot_val_samples(self, batch, ni):
|
def plot_val_samples(self, batch, ni):
|
||||||
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
|
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
|
||||||
|
@ -150,20 +150,7 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
else: # boxes
|
else: # boxes
|
||||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||||
|
|
||||||
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
|
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
|
||||||
correct_class = labels[:, 0:1] == detections[:, 5]
|
|
||||||
for i in range(len(self.iouv)):
|
|
||||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
|
||||||
if x[0].shape[0]:
|
|
||||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
|
|
||||||
1).cpu().numpy() # [label, detect, iou]
|
|
||||||
if x[0].shape[0] > 1:
|
|
||||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
|
||||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
|
||||||
# matches = matches[matches[:, 2].argsort()[::-1]]
|
|
||||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
|
||||||
correct[matches[:, 1].astype(int), i] = True
|
|
||||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
|
||||||
|
|
||||||
def plot_val_samples(self, batch, ni):
|
def plot_val_samples(self, batch, ni):
|
||||||
"""Plots validation samples with bounding box labels."""
|
"""Plots validation samples with bounding box labels."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user