mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Fix save_hybrid
(#4245)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
7dfdb63cde
commit
e9c9b82c42
@ -26,6 +26,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
||||||
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
|
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
|
||||||
self.niou = self.iouv.numel()
|
self.niou = self.iouv.numel()
|
||||||
|
self.lb = [] # for autolabelling
|
||||||
|
|
||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
"""Preprocesses batch of images for YOLO training."""
|
"""Preprocesses batch of images for YOLO training."""
|
||||||
@ -34,8 +35,12 @@ class DetectionValidator(BaseValidator):
|
|||||||
for k in ['batch_idx', 'cls', 'bboxes']:
|
for k in ['batch_idx', 'cls', 'bboxes']:
|
||||||
batch[k] = batch[k].to(self.device)
|
batch[k] = batch[k].to(self.device)
|
||||||
|
|
||||||
|
if self.args.save_hybrid:
|
||||||
|
height, width = batch['img'].shape[2:]
|
||||||
nb = len(batch['img'])
|
nb = len(batch['img'])
|
||||||
self.lb = [torch.cat([batch['cls'], batch['bboxes']], dim=-1)[batch['batch_idx'] == i]
|
bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device)
|
||||||
|
self.lb = [
|
||||||
|
torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1)
|
||||||
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
|
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
@ -225,8 +225,8 @@ def non_max_suppression(
|
|||||||
# Cat apriori labels if autolabelling
|
# Cat apriori labels if autolabelling
|
||||||
if labels and len(labels[xi]):
|
if labels and len(labels[xi]):
|
||||||
lb = labels[xi]
|
lb = labels[xi]
|
||||||
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
|
v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
|
||||||
v[:, :4] = lb[:, 1:5] # box
|
v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box
|
||||||
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
|
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
|
||||||
x = torch.cat((x, v), 0)
|
x = torch.cat((x, v), 0)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user