forked from ai-dawang/PlugNPlay-Modules
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDynamicFilter(频域_CV2维图像).py
111 lines (97 loc) · 4.34 KB
/
DynamicFilter(频域_CV2维图像).py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn
from timm.layers.helpers import to_2tuple
# Paper: FFT-based Dynamic Token Mixer for Vision
# Paper URL: https://arxiv.org/pdf/2303.03932
# The most complete collection of 100+ plug-and-play modules, GitHub URL: https://github.com/ai-dawang/PlugNPlay-Modules
class StarReLU(nn.Module):
    """Squared-ReLU activation with an affine output.

    Computes ``scale * relu(x) ** 2 + bias`` where ``scale`` and ``bias``
    are one-element parameters.

    Args:
        scale_value: Initial value of ``scale``.
        bias_value: Initial value of ``bias``.
        scale_learnable: Whether ``scale`` requires gradients.
        bias_learnable: Whether ``bias`` requires gradients.
        mode: Accepted for signature compatibility; not used by this class.
        inplace: Forwarded to the wrapped ``nn.ReLU``.
    """

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)

        # Build the initial one-element tensors first, then register them
        # (scale before bias, so parameter order is preserved).
        scale_init = scale_value * torch.ones(1)
        bias_init = bias_value * torch.ones(1)
        self.scale = nn.Parameter(scale_init, requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_init, requires_grad=bias_learnable)

    def forward(self, x):
        """Return ``scale * relu(x) ** 2 + bias`` element-wise."""
        rectified = self.relu(x)
        squared = rectified ** 2
        return self.scale * squared + self.bias
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> Dropout -> Linear -> Dropout.

    MLP as used in MetaFormer models, e.g. Transformer, MLP-Mixer,
    PoolFormer, MetaFormer baselines and related networks.
    Mostly copied from timm.

    Args:
        dim: Size of the last dimension of the input.
        mlp_ratio: Hidden width is ``int(mlp_ratio * dim)``.
        out_features: Output width; falls back to ``dim`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout setting, expanded via ``to_2tuple``; element 0 is used
            after the activation and element 1 after the second linear layer.
        bias: Whether the two linear layers carry a bias term.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,
                 bias=False, **kwargs):
        super().__init__()
        width_in = dim
        width_out = out_features if out_features else width_in
        width_hidden = int(mlp_ratio * width_in)
        dropout_rates = to_2tuple(drop)

        # Submodules are created in the same order they are applied.
        self.fc1 = nn.Linear(width_in, width_hidden, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(dropout_rates[0])
        self.fc2 = nn.Linear(width_hidden, width_out, bias=bias)
        self.drop2 = nn.Dropout(dropout_rates[1])

    def forward(self, x):
        """Apply the five stages in sequence and return the result."""
        for stage in (self.fc1, self.act, self.drop1, self.fc2, self.drop2):
            x = stage(x)
        return x
class DynamicFilter(nn.Module):
    """Frequency-domain dynamic token mixer for channels-last 2-D feature maps.

    A bank of ``num_filters`` learnable complex filters is mixed per sample
    and per channel, using softmax routing weights predicted from the
    spatially averaged input. The mixed filter is multiplied with the real
    2-D FFT of the expanded features, which are then transformed back.

    Args:
        dim: Channel count of the input (its last dimension).
        expansion_ratio: Intermediate width is ``int(expansion_ratio * dim)``.
        reweight_expansion_ratio: ``mlp_ratio`` of the routing ``Mlp``.
        act1_layer: Zero-argument factory for the activation applied after
            the first pointwise projection.
        act2_layer: Zero-argument factory for the activation applied after
            the inverse FFT.
        bias: Whether the two pointwise ``nn.Linear`` layers carry a bias.
        num_filters: Number of filters in the learnable bank.
        size: Spatial size the filter bank is built for, expanded via
            ``to_2tuple``. Element 0 is the filter height and
            ``element_1 // 2 + 1`` the number of rfft columns.
        weight_resize: If True, the filter bank is resampled to the
            spectrum size of each input, so inputs need not match ``size``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, num_filters=4, size=14, weight_resize=False,
                 **kwargs):
        super().__init__()
        size = to_2tuple(size)
        self.size = size[0]
        # rfft2 keeps only the non-redundant half of the last transformed axis.
        self.filter_size = size[1] // 2 + 1
        self.num_filters = num_filters
        self.dim = dim
        self.med_channels = int(expansion_ratio * dim)
        self.weight_resize = weight_resize
        self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)
        self.act1 = act1_layer()
        # Predicts one routing logit per (filter, intermediate channel).
        self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)
        # The trailing axis of length 2 holds the (real, imaginary) parts.
        self.complex_weights = nn.Parameter(
            torch.randn(self.size, self.filter_size, num_filters, 2,
                        dtype=torch.float32) * 0.02)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)

    @staticmethod
    def _resize_complex_weight(origin_weight, new_h, new_w):
        """Resample a ``(h, w, num_filters, 2)`` filter bank to ``(new_h, new_w)``.

        The filter and real/imaginary axes are folded into one channel axis
        so the bank can be interpolated like a single image.

        NOTE(review): bicubic interpolation with ``align_corners=True`` is
        believed to match the reference DFFormer implementation; confirm
        before loading pretrained checkpoints.
        """
        h, w, num_filters = origin_weight.shape[0:3]
        as_image = origin_weight.reshape(1, h, w, num_filters * 2).permute(0, 3, 1, 2)
        resized = nn.functional.interpolate(
            as_image, size=(new_h, new_w), mode='bicubic', align_corners=True)
        return resized.permute(0, 2, 3, 1).reshape(new_h, new_w, num_filters, 2)

    def forward(self, x):
        """Mix tokens of ``x`` in the frequency domain.

        Args:
            x: Tensor whose shape unpacks as ``(B, H, W, dim)``.

        Returns:
            The output of the final ``nn.Linear(med_channels, dim)``
            projection, with the same leading ``(B, H, W)`` dimensions.
        """
        B, H, W, _ = x.shape
        # Per-sample routing weights, normalised across the filter axis.
        routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,
                                                          -1).softmax(dim=1)
        x = self.pwconv1(x)
        x = self.act1(x)
        # The FFT is run in float32 regardless of the incoming dtype.
        x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        if self.weight_resize:
            # Previously this called an undefined module-level function and
            # raised NameError; the resampling now lives on the class.
            complex_weights = self._resize_complex_weight(
                self.complex_weights, x.shape[1], x.shape[2])
            complex_weights = torch.view_as_complex(complex_weights.contiguous())
        else:
            complex_weights = torch.view_as_complex(self.complex_weights)
        routeing = routeing.to(torch.complex64)
        # Result is already (B, H, W_freq, med_channels); no reshape needed.
        weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        x = self.act2(x)
        x = self.pwconv2(x)
        return x
if __name__ == '__main__':
    # Smoke test: run one forward pass and print the input and output shapes.
    block = DynamicFilter(32, size=64)  # size == H, W of the input
    # Layout is channels-last (B, H, W, C); the last axis must equal `dim`.
    # Named `inputs` to avoid shadowing the builtin `input`.
    inputs = torch.rand(3, 64, 64, 32)
    output = block(inputs)
    print(inputs.size())
    print(output.size())