mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
[rename] - preprocess-batch -> preprocess, preprocess_preds -> postprocess (#42)
This commit is contained in:
parent
4c68b9dcf6
commit
d143ac666f
@ -54,7 +54,7 @@ class BaseValidator:
|
|||||||
self.batch_i = batch_i
|
self.batch_i = batch_i
|
||||||
# pre-process
|
# pre-process
|
||||||
with dt[0]:
|
with dt[0]:
|
||||||
batch = self.preprocess_batch(batch)
|
batch = self.preprocess(batch)
|
||||||
|
|
||||||
# inference
|
# inference
|
||||||
with dt[1]:
|
with dt[1]:
|
||||||
@ -69,7 +69,7 @@ class BaseValidator:
|
|||||||
|
|
||||||
# pre-process predictions
|
# pre-process predictions
|
||||||
with dt[3]:
|
with dt[3]:
|
||||||
preds = self.preprocess_preds(preds)
|
preds = self.postprocess(preds)
|
||||||
|
|
||||||
self.update_metrics(preds, batch)
|
self.update_metrics(preds, batch)
|
||||||
|
|
||||||
@ -89,10 +89,10 @@ class BaseValidator:
|
|||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def preprocess_batch(self, batch):
|
def preprocess(self, batch):
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def preprocess_preds(self, preds):
|
def postprocess(self, preds):
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
def init_metrics(self):
|
def init_metrics(self):
|
||||||
|
@ -8,7 +8,7 @@ class ClassificationValidator(BaseValidator):
|
|||||||
def init_metrics(self, model):
|
def init_metrics(self, model):
|
||||||
self.correct = torch.tensor([])
|
self.correct = torch.tensor([])
|
||||||
|
|
||||||
def preprocess_batch(self, batch):
|
def preprocess(self, batch):
|
||||||
batch["img"] = batch["img"].to(self.device)
|
batch["img"] = batch["img"].to(self.device)
|
||||||
batch["cls"] = batch["cls"].to(self.device)
|
batch["cls"] = batch["cls"].to(self.device)
|
||||||
return batch
|
return batch
|
||||||
|
@ -28,7 +28,7 @@ class SegmentationValidator(BaseValidator):
|
|||||||
self.class_map = None
|
self.class_map = None
|
||||||
self.targets = None
|
self.targets = None
|
||||||
|
|
||||||
def preprocess_batch(self, batch):
|
def preprocess(self, batch):
|
||||||
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
||||||
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225
|
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225
|
||||||
batch["bboxes"] = batch["bboxes"].to(self.device)
|
batch["bboxes"] = batch["bboxes"].to(self.device)
|
||||||
@ -66,7 +66,7 @@ class SegmentationValidator(BaseValidator):
|
|||||||
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
|
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
|
||||||
"R", "mAP50", "mAP50-95)")
|
"R", "mAP50", "mAP50-95)")
|
||||||
|
|
||||||
def preprocess_preds(self, preds):
|
def postprocess(self, preds):
|
||||||
p = ops.non_max_suppression(preds[0],
|
p = ops.non_max_suppression(preds[0],
|
||||||
self.args.conf_thres,
|
self.args.conf_thres,
|
||||||
self.args.iou_thres,
|
self.args.iou_thres,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user