[RFC]: Support sparse KV cache framework #5751

Open
chizhang118 opened this issue Jun 21, 2024 · 16 comments

@chizhang118

chizhang118 commented Jun 21, 2024

Motivation

For current large-model inference, the KV cache occupies a significant portion of GPU memory, so reducing its size is an important direction for improvement. Recently, several papers have approached this issue from different angles (a detailed comparison is given in the table), including:

  • FastDecode: This method offloads the KV cache to the CPU; both the storage of the KV cache and the computation over it take place on the CPU.

  • Compression methods based on quantization (GEAR, Mixed Precision): By applying various quantization techniques, these methods reduce the size of each token's KV cache entry without decreasing the number of tokens stored in the KV cache. They may also produce residual and outlier matrices, which must be stored in memory but outside the KV cache. They may also quantize the KV cache of unimportant tokens to reduce the memory footprint of the KV cache.

  • Partial KV cache eviction (H2O, SnapKV, LESS, Adaptive Compression, Scissorhands, Dynamic Memory Compression, StreamingLLM): These methods reduce the memory footprint of the KV cache by removing relatively unimportant KV cache entries. Essentially, they reduce the number of tokens stored in the KV cache without reducing the size of each token's KV cache entry.
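To make the difference between the two families concrete, here is a rough back-of-the-envelope sketch (not from the RFC): quantization shrinks the bytes per element, while eviction shrinks the number of tokens kept. The model shape (32 layers, 32 KV heads, head dimension 128, roughly LLaMA-2-7B-shaped), the 4096-token sequence and the 20% keep ratio are assumptions for illustration, and the sketch ignores the residual/outlier matrices that methods like GEAR store outside the KV cache.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   num_tokens: int, bytes_per_elem: int) -> int:
    """Bytes needed to cache keys and values for `num_tokens` tokens."""
    # Factor 2: one key tensor and one value tensor per layer.
    return 2 * num_layers * num_kv_heads * head_dim * num_tokens * bytes_per_elem


GIB = 1024 ** 3
# Assumed model shape, chosen only for illustration.
shape = dict(num_layers=32, num_kv_heads=32, head_dim=128)

full = kv_cache_bytes(**shape, num_tokens=4096, bytes_per_elem=2)       # FP16, all tokens
quantized = kv_cache_bytes(**shape, num_tokens=4096, bytes_per_elem=1)  # 8-bit, all tokens
evicted = kv_cache_bytes(**shape, num_tokens=int(0.2 * 4096), bytes_per_elem=2)  # FP16, 20% of tokens

print(full / GIB, quantized / GIB, evicted / GIB)  # 2.0 1.0 0.39990234375
```

The two approaches are orthogonal: quantization changes `bytes_per_elem`, eviction changes `num_tokens`, so in principle they could be combined.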

When addressing the sparse KV cache issue, we previously considered: supporting quantization (vLLM has already implemented this); implementing quantization + outlier + residual as in GEAR (not widely applicable, because it requires generating outliers and residuals at every generation step, which is costly); and implementing KV cache accumulation + appendix (not widely applicable, because it requires models to be trained with the same method). In the end, we chose to implement partial KV cache eviction, aiming primarily for generality and abstraction rather than being specific to one or two approaches. Since six of the sparse KV cache methods we found are based on evicting cache entries, this approach is also well suited to being built as a framework and integrated into vLLM.

Sparse KV Cache Workflow

First, the required parameters:

  • An optional flag "--sparse-kv-cache-type" that selects a sparse KV cache method. The default is 'auto', which uses no sparse KV cache; other values select a specific method, such as the attention-score-based method used by H2O.

  • A compression ratio for evicting KV cache entries, e.g. 20% to achieve an 80% reduction in KV cache usage. From this ratio we can calculate 'n', the number of steps between successive rebuilds of the KV cache.
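The RFC does not spell out how 'n' is derived from the compression ratio. One possible rule, purely an assumption for illustration: give each sequence a fixed token budget, compress back to `ratio * budget` tokens whenever the cache reaches that budget, and let 'n' be the number of single-token decode steps it takes to fill the budget again. The function name and the budget parameter are hypothetical.

```python
import math


def steps_between_compressions(budget_tokens: int,
                               compression_ratio: float) -> tuple[int, int]:
    """Hypothetical rule: when the cache reaches `budget_tokens`, keep
    ceil(compression_ratio * budget_tokens) tokens. Returns (kept, n), where
    n is the number of decode steps (one token each) until the budget is hit again."""
    if not 0 < compression_ratio < 1:
        raise ValueError("compression_ratio must be in (0, 1)")
    kept = math.ceil(compression_ratio * budget_tokens)
    n = budget_tokens - kept
    return kept, n


kept, n = steps_between_compressions(budget_tokens=512, compression_ratio=0.2)
print(kept, n)  # 103 409
```

Under this rule a smaller ratio both shrinks the retained cache and lengthens the interval between compressions, so compression cost is amortized over more steps.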

The entire workflow includes:

  • During the first decoding pass, in addition to computing the KV values for all input tokens, we also need to calculate and retain information about the priority ranking of all tokens (e.g., the attention scores between token pairs in H2O).

  • During each vLLM scheduling step, we check whether 'n' steps have been completed, which indicates that the KV cache needs to be compressed. If so, one or more new KV cache blocks are allocated based on the tokens' priority ranking, and the input position information is updated accordingly. The block manager then copies the corresponding KV entries from the sequence group's original blocks into the new blocks. Finally, the reference counts of the original KV blocks are decremented, and those blocks may be released.

  • New KV values are then added to the KV cache until the next compression 'n' steps later, and this cycle repeats until generation completes.
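The compression cycle described above can be sketched as a toy model in plain Python. This is not the PR's code: it assumes each sequence tracks an accumulated attention score per cached token (the H2O-style priority signal) and, every 'n' steps, keeps a fixed number of tokens made up of the most recent ones plus the highest-scoring older ones (the heavy-hitter plus recent-window policy described for H2O). `SparseSequence`, `should_compress`, `keep` and `recent` are hypothetical names; position remapping and the physical block copy are left out.

```python
from dataclasses import dataclass, field


def should_compress(decode_step: int, n: int) -> bool:
    """Scheduler-side check: compress once every n decoding steps."""
    return decode_step > 0 and decode_step % n == 0


@dataclass
class SparseSequence:
    """Hypothetical per-sequence state: cached token positions and their
    accumulated attention scores (the priority signal, as in H2O)."""
    positions: list[int] = field(default_factory=list)
    scores: dict[int, float] = field(default_factory=dict)

    def append(self, position: int) -> None:
        # A newly cached token starts with zero accumulated score.
        self.positions.append(position)
        self.scores[position] = 0.0

    def accumulate(self, attn: dict[int, float]) -> None:
        # Add this step's attention weights (key position -> weight) to the totals.
        for pos, weight in attn.items():
            if pos in self.scores:
                self.scores[pos] += weight

    def compress(self, keep: int, recent: int) -> list[int]:
        """Keep `keep` tokens: the `recent` newest ones plus the highest-scoring
        older ones. Returns the evicted positions."""
        if not 0 <= recent <= keep:
            raise ValueError("need 0 <= recent <= keep")
        if len(self.positions) <= keep:
            return []
        split = len(self.positions) - recent
        older, newest = self.positions[:split], self.positions[split:]
        heavy = sorted(older, key=lambda p: self.scores[p], reverse=True)[:keep - recent]
        kept = set(heavy) | set(newest)
        evicted = [p for p in self.positions if p not in kept]
        self.positions = [p for p in self.positions if p in kept]
        for p in evicted:
            del self.scores[p]
        return evicted


seq = SparseSequence()
for pos in range(8):
    seq.append(pos)
# Made-up attention weights in which tokens 0 and 3 draw most of the attention.
seq.accumulate({0: 0.5, 1: 0.05, 2: 0.05, 3: 0.3, 4: 0.02, 5: 0.03, 6: 0.02, 7: 0.03})
if should_compress(decode_step=8, n=4):
    evicted = seq.compress(keep=4, recent=2)
    print(seq.positions, evicted)  # [0, 3, 6, 7] [1, 2, 4, 5]
```

Keeping a recent window alongside the heavy hitters means the tokens just generated are never evicted before they have had a chance to accumulate attention.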

Proposed Change

The modified files mainly include:

  • Modify vllm/core/scheduler.py: Add logic to check whether a sparse KV cache action should be taken.

  • Modify vllm/core/block_manager_v1.py: Add logic to update the block table mapping and to manage the related allocated/freed blocks.

  • Modify vllm/worker/model_runner.py: Update the position-related code after sparse KV cache compression and pass blocks_to_sparse_copy to the corresponding models.

  • Modify models, such as vllm/model_executor/models/opt.py: Indicate which KV entries should be filtered out.

  • Modify csrc/attention/attention_kernels.cu, csrc/cache_kernels.cu: Calculate attention scores to select the KV of "important" tokens, and support sparse_cache_copy for copying the KV of those "important" tokens.
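For the block-manager side, one possible shape of the compaction step is sketched below: the kept logical token slots of a sequence (for example slots `[0, 3, 6, 7]`) are packed into newly allocated blocks, a list of per-token copy commands is produced, and the reference counts of the old blocks are decremented. `BlockPool`, `sparse_compact` and the `(src_block, src_offset, dst_block, dst_offset)` tuple format are assumptions made for illustration; the RFC names blocks_to_sparse_copy and sparse_cache_copy but does not specify their formats, and this is not vLLM's actual block manager API.

```python
from dataclasses import dataclass


@dataclass
class BlockPool:
    """Toy physical-block allocator with reference counts (hypothetical)."""
    free: list[int]
    ref_count: dict[int, int]

    def allocate(self) -> int:
        block = self.free.pop()
        self.ref_count[block] = 1
        return block

    def release(self, block: int) -> None:
        # Decrement the reference count; a block no one references is freed.
        self.ref_count[block] -= 1
        if self.ref_count[block] == 0:
            del self.ref_count[block]
            self.free.append(block)


def sparse_compact(block_table: list[int], kept_slots: list[int],
                   block_size: int, pool: BlockPool
                   ) -> tuple[list[int], list[tuple[int, int, int, int]]]:
    """Move the kept token slots of one sequence into freshly allocated blocks.

    Returns the new block table and a list of
    (src_block, src_offset, dst_block, dst_offset) copy commands.
    """
    num_new_blocks = -(-len(kept_slots) // block_size)  # ceiling division
    new_table = [pool.allocate() for _ in range(num_new_blocks)]
    copies = []
    for new_slot, old_slot in enumerate(kept_slots):
        src_block = block_table[old_slot // block_size]
        dst_block = new_table[new_slot // block_size]
        copies.append((src_block, old_slot % block_size,
                       dst_block, new_slot % block_size))
    # Drop this sequence's references to its old blocks. In a real engine the
    # copies must execute before any freed block is handed out again.
    for block in block_table:
        pool.release(block)
    return new_table, copies


pool = BlockPool(free=[7, 6, 5, 4], ref_count={0: 1, 1: 1})
new_table, copies = sparse_compact([0, 1], kept_slots=[0, 3, 6, 7], block_size=4, pool=pool)
print(new_table)  # [4]
print(copies)     # [(0, 0, 4, 0), (0, 3, 4, 1), (1, 2, 4, 2), (1, 3, 4, 3)]
```

Using reference counts rather than freeing blocks outright keeps the step safe when an old block is shared with another sequence (for example through prefix sharing): such a block simply stays allocated.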

PR

PR link: #5752

Design doc

https://docs.google.com/document/d/13_cpb31P9VOmPGa_tZ70s7z1vXGP_UenXf1WVuIppCk/

Feedback Period.

No response

CC List.

@simon-mo @youkaichao @zhuohan123 @cadedaniel @ywang96 @WoosukKwon @LiuXiaoxuanPKU

Any Other Things.

No response

@robertgshaw2-redhat
Collaborator

Very exciting!

@thesues
Contributor

thesues commented Jun 21, 2024

How much GPU memory can be saved? Do you have any benchmark data?

@chizhang118
Author

How much GPU memory can be saved? Do you have any benchmark data?

This depends on the sparse KV cache compression ratio. Based on current papers, a 20% compression ratio (i.e., an 80% reduction) is a rough figure. The RFC is still pending feedback from the community, so there is no benchmark data yet.

@chizhang118 chizhang118 changed the title [RFC]: Support sparse KV cache [RFC]: Support sparse KV cache framework Jun 21, 2024
@Zefan-Cai

Would you mind adding newly proposed KV cache compression methods other than SnapKV and H2O (e.g., PyramidKV)?

@chizhang118
Author

Would you mind adding newly proposed KV cache compression methods other than SnapKV and H2O (e.g., PyramidKV)?

Sure, it should not be difficult to add based on the current framework. Will be on my radar. Thanks!

@Zefan-Cai

Would you mind adding newly proposed KV cache compression methods other than SnapKV and H2O (e.g., PyramidKV)?

Sure, it should not be difficult to add based on the current framework. Will be on my radar. Thanks!

Super cool! Thank you so much for your efforts!

@simon-mo
Collaborator

This is exciting indeed. A few things:

@Zefan-Cai

Would you mind adding newly proposed KV cache compression methods other than SnapKV and H2O (e.g., PyramidKV)?

Sure, it should not be difficult to add based on the current framework. Will be on my radar. Thanks!

Would you mind @-mentioning me when the new method is added? I can't wait to try it with vLLM!

@dongxiaolong

dongxiaolong commented Jul 4, 2024

https://github.com/microsoft/MInference
Is there a combination of dynamic sparse attention and sparse KV cache?
The vllm implementation is provided here

@Zefan-Cai

https://github.com/microsoft/MInference Is there a combination of dynamic sparse attention and sparse KV cache? The vllm implementation is provided here

This repo does not provide a sparse KV cache implementation for vLLM; they only provide HF ones.

@dongxiaolong

https://github.com/microsoft/MInference Is there a combination of dynamic sparse attention and sparse KV cache? The vllm implementation is provided here

This repo does not provide a sparse KV cache implementation for vLLM; they only provide HF ones.

For vLLM:

```python
from vllm import LLM, SamplingParams
from minference import MInference

# model_name, prompts and sampling_params are assumed to be defined earlier.
llm = LLM(model_name, max_num_seqs=1, enforce_eager=True, max_model_len=128000)

# Patch MInference Module
minference_patch = MInference("vllm", model_name)
llm = minference_patch(llm)

outputs = llm.generate(prompts, sampling_params)
```

Using only the kernel:

```python
from minference import vertical_slash_sparse_attention, block_sparse_attention, streaming_forward

# q, k, v and the sparsity parameters are assumed to be defined earlier.
attn_output = vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)
attn_output = block_sparse_attention(q, k, v, topk)
attn_output = streaming_forward(q, k, v, init_num, local_window_num)
```

For more details, please refer to our Examples and Experiments. You can find more information about the dynamic compiler PIT in this paper and on GitHub.

@Zefan-Cai

https://github.com/microsoft/MInference Is there a combination of dynamic sparse attention and sparse KV cache? The vllm implementation is provided here

This repo does not provide a sparse KV cache implementation for vLLM; they only provide HF ones.

Are you an author of this repo? The code you attached does not seem to contain a sparse KV cache implementation, and neither does the Examples folder. Am I missing something?

@dongxiaolong

https://github.com/microsoft/MInference Is there a combination of dynamic sparse attention and sparse KV cache? The vllm implementation is provided here

This repo does not provide a sparse KV cache implementation for vLLM; they only provide HF ones.

Are you an author of this repo? The code you attached does not seem to contain a sparse KV cache implementation, and neither does the Examples folder. Am I missing something?

I am not the author of this repo. It's not a sparse KV cache, it's sparse attention. Don't the two have something in common?

@PatchouliTIS

Great work! However, I noticed that your implementation only supports the memory-efficient attention from xFormers. Do you think it would be a lot of work to adapt it to FlashAttention-2 within the current architecture? Or do you plan to support FlashAttention-2 in the future?
https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/flash_attn.py

@PatchouliTIS

BTW, I tried long prompts with your framework and found that with long prompts (approximately 3k tokens) the outputs make no sense: they just repeat some tokens until the output length limit is reached. Could this be related to the sparse KV implementation?

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
