
Masked linear for MADE #1231

Answered by ikostrikov
ikostrikov asked this question in Q&A

I found a solution with jax.util.cache:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import enum
from jax.experimental.host_callback import id_print


class MaskType(enum.Enum):
    input = 1
    hidden = 2
    output = 3

@jax.util.cache()
def get_mask(input_dim: int, output_dim: int, randvar_dim: int,
             mask_type: MaskType) -> jnp.DeviceArray:
    """
    Create a mask for MADE.
    See Figure 1 for a better illustration:
    https://arxiv.org/pdf/1502.03509.pdf
    Args:
        input_dim: Dimensionality of the inputs.
        output_dim: Dimensionality of the outputs.
        randvar_dim: Dimensionality of the random variable.
        mask_type: MaskType.
```
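The following is a minimal sketch of the degree-based masking that the `get_mask` docstring refers to (Figure 1 of the MADE paper). It uses NumPy so that it runs without JAX. This is not ikostrikov's implementation. Several pieces are assumptions made for illustration: the cyclic degree assignment, `functools.lru_cache` standing in for `jax.util.cache`, and the hypothetical helpers `_variable_degrees`, `_hidden_degrees`, `masked_linear` and `made_forward`. Masks use the (in, out) kernel layout of `flax.linen.Dense`.

```python
import enum
import functools

import numpy as np


class MaskType(enum.Enum):
    input = 1
    hidden = 2
    output = 3


def _variable_degrees(n: int, randvar_dim: int) -> np.ndarray:
    # Assumed scheme: input/output unit i belongs to variable i % randvar_dim.
    return np.arange(n) % randvar_dim


def _hidden_degrees(n: int, randvar_dim: int) -> np.ndarray:
    # Hidden degrees cycle through 0..randvar_dim - 2; a hidden unit of
    # degree randvar_dim - 1 could never reach any output, so it is skipped.
    return np.arange(n) % (randvar_dim - 1)


@functools.lru_cache(maxsize=None)  # stdlib stand-in for jax.util.cache
def get_mask(input_dim: int, output_dim: int, randvar_dim: int,
             mask_type: MaskType) -> np.ndarray:
    """Boolean MADE mask with shape (input_dim, output_dim)."""
    if randvar_dim < 2:
        raise ValueError("randvar_dim must be at least 2")
    if mask_type is MaskType.input:
        in_deg = _variable_degrees(input_dim, randvar_dim)
        out_deg = _hidden_degrees(output_dim, randvar_dim)
        mask = in_deg[:, None] <= out_deg[None, :]
    elif mask_type is MaskType.hidden:
        in_deg = _hidden_degrees(input_dim, randvar_dim)
        out_deg = _hidden_degrees(output_dim, randvar_dim)
        mask = in_deg[:, None] <= out_deg[None, :]
    else:
        # Output layer uses a strict inequality so output d never sees input d.
        in_deg = _hidden_degrees(input_dim, randvar_dim)
        out_deg = _variable_degrees(output_dim, randvar_dim)
        mask = in_deg[:, None] < out_deg[None, :]
    mask.setflags(write=False)  # every caller shares the cached array
    return mask


def masked_linear(x: np.ndarray, kernel: np.ndarray, bias: np.ndarray,
                  mask: np.ndarray) -> np.ndarray:
    # Zero the forbidden weights before the matmul; kernel is (in, out).
    return x @ (kernel * mask) + bias


# Usage: a one-hidden-layer MADE over D variables with H hidden units.
D, H = 4, 16
rng = np.random.default_rng(0)
w1, w2 = rng.normal(size=(D, H)), rng.normal(size=(H, D))


def made_forward(x: np.ndarray) -> np.ndarray:
    m_in = get_mask(D, H, D, MaskType.input)
    m_out = get_mask(H, D, D, MaskType.output)
    h = np.tanh(masked_linear(x, w1, np.zeros(H), m_in))
    return masked_linear(h, w2, np.zeros(D), m_out)
```

Caching works here because every argument is a hashable `int` or enum member, so each distinct mask is built once per shape. In a Flax `nn.Module`, a mask like this would typically be converted with `jnp.asarray` and multiplied into the kernel inside `__call__`, the same way `masked_linear` does.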
