mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Improve NMS speed (#3467)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
586c95b8a9
commit
5e38c7d71b
@ -200,12 +200,16 @@ def non_max_suppression(
|
||||
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
||||
merge = False # use merge-NMS
|
||||
|
||||
prediction = prediction.clone() # don't modify original
|
||||
prediction = prediction.transpose(-1, -2) # to (batch, boxes, items)
|
||||
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
|
||||
|
||||
t = time.time()
|
||||
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
|
||||
for xi, x in enumerate(prediction): # image index, image inference
|
||||
# Apply constraints
|
||||
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
||||
x = x.transpose(0, -1)[xc[xi]] # confidence
|
||||
x = x[xc[xi]] # confidence
|
||||
|
||||
# Cat apriori labels if autolabelling
|
||||
if labels and len(labels[xi]):
|
||||
@ -221,9 +225,9 @@ def non_max_suppression(
|
||||
|
||||
# Detections matrix nx6 (xyxy, conf, cls)
|
||||
box, cls, mask = x.split((4, nc, nm), 1)
|
||||
box = xywh2xyxy(box) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
||||
|
||||
if multi_label:
|
||||
i, j = (cls > conf_thres).nonzero(as_tuple=False).T
|
||||
i, j = torch.where(cls > conf_thres)
|
||||
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
|
||||
else: # best class only
|
||||
conf, j = cls.max(1, keepdim=True)
|
||||
@ -241,6 +245,8 @@ def non_max_suppression(
|
||||
n = x.shape[0] # number of boxes
|
||||
if not n: # no boxes
|
||||
continue
|
||||
|
||||
if n > max_nms: # excess boxes
|
||||
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
|
||||
|
||||
# Batched NMS
|
||||
|
Loading…
x
Reference in New Issue
Block a user