From 6171e756d33b170b2cc2f55704066fc5e53edfd3 Mon Sep 17 00:00:00 2001 From: Louis Lac Date: Wed, 1 Jan 2025 15:37:28 +0100 Subject: [PATCH] Fix MQA V2 scale and out shape --- tests/test_layers.py | 18 +++++++++++++++++- timm/layers/attention2d.py | 10 +++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 2cc8420abf..53aeb0a912 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1,7 +1,8 @@ +import pytest import torch import torch.nn as nn -from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn +from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, MultiQueryAttentionV2 import importlib import os @@ -119,3 +120,18 @@ def test_get_act_fn_none(): assert get_act_fn(None) is None assert get_act_fn('') is None +@pytest.mark.parametrize("dim", [128]) +@pytest.mark.parametrize("dim_out", [128, 256]) +@pytest.mark.parametrize("use_m", [True, False]) +def test_mqa_v2(dim, dim_out, use_m): + mqa = MultiQueryAttentionV2(dim, dim_out) + + x = torch.randn(1, dim, 32, 48) + if use_m: + m = torch.randn(1, dim, 16, 24) + else: + m = None + + y = mqa(x, m=m) + + assert (y.shape) == (1, dim_out, 32, 48) \ No newline at end of file diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py index 31200adf76..9ed3cf17da 100644 --- a/timm/layers/attention2d.py +++ b/timm/layers/attention2d.py @@ -59,8 +59,8 @@ def _reshape_input(self, t): def forward(self, x, m: Optional[torch.Tensor] = None): """Run layer computation.""" - s = x.shape - m = m or x + b, _, h, w = x.shape + m = m if m is not None else x reshaped_x = self._reshape_input(x) reshaped_m = self._reshape_input(m) @@ -68,15 +68,15 @@ def forward(self, x, m: Optional[torch.Tensor] = None): q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj) k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj) - attn = torch.einsum('bnhk,bmk->bnhm', q, k) + attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj) o = torch.einsum('bnhm,bmv->bnhv', attn, v) - result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj) + result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj) result = self.proj_drop(result) - return result.reshape(s) + return result.reshape(b, -1, h, w) class MultiQueryAttention2d(nn.Module):