Mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-05-24 06:14:55 +08:00)

Commit 1e9a7d6736: "update"
Parent: 182f0b09a9
Diffstat: README.md (18)
README.md

@@ -58,15 +58,23 @@ yolo predict model=yolov10n/s/m/b/l/x.pt
 ```
 
 ## Export
-## Latency Measurement
+```
+# End-to-End ONNX
+yolo export model=yolov10n/s/m/b/l/x.pt format=onnx opset=13 simplify
+# Predict with ONNX
+yolo predict model=yolov10n/s/m/b/l/x.onnx
+
+# End-to-End TensorRT
+yolo export model=yolov10n/s/m/b/l/x.pt format=engine half=True simplify opset=13 workspace=16
+# Or
+trtexec --onnx=onnxs/yolov10n/s/m/b/l/x.onnx --saveEngine=engines/yolov10n/s/m/b/l/x.engine --fp16
+# Predict with TensorRT
+yolo predict model=yolov10n/s/m/b/l/x.engine
+```
 
 ## Acknowledgement
 
-The code base is built with [ultralytics](https://github.com/ultralytics/ultralytics)
+The code base is built with [ultralytics](https://github.com/ultralytics/ultralytics) and [RT-DETR](https://github.com/lyuwenyu/RT-DETR)
 
 Thanks for the great implementations!
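The exported end-to-end ONNX model decodes its detections in-graph (the predictor change further down keys off a trailing dimension of 6), so it can be driven directly from onnxruntime. Below is a minimal sketch; the input name "images", the 640x640 BCHW float input, and the single `[batch, max_det, 6]` output laid out as `(x1, y1, x2, y2, score, class)` are assumptions about the exported graph, not something this commit states.

```python
import numpy as np
import onnxruntime as ort

# Minimal sketch: run an exported end-to-end YOLOv10 ONNX model with onnxruntime.
# Assumed (not taken from this commit): 640x640 BCHW float32 input and a single
# output shaped [batch, max_det, 6] = (x1, y1, x2, y2, score, class).
session = ort.InferenceSession("yolov10n.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

image = np.random.rand(1, 3, 640, 640).astype(np.float32)  # stand-in for a preprocessed image
outputs = session.run(None, {input_name: image})
detections = outputs[0]                                     # [1, max_det, 6]

keep = detections[0][detections[0, :, 4] > 0.25]            # confidence filter
for x1, y1, x2, y2, score, cls in keep:
    print(f"class={int(cls)} score={score:.2f} box=({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})")
```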
requirements.txt

@@ -5,3 +5,5 @@ onnxruntime
 pycocotools
 PyYAML
 scipy
+onnxsim
+onnxruntime-gpu
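The two new pins line up with the README changes above: onnxsim presumably backs the `simplify` flag during ONNX export, and onnxruntime-gpu lets the exported model run on CUDA. A quick, generic check (not part of this commit) that the GPU build of onnxruntime is the one being picked up:

```python
import onnxruntime as ort

# With onnxruntime-gpu installed (and a compatible CUDA setup),
# "CUDAExecutionProvider" should appear in this list.
print(ort.get_available_providers())
```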
ultralytics/engine/exporter.py

@@ -67,7 +67,7 @@ from ultralytics.cfg import get_cfg
 from ultralytics.data.dataset import YOLODataset
 from ultralytics.data.utils import check_det_dataset
 from ultralytics.nn.autobackend import check_class_names, default_class_names
-from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
+from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder, v10Detect
 from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
 from ultralytics.utils import (
     ARM64,

@@ -229,6 +229,9 @@ class Exporter:
                 m.dynamic = self.args.dynamic
                 m.export = True
                 m.format = self.args.format
+                if isinstance(m, v10Detect):
+                    m.max_det = self.args.max_det
+
             elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
                 # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
                 m.forward = m.forward_split
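The added branch stamps the export-time `max_det` onto every `v10Detect` head before the model is traced, so the top-k selection done by `ops.v10postprocess` (see the head and ops changes below) is baked into the exported graph with a fixed detection count. A paraphrased sketch of where this sits, assuming the usual pre-export walk over `model.modules()`; the function name and the `model`/`args` parameters are illustrative, not quoted from the Exporter:

```python
from ultralytics.nn.modules import Detect, v10Detect  # v10Detect is exported after this commit


def prepare_heads_for_export(model, args):
    """Paraphrased sketch of the exporter-side head preparation (not the literal Exporter code)."""
    for m in model.modules():
        if isinstance(m, Detect):              # v10Detect is a Detect subclass
            m.dynamic = args.dynamic
            m.export = True
            m.format = args.format
            if isinstance(m, v10Detect):
                m.max_det = args.max_det       # consumed by ops.v10postprocess during the traced forward
```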
ultralytics/models/yolov10/predict.py

@@ -9,10 +9,13 @@ class YOLOv10DetectionPredictor(DetectionPredictor):
         if isinstance(preds, (list, tuple)):
             preds = preds[0]
 
-        preds = preds.transpose(-1, -2)
-        bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det)
-        bboxes = ops.xywh2xyxy(bboxes)
-        preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
+        if preds.shape[-1] == 6:
+            pass
+        else:
+            preds = preds.transpose(-1, -2)
+            bboxes, scores, labels = ops.v10postprocess(preds, self.args.max_det)
+            bboxes = ops.xywh2xyxy(bboxes)
+            preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
 
         mask = preds[..., 4] > self.args.conf
 
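After this change the predictor accepts both flavours of model output: an end-to-end export already returns `(x1, y1, x2, y2, score, label)` rows (trailing dimension 6), while the regular PyTorch head still emits raw `[batch, 4 + nc, anchors]` predictions that need `ops.v10postprocess` plus an xywh to xyxy conversion. A condensed, standalone sketch of that branching; the function name and signature are illustrative, the helpers are the ones used in the hunk:

```python
import torch
from ultralytics.utils import ops


def decode_yolov10_preds(preds, max_det, conf):
    """Illustrative condensation of the predictor's new postprocess branching."""
    if isinstance(preds, (list, tuple)):
        preds = preds[0]

    if preds.shape[-1] == 6:
        # End-to-end exported model: detections are already decoded in-graph.
        pass
    else:
        # Raw head output [batch, 4 + nc, anchors]: decode it on the Python side.
        preds = preds.transpose(-1, -2)
        bboxes, scores, labels = ops.v10postprocess(preds, max_det)
        bboxes = ops.xywh2xyxy(bboxes)
        preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)

    mask = preds[..., 4] > conf  # per-detection confidence filter, as in the hunk
    return preds, mask
```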
ultralytics/nn/modules/head.py

@@ -13,6 +13,7 @@ from .conv import Conv
 from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 from .utils import bias_init_with_prob, linear_init
 import copy
+from ultralytics.utils import ops
 
 __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"
 

@@ -51,7 +52,6 @@ class Detect(nn.Module):
         shape = x[0].shape  # BCHW
         x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
         if self.dynamic or self.shape != shape:
-            assert(not self.export)
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape
 

@@ -501,6 +501,8 @@ class RTDETRDecoder(nn.Module):
 
 class v10Detect(Detect):
 
+    max_det = -1
+
     def __init__(self, nc=80, ch=()):
         super().__init__(nc, ch)
         c3 = max(ch[0], min(self.nc, 100))  # channels

@@ -515,7 +517,12 @@ class v10Detect(Detect):
         one2one = self.forward_feat([xi.detach() for xi in x], self.one2one_cv2, self.one2one_cv3)
         if not self.training:
             one2one = self.inference(one2one)
-            return one2one
+            if not self.export:
+                return one2one
+            else:
+                assert(self.max_det != -1)
+                boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det)
+                return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
         else:
             one2many = super().forward(x)
             return {"one2many": one2many, "one2one": one2one}
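Taken together, `v10Detect.forward` now has three outcomes: training returns both heads (for the one-to-many and one-to-one losses), plain eval returns the decoded one-to-one predictions as `[batch, 4 + nc, anchors]`, and export mode additionally runs `ops.v10postprocess` in-graph so the serialized model emits `[batch, max_det, 6]` directly. The toy below only illustrates the tensor shapes involved; it uses placeholder tensors rather than a real model, and the 8400-anchor count assumes a 640x640 input.

```python
import torch

batch, nc, anchors, max_det = 2, 80, 8400, 300  # 8400 anchors assumes a 640x640 input

# Eval without export: the head returns the decoded one-to-one output as-is,
# and YOLOv10DetectionPredictor (see the hunk above) finishes the decoding.
one2one = torch.rand(batch, 4 + nc, anchors)
print(one2one.shape)   # torch.Size([2, 84, 8400])

# Export: the head permutes to [batch, anchors, 4 + nc], applies ops.v10postprocess
# with the max_det set by the exporter, and packs the result as [batch, max_det, 6].
# (Placeholder tensors stand in for the real helper here.)
boxes = torch.rand(batch, max_det, 4)
scores = torch.rand(batch, max_det)
labels = torch.randint(0, nc, (batch, max_det)).float()
packed = torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
print(packed.shape)    # torch.Size([2, 300, 6])

# Training: forward returns {"one2many": ..., "one2one": ...} so both losses can be computed.
```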
ultralytics/utils/ops.py

@@ -848,8 +848,8 @@ def clean_str(s):
     """
     return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
 
-def v10postprocess(preds, max_det):
-    nc = preds.shape[-1] - 4
+def v10postprocess(preds, max_det, nc=80):
+    assert(4 + nc == preds.shape[-1])
     boxes, scores = preds.split([4, nc], dim=-1)
     max_scores = scores.amax(dim=-1)
     max_scores, index = torch.topk(max_scores, max_det, axis=-1)
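The hunk is truncated after the first top-k, so for orientation here is a self-contained reconstruction of the selection logic in the same spirit: keep the `max_det` anchors with the best per-anchor class score, then take a second top-k over the flattened `(anchor, class)` scores so the class labels fall out alongside the scores. This is an illustrative re-implementation, not a verbatim copy of the function body.

```python
import torch


def v10postprocess_sketch(preds, max_det, nc=80):
    """Illustrative reconstruction of YOLOv10's NMS-free top-k selection (not the exact source).

    preds: [batch, anchors, 4 + nc] with xywh boxes followed by per-class scores.
    Returns boxes [batch, max_det, 4], scores [batch, max_det], labels [batch, max_det].
    """
    assert 4 + nc == preds.shape[-1]
    boxes, scores = preds.split([4, nc], dim=-1)

    # First top-k: keep the max_det anchors with the highest best-class score.
    max_scores = scores.amax(dim=-1)
    max_scores, index = torch.topk(max_scores, max_det, dim=-1)
    index = index.unsqueeze(-1)
    boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
    scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))

    # Second top-k over all (anchor, class) pairs, so one anchor may yield several detections.
    scores, index = torch.topk(scores.flatten(1), max_det, dim=-1)
    labels = index % nc
    index = index // nc
    boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
    return boxes, scores, labels


if __name__ == "__main__":
    preds = torch.rand(2, 8400, 84)  # batch=2, 8400 anchors, 4 box coords + 80 classes
    b, s, l = v10postprocess_sketch(preds, max_det=300)
    print(b.shape, s.shape, l.shape)  # (2, 300, 4) (2, 300) (2, 300)
```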
|
Loading…
x
Reference in New Issue
Block a user