mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Support fuse-deconv-and-bn (#786)
This commit is contained in:
parent
fa8811dcee
commit
5a80ad98db
@ -62,6 +62,9 @@ class ConvTranspose(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.act(self.bn(self.conv_transpose(x)))
|
return self.act(self.bn(self.conv_transpose(x)))
|
||||||
|
|
||||||
|
def forward_fuse(self, x):
|
||||||
|
return self.act(self.conv_transpose(x))
|
||||||
|
|
||||||
|
|
||||||
class DFL(nn.Module):
|
class DFL(nn.Module):
|
||||||
# Integral module of Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
# Integral module of Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
||||||
|
@ -12,8 +12,8 @@ from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, Bot
|
|||||||
GhostBottleneck, GhostConv, Segment)
|
GhostBottleneck, GhostConv, Segment)
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
|
||||||
model_info, scale_img, time_sync)
|
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(nn.Module):
|
class BaseModel(nn.Module):
|
||||||
@ -100,6 +100,10 @@ class BaseModel(nn.Module):
|
|||||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
||||||
delattr(m, 'bn') # remove batchnorm
|
delattr(m, 'bn') # remove batchnorm
|
||||||
m.forward = m.forward_fuse # update forward
|
m.forward = m.forward_fuse # update forward
|
||||||
|
if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
|
||||||
|
m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
|
||||||
|
delattr(m, 'bn') # remove batchnorm
|
||||||
|
m.forward = m.forward_fuse # update forward
|
||||||
self.info()
|
self.info()
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
@ -135,6 +135,30 @@ def fuse_conv_and_bn(conv, bn):
|
|||||||
return fusedconv
|
return fusedconv
|
||||||
|
|
||||||
|
|
||||||
|
def fuse_deconv_and_bn(deconv, bn):
|
||||||
|
fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
|
||||||
|
deconv.out_channels,
|
||||||
|
kernel_size=deconv.kernel_size,
|
||||||
|
stride=deconv.stride,
|
||||||
|
padding=deconv.padding,
|
||||||
|
output_padding=deconv.output_padding,
|
||||||
|
dilation=deconv.dilation,
|
||||||
|
groups=deconv.groups,
|
||||||
|
bias=True).requires_grad_(False).to(deconv.weight.device)
|
||||||
|
|
||||||
|
# prepare filters
|
||||||
|
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
|
||||||
|
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
||||||
|
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
|
||||||
|
|
||||||
|
# Prepare spatial bias
|
||||||
|
b_conv = torch.zeros(deconv.weight.size(1), device=deconv.weight.device) if deconv.bias is None else deconv.bias
|
||||||
|
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
||||||
|
fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
||||||
|
|
||||||
|
return fuseddconv
|
||||||
|
|
||||||
|
|
||||||
def model_info(model, verbose=False, imgsz=640):
|
def model_info(model, verbose=False, imgsz=640):
|
||||||
# Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
|
# Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
|
||||||
n_p = get_num_params(model)
|
n_p = get_num_params(model)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user