mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Add Conv2() module (#2820)
This commit is contained in:
parent
d19c5b6ce8
commit
441e67d330
@ -48,16 +48,17 @@ trainer.train()
|
|||||||
|
|
||||||
You now realize that you need to customize the trainer further to:
|
You now realize that you need to customize the trainer further to:
|
||||||
|
|
||||||
* * Customize the `loss function`.
|
* Customize the `loss function`.
|
||||||
* Add `callback` that uploads model to your Google Drive after every 10 `epochs`
|
* Add `callback` that uploads model to your Google Drive after every 10 `epochs`
|
||||||
Here's how you can do it:
|
Here's how you can do it:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from ultralytics.yolo.v8.detect import DetectionTrainer
|
from ultralytics.yolo.v8.detect import DetectionTrainer
|
||||||
from ultralytcs.nn.tasks import DetectionModel
|
from ultralytics.nn.tasks import DetectionModel
|
||||||
|
|
||||||
|
|
||||||
class MyCustomModel(DetectionModel):
|
class MyCustomModel(DetectionModel):
|
||||||
def init_criterion():
|
def init_criterion(self):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@ -65,6 +66,7 @@ class CustomTrainer(DetectionTrainer):
|
|||||||
def get_model(self, cfg, weights):
|
def get_model(self, cfg, weights):
|
||||||
return MyCustomModel(...)
|
return MyCustomModel(...)
|
||||||
|
|
||||||
|
|
||||||
# callback to upload model weights
|
# callback to upload model weights
|
||||||
def log_model(trainer):
|
def log_model(trainer):
|
||||||
last_weight_path = trainer.last
|
last_weight_path = trainer.last
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
nc: 80 # number of classes
|
nc: 80 # number of classes
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
# RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
nc: 80 # number of classes
|
nc: 80 # number of classes
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
# YOLOv3-SPP object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
nc: 80 # number of classes
|
nc: 80 # number of classes
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
# YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
nc: 80 # number of classes
|
nc: 80 # number of classes
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
# YOLOv3 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
nc: 80 # number of classes
|
nc: 80 # number of classes
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
# YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
nc: 80 # number of classes
|
nc: 80 # number of classes
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
# YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
nc: 80 # number of classes
|
nc: 80 # number of classes
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
# YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
# YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
act: nn.ReLU()
|
act: nn.ReLU()
|
||||||
@ -23,29 +23,31 @@ backbone:
|
|||||||
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
||||||
- [-1, 18, Conv, [512, 3, 1]]
|
- [-1, 18, Conv, [512, 3, 1]]
|
||||||
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
||||||
- [ -1, 9, Conv, [ 1024, 3, 1 ] ]
|
- [-1, 6, Conv, [1024, 3, 1]]
|
||||||
- [-1, 1, SPPF, [1024, 5]] # 9
|
- [-1, 1, SPPF, [1024, 5]] # 9
|
||||||
|
|
||||||
# YOLOv6-3.0s head
|
# YOLOv6-3.0s head
|
||||||
head:
|
head:
|
||||||
|
- [-1, 1, Conv, [256, 1, 1]]
|
||||||
- [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]]
|
- [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]]
|
||||||
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
||||||
- [-1, 1, Conv, [256, 3, 1]]
|
- [-1, 1, Conv, [256, 3, 1]]
|
||||||
- [ -1, 9, Conv, [ 256, 3, 1 ] ] # 13
|
- [-1, 9, Conv, [256, 3, 1]] # 14
|
||||||
|
|
||||||
|
- [-1, 1, Conv, [128, 1, 1]]
|
||||||
- [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]]
|
- [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]]
|
||||||
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
||||||
- [-1, 1, Conv, [128, 3, 1]]
|
- [-1, 1, Conv, [128, 3, 1]]
|
||||||
- [ -1, 9, Conv, [ 128, 3, 1 ] ] # 17
|
- [-1, 9, Conv, [128, 3, 1]] # 19
|
||||||
|
|
||||||
- [-1, 1, Conv, [128, 3, 2]]
|
- [-1, 1, Conv, [128, 3, 2]]
|
||||||
- [ [ -1, 12 ], 1, Concat, [ 1 ] ] # cat head P4
|
- [[-1, 15], 1, Concat, [1]] # cat head P4
|
||||||
- [-1, 1, Conv, [256, 3, 1]]
|
- [-1, 1, Conv, [256, 3, 1]]
|
||||||
- [ -1, 9, Conv, [ 256, 3, 1 ] ] # 21
|
- [-1, 9, Conv, [256, 3, 1]] # 23
|
||||||
|
|
||||||
- [-1, 1, Conv, [256, 3, 2]]
|
- [-1, 1, Conv, [256, 3, 2]]
|
||||||
- [ [ -1, 9 ], 1, Concat, [ 1 ] ] # cat head P5
|
- [[-1, 10], 1, Concat, [1]] # cat head P5
|
||||||
- [-1, 1, Conv, [512, 3, 1]]
|
- [-1, 1, Conv, [512, 3, 1]]
|
||||||
- [ -1, 9, Conv, [ 512, 3, 1 ] ] # 25
|
- [-1, 9, Conv, [512, 3, 1]] # 27
|
||||||
|
|
||||||
- [ [ 17, 21, 25 ], 1, Detect, [ nc ] ] # Detect(P3, P4, P5)
|
- [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
# YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
# YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
nc: 1 # number of classes
|
nc: 1 # number of classes
|
||||||
|
@ -1,15 +1,28 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
"""
|
||||||
|
Ultralytics modules. Visualize with:
|
||||||
|
|
||||||
|
from ultralytics.nn.modules import *
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
|
x = torch.ones(1, 128, 40, 40)
|
||||||
|
m = Conv(128, 128)
|
||||||
|
f = f'{m._get_name()}.onnx'
|
||||||
|
torch.onnx.export(m, x, f)
|
||||||
|
os.system(f'onnxsim {f} {f} && open {f}')
|
||||||
|
"""
|
||||||
|
|
||||||
from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
|
from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
|
||||||
HGBlock, HGStem, Proto, RepC3)
|
HGBlock, HGStem, Proto, RepC3)
|
||||||
from .conv import (CBAM, ChannelAttention, Concat, Conv, ConvTranspose, DWConv, DWConvTranspose2d, Focus, GhostConv,
|
from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
|
||||||
LightConv, RepConv, SpatialAttention)
|
GhostConv, LightConv, RepConv, SpatialAttention)
|
||||||
from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
|
from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
|
||||||
from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
|
from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
|
||||||
MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
|
MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Conv', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
|
'Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
|
||||||
'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', 'TransformerBlock', 'MLPBlock',
|
'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', 'TransformerBlock', 'MLPBlock',
|
||||||
'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
|
'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
|
||||||
'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', 'Segment', 'Pose', 'Classify',
|
'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', 'Segment', 'Pose', 'Classify',
|
||||||
|
@ -43,6 +43,27 @@ class Conv(nn.Module):
|
|||||||
return self.act(self.conv(x))
|
return self.act(self.conv(x))
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2(Conv):
|
||||||
|
"""Simplified RepConv module with Conv fusing."""
|
||||||
|
|
||||||
|
def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
|
||||||
|
"""Initialize Conv layer with given arguments including activation."""
|
||||||
|
super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
|
||||||
|
self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Apply convolution, batch normalization and activation to input tensor."""
|
||||||
|
return self.act(self.bn(self.conv(x) + self.cv2(x)))
|
||||||
|
|
||||||
|
def fuse_convs(self):
|
||||||
|
"""Fuse parallel convolutions."""
|
||||||
|
w = torch.zeros_like(self.conv.weight.data)
|
||||||
|
i = [x // 2 for x in w.shape[2:]]
|
||||||
|
w[:, :, i[0] - 1:i[0], i[1] - 1:i[1]] = self.cv2.weight.data.clone()
|
||||||
|
self.conv.weight.data += w
|
||||||
|
self.__delattr__('cv2')
|
||||||
|
|
||||||
|
|
||||||
class LightConv(nn.Module):
|
class LightConv(nn.Module):
|
||||||
"""Light convolution with args(ch_in, ch_out, kernel).
|
"""Light convolution with args(ch_in, ch_out, kernel).
|
||||||
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
|
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
|
||||||
|
@ -8,9 +8,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
|
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
|
||||||
Classify, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus,
|
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
|
||||||
GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv, RTDETRDecoder,
|
Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
|
||||||
Segment)
|
RTDETRDecoder, Segment)
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
||||||
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
||||||
@ -103,7 +103,9 @@ class BaseModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
if not self.is_fused():
|
if not self.is_fused():
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
|
||||||
|
if isinstance(m, Conv2):
|
||||||
|
m.fuse_convs()
|
||||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
||||||
delattr(m, 'bn') # remove batchnorm
|
delattr(m, 'bn') # remove batchnorm
|
||||||
m.forward = m.forward_fuse # update forward
|
m.forward = m.forward_fuse # update forward
|
||||||
|
Loading…
x
Reference in New Issue
Block a user