Mirror of https://github.com/THU-MIG/yolov10.git
Prevent duplicate detections caused by the postprocess
parent 6fbaf42b23
commit dd11a107ce
@@ -848,7 +848,7 @@ def clean_str(s):
     """
     return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
 
-def v10postprocess(preds, max_det, nc=80):
+def origin_v10postprocess(preds, max_det, nc=80):
     assert(4 + nc == preds.shape[-1])
     boxes, scores = preds.split([4, nc], dim=-1)
     max_scores = scores.amax(dim=-1)
@@ -862,3 +862,57 @@ def v10postprocess(preds, max_det, nc=80):
     index = index // nc
     boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
     return boxes, scores, labels
+
+def v10postprocess(preds, max_det, nc=80):
+    assert(4 + nc == preds.shape[-1])
+    boxes, scores = preds.split([4, nc], dim=-1)
+    max_scores, scores_index = scores.max(dim=-1)
+    instance_cls, instance_index = torch.topk(max_scores, max_det, dim=-1)
+    boxes = torch.gather(boxes, dim=1, index=instance_index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
+    labels = torch.gather(scores_index, dim=1, index=instance_index)
+    scores = instance_cls
+    return boxes, scores, labels
+
+if __name__ == '__main__':
+    """
+    Suppose the model has 4 detection classes and only 3 candidate detections.
+    The original postprocess function produces the following output:
+
+    tensor([[[3.0000, 4.0000, 5.0000, 6.0000, 0.5000, 3.0000],
+             [3.0000, 4.0000, 5.0000, 6.0000, 0.4000, 2.0000],  <- same box again, second-best class
+             [2.0000, 3.0000, 4.0000, 5.0000, 0.3100, 3.0000]]])
+
+    Notice how the detection with two high-confidence classes appears twice,
+    crowding out the third detection. The expected output is:
+
+    tensor([[[3.0000, 4.0000, 5.0000, 6.0000, 0.5000, 3.0000],
+             [2.0000, 3.0000, 4.0000, 5.0000, 0.3100, 3.0000],
+             [1.0000, 2.0000, 3.0000, 4.0000, 0.3000, 3.0000]]])  <- detection missing from the original output
+
+    which is what the new postprocess function produces.
+
+    However, while the new postprocess runs faster, it currently yields a lower mAP.
+    It is unclear whether this change in behaviour is intentional.
+
+    Below are some benchmark results:
+
+    yolov10n
+    Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████|
+      all       5000      36335      0.649      0.483      0.519      0.372
+
+    yolov10s
+    Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████|
+      all       5000      36335      0.706       0.56       0.61      0.449
+    """
+
+    test_preds = torch.tensor([[
+        [1, 2, 3, 4, 0.2, 0.25, 0.25, 0.3],
+        [2, 3, 4, 5, 0.2, 0.24, 0.25, 0.31],
+        [3, 4, 5, 6, 0.05, 0.05, 0.4, 0.5]
+    ]])
+    boxes, scores, labels = origin_v10postprocess(test_preds, max_det=3, nc=4)
+    pred = torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
+    print(pred)
+    boxes, scores, labels = v10postprocess(test_preds, max_det=3, nc=4)
+    pred = torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
+    print(pred)
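
Editor's note, not part of the diff: the context line `index = index // nc` suggests that origin_v10postprocess picks its final max_det entries with a topk over class scores flattened across classes, so a single box can fill several output slots, one per high-scoring class; the new function instead keeps only each box's best class before ranking boxes. Below is a minimal, self-contained sketch of the two selection strategies on the same toy scores. The helper names flattened_topk_select and per_box_topk_select are hypothetical, introduced here for illustration only, and the flattened variant is an assumption about the parts of origin_v10postprocess not shown in this hunk.

import torch

def flattened_topk_select(scores, max_det):
    # Hypothetical illustration: topk over flattened (box, class) pairs.
    # The same box may be returned more than once if several of its classes
    # rank among the top max_det confidences.
    conf, idx = torch.topk(scores.flatten(1), max_det, dim=-1)
    nc = scores.shape[-1]
    return idx // nc, idx % nc, conf  # box index, class label, confidence

def per_box_topk_select(scores, max_det):
    # Hypothetical illustration: keep only each box's best class, then rank boxes.
    # Each box appears at most once in the output.
    best_conf, best_cls = scores.max(dim=-1)
    conf, box_idx = torch.topk(best_conf, max_det, dim=-1)
    return box_idx, best_cls.gather(1, box_idx), conf

scores = torch.tensor([[[0.20, 0.25, 0.25, 0.30],
                        [0.20, 0.24, 0.25, 0.31],
                        [0.05, 0.05, 0.40, 0.50]]])
print(flattened_topk_select(scores, 3))  # box 2 appears twice (classes 3 and 2), box 0 is dropped
print(per_box_topk_select(scores, 3))    # boxes 2, 1, 0 each appear once

One plausible reading of the mAP gap reported in the docstring is that the flattened selection effectively gives each box a second-choice class, which COCO-style evaluation can reward, whereas the per-box selection trades that away to avoid the duplicated boxes shown in the example.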