mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 17:05:40 +08:00 
			
		
		
		
	Added Support ONNX End2End (TRT EfficientNMS)
This commit is contained in:
		
							parent
							
								
									74f62fb017
								
							
						
					
					
						commit
						fd013f5442
					
				@ -68,7 +68,7 @@ from ultralytics.data.dataset import YOLODataset
 | 
			
		||||
from ultralytics.data.utils import check_det_dataset
 | 
			
		||||
from ultralytics.nn.autobackend import check_class_names, default_class_names
 | 
			
		||||
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder, v10Detect
 | 
			
		||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
 | 
			
		||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel, YOLOv10DetectionModel
 | 
			
		||||
from ultralytics.utils import (
 | 
			
		||||
    ARM64,
 | 
			
		||||
    DEFAULT_CFG,
 | 
			
		||||
@ -231,6 +231,10 @@ class Exporter:
 | 
			
		||||
                m.format = self.args.format
 | 
			
		||||
                if isinstance(m, v10Detect):
 | 
			
		||||
                    m.max_det = self.args.max_det
 | 
			
		||||
                    if self.args.nms and (onnx or engine):
 | 
			
		||||
                        m.end2end = True
 | 
			
		||||
                        m.iou_thres = self.args.iou if self.args.iou else 0.65
 | 
			
		||||
                        m.conf_thres = self.args.conf if self.args.conf else 0.25
 | 
			
		||||
 | 
			
		||||
            elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
 | 
			
		||||
                # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
 | 
			
		||||
@ -271,6 +275,10 @@ class Exporter:
 | 
			
		||||
            "batch": self.args.batch,
 | 
			
		||||
            "imgsz": self.imgsz,
 | 
			
		||||
            "names": model.names,
 | 
			
		||||
            "nms": int(self.args.nms),  # json fails if store as bool value
 | 
			
		||||
            "max_det": self.args.max_det,
 | 
			
		||||
            "conf": self.args.conf if self.args.conf else 0.25,
 | 
			
		||||
            "iou": self.args.iou if self.args.iou else 0.65,
 | 
			
		||||
        }  # model metadata
 | 
			
		||||
        if model.task == "pose":
 | 
			
		||||
            self.metadata["kpt_shape"] = model.model[-1].kpt_shape
 | 
			
		||||
@ -366,6 +374,16 @@ class Exporter:
 | 
			
		||||
        f = str(self.file.with_suffix(".onnx"))
 | 
			
		||||
 | 
			
		||||
        output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
 | 
			
		||||
 | 
			
		||||
        if self.args.nms and isinstance(self.model, YOLOv10DetectionModel):
 | 
			
		||||
            output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
 | 
			
		||||
            shapes = {
 | 
			
		||||
                'num_dets': ["batch" if self.args.dynamic else self.args.batch, 1],
 | 
			
		||||
                'det_boxes': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det, 4],
 | 
			
		||||
                'det_scores': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det],
 | 
			
		||||
                'det_classes': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det],
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        dynamic = self.args.dynamic
 | 
			
		||||
        if dynamic:
 | 
			
		||||
            dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
 | 
			
		||||
@ -374,6 +392,11 @@ class Exporter:
 | 
			
		||||
                dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
 | 
			
		||||
            elif isinstance(self.model, DetectionModel):
 | 
			
		||||
                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)
 | 
			
		||||
                if self.args.nms and isinstance(self.model, YOLOv10DetectionModel):
 | 
			
		||||
                    dynamic["num_dets"] = {0: "batch"} 
 | 
			
		||||
                    dynamic["det_boxes"] = {0: "batch"}
 | 
			
		||||
                    dynamic["det_scores"] = {0: "batch"}
 | 
			
		||||
                    dynamic["det_classes"] = {0: "batch"}
 | 
			
		||||
 | 
			
		||||
        torch.onnx.export(
 | 
			
		||||
            self.model.cpu() if dynamic else self.model,  # dynamic=True only compatible with cpu
 | 
			
		||||
@ -391,6 +414,12 @@ class Exporter:
 | 
			
		||||
        model_onnx = onnx.load(f)  # load onnx model
 | 
			
		||||
        # onnx.checker.check_model(model_onnx)  # check onnx model
 | 
			
		||||
 | 
			
		||||
        if self.args.nms and isinstance(self.model, YOLOv10DetectionModel):
 | 
			
		||||
            for output in model_onnx.graph.output:
 | 
			
		||||
                for idx, dim in enumerate(output.type.tensor_type.shape.dim):
 | 
			
		||||
                    if output.name in shapes and len(shapes[output.name]):
 | 
			
		||||
                        dim.dim_param = str(shapes[output.name][idx])
 | 
			
		||||
 | 
			
		||||
        # Simplify
 | 
			
		||||
        if self.args.simplify:
 | 
			
		||||
            try:
 | 
			
		||||
 | 
			
		||||
@ -1,9 +1,12 @@
 | 
			
		||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
			
		||||
"""Block modules."""
 | 
			
		||||
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from torch import Graph, Tensor, Value
 | 
			
		||||
 | 
			
		||||
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
 | 
			
		||||
from .transformer import TransformerBlock
 | 
			
		||||
@ -825,3 +828,54 @@ class SCDown(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        return self.cv2(self.cv1(x))
 | 
			
		||||
 | 
			
		||||
class Efficient_TRT_NMS(torch.autograd.Function):
 | 
			
		||||
    """NMS block for YOLO-fused model for TensorRT."""
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def forward(
 | 
			
		||||
        ctx: Graph,
 | 
			
		||||
        boxes: Tensor,
 | 
			
		||||
        scores: Tensor,
 | 
			
		||||
        iou_threshold: float = 0.65,
 | 
			
		||||
        score_threshold: float = 0.25,
 | 
			
		||||
        max_output_boxes: int = 100,
 | 
			
		||||
        background_class: int = -1,
 | 
			
		||||
        box_coding: int = 0,
 | 
			
		||||
        plugin_version: str = "1",
 | 
			
		||||
        score_activation: int = 0,
 | 
			
		||||
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
 | 
			
		||||
        batch_size, num_boxes, num_classes = scores.shape
 | 
			
		||||
        num_dets = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
 | 
			
		||||
        boxes = torch.randn(batch_size, max_output_boxes, 4)
 | 
			
		||||
        scores = torch.randn(batch_size, max_output_boxes)
 | 
			
		||||
        labels = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
 | 
			
		||||
 | 
			
		||||
        return num_dets, boxes, scores, labels
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def symbolic(
 | 
			
		||||
        g,
 | 
			
		||||
        boxes: Value,
 | 
			
		||||
        scores: Value,
 | 
			
		||||
        iou_threshold: float = 0.65,
 | 
			
		||||
        score_threshold: float = 0.25,
 | 
			
		||||
        max_output_boxes: int = 100,
 | 
			
		||||
        background_class: int = -1,
 | 
			
		||||
        box_coding: int = 0,
 | 
			
		||||
        plugin_version: str = "1",
 | 
			
		||||
        score_activation: int = 0,
 | 
			
		||||
    ) -> Tuple[Value, Value, Value, Value]:
 | 
			
		||||
        return g.op(
 | 
			
		||||
            "TRT::EfficientNMS_TRT",
 | 
			
		||||
            boxes,
 | 
			
		||||
            scores,
 | 
			
		||||
            iou_threshold_f=iou_threshold,
 | 
			
		||||
            score_threshold_f=score_threshold,
 | 
			
		||||
            max_output_boxes_i=max_output_boxes,
 | 
			
		||||
            background_class_i=background_class,
 | 
			
		||||
            box_coding_i=box_coding,
 | 
			
		||||
            plugin_version_s=plugin_version,
 | 
			
		||||
            score_activation_i=score_activation,
 | 
			
		||||
            outputs=4,
 | 
			
		||||
        )
 | 
			
		||||
@ -8,7 +8,7 @@ import torch.nn as nn
 | 
			
		||||
from torch.nn.init import constant_, xavier_uniform_
 | 
			
		||||
 | 
			
		||||
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
 | 
			
		||||
from .block import DFL, Proto, ContrastiveHead, BNContrastiveHead
 | 
			
		||||
from .block import DFL, Proto, Efficient_TRT_NMS, ContrastiveHead, BNContrastiveHead
 | 
			
		||||
from .conv import Conv
 | 
			
		||||
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 | 
			
		||||
from .utils import bias_init_with_prob, linear_init
 | 
			
		||||
@ -496,7 +496,10 @@ class RTDETRDecoder(nn.Module):
 | 
			
		||||
 | 
			
		||||
class v10Detect(Detect):
 | 
			
		||||
 | 
			
		||||
    end2end = False
 | 
			
		||||
    max_det = -1
 | 
			
		||||
    iou_thres = 0.65
 | 
			
		||||
    conf_thres = 0.25
 | 
			
		||||
 | 
			
		||||
    def __init__(self, nc=80, ch=()):
 | 
			
		||||
        super().__init__(nc, ch)
 | 
			
		||||
@ -519,6 +522,12 @@ class v10Detect(Detect):
 | 
			
		||||
                return {"one2many": one2many, "one2one": one2one}
 | 
			
		||||
            else:
 | 
			
		||||
                assert(self.max_det != -1)
 | 
			
		||||
                if self.end2end:
 | 
			
		||||
                    preds = one2one.permute(0, 2, 1)
 | 
			
		||||
                    assert(4 + self.nc == preds.shape[-1])
 | 
			
		||||
                    boxes, scores = preds.split([4, self.nc], dim=-1)
 | 
			
		||||
                    return Efficient_TRT_NMS.apply(boxes, scores, self.iou_thres, self.conf_thres, self.max_det)
 | 
			
		||||
                else:
 | 
			
		||||
                    boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc)
 | 
			
		||||
                    return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
 | 
			
		||||
        else:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user