How to wrap a Module with a user-defined container object (e.g., a dataclass) in Flax/JAX? #1215
-
Original question by @rsepassi: I noticed that in the Module.setup method, if I have a user-defined container object (e.g., a dataclass) wrapping a Module, then Flax can't seem to see it and I get "unbound" errors. I think this has to do with pytrees and the like? Does Flax have some way to register the user-defined type? I have something like:

    @dataclasses.dataclass
    class MyClass:
        module: nn.Module

    def setup(self):
        self.foo = MyClass(nn.Dense(4))
        ...
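The pytree hunch in the question can be checked without Flax at all: JAX treats any type it has not been told about, including a plain dataclass, as a single opaque leaf, so whatever sits inside it is invisible to tree traversals. The sketch below contrasts an unregistered and a registered container using JAX's generic `jax.tree_util.register_pytree_node`, which is a swapped-in, lower-level technique and not the Flax API discussed in the answer. `PlainBox`, `RegisteredBox` and the `object()` payload are hypothetical stand-ins for the question's `MyClass` and `nn.Dense(4)`. It only illustrates the general mechanism; that Flax's own submodule discovery behaves analogously is an assumption here.

```python
import dataclasses

import jax


@dataclasses.dataclass
class PlainBox:
    # Hypothetical container that JAX knows nothing about.
    item: object


@dataclasses.dataclass
class RegisteredBox:
    # Same layout, but registered below as a pytree node.
    item: object


# Tell JAX how to take RegisteredBox apart: `item` is its only child,
# and there is no static auxiliary data (None).
jax.tree_util.register_pytree_node(
    RegisteredBox,
    lambda box: ((box.item,), None),                 # flatten -> (children, aux_data)
    lambda aux, children: RegisteredBox(*children),  # unflatten
)

payload = object()  # hypothetical stand-in for a submodule

# The unregistered box is a single leaf: its contents are hidden.
plain_leaves = jax.tree_util.tree_leaves(PlainBox(payload))
# The registered box is opened up, so the payload itself is a leaf.
registered_leaves = jax.tree_util.tree_leaves(RegisteredBox(payload))

print(plain_leaves)
print(registered_leaves)
```

In practice the answer below avoids hand-written flatten/unflatten functions entirely by using Flax's `struct` helpers, which perform this registration for you.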
-
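Answer by @jheek: This works:

    from typing import Any
    import flax.linen as nn
    import jax.numpy as jnp
    from flax import struct
    from jax import random

    class DC(struct.PyTreeNode):
        hidden: Any  # annotated, so it is a real dataclass (pytree) field

    class MLP(nn.Module):
        dc: Any

        def __call__(self, x):
            return self.dc.hidden(x)

    dc = DC(nn.Dense(2))
    mlp = MLP(dc)
    x = jnp.ones((3,))
    mlp.init(random.PRNGKey(0), x)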
PyTreeNode is the same as @flax.struct.dataclass, but it's better behaved with respect to PyType. Note that if hidden does not have a type annotation, it becomes a static class attribute of DC rather than a dataclass field, and this won't work. This is somewhat weird, but that's just Python syntax.
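The annotation gotcha in the note above comes from Python's standard `dataclasses` module, on which `flax.struct.dataclass` and `PyTreeNode` are built: only annotated class attributes become fields. The sketch below reproduces it with the standard library alone, so no Flax is needed; `Annotated` and `Unannotated` are hypothetical names, and `object` stands in for the `Any`-typed submodule.

```python
import dataclasses


@dataclasses.dataclass
class Annotated:
    hidden: object  # annotated -> becomes a dataclass field


@dataclasses.dataclass
class Unannotated:
    hidden = None  # no annotation -> plain class attribute, NOT a field


annotated_fields = [f.name for f in dataclasses.fields(Annotated)]
unannotated_fields = [f.name for f in dataclasses.fields(Unannotated)]
print(annotated_fields)    # ['hidden']
print(unannotated_fields)  # []

# The generated __init__ of Unannotated takes no `hidden` argument,
# so passing a submodule to it fails outright.
try:
    Unannotated(object())
    rejected = False
except TypeError:
    rejected = True
```

Because the unannotated attribute never becomes a field, a dataclass-based pytree container has no child slot to expose, which is consistent with the note that the example stops working without the annotation.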