diff --git a/ultralytics/nn/modules/block.py b/ultralytics/nn/modules/block.py
index b76a9595..d11c16e0 100644
--- a/ultralytics/nn/modules/block.py
+++ b/ultralytics/nn/modules/block.py
@@ -783,16 +783,16 @@ class Attention(nn.Module):
         self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
 
     def forward(self, x):
-        B, _, H, W = x.shape
+        B, C, H, W = x.shape
         N = H * W
         qkv = self.qkv(x)
-        q, k, v = qkv.view(B, self.num_heads, -1, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)
+        q, k, v = qkv.view(B, self.num_heads, self.key_dim*2 + self.head_dim, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)
         attn = (
             (q.transpose(-2, -1) @ k) * self.scale
         )
         attn = attn.softmax(dim=-1)
-        x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + self.pe(v.reshape(B, -1, H, W))
+        x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
         x = self.proj(x)
         return x
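For context, here is a minimal standalone sketch of the shape arithmetic the patch makes explicit: `qkv` is viewed as `(B, num_heads, key_dim*2 + head_dim, N)` and split into `q`, `k`, `v`, and the attention output folds back to `(B, C, H, W)`. The `dim`, `num_heads`, and `attn_ratio` values below are illustrative assumptions, not taken from the patch.

```python
import torch

# Illustrative hyperparameters (assumptions, not from the patch). In Ultralytics'
# Attention, head_dim = dim // num_heads and key_dim = int(head_dim * attn_ratio).
B, H, W = 2, 8, 8
dim, num_heads, attn_ratio = 256, 8, 0.5
head_dim = dim // num_heads              # 32
key_dim = int(head_dim * attn_ratio)     # 16
N = H * W

# The qkv conv outputs num_heads * (2 * key_dim + head_dim) channels,
# i.e. dim + 2 * num_heads * key_dim.
qkv = torch.randn(B, num_heads * (2 * key_dim + head_dim), H, W)

# Patched view: the per-head channel count is written out explicitly instead of -1.
q, k, v = qkv.view(B, num_heads, key_dim * 2 + head_dim, N).split(
    [key_dim, key_dim, head_dim], dim=2
)
print(q.shape, k.shape, v.shape)
# torch.Size([2, 8, 16, 64]) torch.Size([2, 8, 16, 64]) torch.Size([2, 8, 32, 64])

# The attention output collapses back to (B, C, H, W), matching the patched .view(B, C, H, W).
attn = ((q.transpose(-2, -1) @ k) * key_dim**-0.5).softmax(dim=-1)
out = (v @ attn.transpose(-2, -1)).reshape(B, dim, H, W)
print(out.shape)  # torch.Size([2, 256, 8, 8])
```

The behavior is unchanged; spelling out the split sizes and using `C` instead of `-1` simply makes the intended shapes readable and lets a mismatch fail loudly at the `view` rather than silently reshaping.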