
How to wrap a Module with a user-defined container object (e.g., dataclass) in Flax/JAX? #1215

Answered by marcvanzee
marcvanzee asked this question in Q&A

Answer by @jheek:

This works:

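from typing import Any

import jax.numpy as jnp
from flax import linen as nn
from flax import struct
from jax import random

class DC(struct.PyTreeNode):
  hidden: Any  # the annotation makes this a dataclass field

class MLP(nn.Module):
  dc: Any

  def __call__(self, x):
    return self.dc.hidden(x)

dc = DC(nn.Dense(2))
mlp = MLP(dc)
x = jnp.ones((3,))
variables = mlp.init(random.PRNGKey(0), x)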

PyTreeNode does the same thing as @flax.struct.dataclass, but it is better behaved with respect to pytype.

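Note that if hidden has no type annotation, it is not a dataclass field of DC but a plain class attribute (a static field), and the example won't work. This is somewhat surprising, but it is just how Python dataclasses work: only annotated names become fields.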
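
To make the annotation point concrete, here is a minimal sketch using only the standard library's dataclasses module rather than Flax. The class names WithAnnotation and WithoutAnnotation are hypothetical and only for illustration; the assumption is that flax.struct follows the same field-discovery rule, since it builds on Python dataclasses.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class WithAnnotation:
    # Annotated name: the dataclass machinery registers it as a field.
    hidden: Any = None


@dataclasses.dataclass
class WithoutAnnotation:
    # No annotation: this is a plain class attribute, not a field.
    hidden = None


annotated_fields = [f.name for f in dataclasses.fields(WithAnnotation)]
unannotated_fields = [f.name for f in dataclasses.fields(WithoutAnnotation)]

# The generated __init__ only accepts registered fields, so passing
# `hidden` to the un-annotated class is rejected.
try:
    WithoutAnnotation(hidden=1)
    constructor_rejected = False
except TypeError:
    constructor_rejected = True

print(annotated_fields, unannotated_fields, constructor_rejected)
```

With the annotation, hidden shows up as a field and can be passed to the constructor; without it, the container has no fields at all, which is why the wrapped submodule would not be picked up.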

Replies: 1 comment

marcvanzee
Apr 8, 2021
Maintainer Author

Answer selected by marcvanzee