From fd013f544228e72da47edba8e7e6f63b6c78fc8c Mon Sep 17 00:00:00 2001 From: laugh12321 Date: Sat, 25 May 2024 10:00:26 +0800 Subject: [PATCH] Added Support ONNX End2End (TRT EfficientNMS) --- ultralytics/engine/exporter.py | 31 +++++++++++++++++- ultralytics/nn/modules/block.py | 56 ++++++++++++++++++++++++++++++++- ultralytics/nn/modules/head.py | 15 +++++++-- 3 files changed, 97 insertions(+), 5 deletions(-) diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 1fa3f2e1..fc077e60 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -68,7 +68,7 @@ from ultralytics.data.dataset import YOLODataset from ultralytics.data.utils import check_det_dataset from ultralytics.nn.autobackend import check_class_names, default_class_names from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder, v10Detect -from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel +from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel, YOLOv10DetectionModel from ultralytics.utils import ( ARM64, DEFAULT_CFG, @@ -231,6 +231,10 @@ class Exporter: m.format = self.args.format if isinstance(m, v10Detect): m.max_det = self.args.max_det + if self.args.nms and (onnx or engine): + m.end2end = True + m.iou_thres = self.args.iou if self.args.iou else 0.65 + m.conf_thres = self.args.conf if self.args.conf else 0.25 elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)): # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph @@ -271,6 +275,10 @@ class Exporter: "batch": self.args.batch, "imgsz": self.imgsz, "names": model.names, + "nms": int(self.args.nms), # json fails if store as bool value + "max_det": self.args.max_det, + "conf": self.args.conf if self.args.conf else 0.25, + "iou": self.args.iou if self.args.iou else 0.65, } # model metadata if model.task == "pose": self.metadata["kpt_shape"] = model.model[-1].kpt_shape @@ -366,6 +374,16 @@ class Exporter: f = str(self.file.with_suffix(".onnx")) output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"] + + if self.args.nms and isinstance(self.model, YOLOv10DetectionModel): + output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes'] + shapes = { + 'num_dets': ["batch" if self.args.dynamic else self.args.batch, 1], + 'det_boxes': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det, 4], + 'det_scores': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det], + 'det_classes': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det], + } + dynamic = self.args.dynamic if dynamic: dynamic = {"images": {0: "batch", 2: "height", 3: "width"}} # shape(1,3,640,640) @@ -374,6 +392,11 @@ class Exporter: dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160) elif isinstance(self.model, DetectionModel): dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400) + if self.args.nms and isinstance(self.model, YOLOv10DetectionModel): + dynamic["num_dets"] = {0: "batch"} + dynamic["det_boxes"] = {0: "batch"} + dynamic["det_scores"] = {0: "batch"} + dynamic["det_classes"] = {0: "batch"} torch.onnx.export( self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu @@ -391,6 +414,12 @@ class Exporter: model_onnx = onnx.load(f) # load onnx model # onnx.checker.check_model(model_onnx) # check onnx model + if self.args.nms and isinstance(self.model, YOLOv10DetectionModel): + for output in model_onnx.graph.output: + for idx, dim in enumerate(output.type.tensor_type.shape.dim): + if output.name in shapes and len(shapes[output.name]): + dim.dim_param = str(shapes[output.name][idx]) + # Simplify if self.args.simplify: try: diff --git a/ultralytics/nn/modules/block.py b/ultralytics/nn/modules/block.py index b76a9595..dcac7738 100644 --- a/ultralytics/nn/modules/block.py +++ b/ultralytics/nn/modules/block.py @@ -1,9 +1,12 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license """Block modules.""" +from typing import Tuple + import torch import torch.nn as nn import torch.nn.functional as F +from torch import Graph, Tensor, Value from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad from .transformer import TransformerBlock @@ -824,4 +827,55 @@ class SCDown(nn.Module): self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False) def forward(self, x): - return self.cv2(self.cv1(x)) \ No newline at end of file + return self.cv2(self.cv1(x)) + +class Efficient_TRT_NMS(torch.autograd.Function): + """NMS block for YOLO-fused model for TensorRT.""" + + @staticmethod + def forward( + ctx: Graph, + boxes: Tensor, + scores: Tensor, + iou_threshold: float = 0.65, + score_threshold: float = 0.25, + max_output_boxes: int = 100, + background_class: int = -1, + box_coding: int = 0, + plugin_version: str = "1", + score_activation: int = 0, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + batch_size, num_boxes, num_classes = scores.shape + num_dets = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32) + boxes = torch.randn(batch_size, max_output_boxes, 4) + scores = torch.randn(batch_size, max_output_boxes) + labels = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32) + + return num_dets, boxes, scores, labels + + @staticmethod + def symbolic( + g, + boxes: Value, + scores: Value, + iou_threshold: float = 0.65, + score_threshold: float = 0.25, + max_output_boxes: int = 100, + background_class: int = -1, + box_coding: int = 0, + plugin_version: str = "1", + score_activation: int = 0, + ) -> Tuple[Value, Value, Value, Value]: + return g.op( + "TRT::EfficientNMS_TRT", + boxes, + scores, + iou_threshold_f=iou_threshold, + score_threshold_f=score_threshold, + max_output_boxes_i=max_output_boxes, + background_class_i=background_class, + box_coding_i=box_coding, + plugin_version_s=plugin_version, + score_activation_i=score_activation, + outputs=4, + ) \ No newline at end of file diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 5bc7c068..95f358be 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -8,7 +8,7 @@ import torch.nn as nn from torch.nn.init import constant_, xavier_uniform_ from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors -from .block import DFL, Proto, ContrastiveHead, BNContrastiveHead +from .block import DFL, Proto, Efficient_TRT_NMS, ContrastiveHead, BNContrastiveHead from .conv import Conv from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer from .utils import bias_init_with_prob, linear_init @@ -496,7 +496,10 @@ class RTDETRDecoder(nn.Module): class v10Detect(Detect): + end2end = False max_det = -1 + iou_thres = 0.65 + conf_thres = 0.25 def __init__(self, nc=80, ch=()): super().__init__(nc, ch) @@ -519,8 +522,14 @@ class v10Detect(Detect): return {"one2many": one2many, "one2one": one2one} else: assert(self.max_det != -1) - boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc) - return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) + if self.end2end: + preds = one2one.permute(0, 2, 1) + assert(4 + self.nc == preds.shape[-1]) + boxes, scores = preds.split([4, self.nc], dim=-1) + return Efficient_TRT_NMS.apply(boxes, scores, self.iou_thres, self.conf_thres, self.max_det) + else: + boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc) + return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) else: return {"one2many": one2many, "one2one": one2one}