
Cannot use linen.Dropout inside linen.scan #1253

Answered by jheek
n2cholas asked this question in Q&A

This is a great example of why we should probably start using chex to validate inputs consistently.
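Here is the kind of input check that would catch this early. The sketch uses a hypothetical `require_array` helper built on a plain `isinstance` check, not chex itself. It accepts a single array but rejects the LSTM `(c, h)` tuple with a readable message, instead of letting the tuple reach `Dropout`.

```python
import jax.numpy as jnp
import numpy as np


def require_array(x, name="inputs"):
    """Hypothetical guard: raise a clear error if `x` is not a single array."""
    if not isinstance(x, (jnp.ndarray, np.ndarray)):
        raise TypeError(
            f"{name} must be a JAX/NumPy array, got {type(x).__name__}; "
            "for a pytree such as an LSTM (c, h) carry, map over its leaves instead."
        )
    return x


carry = (jnp.zeros((2, 4)), jnp.zeros((2, 4)))  # (memory, hidden) pair
require_array(carry[0])  # a single array passes through unchanged
try:
    require_array(carry)  # the tuple itself is rejected
except TypeError as err:
    print(err)
```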

You are feeding a tuple to `Dropout`, but it expects a JAX array. The state (carry) of an LSTM is a tuple pair holding the memory state and the hidden state, so dropout has to be applied to each array in the pair. You should be able to do the following instead:

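```python
import jax
import flax.linen as nn
dropout = nn.Dropout(rate=0.1, deterministic=not lstm_dropout)
c = jax.tree_util.tree_map(dropout, c)  # apply dropout to each array in the (memory, hidden) tuple
```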

Answer selected by n2cholas

This discussion was converted from issue #1252 on April 16, 2021 07:54.