diff --git a/timm/layers/mlp.py b/timm/layers/mlp.py
index d1e6774cc..188c6b530 100644
--- a/timm/layers/mlp.py
+++ b/timm/layers/mlp.py
@@ -12,6 +12,8 @@
 
 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+
+    NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
     """
     def __init__(
             self,
@@ -51,6 +53,8 @@ def forward(self, x):
 class GluMlp(nn.Module):
     """ MLP w/ GLU style gating
     See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
+
+    NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
     """
     def __init__(
             self,
@@ -192,7 +196,7 @@ def forward(self, x):
 
 
 class ConvMlp(nn.Module):
-    """ MLP using 1x1 convs that keeps spatial dims
+    """ MLP using 1x1 convs that keeps spatial dims (for 2D NCHW tensors)
     """
     def __init__(
             self,
@@ -226,6 +230,8 @@ def forward(self, x):
 
 class GlobalResponseNormMlp(nn.Module):
     """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
+
+    NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts
     """
     def __init__(
             self,
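
The sketch below (not part of the diff) illustrates the two tensor layouts the new docstring notes describe, assuming timm is installed and exposes timm.layers.Mlp with a use_conv argument as in the file being patched; shapes and values are illustrative only.

    # Hedged usage sketch for the layouts documented above.
    import torch
    from timm.layers import Mlp

    # use_conv=False: nn.Linear layers, any (..., C) input such as N, L, C tokens.
    mlp_linear = Mlp(in_features=64, hidden_features=128, use_conv=False)
    tokens = torch.randn(2, 196, 64)          # N, L, C
    print(mlp_linear(tokens).shape)           # torch.Size([2, 196, 64])

    # use_conv=True: 1x1 Conv2d layers, expects 2D NCHW feature maps.
    mlp_conv = Mlp(in_features=64, hidden_features=128, use_conv=True)
    feature_map = torch.randn(2, 64, 14, 14)  # N, C, H, W
    print(mlp_conv(feature_map).shape)        # torch.Size([2, 64, 14, 14])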