mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 00:45:38 +08:00 
			
		
		
		
	Added support for ONNX End2End (TRT EfficientNMS)
This commit is contained in:
		
							parent
							
								
									74f62fb017
								
							
						
					
					
						commit
						fd013f5442
					
				@ -68,7 +68,7 @@ from ultralytics.data.dataset import YOLODataset
 | 
				
			|||||||
from ultralytics.data.utils import check_det_dataset
 | 
					from ultralytics.data.utils import check_det_dataset
 | 
				
			||||||
from ultralytics.nn.autobackend import check_class_names, default_class_names
 | 
					from ultralytics.nn.autobackend import check_class_names, default_class_names
 | 
				
			||||||
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder, v10Detect
 | 
					from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder, v10Detect
 | 
				
			||||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
 | 
					from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel, YOLOv10DetectionModel
 | 
				
			||||||
from ultralytics.utils import (
 | 
					from ultralytics.utils import (
 | 
				
			||||||
    ARM64,
 | 
					    ARM64,
 | 
				
			||||||
    DEFAULT_CFG,
 | 
					    DEFAULT_CFG,
 | 
				
			||||||
@ -231,6 +231,10 @@ class Exporter:
 | 
				
			|||||||
                m.format = self.args.format
 | 
					                m.format = self.args.format
 | 
				
			||||||
                if isinstance(m, v10Detect):
 | 
					                if isinstance(m, v10Detect):
 | 
				
			||||||
                    m.max_det = self.args.max_det
 | 
					                    m.max_det = self.args.max_det
 | 
				
			||||||
 | 
					                    if self.args.nms and (onnx or engine):
 | 
				
			||||||
 | 
					                        m.end2end = True
 | 
				
			||||||
 | 
					                        m.iou_thres = self.args.iou if self.args.iou else 0.65
 | 
				
			||||||
 | 
					                        m.conf_thres = self.args.conf if self.args.conf else 0.25
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
 | 
					            elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
 | 
				
			||||||
                # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
 | 
					                # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
 | 
				
			||||||
@ -271,6 +275,10 @@ class Exporter:
 | 
				
			|||||||
            "batch": self.args.batch,
 | 
					            "batch": self.args.batch,
 | 
				
			||||||
            "imgsz": self.imgsz,
 | 
					            "imgsz": self.imgsz,
 | 
				
			||||||
            "names": model.names,
 | 
					            "names": model.names,
 | 
				
			||||||
 | 
					            "nms": int(self.args.nms),  # json fails if store as bool value
 | 
				
			||||||
 | 
					            "max_det": self.args.max_det,
 | 
				
			||||||
 | 
					            "conf": self.args.conf if self.args.conf else 0.25,
 | 
				
			||||||
 | 
					            "iou": self.args.iou if self.args.iou else 0.65,
 | 
				
			||||||
        }  # model metadata
 | 
					        }  # model metadata
 | 
				
			||||||
        if model.task == "pose":
 | 
					        if model.task == "pose":
 | 
				
			||||||
            self.metadata["kpt_shape"] = model.model[-1].kpt_shape
 | 
					            self.metadata["kpt_shape"] = model.model[-1].kpt_shape
 | 
				
			||||||
@ -366,6 +374,16 @@ class Exporter:
 | 
				
			|||||||
        f = str(self.file.with_suffix(".onnx"))
 | 
					        f = str(self.file.with_suffix(".onnx"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
 | 
					        output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.args.nms and isinstance(self.model, YOLOv10DetectionModel):
 | 
				
			||||||
 | 
					            output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
 | 
				
			||||||
 | 
					            shapes = {
 | 
				
			||||||
 | 
					                'num_dets': ["batch" if self.args.dynamic else self.args.batch, 1],
 | 
				
			||||||
 | 
					                'det_boxes': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det, 4],
 | 
				
			||||||
 | 
					                'det_scores': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det],
 | 
				
			||||||
 | 
					                'det_classes': ["batch" if self.args.dynamic else self.args.batch, self.args.max_det],
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        dynamic = self.args.dynamic
 | 
					        dynamic = self.args.dynamic
 | 
				
			||||||
        if dynamic:
 | 
					        if dynamic:
 | 
				
			||||||
            dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
 | 
					            dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
 | 
				
			||||||
@ -374,6 +392,11 @@ class Exporter:
 | 
				
			|||||||
                dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
 | 
					                dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
 | 
				
			||||||
            elif isinstance(self.model, DetectionModel):
 | 
					            elif isinstance(self.model, DetectionModel):
 | 
				
			||||||
                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)
 | 
					                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)
 | 
				
			||||||
 | 
					                if self.args.nms and isinstance(self.model, YOLOv10DetectionModel):
 | 
				
			||||||
 | 
					                    dynamic["num_dets"] = {0: "batch"} 
 | 
				
			||||||
 | 
					                    dynamic["det_boxes"] = {0: "batch"}
 | 
				
			||||||
 | 
					                    dynamic["det_scores"] = {0: "batch"}
 | 
				
			||||||
 | 
					                    dynamic["det_classes"] = {0: "batch"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        torch.onnx.export(
 | 
					        torch.onnx.export(
 | 
				
			||||||
            self.model.cpu() if dynamic else self.model,  # dynamic=True only compatible with cpu
 | 
					            self.model.cpu() if dynamic else self.model,  # dynamic=True only compatible with cpu
 | 
				
			||||||
@ -391,6 +414,12 @@ class Exporter:
 | 
				
			|||||||
        model_onnx = onnx.load(f)  # load onnx model
 | 
					        model_onnx = onnx.load(f)  # load onnx model
 | 
				
			||||||
        # onnx.checker.check_model(model_onnx)  # check onnx model
 | 
					        # onnx.checker.check_model(model_onnx)  # check onnx model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.args.nms and isinstance(self.model, YOLOv10DetectionModel):
 | 
				
			||||||
 | 
					            for output in model_onnx.graph.output:
 | 
				
			||||||
 | 
					                for idx, dim in enumerate(output.type.tensor_type.shape.dim):
 | 
				
			||||||
 | 
					                    if output.name in shapes and len(shapes[output.name]):
 | 
				
			||||||
 | 
					                        dim.dim_param = str(shapes[output.name][idx])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Simplify
 | 
					        # Simplify
 | 
				
			||||||
        if self.args.simplify:
 | 
					        if self.args.simplify:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
 | 
				
			|||||||
@ -1,9 +1,12 @@
 | 
				
			|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
					# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
				
			||||||
"""Block modules."""
 | 
					"""Block modules."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torch.nn as nn
 | 
					import torch.nn as nn
 | 
				
			||||||
import torch.nn.functional as F
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					from torch import Graph, Tensor, Value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
 | 
					from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
 | 
				
			||||||
from .transformer import TransformerBlock
 | 
					from .transformer import TransformerBlock
 | 
				
			||||||
@ -825,3 +828,54 @@ class SCDown(nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def forward(self, x):
 | 
					    def forward(self, x):
 | 
				
			||||||
        return self.cv2(self.cv1(x))
 | 
					        return self.cv2(self.cv1(x))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Efficient_TRT_NMS(torch.autograd.Function):
					    """Export-time stand-in for the TensorRT ``EfficientNMS_TRT`` plugin.

					    ``symbolic`` is what ends up in an exported ONNX graph: a single
					    ``TRT::EfficientNMS_TRT`` node with four outputs. ``forward`` performs no
					    NMS at all; it only fabricates random tensors with the four output shapes
					    so that a traced run of the model has something to return.
					    """

					    @staticmethod
					    def forward(
					        ctx,
					        boxes: Tensor,
					        scores: Tensor,
					        iou_threshold: float = 0.65,
					        score_threshold: float = 0.25,
					        max_output_boxes: int = 100,
					        background_class: int = -1,
					        box_coding: int = 0,
					        plugin_version: str = "1",
					        score_activation: int = 0,
					    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
					        """Return random placeholder outputs shaped like the plugin's results.

					        Args:
					            ctx: Autograd function context supplied by ``apply``; unused.
					            boxes: Candidate boxes. Only its device and dtype are read here.
					            scores: Tensor of shape (batch, num_boxes, num_classes); the batch
					                size and class count are taken from it.
					            iou_threshold: Unused here; consumed by ``symbolic``.
					            score_threshold: Unused here; consumed by ``symbolic``.
					            max_output_boxes: Size of the detection axis of the outputs and
					                exclusive upper bound of the fake detection counts.
					            background_class: Unused here; consumed by ``symbolic``.
					            box_coding: Unused here; consumed by ``symbolic``.
					            plugin_version: Unused here; consumed by ``symbolic``.
					            score_activation: Unused here; consumed by ``symbolic``.

					        Returns:
					            Tuple ``(num_dets, det_boxes, det_scores, det_classes)`` with shapes
					            (batch, 1), (batch, max_output_boxes, 4), (batch, max_output_boxes)
					            and (batch, max_output_boxes). ``num_dets`` and ``det_classes`` are
					            int32; the other two follow the dtype of ``boxes`` / ``scores``.
					        """
					        batch_size, _, num_classes = scores.shape
					        # Allocate on the input's device so a model traced on GPU does not
					        # hand back CPU tensors.
					        device = boxes.device

					        num_dets = torch.randint(
					            0, max_output_boxes, (batch_size, 1), dtype=torch.int32, device=device
					        )
					        # Sample in the default float type and cast afterwards: this keeps the
					        # placeholder dtype aligned with the inputs (e.g. FP16 export) without
					        # relying on randn supporting every dtype on every backend.
					        det_boxes = torch.randn(batch_size, max_output_boxes, 4, device=device).to(boxes.dtype)
					        det_scores = torch.randn(batch_size, max_output_boxes, device=device).to(scores.dtype)
					        det_classes = torch.randint(
					            0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32, device=device
					        )

					        return num_dets, det_boxes, det_scores, det_classes

					    @staticmethod
					    def symbolic(
					        g,
					        boxes: Value,
					        scores: Value,
					        iou_threshold: float = 0.65,
					        score_threshold: float = 0.25,
					        max_output_boxes: int = 100,
					        background_class: int = -1,
					        box_coding: int = 0,
					        plugin_version: str = "1",
					        score_activation: int = 0,
					    ) -> Tuple[Value, Value, Value, Value]:
					        """Emit one ``TRT::EfficientNMS_TRT`` node with four outputs.

					        ``boxes`` and ``scores`` become the node inputs; every remaining
					        argument is forwarded unchanged as a node attribute. The ``_f`` / ``_i``
					        / ``_s`` suffixes on the keyword names are the ``g.op`` convention for
					        float / int / string attributes.
					        """
					        return g.op(
					            "TRT::EfficientNMS_TRT",
					            boxes,
					            scores,
					            iou_threshold_f=iou_threshold,
					            score_threshold_f=score_threshold,
					            max_output_boxes_i=max_output_boxes,
					            background_class_i=background_class,
					            box_coding_i=box_coding,
					            plugin_version_s=plugin_version,
					            score_activation_i=score_activation,
					            outputs=4,
					        )
 | 
				
			||||||
@ -8,7 +8,7 @@ import torch.nn as nn
 | 
				
			|||||||
from torch.nn.init import constant_, xavier_uniform_
 | 
					from torch.nn.init import constant_, xavier_uniform_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
 | 
					from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
 | 
				
			||||||
from .block import DFL, Proto, ContrastiveHead, BNContrastiveHead
 | 
					from .block import DFL, Proto, Efficient_TRT_NMS, ContrastiveHead, BNContrastiveHead
 | 
				
			||||||
from .conv import Conv
 | 
					from .conv import Conv
 | 
				
			||||||
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 | 
					from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 | 
				
			||||||
from .utils import bias_init_with_prob, linear_init
 | 
					from .utils import bias_init_with_prob, linear_init
 | 
				
			||||||
@ -496,7 +496,10 @@ class RTDETRDecoder(nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class v10Detect(Detect):
 | 
					class v10Detect(Detect):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    end2end = False
 | 
				
			||||||
    max_det = -1
 | 
					    max_det = -1
 | 
				
			||||||
 | 
					    iou_thres = 0.65
 | 
				
			||||||
 | 
					    conf_thres = 0.25
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, nc=80, ch=()):
 | 
					    def __init__(self, nc=80, ch=()):
 | 
				
			||||||
        super().__init__(nc, ch)
 | 
					        super().__init__(nc, ch)
 | 
				
			||||||
@ -519,8 +522,14 @@ class v10Detect(Detect):
 | 
				
			|||||||
                return {"one2many": one2many, "one2one": one2one}
 | 
					                return {"one2many": one2many, "one2one": one2one}
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                assert(self.max_det != -1)
 | 
					                assert(self.max_det != -1)
 | 
				
			||||||
                boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc)
 | 
					                if self.end2end:
 | 
				
			||||||
                return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
 | 
					                    preds = one2one.permute(0, 2, 1)
 | 
				
			||||||
 | 
					                    assert(4 + self.nc == preds.shape[-1])
 | 
				
			||||||
 | 
					                    boxes, scores = preds.split([4, self.nc], dim=-1)
 | 
				
			||||||
 | 
					                    return Efficient_TRT_NMS.apply(boxes, scores, self.iou_thres, self.conf_thres, self.max_det)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc)
 | 
				
			||||||
 | 
					                    return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return {"one2many": one2many, "one2one": one2one}
 | 
					            return {"one2many": one2many, "one2one": one2one}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user