How to do autoregressive decoding in JAX/Flax? #920

Answered by marcvanzee
marcvanzee asked this question in Q&A

Autoregressive decoding is implemented in attention.py, and it is used in the WMT example.

Some more explanation from @levskaya, who wrote most of the code:

You only need caching for the self-attention layers in a decoder (not for an encoder, and not for the encoder-decoder cross-attention layers found in a full encoder-decoder transformer model).
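As a rough illustration (a minimal sketch, not the code in attention.py), here is a decoder block in which only the self-attention carries a decoding cache, enabled via Flax's `decode` flag. The cross-attention over the encoder output needs no cache because the encoder keys/values are fixed. Module structure and sizes here are illustrative assumptions; the causal mask used during training is omitted for brevity.

```python
import flax.linen as nn

class DecoderBlock(nn.Module):
  num_heads: int = 8
  features: int = 256
  decode: bool = False  # set True at inference time to enable the cache

  @nn.compact
  def __call__(self, x, encoded=None):
    # Cached self-attention: with decode=True, keys/values of past positions
    # are stored in the "cache" variable collection.
    y = nn.SelfAttention(num_heads=self.num_heads,
                         qkv_features=self.features,
                         decode=self.decode)(x)
    x = nn.LayerNorm()(x + y)
    if encoded is not None:
      # Encoder-decoder cross-attention needs no cache: the encoder output is
      # fixed during decoding, so its keys/values don't grow step by step.
      y = nn.MultiHeadDotProductAttention(num_heads=self.num_heads,
                                          qkv_features=self.features)(x, encoded)
      x = nn.LayerNorm()(x + y)
    # Position-wise feed-forward block.
    y = nn.Dense(self.features)(nn.relu(nn.Dense(4 * self.features)(x)))
    return nn.LayerNorm()(x + y)
```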

Let's say you have a decoder stack (i.e. a language model). If you run a sampler (beam, top-k, etc.) iteratively on the decoder at inference time, you pass the first token through, embed it, run the layers, and get logits for the next position, then sample those to choose a next token. Then you feed that token back into the decoder input and run it again to get the next token,…
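A minimal greedy-decoding sketch of that loop (not the actual WMT example code), assuming a hypothetical `LM` Flax module built from cached decoder blocks constructed with `decode=True`, plus a token embedding and an output projection to vocabulary logits. The cache is created by tracing the model once over a dummy input of the maximum target length; afterwards tokens are fed one position at a time and the cache is threaded through `mutable=['cache']`. Greedy argmax stands in for beam or top-k sampling.

```python
import jax
import jax.numpy as jnp

def greedy_decode(model, params, bos_id, max_len):
  # Initialize the "cache" collection by running init over a dummy batch of
  # the full target length (the cache is sized from this shape).
  dummy = jnp.zeros((1, max_len), dtype=jnp.int32)
  cache = model.init(jax.random.PRNGKey(0), dummy)['cache']

  tok = jnp.array([[bos_id]], dtype=jnp.int32)
  out = []
  for _ in range(max_len):
    # Each call consumes one token, returns logits for the next position,
    # and mutates the cached keys/values of every self-attention layer.
    logits, mutated = model.apply({'params': params, 'cache': cache},
                                  tok, mutable=['cache'])
    cache = mutated['cache']
    tok = jnp.argmax(logits[:, -1:], axis=-1)  # greedy; could be top-k, beam, ...
    out.append(int(tok[0, 0]))
  return out
```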

This discussion was converted from issue #869 on January 21, 2021 14:13.