diff --git a/tests/test_layers.py b/tests/test_layers.py
index 7726c3d05..15813d46f 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2
 
 import importlib
 import os
@@ -121,6 +121,23 @@ def test_get_act_fn_none():
     assert get_act_fn('') is None
 
 
+@pytest.mark.parametrize("dim", [128])
+@pytest.mark.parametrize("dim_out", [128, 256])
+@pytest.mark.parametrize("use_m", [True, False])
+def test_mqa_v2(dim, dim_out, use_m):
+    mqa = MultiQueryAttentionV2(dim, dim_out)
+
+    x = torch.randn(1, dim, 32, 48)
+    if use_m:
+        m = torch.randn(1, dim, 16, 24)
+    else:
+        m = None
+
+    y = mqa(x, m=m)
+
+    assert (y.shape) == (1, dim_out, 32, 48)
+
+
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("expand_first", [True, False])
 @pytest.mark.parametrize("head_first", [True, False])
@@ -141,6 +158,3 @@ def test_attn2d(bias, expand_first, head_first, attn_mask):
         o2 = attn(x, mask)
 
         assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
-
-
-
\ No newline at end of file
diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py
index 1b4c65842..d27e7ebd1 100644
--- a/timm/layers/attention2d.py
+++ b/timm/layers/attention2d.py
@@ -59,8 +59,8 @@ def _reshape_input(self, t):
 
     def forward(self, x, m: Optional[torch.Tensor] = None):
         """Run layer computation."""
-        s = x.shape
-        m = m or x
+        b, _, h, w = x.shape
+        m = m if m is not None else x
 
         reshaped_x = self._reshape_input(x)
         reshaped_m = self._reshape_input(m)
@@ -68,15 +68,15 @@ def forward(self, x, m: Optional[torch.Tensor] = None):
         q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
         k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)
 
-        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
+        attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
         v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
         o = torch.einsum('bnhm,bmv->bnhv', attn, v)
-        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
+        result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
         result = self.proj_drop(result)
-        return result.reshape(s)
+        return result.reshape(b, -1, h, w)
 
 
 class MultiQueryAttention2d(nn.Module):
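
Note (not part of the patch): a minimal sketch of what the forward() fixes change, using the shapes from test_mqa_v2 above. The old `m = m or x` evaluates bool() on the context tensor, so passing an explicit multi-element `m` raises; the old `result.reshape(s)` restored the input shape, which cannot hold dim_out channels when dim_out != dim. The attention scores are also now multiplied by the existing self.scale attribute.

    import torch
    from timm.layers import MultiQueryAttentionV2

    mqa = MultiQueryAttentionV2(128, 256)    # dim=128, dim_out=256

    x = torch.randn(1, 128, 32, 48)          # query feature map
    m = torch.randn(1, 128, 16, 24)          # separate context feature map

    # Before the fix: `m = m or x` calls bool() on `m` here and PyTorch raises
    # "Boolean value of Tensor with more than one element is ambiguous";
    # `result.reshape(s)` then could not map 256 output channels back onto the
    # 128-channel input shape. After the fix the output is (b, dim_out, h, w).
    y = mqa(x, m=m)
    assert y.shape == (1, 256, 32, 48)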