How to do autoregressive decoding in JAX/Flax? #920
-
Autoregressive decoding is implemented in attention.py, and it is used in the WMT example. Some more explanation from @levskaya, who wrote most of the code:

You only need caching for the self-attention layers in a decoder (not for an encoder, and not for the encoder-decoder attention layers found in a full encoder-decoder transformer model).

Let's say you have a decoder stack (i.e. a language model). If you run a sampler (beam, top-k, etc.) iteratively on the decoder at inference time, you pass the first token through, embed it, run the layers, get logits for the next position, and sample those to choose the next token. Then you stick that token back into the decoder input and run the model again to get the token after that, and so on. If you do this naively, you re-calculate the past keys, values, and attention interactions for all the past tokens again and again, leading to an O(L^2) algorithm that is very wasteful for longer sequences.

So instead of feeding entire length-L arrays through each layer, we only feed a length-1 array through, and we store the previously calculated keys and values at each self-attention layer in stateful "cache" variables. The "query" at each new timestep can then attend to the past cached keys and values, and you avoid the L^2 blowup. In practice, at each layer you also store an integer 'index' encoding the position you're currently at, and you create a similar cached index for the absolute position-encoding layer so it can keep track of the current position in the decoder. So the 'cache' is just a set of zero-initialized arrays for the keys and values at each decoder self-attention layer, whose length dimension is set to the longest possible sequence you plan to generate. A minimal sketch of this layout is shown below.

NB: the cached attention implementation in the linen transformer layers was written to be as simple as possible for readability, but we'll probably update it soon to use a slightly more complicated layout and a "one-hot scatter" pattern that performs much faster on TPUs.
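To make the cache layout concrete, here is a minimal, self-contained sketch in plain JAX (not Flax's actual attention.py implementation): a per-layer cache of zero-initialized key/value arrays sized to the maximum decode length, plus an integer index, updated one step at a time. The shapes, the single-example layout, and the names `init_cache` / `cached_attention_step` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def init_cache(max_len, num_heads, head_dim, dtype=jnp.float32):
    # Zero-initialized key/value arrays sized to the longest sequence we
    # plan to generate, plus an index tracking the current position.
    return {
        "key": jnp.zeros((max_len, num_heads, head_dim), dtype),
        "value": jnp.zeros((max_len, num_heads, head_dim), dtype),
        "index": jnp.zeros((), jnp.int32),
    }

def cached_attention_step(cache, query, key, value):
    """One decode step for a single self-attention layer.

    query/key/value: [num_heads, head_dim] for the current token only.
    Writes this step's key/value into the cache at the current index,
    then lets the single query attend to all cached positions so far.
    """
    i = cache["index"]
    # Insert the new key/value at position i; all shapes stay static,
    # so a jitted decode step compiles only once.
    keys = jax.lax.dynamic_update_slice(cache["key"], key[None], (i, 0, 0))
    values = jax.lax.dynamic_update_slice(cache["value"], value[None], (i, 0, 0))
    # Mask out cache slots that have not been written yet.
    valid = jnp.arange(keys.shape[0]) <= i                          # [max_len]
    logits = jnp.einsum("hd,lhd->hl", query, keys) / jnp.sqrt(query.shape[-1])
    logits = jnp.where(valid[None, :], logits, -1e9)
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("hl,lhd->hd", weights, values)                 # [num_heads, head_dim]
    new_cache = {"key": keys, "value": values, "index": i + 1}
    return new_cache, out
```

In Flax itself, the equivalent arrays live in a mutable "cache" variable collection inside each attention module rather than being threaded around by hand as a dict.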
-
Hi, this is a question about the decoding in the lm1b example: each new token attends to all the positions up to the max decode length, with the future tokens masked out using a boolean mask. I'm guessing that the self-attention in Flax is more optimized for TPU, but what would be a good way to handle this on CPU/GPU?
-
Just to shed some more light on what @jheek and @levskaya are pointing out: the attention function will be jit-compiled for each input shape. So if you implement attention with a dynamically increasing index (and correspondingly growing arrays), the attention function will end up being jit-compiled once for every sequence length. In terms of computation cost, that means you spend a lot of time compiling attention functions, which is infeasible in practice.
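A tiny illustration of this recompilation issue (hypothetical code, not from the example): `jax.jit` specializes on array shapes, so a prefix that grows by one token per step triggers a fresh trace and XLA compile for every length.

```python
import jax
import jax.numpy as jnp

@jax.jit
def score_prefix(prefix):  # prefix: [length, d_model]
    # This print runs at trace time, i.e. once per distinct input shape.
    print("tracing for shape", prefix.shape)
    return prefix.sum()

for length in (1, 2, 3):
    score_prefix(jnp.ones((length, 8)))  # re-traces (and recompiles) 3 times
```

With a fixed-size cache plus an index, as in the sketch above, every decode step sees identical shapes, so the step function is compiled exactly once.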
-
Here's a notebook with some benchmarking. These are running times for an overly simplified encoder-decoder Transformer (3 layers) on CPU:
Here, the … Notably, the compile time is indeed inverse to …

Remaining work: …
-
One thing to be aware of, especially for XLA:TPU but also XLA:GPU, is that although the high-level IR of XLA (HLO) offers only a static-size API, when XLA lowers to machine code it does quite a few transformations that can internally do smart, dynamic things. For instance, you can write "one-hot" gathers and scatters using dots that look horrendous from a compute/memory perspective, yet are optimized into efficient code on TPUs. I'm not sure XLA will do the optimal thing in this particular case, but it might not be as bad as expected. In any case, it's useful to measure things directly on TPU: you can't reason about what's going on the way you might with low-level CPU programming, and the performance cliffs live in very different regimes and places on TPUs than on CPUs or even GPUs.
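As a rough sketch of what such a "one-hot scatter" can look like (an assumed illustration, not Flax's actual code): instead of an in-place update at position `i`, you build a one-hot vector over the length dimension and mix the new row in with dense element-wise math, which XLA:TPU can often lower to an efficient update even though it nominally touches the whole cache.

```python
import jax.numpy as jnp

def one_hot_scatter(cache, new_value, i):
    """Write new_value into row i of cache, expressed without a scatter op.

    cache: [max_len, d], new_value: [d], i: int32 scalar (possibly traced).
    """
    one_hot = (jnp.arange(cache.shape[0]) == i).astype(cache.dtype)  # [max_len]
    # Zero out row i, then add the new row; everything is dense arithmetic.
    return cache * (1.0 - one_hot)[:, None] + one_hot[:, None] * new_value[None, :]

cache = jnp.zeros((16, 4))
cache = one_hot_scatter(cache, jnp.ones(4), jnp.int32(2))
```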