mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Clean validator (#144)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
46cb657b64
commit
19334ebb16
@ -463,6 +463,8 @@ class LetterBox:
|
|||||||
|
|
||||||
dw /= 2 # divide padding into 2 sides
|
dw /= 2 # divide padding into 2 sides
|
||||||
dh /= 2
|
dh /= 2
|
||||||
|
if labels.get("ratio_pad"):
|
||||||
|
labels["ratio_pad"] = (labels["ratio_pad"], (dw, dh)) # for evaluation
|
||||||
|
|
||||||
if shape[::-1] != new_unpad: # resize
|
if shape[::-1] != new_unpad: # resize
|
||||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||||
|
@ -179,6 +179,10 @@ class BaseDataset(Dataset):
|
|||||||
def get_label_info(self, index):
|
def get_label_info(self, index):
|
||||||
label = self.labels[index].copy()
|
label = self.labels[index].copy()
|
||||||
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
|
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
|
||||||
|
label["ratio_pad"] = (
|
||||||
|
label["resized_shape"][0] / label["ori_shape"][0],
|
||||||
|
label["resized_shape"][1] / label["ori_shape"][1],
|
||||||
|
) # for evaluation
|
||||||
if self.rect:
|
if self.rect:
|
||||||
label["rect_shape"] = self.batch_shapes[self.batch[index]]
|
label["rect_shape"] = self.batch_shapes[self.batch[index]]
|
||||||
label = self.update_labels_info(label)
|
label = self.update_labels_info(label)
|
||||||
|
@ -895,7 +895,7 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
batch_idx, cls, bboxes = torch.cat(label, 0).split((1, 1, 4), dim=1)
|
batch_idx, cls, bboxes = torch.cat(label, 0).split((1, 1, 4), dim=1)
|
||||||
return {
|
return {
|
||||||
'ori_shape': tuple((x[0] if x else None) for x in shapes),
|
'ori_shape': tuple((x[0] if x else None) for x in shapes),
|
||||||
'resized_shape': tuple(tuple(x.shape[1:]) for x in im),
|
'ratio_pad': tuple((x[1] if x else None) for x in shapes),
|
||||||
'im_file': path,
|
'im_file': path,
|
||||||
'img': torch.stack(im, 0),
|
'img': torch.stack(im, 0),
|
||||||
'cls': cls,
|
'cls': cls,
|
||||||
|
@ -127,7 +127,7 @@ class YOLODataset(BaseDataset):
|
|||||||
mosaic = self.augment and not self.rect
|
mosaic = self.augment and not self.rect
|
||||||
transforms = mosaic_transforms(self, self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
|
transforms = mosaic_transforms(self, self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
|
||||||
else:
|
else:
|
||||||
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz))])
|
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
|
||||||
transforms.append(
|
transforms.append(
|
||||||
Format(bbox_format="xywh",
|
Format(bbox_format="xywh",
|
||||||
normalize=True,
|
normalize=True,
|
||||||
|
@ -224,7 +224,7 @@ class BaseTrainer:
|
|||||||
if rank in {0, -1}:
|
if rank in {0, -1}:
|
||||||
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
|
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
|
||||||
self.validator = self.get_validator()
|
self.validator = self.get_validator()
|
||||||
metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
|
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
|
||||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||||
self.ema = ModelEMA(self.model)
|
self.ema = ModelEMA(self.model)
|
||||||
self.resume_training(ckpt)
|
self.resume_training(ckpt)
|
||||||
|
@ -469,7 +469,7 @@ class Metric:
|
|||||||
|
|
||||||
def mean_results(self):
|
def mean_results(self):
|
||||||
"""Mean of results, return mp, mr, map50, map"""
|
"""Mean of results, return mp, mr, map50, map"""
|
||||||
return self.mp, self.mr, self.map50, self.map
|
return [self.mp, self.mr, self.map50, self.map]
|
||||||
|
|
||||||
def class_result(self, i):
|
def class_result(self, i):
|
||||||
"""class-aware result, return p[i], r[i], ap50[i], ap[i]"""
|
"""class-aware result, return p[i], r[i], ap50[i], ap[i]"""
|
||||||
@ -520,6 +520,7 @@ class DetMetrics:
|
|||||||
def get_maps(self, nc):
|
def get_maps(self, nc):
|
||||||
return self.metric.get_maps(nc)
|
return self.metric.get_maps(nc)
|
||||||
|
|
||||||
|
@property
|
||||||
def fitness(self):
|
def fitness(self):
|
||||||
return self.metric.fitness()
|
return self.metric.fitness()
|
||||||
|
|
||||||
@ -527,6 +528,10 @@ class DetMetrics:
|
|||||||
def ap_class_index(self):
|
def ap_class_index(self):
|
||||||
return self.metric.ap_class_index
|
return self.metric.ap_class_index
|
||||||
|
|
||||||
|
@property
|
||||||
|
def results_dict(self):
|
||||||
|
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
||||||
|
|
||||||
|
|
||||||
class SegmentMetrics:
|
class SegmentMetrics:
|
||||||
|
|
||||||
@ -578,6 +583,7 @@ class SegmentMetrics:
|
|||||||
def get_maps(self, nc):
|
def get_maps(self, nc):
|
||||||
return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc)
|
return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc)
|
||||||
|
|
||||||
|
@property
|
||||||
def fitness(self):
|
def fitness(self):
|
||||||
return self.metric_mask.fitness() + self.metric_box.fitness()
|
return self.metric_mask.fitness() + self.metric_box.fitness()
|
||||||
|
|
||||||
@ -585,3 +591,30 @@ class SegmentMetrics:
|
|||||||
def ap_class_index(self):
|
def ap_class_index(self):
|
||||||
# boxes and masks have the same ap_class_index
|
# boxes and masks have the same ap_class_index
|
||||||
return self.metric_box.ap_class_index
|
return self.metric_box.ap_class_index
|
||||||
|
|
||||||
|
@property
|
||||||
|
def results_dict(self):
|
||||||
|
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifyMetrics:
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.top1 = 0
|
||||||
|
self.top5 = 0
|
||||||
|
|
||||||
|
def process(self, correct):
|
||||||
|
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
|
||||||
|
self.top1, self.top5 = acc.mean(0).tolist()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fitness(self):
|
||||||
|
return self.top5
|
||||||
|
|
||||||
|
@property
|
||||||
|
def results_dict(self):
|
||||||
|
return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def keys(self):
|
||||||
|
return ["top1", "top5"]
|
||||||
|
@ -4,10 +4,15 @@ import torch
|
|||||||
from ultralytics.yolo.data import build_classification_dataloader
|
from ultralytics.yolo.data import build_classification_dataloader
|
||||||
from ultralytics.yolo.engine.validator import BaseValidator
|
from ultralytics.yolo.engine.validator import BaseValidator
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG
|
from ultralytics.yolo.utils import DEFAULT_CONFIG
|
||||||
|
from ultralytics.yolo.utils.metrics import ClassifyMetrics
|
||||||
|
|
||||||
|
|
||||||
class ClassificationValidator(BaseValidator):
|
class ClassificationValidator(BaseValidator):
|
||||||
|
|
||||||
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||||
|
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||||
|
self.metrics = ClassifyMetrics()
|
||||||
|
|
||||||
def init_metrics(self, model):
|
def init_metrics(self, model):
|
||||||
self.correct = torch.tensor([], device=next(model.parameters()).device)
|
self.correct = torch.tensor([], device=next(model.parameters()).device)
|
||||||
|
|
||||||
@ -23,17 +28,12 @@ class ClassificationValidator(BaseValidator):
|
|||||||
self.correct = torch.cat((self.correct, correct_in_batch))
|
self.correct = torch.cat((self.correct, correct_in_batch))
|
||||||
|
|
||||||
def get_stats(self):
|
def get_stats(self):
|
||||||
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
|
self.metrics.process(self.correct)
|
||||||
top1, top5 = acc.mean(0).tolist()
|
return self.metrics.results_dict
|
||||||
return {"top1": top1, "top5": top5, "fitness": top5}
|
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size):
|
def get_dataloader(self, dataset_path, batch_size):
|
||||||
return build_classification_dataloader(path=dataset_path, imgsz=self.args.imgsz, batch_size=batch_size)
|
return build_classification_dataloader(path=dataset_path, imgsz=self.args.imgsz, batch_size=batch_size)
|
||||||
|
|
||||||
@property
|
|
||||||
def metric_keys(self):
|
|
||||||
return ["top1", "top5"]
|
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||||
def val(cfg):
|
def val(cfg):
|
||||||
|
@ -22,7 +22,6 @@ class DetectionValidator(BaseValidator):
|
|||||||
self.data_dict = yaml_load(check_file(self.args.data), append_filename=True) if self.args.data else None
|
self.data_dict = yaml_load(check_file(self.args.data), append_filename=True) if self.args.data else None
|
||||||
self.is_coco = False
|
self.is_coco = False
|
||||||
self.class_map = None
|
self.class_map = None
|
||||||
self.targets = None
|
|
||||||
self.metrics = DetMetrics(save_dir=self.save_dir, plot=self.args.plots)
|
self.metrics = DetMetrics(save_dir=self.save_dir, plot=self.args.plots)
|
||||||
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
|
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
|
||||||
self.niou = self.iouv.numel()
|
self.niou = self.iouv.numel()
|
||||||
@ -30,13 +29,13 @@ class DetectionValidator(BaseValidator):
|
|||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
||||||
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
|
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
|
||||||
self.nb, _, self.height, self.width = batch["img"].shape # batch size, channels, height, width
|
for k in ["batch_idx", "cls", "bboxes"]:
|
||||||
self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
batch[k] = batch[k].to(self.device)
|
||||||
self.targets = self.targets.to(self.device)
|
|
||||||
height, width = batch["img"].shape[2:]
|
nb, _, height, width = batch["img"].shape
|
||||||
self.targets[:, 2:] *= torch.tensor((width, height, width, height), device=self.device) # to pixels
|
batch["bboxes"] *= torch.tensor((width, height, width, height), device=self.device) # to pixels
|
||||||
self.lb = [self.targets[self.targets[:, 0] == i, 1:]
|
self.lb = [torch.cat([batch["cls"], batch["bboxes"]], dim=-1)[batch["batch_idx"] == i]
|
||||||
for i in range(self.nb)] if self.args.save_hybrid else [] # for autolabelling
|
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@ -69,36 +68,39 @@ class DetectionValidator(BaseValidator):
|
|||||||
def update_metrics(self, preds, batch):
|
def update_metrics(self, preds, batch):
|
||||||
# Metrics
|
# Metrics
|
||||||
for si, pred in enumerate(preds):
|
for si, pred in enumerate(preds):
|
||||||
labels = self.targets[self.targets[:, 0] == si, 1:]
|
idx = batch["batch_idx"] == si
|
||||||
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
cls = batch["cls"][idx]
|
||||||
|
bbox = batch["bboxes"][idx]
|
||||||
|
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
||||||
shape = batch["ori_shape"][si]
|
shape = batch["ori_shape"][si]
|
||||||
# path = batch["shape"][si][0]
|
|
||||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||||
self.seen += 1
|
self.seen += 1
|
||||||
|
|
||||||
if npr == 0:
|
if npr == 0:
|
||||||
if nl:
|
if nl:
|
||||||
self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), labels[:, 0]))
|
self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
|
||||||
if self.args.plots:
|
if self.args.plots:
|
||||||
self.confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
|
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Predictions
|
# Predictions
|
||||||
if self.args.single_cls:
|
if self.args.single_cls:
|
||||||
pred[:, 5] = 0
|
pred[:, 5] = 0
|
||||||
predn = pred.clone()
|
predn = pred.clone()
|
||||||
ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape) # native-space pred
|
ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape,
|
||||||
|
ratio_pad=batch["ratio_pad"][si]) # native-space pred
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
if nl:
|
if nl:
|
||||||
tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes
|
tbox = ops.xywh2xyxy(bbox) # target boxes
|
||||||
ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape) # native-space labels
|
ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape,
|
||||||
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
|
ratio_pad=batch["ratio_pad"][si]) # native-space labels
|
||||||
|
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||||
correct_bboxes = self._process_batch(predn, labelsn)
|
correct_bboxes = self._process_batch(predn, labelsn)
|
||||||
# TODO: maybe remove these `self.` arguments as they already are member variable
|
# TODO: maybe remove these `self.` arguments as they already are member variable
|
||||||
if self.args.plots:
|
if self.args.plots:
|
||||||
self.confusion_matrix.process_batch(predn, labelsn)
|
self.confusion_matrix.process_batch(predn, labelsn)
|
||||||
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (conf, pcls, tcls)
|
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
|
||||||
|
|
||||||
# Save
|
# Save
|
||||||
if self.args.save_json:
|
if self.args.save_json:
|
||||||
@ -111,12 +113,10 @@ class DetectionValidator(BaseValidator):
|
|||||||
if len(stats) and stats[0].any():
|
if len(stats) and stats[0].any():
|
||||||
self.metrics.process(*stats)
|
self.metrics.process(*stats)
|
||||||
self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
|
self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
|
||||||
fitness = {"fitness": self.metrics.fitness()}
|
return self.metrics.results_dict
|
||||||
metrics = dict(zip(self.metric_keys, self.metrics.mean_results()))
|
|
||||||
return {**metrics, **fitness}
|
|
||||||
|
|
||||||
def print_results(self):
|
def print_results(self):
|
||||||
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metric_keys) # print format
|
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
|
||||||
self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
|
self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
|
||||||
if self.nt_per_class.sum() == 0:
|
if self.nt_per_class.sum() == 0:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
@ -166,18 +166,13 @@ class DetectionValidator(BaseValidator):
|
|||||||
hyp=dict(self.args),
|
hyp=dict(self.args),
|
||||||
cache=False,
|
cache=False,
|
||||||
pad=0.5,
|
pad=0.5,
|
||||||
rect=self.args.rect,
|
rect=True,
|
||||||
workers=self.args.workers,
|
workers=self.args.workers,
|
||||||
prefix=colorstr(f'{self.args.mode}: '),
|
prefix=colorstr(f'{self.args.mode}: '),
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
seed=self.args.seed)[0] if self.args.v5loader else \
|
seed=self.args.seed)[0] if self.args.v5loader else \
|
||||||
build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
|
build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
|
||||||
|
|
||||||
# TODO: align with train loss metrics
|
|
||||||
@property
|
|
||||||
def metric_keys(self):
|
|
||||||
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
|
||||||
|
|
||||||
def plot_val_samples(self, batch, ni):
|
def plot_val_samples(self, batch, ni):
|
||||||
plot_images(batch["img"],
|
plot_images(batch["img"],
|
||||||
batch["batch_idx"],
|
batch["batch_idx"],
|
||||||
@ -226,7 +221,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
eval.evaluate()
|
eval.evaluate()
|
||||||
eval.accumulate()
|
eval.accumulate()
|
||||||
eval.summarize()
|
eval.summarize()
|
||||||
stats[self.metric_keys[-1]], stats[self.metric_keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
|
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(f'pycocotools unable to run: {e}')
|
self.logger.warning(f'pycocotools unable to run: {e}')
|
||||||
return stats
|
return stats
|
||||||
|
@ -22,17 +22,8 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots)
|
self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots)
|
||||||
|
|
||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
batch = super().preprocess(batch)
|
||||||
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
|
|
||||||
batch["masks"] = batch["masks"].to(self.device).float()
|
batch["masks"] = batch["masks"].to(self.device).float()
|
||||||
self.nb, _, self.height, self.width = batch["img"].shape # batch size, channels, height, width
|
|
||||||
self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
|
||||||
self.targets = self.targets.to(self.device)
|
|
||||||
height, width = batch["img"].shape[2:]
|
|
||||||
self.targets[:, 2:] *= torch.tensor((width, height, width, height), device=self.device) # to pixels
|
|
||||||
self.lb = [self.targets[self.targets[:, 0] == i, 1:]
|
|
||||||
for i in range(self.nb)] if self.args.save_hybrid else [] # for autolabelling
|
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def init_metrics(self, model):
|
def init_metrics(self, model):
|
||||||
@ -72,10 +63,11 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
def update_metrics(self, preds, batch):
|
def update_metrics(self, preds, batch):
|
||||||
# Metrics
|
# Metrics
|
||||||
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
||||||
labels = self.targets[self.targets[:, 0] == si, 1:]
|
idx = batch["batch_idx"] == si
|
||||||
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
cls = batch["cls"][idx]
|
||||||
|
bbox = batch["bboxes"][idx]
|
||||||
|
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
||||||
shape = batch["ori_shape"][si]
|
shape = batch["ori_shape"][si]
|
||||||
# path = batch["shape"][si][0]
|
|
||||||
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||||
self.seen += 1
|
self.seen += 1
|
||||||
@ -83,13 +75,13 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
if npr == 0:
|
if npr == 0:
|
||||||
if nl:
|
if nl:
|
||||||
self.stats.append((correct_masks, correct_bboxes, *torch.zeros(
|
self.stats.append((correct_masks, correct_bboxes, *torch.zeros(
|
||||||
(2, 0), device=self.device), labels[:, 0]))
|
(2, 0), device=self.device), cls.squeeze(-1)))
|
||||||
if self.args.plots:
|
if self.args.plots:
|
||||||
self.confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
|
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Masks
|
# Masks
|
||||||
midx = [si] if self.args.overlap_mask else self.targets[:, 0] == si
|
midx = [si] if self.args.overlap_mask else idx
|
||||||
gt_masks = batch["masks"][midx]
|
gt_masks = batch["masks"][midx]
|
||||||
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch["img"][si].shape[1:])
|
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch["img"][si].shape[1:])
|
||||||
|
|
||||||
@ -101,9 +93,9 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
if nl:
|
if nl:
|
||||||
tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes
|
tbox = ops.xywh2xyxy(bbox) # target boxes
|
||||||
ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape) # native-space labels
|
ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape) # native-space labels
|
||||||
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
|
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||||
correct_bboxes = self._process_batch(predn, labelsn)
|
correct_bboxes = self._process_batch(predn, labelsn)
|
||||||
# TODO: maybe remove these `self.` arguments as they already are member variable
|
# TODO: maybe remove these `self.` arguments as they already are member variable
|
||||||
correct_masks = self._process_batch(predn,
|
correct_masks = self._process_batch(predn,
|
||||||
@ -114,7 +106,8 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
masks=True)
|
masks=True)
|
||||||
if self.args.plots:
|
if self.args.plots:
|
||||||
self.confusion_matrix.process_batch(predn, labelsn)
|
self.confusion_matrix.process_batch(predn, labelsn)
|
||||||
self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # conf, pcls, tcls
|
self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:,
|
||||||
|
5], cls.squeeze(-1))) # conf, pcls, tcls
|
||||||
|
|
||||||
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
||||||
if self.args.plots and self.batch_i < 3:
|
if self.args.plots and self.batch_i < 3:
|
||||||
@ -165,19 +158,6 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
correct[matches[:, 1].astype(int), i] = True
|
correct[matches[:, 1].astype(int), i] = True
|
||||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
||||||
|
|
||||||
# TODO: probably add this to class Metrics
|
|
||||||
@property
|
|
||||||
def metric_keys(self):
|
|
||||||
return [
|
|
||||||
"metrics/precision(B)",
|
|
||||||
"metrics/recall(B)",
|
|
||||||
"metrics/mAP50(B)",
|
|
||||||
"metrics/mAP50-95(B)", # metrics
|
|
||||||
"metrics/precision(M)",
|
|
||||||
"metrics/recall(M)",
|
|
||||||
"metrics/mAP50(M)",
|
|
||||||
"metrics/mAP50-95(M)",]
|
|
||||||
|
|
||||||
def plot_val_samples(self, batch, ni):
|
def plot_val_samples(self, batch, ni):
|
||||||
plot_images(batch["img"],
|
plot_images(batch["img"],
|
||||||
batch["batch_idx"],
|
batch["batch_idx"],
|
||||||
@ -243,8 +223,8 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
eval.accumulate()
|
eval.accumulate()
|
||||||
eval.summarize()
|
eval.summarize()
|
||||||
idx = i * 4 + 2
|
idx = i * 4 + 2
|
||||||
stats[self.metric_keys[idx + 1]], stats[
|
stats[self.metrics.keys[idx + 1]], stats[
|
||||||
self.metric_keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
|
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(f'pycocotools unable to run: {e}')
|
self.logger.warning(f'pycocotools unable to run: {e}')
|
||||||
return stats
|
return stats
|
||||||
|
Loading…
x
Reference in New Issue
Block a user