Update LayerNorm
brianhou0208 committed Nov 29, 2024
1 parent 2989d41 commit a039526
Showing 1 changed file with 8 additions and 12 deletions.
segmentation_models_pytorch/encoders/mix_transformer.py
@@ -21,14 +21,10 @@ def forward(self, x):
         if x.ndim == 4:
             B, C, H, W = x.shape
             x = x.view(B, C, -1).transpose(1, 2)
-            x = nn.functional.layer_norm(
-                x, self.normalized_shape, self.weight, self.bias, self.eps
-            )
-            x = x.transpose(1, 2).view(B, -1, H, W).contiguous()
+            x = super().forward(x)
+            x = x.transpose(1, 2).view(B, C, H, W)
         else:
-            x = nn.functional.layer_norm(
-                x, self.normalized_shape, self.weight, self.bias, self.eps
-            )
+            x = super().forward(x)
         return x
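
For context, a minimal sketch of the full module as it plausibly reads after this hunk; the class name and constructor are not shown in the diff, so they are assumed here. super().forward(x) is equivalent to the deleted functional call, since nn.LayerNorm.forward applies nn.functional.layer_norm with the module's own normalized_shape, weight, bias, and eps. Note also that the per-call .contiguous() is dropped on this path and applied once per stage output instead (see the later hunks):

import torch
import torch.nn as nn

class LayerNorm(nn.LayerNorm):
    # Hypothetical reconstruction: a LayerNorm that also accepts NCHW input
    # by normalizing over the channel dim. Only forward() appears in the diff.
    def forward(self, x):
        if x.ndim == 4:
            B, C, H, W = x.shape
            x = x.view(B, C, -1).transpose(1, 2)  # (B, H*W, C), channels last
            x = super().forward(x)  # normalizes over the last dim, C
            x = x.transpose(1, 2).view(B, C, H, W)  # back to NCHW, non-contiguous
        else:
            x = super().forward(x)
        return x

ln = LayerNorm(8)
out = ln(torch.randn(2, 8, 4, 4))  # 4D path; ln(torch.randn(2, 8)) takes the else branch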


@@ -472,25 +468,25 @@ def forward_features(self, x):
         # stage 1
         x = self.patch_embed1(x)
         x = self.block1(x)
-        x = self.norm1(x)
+        x = self.norm1(x).contiguous()
         outs.append(x)

         # stage 2
         x = self.patch_embed2(x)
         x = self.block2(x)
-        x = self.norm2(x)
+        x = self.norm2(x).contiguous()
         outs.append(x)

         # stage 3
         x = self.patch_embed3(x)
         x = self.block3(x)
-        x = self.norm3(x)
+        x = self.norm3(x).contiguous()
         outs.append(x)

         # stage 4
         x = self.patch_embed4(x)
         x = self.block4(x)
-        x = self.norm4(x)
+        x = self.norm4(x).contiguous()
         outs.append(x)

         return outs
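
The reworked LayerNorm now returns a non-contiguous tensor on the 4D path (a transpose followed by a view yields a legal view with permuted strides), which is why each stage output gains a single .contiguous() call before being appended. A small illustration of the layout issue, independent of this repository:

import torch

B, C, H, W = 2, 8, 4, 4
x = torch.randn(B, H * W, C)             # token layout coming out of LayerNorm
y = x.transpose(1, 2).view(B, C, H, W)   # legal view, but strides are permuted
print(y.is_contiguous())                 # False
print(y.contiguous().is_contiguous())    # True: one dense copy per stage output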
@@ -552,7 +548,7 @@ def forward(self, x):
             if i == 1:
                 features.append(dummy)
             else:
-                x = stages[i](x)
+                x = stages[i](x).contiguous()
                 features.append(x)
         return features
