Masked linear for MADE #1231
Asked by ikostrikov in Q&A
Hi all, I'm trying to implement a masked linear layer for the Masked Autoregressive Density Estimator, MADE (https://arxiv.org/pdf/1502.03509.pdf). The obvious way is to implement it using Flax variables:

```python
class MaskedDense(nn.Dense):
    flow_dim: int = 1
    mask_type: MaskType = MaskType.hidden

    @nn.compact
    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        inputs = jnp.asarray(inputs, self.dtype)
        kernel = self.param('kernel', self.kernel_init,
                            (inputs.shape[-1], self.features))
        kernel = jnp.asarray(kernel, self.dtype)
        # The mask lives in its own 'mask' variable collection.
        mask = self.variable('mask', 'mask', lambda args: get_mask(*args),
                             (*kernel.shape, self.flow_dim, self.mask_type))
        kernel = kernel * mask.value
        y = jax.lax.dot_general(inputs,
                                kernel,
                                (((inputs.ndim - 1,), (0,)), ((), ())),
                                precision=self.precision)
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features,))
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return y
```

But in this case I also need to pass the masks to `module.apply`, which is redundant since the masks do not change over training. Is there a more elegant way to implement this in Flax?
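For context, this is roughly what the call sites look like with the variable-based version above (a minimal sketch; `key` and `x` are placeholders):

```python
variables = model.init(key, x)  # {'params': ..., 'mask': ...}
# Every call to apply has to thread the 'mask' collection through,
# even though the masks are fixed once the shapes are known.
y = model.apply({'params': variables['params'], 'mask': variables['mask']}, x)
```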
Answered by ikostrikov · Apr 10, 2021
I found a solution with `jax.util.cache`:

```python
import enum

import flax.linen as nn
import jax
import jax.numpy as jnp


class MaskType(enum.Enum):
    input = 1
    hidden = 2
    output = 3


@jax.util.cache()
def get_mask(input_dim: int, output_dim: int, randvar_dim: int,
             mask_type: MaskType) -> jnp.ndarray:
    """Create a mask for MADE.

    See Figure 1 for a better illustration:
    https://arxiv.org/pdf/1502.03509.pdf

    Args:
        input_dim: Dimensionality of the inputs.
        output_dim: Dimensionality of the outputs.
        randvar_dim: Dimensionality of the random variable.
        mask_type: MaskType.

    Returns:
        A mask of shape (input_dim, output_dim).
    """
    # Assign each unit a degree: the index of the last random-variable
    # dimension it is allowed to depend on.
    if mask_type == MaskType.input:
        in_degrees = jnp.arange(input_dim) % randvar_dim
    else:
        in_degrees = jnp.arange(input_dim) % (randvar_dim - 1)
    if mask_type == MaskType.output:
        # The -1 shift makes the >= comparison below strict for outputs.
        out_degrees = (jnp.arange(output_dim) % randvar_dim) - 1
    else:
        out_degrees = jnp.arange(output_dim) % (randvar_dim - 1)
    in_degrees = jnp.expand_dims(in_degrees, 0)
    out_degrees = jnp.expand_dims(out_degrees, -1)
    # Allow a connection only where the output degree >= the input degree.
    return (out_degrees >= in_degrees).astype(jnp.float32).transpose()


class MaskedDense(nn.Dense):
    flow_dim: int = 1
    mask_type: MaskType = MaskType.hidden

    @nn.compact
    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        inputs = jnp.asarray(inputs, self.dtype)
        kernel = self.param('kernel', self.kernel_init,
                            (inputs.shape[-1], self.features))
        kernel = jnp.asarray(kernel, self.dtype)
        # All arguments are static Python values, so the cached call
        # builds the mask once per shape instead of once per apply.
        mask = get_mask(*kernel.shape, self.flow_dim, self.mask_type)
        kernel = kernel * mask
        y = jax.lax.dot_general(inputs,
                                kernel,
                                (((inputs.ndim - 1,), (0,)), ((), ())),
                                precision=self.precision)
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features,))
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return y


model = MaskedDense(features=5, flow_dim=3)
rng = jax.random.PRNGKey(1)
rng, key1, key2 = jax.random.split(rng, 3)
x = jax.random.normal(key1, (32, 2))
y, params = model.init_with_output(key2, x)


@jax.jit
def f(x):
    return model.apply(params, x)


for _ in range(3):
    rng, key = jax.random.split(rng)
    x = jax.random.normal(key, (32, 2))
    f(x).block_until_ready()
```
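As a possible variation, if you prefer not to rely on `jax.util.cache` (which is undocumented), a standard-library memoizer over a NumPy version of the same function should behave equivalently, since the mask then becomes a concrete host constant that `jit` bakes into the compiled computation. A minimal sketch, reusing the `MaskType` enum from above:

```python
import functools

import numpy as np


@functools.lru_cache()
def get_mask(input_dim: int, output_dim: int, randvar_dim: int,
             mask_type: MaskType) -> np.ndarray:
    """Same degree logic as above, memoized with functools.lru_cache."""
    if mask_type == MaskType.input:
        in_degrees = np.arange(input_dim) % randvar_dim
    else:
        in_degrees = np.arange(input_dim) % (randvar_dim - 1)
    if mask_type == MaskType.output:
        out_degrees = (np.arange(output_dim) % randvar_dim) - 1
    else:
        out_degrees = np.arange(output_dim) % (randvar_dim - 1)
    # (output_dim, 1) >= (1, input_dim) broadcasts to (output_dim, input_dim);
    # transpose so the mask matches the (input_dim, output_dim) kernel.
    return (out_degrees[:, None] >= in_degrees[None, :]).astype(np.float32).T
```

This works because all the arguments are hashable Python values (ints and an enum member), which is exactly what `functools.lru_cache` needs as cache keys.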