Added Support ONNX End2End (TRT EfficientNMS)

This commit is contained in:
laugh12321 2024-05-25 10:00:26 +08:00
parent 74f62fb017
commit fd013f5442
3 changed files with 97 additions and 5 deletions

View File

@ -68,7 +68,7 @@ from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.nn.autobackend import check_class_names, default_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder, v10Detect
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel, YOLOv10DetectionModel
from ultralytics.utils import (
ARM64,
DEFAULT_CFG,
@ -231,6 +231,10 @@ class Exporter:
m.format = self.args.format
if isinstance(m, v10Detect):
m.max_det = self.args.max_det
if self.args.nms and (onnx or engine):
m.end2end = True
m.iou_thres = self.args.iou if self.args.iou else 0.65
m.conf_thres = self.args.conf if self.args.conf else 0.25
elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
@ -271,6 +275,10 @@ class Exporter:
"batch": self.args.batch,
"imgsz": self.imgsz,
"names": model.names,
"nms": int(self.args.nms), # json fails if store as bool value
"max_det": self.args.max_det,
"conf": self.args.conf if self.args.conf else 0.25,
"iou": self.args.iou if self.args.iou else 0.65,
} # model metadata
if model.task == "pose":
self.metadata["kpt_shape"] = model.model[-1].kpt_shape
@ -366,6 +374,16 @@ class Exporter:
f = str(self.file.with_suffix(".onnx"))
output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
if self.args.nms and isinstance(self.model, YOLOv10DetectionModel):
output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
shapes = {
'num_dets': ["batch" if self.args.dynamic else self.args.batch, 1],
'det_boxes': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det, 4],
'det_scores': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det],
'det_classes': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det],
}
dynamic = self.args.dynamic
if dynamic:
dynamic = {"images": {0: "batch", 2: "height", 3: "width"}} # shape(1,3,640,640)
@ -374,6 +392,11 @@ class Exporter:
dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160)
elif isinstance(self.model, DetectionModel):
dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400)
if self.args.nms and isinstance(self.model, YOLOv10DetectionModel):
dynamic["num_dets"] = {0: "batch"}
dynamic["det_boxes"] = {0: "batch"}
dynamic["det_scores"] = {0: "batch"}
dynamic["det_classes"] = {0: "batch"}
torch.onnx.export(
self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu
@ -391,6 +414,12 @@ class Exporter:
model_onnx = onnx.load(f) # load onnx model
# onnx.checker.check_model(model_onnx) # check onnx model
if self.args.nms and isinstance(self.model, YOLOv10DetectionModel):
for output in model_onnx.graph.output:
for idx, dim in enumerate(output.type.tensor_type.shape.dim):
if output.name in shapes and len(shapes[output.name]):
dim.dim_param = str(shapes[output.name][idx])
# Simplify
if self.args.simplify:
try:

View File

@ -1,9 +1,12 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Block modules."""
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Graph, Tensor, Value
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
from .transformer import TransformerBlock
@ -824,4 +827,55 @@ class SCDown(nn.Module):
self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
def forward(self, x):
return self.cv2(self.cv1(x))
return self.cv2(self.cv1(x))
class Efficient_TRT_NMS(torch.autograd.Function):
"""NMS block for YOLO-fused model for TensorRT."""
@staticmethod
def forward(
ctx: Graph,
boxes: Tensor,
scores: Tensor,
iou_threshold: float = 0.65,
score_threshold: float = 0.25,
max_output_boxes: int = 100,
background_class: int = -1,
box_coding: int = 0,
plugin_version: str = "1",
score_activation: int = 0,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
batch_size, num_boxes, num_classes = scores.shape
num_dets = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
boxes = torch.randn(batch_size, max_output_boxes, 4)
scores = torch.randn(batch_size, max_output_boxes)
labels = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
return num_dets, boxes, scores, labels
@staticmethod
def symbolic(
g,
boxes: Value,
scores: Value,
iou_threshold: float = 0.65,
score_threshold: float = 0.25,
max_output_boxes: int = 100,
background_class: int = -1,
box_coding: int = 0,
plugin_version: str = "1",
score_activation: int = 0,
) -> Tuple[Value, Value, Value, Value]:
return g.op(
"TRT::EfficientNMS_TRT",
boxes,
scores,
iou_threshold_f=iou_threshold,
score_threshold_f=score_threshold,
max_output_boxes_i=max_output_boxes,
background_class_i=background_class,
box_coding_i=box_coding,
plugin_version_s=plugin_version,
score_activation_i=score_activation,
outputs=4,
)

View File

@ -8,7 +8,7 @@ import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
from .block import DFL, Proto, ContrastiveHead, BNContrastiveHead
from .block import DFL, Proto, Efficient_TRT_NMS, ContrastiveHead, BNContrastiveHead
from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init
@ -496,7 +496,10 @@ class RTDETRDecoder(nn.Module):
class v10Detect(Detect):
end2end = False
max_det = -1
iou_thres = 0.65
conf_thres = 0.25
def __init__(self, nc=80, ch=()):
super().__init__(nc, ch)
@ -519,8 +522,14 @@ class v10Detect(Detect):
return {"one2many": one2many, "one2one": one2one}
else:
assert(self.max_det != -1)
boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc)
return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
if self.end2end:
preds = one2one.permute(0, 2, 1)
assert(4 + self.nc == preds.shape[-1])
boxes, scores = preds.split([4, self.nc], dim=-1)
return Efficient_TRT_NMS.apply(boxes, scores, self.iou_thres, self.conf_thres, self.max_det)
else:
boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc)
return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
else:
return {"one2many": one2many, "one2one": one2one}