mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
fix export tensorrt with dynamic size
This commit is contained in:
parent
9f73bc7768
commit
62aaafb3a0
@ -783,16 +783,16 @@ class Attention(nn.Module):
|
|||||||
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
|
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, _, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
N = H * W
|
N = H * W
|
||||||
qkv = self.qkv(x)
|
qkv = self.qkv(x)
|
||||||
q, k, v = qkv.view(B, self.num_heads, -1, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)
|
q, k, v = qkv.view(B, self.num_heads, self.key_dim*2 + self.head_dim, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)
|
||||||
|
|
||||||
attn = (
|
attn = (
|
||||||
(q.transpose(-2, -1) @ k) * self.scale
|
(q.transpose(-2, -1) @ k) * self.scale
|
||||||
)
|
)
|
||||||
attn = attn.softmax(dim=-1)
|
attn = attn.softmax(dim=-1)
|
||||||
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + self.pe(v.reshape(B, -1, H, W))
|
x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user