"""YOLOv10 instance-segmentation validator: extends DetectionValidator with mask decoding and SegmentMetrics."""

import torch
import torch.nn.functional as F

from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou


class YOLOv10SegValidator(DetectionValidator):
    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize the validator and switch the task and metrics to segmentation."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.args.save_json |= self.is_coco
        self.plot_masks = None
        self.process = None
        self.args.task = "segment"
        self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)

    def init_metrics(self, model):
        """Initialize metrics and select mask processing function based on save_json flag."""
        super().init_metrics(model)
        self.plot_masks = []
        if self.args.save_json:
            # check_requirements("pycocotools>=2.0.6")
            self.process = ops.process_mask_upsample  # more accurate
        else:
            self.process = ops.process_mask  # faster
        self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[])

    def get_desc(self):
        """Return a formatted description of evaluation metrics."""
        return ("%22s" + "%11s" * 10) % (
            "Class",
            "Images",
            "Instances",
            "Box(P",
            "R",
            "mAP50",
            "mAP50-95)",
            "Mask(P",
            "R",
            "mAP50",
            "mAP50-95)",
        )

    def _prepare_batch(self, si, batch):
        """Prepare a batch of targets for validation, adding the ground-truth masks for image si."""
        prepared_batch = super()._prepare_batch(si, batch)
        midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
        prepared_batch["masks"] = batch["masks"][midx]
        return prepared_batch

    def finalize_metrics(self, *args, **kwargs):
        """Set speed and confusion matrix for evaluation metrics."""
        self.metrics.speed = self.speed
        self.metrics.confusion_matrix = self.confusion_matrix

    def preprocess(self, batch):
        """Preprocess batch by converting masks to float and sending them to the device."""
        batch = super().preprocess(batch)
        batch["masks"] = batch["masks"].to(self.device).float()
        return batch

    def postprocess(self, preds):
        """Decode raw head outputs into detections and per-instance masks."""
        coef, proto = None, None
        if isinstance(preds, dict):
            coef = preds["coef"]  # mask coefficients, e.g. [1, 32, 5294]
            proto = preds["proto"]  # mask prototypes, e.g. [1, 32, 92, 168]
            preds = preds["one2one"]
        if isinstance(preds, (list, tuple)):
            preds = preds[0]
        # Acknowledgement: Thanks to sanha9999 in #190 and #181!
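        # Layout assumed here (inferred from the shape notes above): "one2one" is the raw
        # one-to-one head output of shape (B, 4 + nc, A), "coef" carries one 32-dim
        # mask-coefficient vector per anchor, and "proto" holds the shared mask prototypes
        # at reduced resolution. If the tensor already has 6 columns
        # (x1, y1, x2, y2, conf, cls), it is treated as fully decoded detections without
        # masks and returned as-is.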
        if preds.shape[-1] == 6:
            return preds
        else:
            preds = preds.transpose(-1, -2)  # (B, 4 + nc, A) -> (B, A, 4 + nc), e.g. [1, 6006, 5]
            coef = coef.transpose(-1, -2)
            # bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, preds.shape[-1] - 4)
            bboxes, scores, labels, segmask = ops.v10segpostprocess(
                [preds, coef, proto], self.args.max_det, preds.shape[-1] - 4
            )
            bboxes = ops.xywh2xyxy(bboxes)
            preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
            return [preds, segmask]

    def process_mask(self, segmask, bboxes, shape, upsample=False):
        """Crop prototype-resolution masks to their boxes and binarize them (optionally upsampled to `shape`)."""
        c, mh, mw = segmask.shape  # CHW
        ih, iw = shape
        width_ratio = mw / iw
        height_ratio = mh / ih
        downsampled_bboxes = bboxes.clone()
        downsampled_bboxes[:, 0] *= width_ratio
        downsampled_bboxes[:, 2] *= width_ratio
        downsampled_bboxes[:, 3] *= height_ratio
        downsampled_bboxes[:, 1] *= height_ratio
        segmask = ops.crop_mask(segmask, downsampled_bboxes)  # CHW
        if upsample:
            segmask = F.interpolate(segmask[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
        return segmask.gt_(0)

    def _prepare_pred(self, inputx, pbatch):
        """Scale predicted boxes to native image space and build binary masks for one image."""
        if not isinstance(inputx, list):
            raise TypeError(f"expected [pred, pred_masks], got {type(inputx).__name__}")
        pred, pred_masks = inputx
        predn = pred.clone()
        predn[:, :4] = ops.scale_boxes(
            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
        )  # native-space pred
        pred_masks = self.process_mask(pred_masks, pred[:, :4], pbatch["imgsz"])
        return predn, pred_masks

    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
        """
        Return the correct prediction matrix.

        Args:
            detections (array[N, 6]): x1, y1, x2, y2, conf, class.
            gt_bboxes (array[M, 4]): ground-truth boxes in x1, y1, x2, y2.
            gt_cls (array[M]): ground-truth classes.

        Returns:
            correct (array[N, 10]): matches for 10 IoU levels.
        """
        if masks:
            if overlap:
                # With overlap_mask=True all instances share one index-encoded mask image,
                # so expand it into one binary mask per ground-truth instance.
                nl = len(gt_cls)
                index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
                gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
                gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
            if gt_masks.shape[1:] != pred_masks.shape[1:]:
                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
                gt_masks = gt_masks.gt_(0.5)
            iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
        else:  # boxes
            iou = box_iou(gt_bboxes, detections[:, :4])

        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def update_metrics(self, preds, batch):
        """Update running statistics with one batch of predictions (box + score + label, plus masks)."""
        # preds: [detections (B, max_det, 6), masks (B, max_det, mh, mw)], e.g. [[1, 300, 6], [1, 300, 96, 168]]
        for si, (pred, pred_masks) in enumerate(zip(preds[0], preds[1])):
            self.seen += 1
            npr = len(pred)  # number of predictions, e.g. 300
            stat = dict(
                conf=torch.zeros(0, device=self.device),
                pred_cls=torch.zeros(0, device=self.device),
                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
                tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
            )
            pbatch = self._prepare_batch(si, batch)
            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
            nl = len(cls)
            stat["target_cls"] = cls
            if npr == 0:
                if nl:
                    for k in self.stats.keys():
                        self.stats[k].append(stat[k])
                    if self.args.plots:
                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
                continue

            # Masks
            gt_masks = pbatch.pop("masks")
            # Predictions
            if self.args.single_cls:
                pred[:, 5] = 0
            predn, pred_masks = self._prepare_pred([pred, pred_masks], pbatch)
            stat["conf"] = predn[:, 4]
            stat["pred_cls"] = predn[:, 5]

            # Evaluate
            if nl:
                stat["tp"] = self._process_batch(predn, bbox, cls)
                stat["tp_m"] = self._process_batch(
                    predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
                )
            if self.args.plots:
                self.confusion_matrix.process_batch(predn, bbox, cls)

            for k in self.stats.keys():
                self.stats[k].append(stat[k])

            pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
            if self.args.plots and self.batch_i < 3:
                self.plot_masks.append(pred_masks[:15].cpu())  # filter top 15 to plot

            # Save
            if False:  # self.args.save_json:
                pred_masks = ops.scale_image(
                    pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
                    pbatch["ori_shape"],
                    ratio_pad=batch["ratio_pad"][si],
                )
                self.pred_to_json(predn, batch["im_file"][si], pred_masks)


"""class YOLOv10SegValidator(DetectionValidator):
    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        if overrides is None:
            overrides = {}
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def get_validator(self):
        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
        return yolo.segment.SegmentationValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            masks=batch["masks"],
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
"""
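# Minimal usage sketch (assumptions, not part of this module): the validator is
# normally constructed by the ultralytics engine when a segmentation-capable
# YOLOv10 model is validated with task="segment"; the checkpoint name and dataset
# YAML below are placeholders.
#
#     from ultralytics import YOLOv10
#     model = YOLOv10("yolov10n-seg.pt")            # hypothetical seg checkpoint
#     metrics = model.val(data="coco128-seg.yaml")  # assumes dispatch to YOLOv10SegValidator
#     print(metrics.seg.map)                        # mask mAP50-95 from SegmentMetrics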