-
I'm looking to implement InstanceNorm, and I was looking at the LayerNorm implementation (https://flax.readthedocs.io/en/latest/_modules/flax/linen/normalization.html#LayerNorm). In CNNs, Flax accepts input in NHWC format (from inspecting the MNIST example). LayerNorm is defined in the paper (https://arxiv.org/abs/1607.06450) as computing mean and variance over all channels (C) and spatial dimensions (HW). But the implementation computes mean and variance over only the last axis:

```python
x = jnp.asarray(x, jnp.float32)
features = x.shape[-1]
mean = jnp.mean(x, axis=-1, keepdims=True)
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
var = mean2 - lax.square(mean)
```

Why don't we compute it like this instead?

```python
mean = jnp.mean(x, axis=(-1, -2, -3), keepdims=True)  # axes -1 through -(len(x.shape) - 1), i.e. everything except the batch axis
```

Thanks!
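For context, here is a small self-contained snippet showing how the two reductions differ on an NHWC batch; the array shape and variable names are illustrative assumptions, not Flax code:

```python
import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (2, 4, 4, 3))  # NHWC: batch, height, width, channels

# Current LayerNorm behaviour: one mean per (N, H, W) position, reduced over C only.
mean_c = jnp.mean(x, axis=-1, keepdims=True)              # shape (2, 4, 4, 1)

# Paper-style reduction over C, H, and W: one mean per example.
mean_chw = jnp.mean(x, axis=(-1, -2, -3), keepdims=True)  # shape (2, 1, 1, 1)

print(mean_c.shape, mean_chw.shape)
```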
-
I think we should be a little careful here to avoid reducing over axes that actually aren't spatial dims. To me it would make most sense if `axis` is a constructor argument with default value `-1`. Feel free to file a PR or issue for this.
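As a rough sketch of that suggestion, a norm layer could take the reduction axes as a constructor field defaulting to the last axis. The module name `AxisNorm`, its field names, and the omission of learnable scale/bias are all assumptions for illustration, not the actual Flax implementation:

```python
import flax.linen as nn
import jax.numpy as jnp
from jax import lax

class AxisNorm(nn.Module):
    # Hypothetical sketch: reduction axes as a constructor argument,
    # defaulting to the last axis (the current LayerNorm behaviour).
    axis: tuple = (-1,)
    epsilon: float = 1e-6

    @nn.compact
    def __call__(self, x):
        x = jnp.asarray(x, jnp.float32)
        mean = jnp.mean(x, axis=self.axis, keepdims=True)
        mean2 = jnp.mean(lax.square(x), axis=self.axis, keepdims=True)
        var = mean2 - lax.square(mean)  # E[x^2] - E[x]^2
        return (x - mean) * lax.rsqrt(var + self.epsilon)
```

With this, `AxisNorm()` would keep today's per-position behaviour, while `AxisNorm(axis=(-1, -2, -3))` would give the paper-style reduction over C, H, and W for NHWC inputs.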