-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Misc] Add attention sinks #3515
base: main
Are you sure you want to change the base?
Conversation
Hi, @felixzhu555 . it is https://arxiv.org/abs/2309.17453 right? |
Yep, trying to implement the logic from that paper. Their repo is https://github.com/mit-han-lab/streaming-llm. |
We need to @rlouf to the PR the guy in charge of outline, it seems that your PR is failing on the guided part. |
@felixzhu555 Hi, this is pr still in progress and should I expect it will be merged? |
hi @hustxiayang, sorry this PR likely won't get merged, it remains an experimental prototype based on an older version of vLLM. After the ongoing engine refactor is complete, the memory manager in vllm should become more extensible and attention sinks can be supported more easily, at which time we can probably open a new PR. |
@felixzhu555 thanks a lot for your clarification! |
Overview
This PR adds experimental support for attention sinks (#1304), based on this paper and repo. Support is currently limited to RoPE and ALiBi models (e.g. Llama, Mistral/Mixtral, Falcon, Bloom, MPT). The attention sink is hard-coded as the first block of tokens in a sequence.
Usage
Set
use_attention_sinks=True
when instantiatingLLM
orLLMEngine
, or set the--use-attention-sinks
CLI argument. Also setenforce_eager=True
(attention sinks currently does not work with CUDA graphs), and ensure the attention backend being used is FlashAttention, XFormers, or FlashInfer (WIP).Background
Experiments show that the attention mechanism heavily attends to the first few tokens of the sequence being completed, regardless of what the tokens are. Once sequence length exceeds the context length of a model, and we start evicting tokens from the beginning of the KV cache (in a sliding window fashion), the model will generate garbage (high perplexity).
This is where attention sinks come in. By always preserving the KVs for the first few tokens of the sequence while using a sliding window approach for the rest of the KV cache, the model can continue to generate sensible output (low perplexity). Theoretically, the model can stream indefinitely, as long as cache eviction is handled properly. Note the sliding window length is the model's context length.
Example
Suppose our model's context length is 2048, which equals 128 blocks of 16 tokens. Let's pass in a prompt of 2000 tokens. For the next 48 generated tokens, nothing changes; we end up filling 128 blocks so far.
Normally, vLLM forces generation to stop here since the model's context length has been reached. However, using attention sinks we bypass this stopping condition and keep generating.
At the next decode, we are writing the 2049th token to the cache and computing the 2050th token (1-based indexing). Here, we edit the block table to be
[block_table[0]] + block_table[2:]
, where we effectively ignore the 2nd block while retaining the 1st block, which is our attention sink. Notice how the block table is still length 128 because the 129th block was just allocated for token 2049. This modified block table is then used in the attention kernel.Every 16th decode that follows will ignore an additional block, but always retain the 1st block as the sink.
Modifications
This PR adds a
StreamingAttentionSink
layer that computes attention using modified block tables with the "sink" block concatenated with the remaining sliding window blocks. In the RoPE case, we always store pre-rope keys into the cache, and extra work must be done at every decode to rotate all keys for a sequence based on their new positions in the cache. Note: due to this extra work, using attention sinks incurs a significant drop in tokens/s for RoPE models (around 50-70% for Llama).use_attention_sinks
is now an argument toLLMEngine
, which passes it to the model runner and injects attention sinks into the model's modules. On every forward call of the model's attention layer, normal attention logic is replaced byStreamingAttentionSink
logic.The scheduler evicts (frees) a block (the "ignored" block) whenever a new block is allocated past the model's context length, such that the total number of used blocks is capped at
max_model_len // block_size
.Future Work
StreamingAttentionSink
assumes only 1 token is generated every decode.StreamingAttentionSink
directly edits the block table for every decode (past the context length), so the hash table for prefix caching cannot be used currently.