# Mirror of https://github.com/THU-MIG/yolov10.git
# Synced 2025-05-23 13:34:23 +08:00
# 227 lines, 9.3 KiB, Python
from ultralytics.models.yolo.detect import DetectionValidator
|
|
from ultralytics.utils import ops
|
|
import torch
|
|
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
|
|
import torch.nn.functional as F
|
|
class YOLOv10SegValidator(DetectionValidator):
    """Validator for YOLOv10 models that predict instance masks alongside boxes.

    Builds on ``DetectionValidator``: the task is switched to ``"segment"``,
    metrics are collected with ``SegmentMetrics``, and every per-image stat
    carries a mask true-positive matrix (``tp_m``) next to the box one (``tp``).
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize the base validator, then switch task and metrics to segmentation."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.args.save_json |= self.is_coco
        self.plot_masks = None  # filled with per-image mask tensors for plotting
        self.process = None  # mask post-processing callable, chosen in init_metrics()
        self.args.task = "segment"
        self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)

    def init_metrics(self, model):
        """Initialize metrics and select mask processing function based on save_json flag."""
        super().init_metrics(model)
        self.plot_masks = []
        # check_requirements("pycocotools>=2.0.6") was disabled here in the original code.
        # NOTE(review): within this class the masks are produced by self.process_mask();
        # self.process is only stored, not called.
        if self.args.save_json:
            self.process = ops.process_mask_upsample  # more accurate
        else:
            self.process = ops.process_mask  # faster
        self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[])

    def get_desc(self):
        """Return a formatted description of evaluation metrics."""
        return ("%22s" + "%11s" * 10) % (
            "Class",
            "Images",
            "Instances",
            "Box(P",
            "R",
            "mAP50",
            "mAP50-95)",
            "Mask(P",
            "R",
            "mAP50",
            "mAP50-95)",
        )

    def _prepare_batch(self, si, batch):
        """Prepare the ground truth of image ``si`` and attach its masks.

        Args:
            si: Index of the image inside ``batch``.
            batch: Batch dict; ``"masks"`` and ``"batch_idx"`` are read here.

        Returns:
            The dict built by the parent ``_prepare_batch`` with a ``"masks"`` entry added.
        """
        prepared_batch = super()._prepare_batch(si, batch)
        # With overlap_mask the masks are indexed per image; otherwise they are
        # selected per instance through the batch index.
        midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
        prepared_batch["masks"] = batch["masks"][midx]
        return prepared_batch

    def finalize_metrics(self, *args, **kwargs):
        """Sets speed and confusion matrix for evaluation metrics."""
        self.metrics.speed = self.speed
        self.metrics.confusion_matrix = self.confusion_matrix

    def preprocess(self, batch):
        """Preprocesses batch by converting masks to float and sending to device."""
        batch = super().preprocess(batch)
        batch["masks"] = batch["masks"].to(self.device).float()
        return batch

    def postprocess(self, preds):
        """Convert raw model output into boxes and per-detection masks.

        Args:
            preds: Raw model output. A dict is expected to provide ``"coef"``,
                ``"proto"`` and ``"one2one"``; a list/tuple is reduced to its
                first element.

        Returns:
            ``preds`` unchanged when its last dimension is already 6, otherwise
            ``[boxes, segmask]`` where ``boxes`` is the concatenation of xyxy
            boxes, scores and labels along the last dimension.

        Raises:
            ValueError: If mask decoding is required but the output carried no
                ``"coef"``/``"proto"`` entries.
        """
        coef = proto = None
        if isinstance(preds, dict):
            coef = preds["coef"]  # e.g. [1,32,5294]
            proto = preds["proto"]  # e.g. [1,32,92,168]
            preds = preds["one2one"]

        if isinstance(preds, (list, tuple)):
            preds = preds[0]

        # Acknowledgement: Thanks to sanha9999 in #190 and #181!
        if preds.shape[-1] == 6:
            # NOTE(review): this path returns a bare tensor, while update_metrics()
            # unpacks a [boxes, masks] pair -- confirm how already-decoded output
            # is meant to be evaluated.
            return preds

        if coef is None or proto is None:
            # Previously this surfaced as AttributeError on `coef.transpose`.
            raise ValueError(
                "postprocess() needs 'coef' and 'proto' entries in the model output to decode masks"
            )

        preds = preds.transpose(-1, -2)  # e.g. [1,6006,5]
        coef = coef.transpose(-1, -2)
        bboxes, scores, labels, segmask = ops.v10segpostprocess(
            [preds, coef, proto], self.args.max_det, preds.shape[-1] - 4
        )
        bboxes = ops.xywh2xyxy(bboxes)
        preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
        return [preds, segmask]

    def process_mask(self, segmask, bboxes, shape, upsample=False):
        """Crop each mask to its box and binarize it.

        Args:
            segmask: Mask tensor laid out as (count, mask_h, mask_w).
            bboxes: Boxes with x1, y1, x2, y2 in columns 0-3, expressed at ``shape`` scale.
            shape: ``(height, width)`` the boxes refer to.
            upsample: When True, bilinearly resize the cropped masks to ``shape``
                before thresholding.

        Returns:
            The cropped masks thresholded in place with ``gt_(0)``.
        """
        _, mh, mw = segmask.shape  # CHW
        ih, iw = shape

        # Bring the boxes down to mask resolution without touching the caller's tensor.
        downsampled_bboxes = bboxes.clone()
        downsampled_bboxes[:, [0, 2]] *= mw / iw
        downsampled_bboxes[:, [1, 3]] *= mh / ih

        segmask = ops.crop_mask(segmask, downsampled_bboxes)  # CHW
        if upsample:
            segmask = F.interpolate(segmask[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
        return segmask.gt_(0)

    def _prepare_pred(self, inputx, pbatch):
        """Scale predicted boxes to native image space and build the matching masks.

        Args:
            inputx: ``[pred, pred_masks]`` pair for a single image.
            pbatch: Prepared batch dict; ``"imgsz"``, ``"ori_shape"`` and
                ``"ratio_pad"`` are read here.

        Returns:
            ``(predn, pred_masks)``: a copy of ``pred`` with rescaled boxes, and
            the masks produced by ``process_mask``.

        Raises:
            TypeError: If ``inputx`` is not a two-element list or tuple.
        """
        if not isinstance(inputx, (list, tuple)) or len(inputx) != 2:
            # The original printed "error!!!!!!!" and then failed on None.clone().
            raise TypeError(
                f"_prepare_pred() expects a [pred, pred_masks] pair, got {type(inputx).__name__}"
            )
        pred, pred_masks = inputx

        predn = pred.clone()
        predn[:, :4] = ops.scale_boxes(
            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
        )  # native-space pred
        # Masks are cropped with the unscaled boxes, which are still at imgsz scale.
        pred_masks = self.process_mask(pred_masks, pred[:, :4], pbatch["imgsz"])
        return predn, pred_masks

    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
        """
        Return correct prediction matrix.

        Args:
            detections (array[N, 6]): x1, y1, x2, y2, conf, class.
            gt_bboxes: Ground-truth boxes, compared by box IoU when ``masks`` is False.
            gt_cls: Ground-truth class per label.
            pred_masks: Predicted masks, required when ``masks`` is True.
            gt_masks: Ground-truth masks, required when ``masks`` is True.
            overlap: When True, ``gt_masks`` is a single index map that is
                expanded into one binary mask per label.
            masks: Match on mask IoU instead of box IoU.

        Returns:
            correct (array[N, 10]), for 10 IoU levels
        """
        if masks:
            if overlap:
                nl = len(gt_cls)
                index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
                gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
                gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
            if gt_masks.shape[1:] != pred_masks.shape[1:]:
                # Resize ground truth to the prediction resolution, then re-binarize.
                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
                gt_masks = gt_masks.gt_(0.5)
            iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
        else:  # boxes
            iou = box_iou(gt_bboxes, detections[:, :4])

        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def update_metrics(self, preds, batch):
        """Accumulate box and mask statistics for every image in the batch.

        Args:
            preds: ``[boxes, masks]`` pair as returned by ``postprocess``,
                e.g. [[1,300,6],[1,300,96,168]].
            batch: Batch dict holding the ground truth.
        """
        # box + score + label
        for si, (pred, pred_masks) in enumerate(zip(preds[0], preds[1])):
            self.seen += 1
            npr = len(pred)  # e.g. 300
            stat = dict(
                conf=torch.zeros(0, device=self.device),
                pred_cls=torch.zeros(0, device=self.device),
                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
                tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
            )
            pbatch = self._prepare_batch(si, batch)
            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
            nl = len(cls)
            stat["target_cls"] = cls

            if npr == 0:
                # No predictions: only record the stats when there were labels to miss.
                if nl:
                    for k in self.stats:
                        self.stats[k].append(stat[k])
                    if self.args.plots:
                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
                continue

            # Masks
            gt_masks = pbatch.pop("masks")

            # Predictions
            if self.args.single_cls:
                pred[:, 5] = 0
            predn, pred_masks = self._prepare_pred([pred, pred_masks], pbatch)

            stat["conf"] = predn[:, 4]
            stat["pred_cls"] = predn[:, 5]

            # Evaluate
            if nl:
                stat["tp"] = self._process_batch(predn, bbox, cls)
                stat["tp_m"] = self._process_batch(
                    predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
                )
                if self.args.plots:
                    self.confusion_matrix.process_batch(predn, bbox, cls)

            for k in self.stats:
                self.stats[k].append(stat[k])

            if self.args.plots and self.batch_i < 3:
                plot_ready = torch.as_tensor(pred_masks, dtype=torch.uint8)
                self.plot_masks.append(plot_ready[:15].cpu())  # filter top 15 to plot

            # NOTE: JSON export of predictions is intentionally disabled in this
            # validator. The original guarded it with `if False:  # self.args.save_json`;
            # that unreachable branch rescaled the masks to pbatch["ori_shape"] with
            # ops.scale_image and called self.pred_to_json(predn, im_file, pred_masks).
|
|
|
# NOTE(review): the string below is disabled code kept as a bare string literal;
# it is never executed. Its methods look like trainer code rather than validator
# code -- confirm before re-enabling.
"""class YOLOv10SegValidator(DetectionValidator):

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        if overrides is None:
            overrides = {}
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
        return yolo.segment.SegmentationValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            masks=batch["masks"],
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png"""
|