From a039526fcc650bdfcda37ed3de537a10945fc8c7 Mon Sep 17 00:00:00 2001
From: Ryan <23580140+brianhou0208@users.noreply.github.com>
Date: Sun, 17 Nov 2024 16:52:51 +0800
Subject: [PATCH] Update LayerNorm

---
 .../encoders/mix_transformer.py              | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/segmentation_models_pytorch/encoders/mix_transformer.py b/segmentation_models_pytorch/encoders/mix_transformer.py
index 9ae4fcd2..0cc3fb21 100644
--- a/segmentation_models_pytorch/encoders/mix_transformer.py
+++ b/segmentation_models_pytorch/encoders/mix_transformer.py
@@ -21,14 +21,10 @@ def forward(self, x):
         if x.ndim == 4:
             B, C, H, W = x.shape
             x = x.view(B, C, -1).transpose(1, 2)
-            x = nn.functional.layer_norm(
-                x, self.normalized_shape, self.weight, self.bias, self.eps
-            )
-            x = x.transpose(1, 2).view(B, -1, H, W).contiguous()
+            x = super().forward(x)
+            x = x.transpose(1, 2).view(B, C, H, W)
         else:
-            x = nn.functional.layer_norm(
-                x, self.normalized_shape, self.weight, self.bias, self.eps
-            )
+            x = super().forward(x)
         return x
 
 
@@ -472,25 +468,25 @@ def forward_features(self, x):
         # stage 1
         x = self.patch_embed1(x)
         x = self.block1(x)
-        x = self.norm1(x)
+        x = self.norm1(x).contiguous()
         outs.append(x)
 
         # stage 2
         x = self.patch_embed2(x)
         x = self.block2(x)
-        x = self.norm2(x)
+        x = self.norm2(x).contiguous()
         outs.append(x)
 
         # stage 3
         x = self.patch_embed3(x)
         x = self.block3(x)
-        x = self.norm3(x)
+        x = self.norm3(x).contiguous()
         outs.append(x)
 
         # stage 4
         x = self.patch_embed4(x)
         x = self.block4(x)
-        x = self.norm4(x)
+        x = self.norm4(x).contiguous()
         outs.append(x)
 
         return outs
@@ -552,7 +548,7 @@ def forward(self, x):
             if i == 1:
                 features.append(dummy)
             else:
-                x = stages[i](x)
+                x = stages[i](x).contiguous()
             features.append(x)
 
         return features
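
For reference, a minimal self-contained sketch of the behavior this patch gives the custom LayerNorm: a 4D (B, C, H, W) input is flattened to (B, H*W, C), normalized over the channel dimension by the stock nn.LayerNorm.forward, and reshaped back. The class body mirrors the patched code; the tensor shapes and the equivalence check against nn.functional.layer_norm below are illustrative additions, not part of the patch.

    import torch
    import torch.nn as nn

    class LayerNorm(nn.LayerNorm):
        # Accepts both (B, N, C) token sequences and (B, C, H, W) feature
        # maps, normalizing over the channel dimension in either case.
        def forward(self, x):
            if x.ndim == 4:
                B, C, H, W = x.shape
                # (B, C, H, W) -> (B, H*W, C): channels last for nn.LayerNorm
                x = x.view(B, C, -1).transpose(1, 2)
                x = super().forward(x)
                # (B, H*W, C) -> (B, C, H, W)
                x = x.transpose(1, 2).view(B, C, H, W)
            else:
                x = super().forward(x)
            return x

    # Illustrative check (shapes are arbitrary): the 4D path agrees with
    # applying layer_norm directly to a channels-last view of the input.
    ln = LayerNorm(64)
    x = torch.randn(2, 64, 8, 8)
    ref = nn.functional.layer_norm(
        x.permute(0, 2, 3, 1), ln.normalized_shape, ln.weight, ln.bias, ln.eps
    ).permute(0, 3, 1, 2)
    assert torch.allclose(ln(x), ref, atol=1e-6)

Note that the .contiguous() calls move out of LayerNorm.forward and into forward_features/forward: the transpose-and-view round trip leaves the normalized tensor non-contiguous, and the patch appears to make the memory layout explicit once per stage output rather than inside every norm call.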