mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-22 21:04:21 +08:00
fix export tensorrt with dynamic size
This commit is contained in:
parent
9f73bc7768
commit
62aaafb3a0
@ -783,16 +783,16 @@ class Attention(nn.Module):
|
||||
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
|
||||
|
||||
def forward(self, x):
|
||||
B, _, H, W = x.shape
|
||||
B, C, H, W = x.shape
|
||||
N = H * W
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = qkv.view(B, self.num_heads, -1, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)
|
||||
q, k, v = qkv.view(B, self.num_heads, self.key_dim*2 + self.head_dim, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)
|
||||
|
||||
attn = (
|
||||
(q.transpose(-2, -1) @ k) * self.scale
|
||||
)
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + self.pe(v.reshape(B, -1, H, W))
|
||||
x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user