
[FP8][Kernel] Dynamic kv cache scaling factors computation #11906

Open
wants to merge 22 commits into main
Conversation

@gshtras (Contributor) commented Jan 9, 2025

This PR deprecates loading kv cache scales from JSON in favor of an option to dynamically compute them from the first real input to the attention layer.
Our tests showed that the dynamic range computed from the first input to each layer is representative of the entire model, and the accuracy is comparable to scaling factors computed with the Quark quantizer (such as in the HF amd/*-FP8-KV models).

Accuracy was measured using the P3L benchmark, which measures accuracy on decode steps using the data in the kv cache.

The K and V scale parameters are made on-device tensors so that their values can be changed after the graph has been captured. This also lays the foundation for per-channel quantization with tensor-like scales.

The effect is most visible on models whose dynamic value ranges fall outside the representable range of fp8 e4m3, such as Qwen2 7B:
using dynamic calculation reduces the PPL score from 34.84 to 22.62.

On Llama-based models the improvement is much smaller, because identity scales already work well, but it can still reach single-digit percents, on par with using the scales from a quantized model.
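
To make the idea concrete, here is a minimal sketch of the computation (hypothetical helper, not the PR's actual code); the `FP8_E4M3_MAX` constant and function name are assumptions (OCP float8_e4m3fn has a maximum of 448.0, while ROCm's float8_e4m3fnuz tops out at 240.0):

```python
import torch

# Assumed constant: max representable value of OCP float8_e4m3fn.
# ROCm's float8_e4m3fnuz variant would use 240.0 instead.
FP8_E4M3_MAX = 448.0

def compute_kv_scales(key: torch.Tensor, value: torch.Tensor):
    """Derive per-tensor FP8 scales from the first real input to a layer."""
    k_scale = key.detach().abs().amax().to(torch.float32) / FP8_E4M3_MAX
    v_scale = value.detach().abs().amax().to(torch.float32) / FP8_E4M3_MAX
    return k_scale, v_scale

# The layer keeps persistent on-device scale tensors and copies the computed
# values into them in place (e.g. layer._k_scale.copy_(k_scale)), so kernels
# and captured graphs that hold a reference to the tensor see the update.
```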

micah-wil and others added 10 commits December 18, 2024 13:40
…ching (#317)

* Changed _k_scale and _v_scale to tensors

* fixed rocm paged attention with tensor kv scales

* Added on the fly scale factor calculation

* trying to fix attn metadata

* fixed AttentionMetadata issue, updated description for calculate-kv-scales flag in arg_utils.py

* Changed K and V scale constants

* Removed unneeded comment

* Changes to pass format.sh, also fixed lingering k_scale/v_scale : float

* Fix for TP > 1

* Ran format.sh

* Removed legacy kv_scale loading from the json file

* Removed the outdated kv cache docs

* Revert some unwanted changes

---------

Co-authored-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
* Using tensors in the explicit cache function calls from mllama implementation

* Properly creating the tensor

Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>

github-actions bot commented Jan 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the documentation label Jan 9, 2025
@mgoin mgoin self-requested a review January 13, 2025 20:23

mergify bot commented Jan 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gshtras.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 13, 2025
@mergify mergify bot removed the needs-rebase label Jan 13, 2025
@mgoin (Member) left a comment


I think this is a great feature to add (and removal of unused pathway). I do think it is a bit opinionated to only calculate the dynamic scales on the first inference - at the least we should document the recommended setup to properly prime the scales. Possibly we could have a mode where we always use dynamic scaling or require N tokens seen before stopping calibration.

The choice of when each attention backend passes a hardcoded enable_kv_scales_calculation=True vs enable_kv_scales_calculation=False seems very unclear at a glance; it would help if you could add comments for that. We should also keep the error checking for backends that don't support quantization at all.

Comment on lines 30 to +32

```diff
   # Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json.
   ("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf",
-   "meta-llama/Llama-2-7b-chat-hf",
-   "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
+   "meta-llama/Llama-2-7b-chat-hf")
```
Member

I think this case can be removed

vllm/attention/backends/flash_attn.py (outdated review thread; resolved)
```diff
@@ -181,6 +182,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
```
Member

Why is this hardcoded to True?

Contributor Author

This metadata field is used to disable the calculation during the profiling and graph-capture stages (assuming the feature is already globally enabled), so that the scales are not primed with dummy data. It is therefore on by default (on platforms that support it) and switched off during those stages.
It does not control the feature at large; that is done by a config parameter.
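
To illustrate the two levels of control described above, here is a hypothetical sketch (method and attribute names such as maybe_calc_kv_scales and fp8_max are assumptions, not the PR's exact code; _k_scale/_v_scale and the calculate-kv-scales argument are taken from the commit messages above):

```python
def maybe_calc_kv_scales(self, key, value, attn_metadata):
    # Global switch: the feature is enabled via a config argument
    # (the calculate-kv-scales flag mentioned in arg_utils.py).
    if not self.calculate_kv_scales:
        return
    # Per-run switch: profiling and graph-capture runs set this metadata
    # flag to False so dummy inputs never prime the scales.
    if not attn_metadata.enable_kv_scales_calculation:
        return
    # Write into the persistent on-device tensors so captured graphs and
    # kernels holding a reference observe the updated values.
    self._k_scale.copy_(key.abs().amax().float() / self.fp8_max)
    self._v_scale.copy_(value.abs().amax().float() / self.fp8_max)
    # Calibrate only on the first real batch; afterwards the scales stay fixed.
    self.calculate_kv_scales = False
```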


mergify bot commented Jan 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gshtras.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 15, 2025
@gshtras (Contributor Author) commented Jan 15, 2025

I think this is a great feature to add (and removal of unused pathway). I do think it is a bit opinionated to only calculate the dynamic scales on the first inference - at the least we should document the recommended setup to properly prime the scales. Possibly we could have a mode where we always use dynamic scaling or require N tokens seen before stopping calibration.

The choice of when each attention backend passes a hardcoded enable_kv_scales_calculation=True vs enable_kv_scales_calculation=False seems very unclear at a glance; it would help if you could add comments for that. We should also keep the error checking for backends that don't support quantization at all.

An issue with that is that once a certain scale has been used to put values into the cache, it must be used from that point on to retrieve that data; so if we chose to calibrate over multiple steps, we would have to re-cache everything at each such step.
Our measurements show that this data is much less sensitive to the input than to the model weights, so a single calibration step already provides a significant improvement, comparable to using offline-calculated scales.
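
A toy round-trip (assumed helper names, using PyTorch's float8_e4m3fn dtype) shows why the scale is effectively locked once the first values are cached: dequantization must use the same scale that quantization used, so changing the scale later would require re-quantizing everything already in the cache.

```python
import torch

FP8_MAX = 448.0  # assumed: OCP float8_e4m3fn maximum

def write_to_cache(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Quantize: divide by the scale and cast to FP8.
    return (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

def read_from_cache(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Dequantize: cast back up and multiply by the *same* scale.
    return q.to(torch.float32) * scale

x = torch.randn(8)
scale = x.abs().amax() / FP8_MAX
cached = write_to_cache(x, scale)
restored = read_from_cache(cached, scale)
# Reading with a different (re-calibrated) scale would distort every token
# already in the cache, hence re-calibration implies re-caching.
```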

@mergify mergify bot removed the needs-rebase label Jan 15, 2025
@mgoin mgoin added the quantization and ready labels Jan 21, 2025
@mgoin (Member) commented Jan 21, 2025

Sorry for enabling the tests so late, but it seems there are several valid errors. Could you please look into updating the tests as well?

Signed-off-by: Gregory Shtrasberg <[email protected]>
@mgoin mgoin enabled auto-merge (squash) January 21, 2025 18:34
Signed-off-by: Gregory Shtrasberg <[email protected]>
auto-merge was automatically disabled January 21, 2025 21:05

Head branch was pushed to by a user without write access

…ds that don't support tensors (Flashinfer), since on CUDA during graph capturing phase referencing tensor values is impossible

Signed-off-by: Gregory Shtrasberg <[email protected]>

mergify bot commented Jan 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gshtras.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 22, 2025
@mergify mergify bot removed the needs-rebase label Jan 22, 2025

mergify bot commented Jan 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gshtras.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 22, 2025
@mergify mergify bot removed the needs-rebase label Jan 22, 2025
Labels
documentation · quantization · ready · rocm

4 participants