-
I would like to use linen.Dropout inside of a scanned module. Below is a minimal example based on the linen.scan documentation (here is a colab version).

```python
import flax.linen as nn
from jax import random
import jax.numpy as jnp

lstm_dropout = True

class MyLSTMCell(nn.Module):
  @nn.compact
  def __call__(self, c, xs):
    c = nn.Dropout(rate=0.1, deterministic=not lstm_dropout)(c)
    return nn.LSTMCell()(c, xs)

class SimpleScan(nn.Module):
  @nn.compact
  def __call__(self, c, xs):
    xs = nn.Dropout(rate=0.1, deterministic=False)(xs)
    LSTM = nn.scan(MyLSTMCell,
                   variable_broadcast='params',
                   split_rngs={'params': False, 'dropout': True})
    return LSTM()(c, xs)

rng = random.PRNGKey(0)
xs = jnp.ones((16, 10))
carry_0 = nn.LSTMCell.initialize_carry(rng, (16,), 10)
model = SimpleScan()
variables = model.init({'params': rng, 'dropout': rng}, carry_0, xs)
_, out_val = model.apply(variables, carry_0, xs, rngs={'dropout': rng})
```

This produces the following error (the full stack trace is in the linked colab):

```
<ipython-input-2-0e8790d078ad> in __call__(self, c, xs)
      8   @nn.compact
      9   def __call__(self, c, xs):
---> 10     c = nn.Dropout(rate=0.1, deterministic=not lstm_dropout)(c)
     11     return nn.LSTMCell()(c, xs)
     12

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    273     _context.module_stack.append(self)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:
    277         filter_fn = _context.capture_stack[-1]

/usr/local/lib/python3.7/dist-packages/flax/linen/stochastic.py in __call__(self, inputs, deterministic, rng)
     64     if rng is None:
     65       rng = self.make_rng('dropout')
---> 66     broadcast_shape = list(inputs.shape)
     67     for dim in self.broadcast_dims:
     68       broadcast_shape[dim] = 1

AttributeError: 'tuple' object has no attribute 'shape'
```

If you set
-
This is a great example of why we should probably start using chex to validate inputs consistently.

You are feeding a tuple to dropout, while it expects a JAX array. The state of an LSTM is a tuple containing the memory and the hidden state. You should be able to do the following instead:
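One way to read this advice is to apply dropout to each array in the carry tuple separately, rather than to the tuple itself. The sketch below is not the code from the reply: it uses plain JAX instead of `nn.Dropout`, and the helpers `dropout` and `dropout_carry` are hypothetical names introduced only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random


def dropout(key, x, rate=0.1):
  """Inverted dropout on a single array (hypothetical helper, not nn.Dropout)."""
  keep_prob = 1.0 - rate
  mask = random.bernoulli(key, p=keep_prob, shape=x.shape)
  # Kept entries are rescaled so the expected value is unchanged.
  return jnp.where(mask, x / keep_prob, jnp.zeros_like(x))


def dropout_carry(key, carry, rate=0.1):
  """Apply dropout to every array in the carry, using one key per array."""
  leaves, treedef = jax.tree_util.tree_flatten(carry)
  keys = random.split(key, len(leaves))
  dropped = [dropout(k, leaf, rate) for k, leaf in zip(keys, leaves)]
  return jax.tree_util.tree_unflatten(treedef, dropped)


key = random.PRNGKey(0)
# An LSTM-style carry: a (memory, hidden) pair of arrays.
carry = (jnp.ones((16, 10)), jnp.ones((16, 10)))
new_carry = dropout_carry(key, carry)
```

The result has the same tuple structure as the input carry, so it could be handed to an LSTM cell in place of the original state. Whether the memory and hidden state should both receive dropout is a modeling choice that this sketch does not settle.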
-
It turns out the minimal example does not reproduce my original issue. The real problem was combining Dropout, linen.scan, and jax.mask: the RNG would need to be split to the length of the mask, which is not allowed, so you get an error. To fix this, you must split the RNGs in advance and pass them into the masked function. This colab contains the full problem and my solution.
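The idea of splitting the RNGs in advance can be sketched in plain JAX. This is not the solution from the colab and it uses neither Flax nor jax.mask: `jax.lax.scan` stands in for the scanned module, and the `step` function and the shapes are assumptions chosen for illustration. The keys are split once outside the loop and passed in as a scanned input, so no splitting happens inside the loop body.

```python
import jax
import jax.numpy as jnp
from jax import random


def step(carry, inputs):
  """One scan step; `key` is this step's own pre-split RNG key."""
  x, key = inputs
  keep_prob = 0.9
  mask = random.bernoulli(key, p=keep_prob, shape=x.shape)
  x = jnp.where(mask, x / keep_prob, 0.0)  # inverted dropout
  carry = carry + x
  return carry, x


num_steps = 5
xs = jnp.ones((num_steps, 10))
# Split the keys before the scan: one key per step.
keys = random.split(random.PRNGKey(0), num_steps)
final_carry, ys = jax.lax.scan(step, jnp.zeros((10,)), (xs, keys))
```

Because the keys are ordinary scanned inputs, each step sees a different dropout mask without calling `random.split` inside the loop body.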