mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Fix save_json(predn, batch)
(#105)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
8b6466f731
commit
38d6df55cb
@ -109,9 +109,6 @@ class BaseValidator:
|
|||||||
self.plot_val_samples(batch, batch_i)
|
self.plot_val_samples(batch, batch_i)
|
||||||
self.plot_predictions(batch, preds, batch_i)
|
self.plot_predictions(batch, preds, batch_i)
|
||||||
|
|
||||||
if self.args.save_json:
|
|
||||||
self.pred_to_json(preds, batch)
|
|
||||||
|
|
||||||
stats = self.get_stats()
|
stats = self.get_stats()
|
||||||
self.check_stats(stats)
|
self.check_stats(stats)
|
||||||
self.print_results()
|
self.print_results()
|
||||||
@ -126,8 +123,7 @@ class BaseValidator:
|
|||||||
with open(str(self.save_dir / "predictions.json"), 'w') as f:
|
with open(str(self.save_dir / "predictions.json"), 'w') as f:
|
||||||
self.logger.info(f"Saving {f.name}...")
|
self.logger.info(f"Saving {f.name}...")
|
||||||
json.dump(self.jdict, f) # flatten and save
|
json.dump(self.jdict, f) # flatten and save
|
||||||
|
stats = self.eval_json(stats) # update stats
|
||||||
stats = self.eval_json(stats)
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size):
|
def get_dataloader(self, dataset_path, batch_size):
|
||||||
|
@ -71,7 +71,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
|
|
||||||
def update_metrics(self, preds, batch):
|
def update_metrics(self, preds, batch):
|
||||||
# Metrics
|
# Metrics
|
||||||
for si, (pred) in enumerate(preds):
|
for si, pred in enumerate(preds):
|
||||||
labels = self.targets[self.targets[:, 0] == si, 1:]
|
labels = self.targets[self.targets[:, 0] == si, 1:]
|
||||||
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
||||||
shape = batch["ori_shape"][si]
|
shape = batch["ori_shape"][si]
|
||||||
@ -103,11 +103,11 @@ class DetectionValidator(BaseValidator):
|
|||||||
self.confusion_matrix.process_batch(predn, labelsn)
|
self.confusion_matrix.process_batch(predn, labelsn)
|
||||||
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (conf, pcls, tcls)
|
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (conf, pcls, tcls)
|
||||||
|
|
||||||
# TODO: Save/log
|
# Save
|
||||||
'''
|
if self.args.save_json:
|
||||||
if self.args.save_txt:
|
self.pred_to_json(predn, batch["im_file"][si])
|
||||||
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
# if self.args.save_txt:
|
||||||
'''
|
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
||||||
|
|
||||||
def get_stats(self):
|
def get_stats(self):
|
||||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
||||||
@ -197,13 +197,12 @@ class DetectionValidator(BaseValidator):
|
|||||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||||
names=self.names) # pred
|
names=self.names) # pred
|
||||||
|
|
||||||
def pred_to_json(self, preds, batch):
|
def pred_to_json(self, predn, filename):
|
||||||
for i, f in enumerate(batch["im_file"]):
|
stem = Path(filename).stem
|
||||||
stem = Path(f).stem
|
|
||||||
image_id = int(stem) if stem.isnumeric() else stem
|
image_id = int(stem) if stem.isnumeric() else stem
|
||||||
box = ops.xyxy2xywh(preds[i][:, :4]) # xywh
|
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
||||||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
||||||
for p, b in zip(preds[i].tolist(), box.tolist()):
|
for p, b in zip(predn.tolist(), box.tolist()):
|
||||||
self.jdict.append({
|
self.jdict.append({
|
||||||
'image_id': image_id,
|
'image_id': image_id,
|
||||||
'category_id': self.class_map[int(p[5])],
|
'category_id': self.class_map[int(p[5])],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user