mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
26 lines
856 B
Python
26 lines
856 B
Python
from ultralytics.models.yolo.detect import DetectionValidator
|
|
from ultralytics.utils import ops
|
|
import torch
|
|
|
|
class YOLOv10DetectionValidator(DetectionValidator):
    """Detection validator for YOLOv10 models.

    Extends ``DetectionValidator`` and overrides ``postprocess`` so validation
    consumes the model's ``"one2one"`` output and decodes it with
    ``ops.v10postprocess`` instead of the parent's post-processing.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the validator; all arguments are forwarded to ``DetectionValidator``.

        After the parent initializer runs, ``self.args.save_json`` is OR-ed with
        ``self.is_coco``, so ``save_json`` ends up true whenever ``is_coco`` is true.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): this only has an effect if the parent initializer has
        # already given ``is_coco`` its final value, and it assumes
        # ``args.save_json`` is a bool (``None | bool`` would raise) — confirm.
        self.args.save_json |= self.is_coco

    def postprocess(self, preds):
        """Turn raw YOLOv10 predictions into one row per detection.

        Args:
            preds: Model output. A dict is reduced to its ``"one2one"`` entry and
                a list/tuple to its first element. The result is then used as a
                ``torch.Tensor``.

        Returns:
            If the tensor's last dimension is already 6, the tensor unchanged.
            This is presumably pre-decoded ``[x1, y1, x2, y2, score, label]``
            rows, e.g. from an end-to-end exported model (TODO confirm).
            Otherwise, a tensor whose last dimension concatenates, per
            detection, the boxes (converted with ``ops.xywh2xyxy``), the score
            and the label returned by ``ops.v10postprocess``. That function is
            called with ``self.args.max_det``.
        """
        if isinstance(preds, dict):
            preds = preds["one2one"]

        if isinstance(preds, (list, tuple)):
            preds = preds[0]

        # Already in the 6-column output format: nothing left to decode.
        if preds.shape[-1] == 6:
            return preds

        # Swap the last two dims so each row holds one candidate's values.
        # Presumably the raw layout is (batch, 4 + nc, anchors) — TODO confirm.
        preds = preds.transpose(-1, -2)
        # Everything after the first 4 columns is presumably per-class scores,
        # so this is the class count passed to the decoder.
        num_classes = preds.shape[-1] - 4
        bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, num_classes)
        bboxes = ops.xywh2xyxy(bboxes)
        return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)