import torch

# Paper: SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design, CVPR 2024
# Paper link: https://arxiv.org/pdf/2401.16456
# GitHub: https://github.com/ysj9909/SHViT
# Collection of 100+ plug-and-play modules: https://github.com/ai-dawang/PlugNPlay-Modules

class GroupNorm(torch.nn.GroupNorm):
    """
    Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)

class Conv2d_BN(torch.nn.Sequential):
    """Conv2d followed by BatchNorm2d; the pair can be fused into a single Conv2d for inference."""
    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)
    @torch.no_grad()
    def fuse(self):
        # Fold the BatchNorm affine transform and running statistics into
        # the convolution, returning an equivalent standalone Conv2d.
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps) ** 0.5
        m = torch.nn.Conv2d(
            w.size(1) * self.c.groups, w.size(0), w.shape[2:],
            stride=self.c.stride, padding=self.c.padding,
            dilation=self.c.dilation, groups=self.c.groups,
            device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
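
# For reference, fuse() applies the standard BatchNorm-folding identities,
# where gamma, beta, mu, and sigma^2 are the BN weight, bias, running mean,
# and running variance:
#     w_fused = w_conv * gamma / sqrt(sigma^2 + eps)
#     b_fused = beta - mu * gamma / sqrt(sigma^2 + eps)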

class SHSA(torch.nn.Module):
    """Single-Head Self-Attention (SHViT).

    Attention is computed only over the first `pdim` channels; the
    remaining `dim - pdim` channels bypass attention and are merged
    back in the output projection.
    """
    def __init__(self, dim, qk_dim=16, pdim=32):
        super().__init__()
        self.scale = qk_dim ** -0.5
        self.qk_dim = qk_dim
        self.dim = dim
        self.pdim = pdim
        self.pre_norm = GroupNorm(pdim)
        self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
        # bn_weight_init=0 makes the projection start as a zero map, so a
        # surrounding residual branch is initialized to the identity.
        self.proj = torch.nn.Sequential(
            torch.nn.ReLU(),
            Conv2d_BN(dim, dim, bn_weight_init=0))
    def forward(self, x):
        B, C, H, W = x.shape
        # Split channels: x1 goes through attention, x2 is passed through.
        x1, x2 = torch.split(x, [self.pdim, self.dim - self.pdim], dim=1)
        x1 = self.pre_norm(x1)
        qkv = self.qkv(x1)
        q, k, v = qkv.split([self.qk_dim, self.qk_dim, self.pdim], dim=1)
        # Flatten the spatial dimensions: [B, C', H, W] -> [B, C', H*W].
        q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
        # Single-head attention over the H*W tokens.
        attn = (q.transpose(-2, -1) @ k) * self.scale
        attn = attn.softmax(dim=-1)
        x1 = (v @ attn.transpose(-2, -1)).reshape(B, self.pdim, H, W)
        # Concatenate attended and bypassed channels, then project.
        x = self.proj(torch.cat([x1, x2], dim=1))
        return x
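
# A minimal usage sketch (an assumption for illustration, not code from the
# SHViT repo): SHSA is typically placed on a residual branch, which the
# zero-initialized projection above makes safe at initialization.
class ResidualSHSA(torch.nn.Module):
    def __init__(self, dim, qk_dim=16, pdim=32):
        super().__init__()
        self.attn = SHSA(dim, qk_dim, pdim)

    def forward(self, x):
        # Identity at init, since SHSA's proj BatchNorm weight starts at 0.
        return x + self.attn(x)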

if __name__ == '__main__':
    block = SHSA(64)  # dim = input channel count C
    x = torch.randn(1, 64, 32, 32)  # input shape: B, C, H, W
    # Print input shape
    print(x.size())
    # Forward pass through the SHSA module
    output = block(x)
    # Print output shape
    print(output.size())
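
    # A minimal sketch (not part of the original file): verify that folding
    # the BatchNorm of the qkv Conv2d_BN into a plain Conv2d via fuse()
    # reproduces the eval-mode output.
    block.eval()
    with torch.no_grad():
        fused_qkv = block.qkv.fuse()
        t = torch.randn(1, block.pdim, 8, 8)
        assert torch.allclose(block.qkv(t), fused_qkv(t), atol=1e-4)
        print('Fused qkv matches eval-mode Conv2d_BN output')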