[Kernel] add triton fused moe kernel for gptq/awq #12185
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
@mgoin @robertgshaw2-redhat Could we expedite this PR + #12036 (not sure if #12204 is needed too or has overlap) now that DeepSeek has released their full lineup? |
I created a new PR with better |
I think this PR could be closed in favor of #12222. Thanks for your work @jinzhen-lin |
#12222 is an optimization over #12036 or #12204; it can be combined with this PR to get better performance. |
Thank you for the work! We will take a look now |
Considering that this adds "another option" for running quantized MoE models, maybe we should write a documentation page specifically for MoE quantization. I think the best case for this kernel to be used more broadly would be to have a heuristic on the number of experts, or some configuration option, to decide whether to use the triton or marlin kernel |
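Purely as an illustration of the heuristic idea above, a minimal sketch; the function name, the `experts_are_large` flag, and the threshold value are assumptions for illustration, not anything defined by this PR:

```python
def prefer_triton_moe(num_experts: int, experts_are_large: bool, threshold: int = 32) -> bool:
    """Prefer the fused triton kernel when there are many small experts, where
    per-expert marlin kernel launches dominate; otherwise keep the marlin kernel.
    The threshold of 32 is only a placeholder, not something this PR defines."""
    return num_experts >= threshold and not experts_are_large
```

The point is only that the decision could key off expert count (and size), both of which are cheap to read from the model config at load time.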
I tested with a small MoE model (https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4) just now; the triton kernel seems much faster than the marlin kernel too. Besides, the marlin kernel seems to generate wrong results for this model. Test result on A100 * 1: marlin kernel:
triton kernel:
Maybe we should set the triton kernel as the default MoE GPTQ/AWQ kernel? But I am not sure how to do this: gptq-marlin-moe is part of the gptq-marlin quantization method, and if I change the MoE kernel of the gptq-marlin method, users cannot use gptq-marlin-moe at all. Is that ok? |
Thank you for benchmarking. I think this case is still exercising the scenario where the experts are small, so if you could benchmark a model where the experts are few and large, such as Mixtral 8x7B or 8x22B, I would feel more confident about using this kernel by default. I think we can clearly land this as-is and treat it as opt-in for the moment. We can follow up later if we want to use this by default in all cases or based on a heuristic. |
@jinzhen-lin I tried loading an awq mixtral model and it failed to pass the right kwargs through to AWQMarlin
Since many of the config initializers don't accept extra kwargs, you will likely need to check the named args of each initializer and prune the full_config before passing it in unpacked |
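For illustration, a minimal sketch (assuming plain Python config classes; `prune_config_for` is a hypothetical helper, not a vLLM API) of pruning the full config down to the named arguments an initializer actually accepts before `**`-unpacking it:

```python
import inspect
from typing import Any

def prune_config_for(cls: type, full_config: dict[str, Any]) -> dict[str, Any]:
    """Keep only keys that cls.__init__ declares as named parameters, so that
    **-unpacking the result does not raise TypeError on unexpected kwargs."""
    params = inspect.signature(cls.__init__).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(full_config)            # initializer already accepts **kwargs
    accepted = {name for name in params if name != "self"}
    return {k: v for k, v in full_config.items() if k in accepted}

# Hypothetical usage: AWQMarlinConfig(**prune_config_for(AWQMarlinConfig, full_config))
```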
Sorry, the commit several hours ago introduced this bug. It should be |
Thank you, it seems to work fine now. I ran a 128/128 benchmark at 10 QPS for the mixtral awq model on H100 and found that the marlin kernels were more performant. In the future we could add a heuristic to choose the kernel based on model configuration, but for now let's keep the kernel as opt-in.
awq_marlin:
moe_wna16:
|
This makes sense since Mixtral has few experts; excited to get this into main to test it out! The main thing I see optimized here is the number of kernel launches. It should still be more performant with a higher number of experts; not sure where the exact threshold is, but 32 or 64 is probably a good minimum. |
Currently the only option for using MoE + GPTQ/AWQ is the Marlin kernel, but a single marlin_gemm_moe call launches at least num_experts CUDA kernels, while the fused_moe triton kernel only needs to launch one CUDA kernel. This makes the Marlin kernel significantly slower than the fused_moe triton kernel. This PR adds support for a fused_moe triton kernel with GPTQ/AWQ.
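For illustration only, a small PyTorch sketch (not the vLLM implementation; routing weights and GPTQ/AWQ dequantization omitted) of why a per-expert dispatch path issues on the order of num_experts GEMM launches, whereas a fused kernel can cover all experts in a single launch grid:

```python
import torch

def moe_per_expert(hidden: torch.Tensor,    # [num_tokens, hidden_dim]
                   expert_w: torch.Tensor,  # [num_experts, hidden_dim, out_dim]
                   topk_ids: torch.Tensor): # [num_tokens, top_k]
    """Per-expert dispatch: the loop body issues a separate GEMM per expert,
    so kernel launches scale with num_experts. A fused MoE kernel instead maps
    every (token-block, expert) pair onto the launch grid of one kernel."""
    out = hidden.new_zeros(hidden.shape[0], expert_w.shape[-1])
    for e in range(expert_w.shape[0]):               # one launch (or more) per expert
        mask = (topk_ids == e).any(dim=-1)           # tokens routed to expert e
        if mask.any():
            out[mask] += hidden[mask] @ expert_w[e]  # separate GEMM launch
    return out
```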
Generation speed of deepseek-v3-awq (8*A100-SXM4-80GB, bs=1, short prompt)
Note: moe_align_block_size kernel support for deepseek-v3.
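For readers unfamiliar with the alignment step, a rough, unoptimized Python sketch of what block-size alignment means for a fused MoE kernel; this illustrates the idea only and is not the actual moe_align_block_size CUDA kernel:

```python
import torch

def align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    """Group token slots by expert and pad each expert's group up to a multiple
    of block_size, so each thread block only touches tokens of one expert."""
    flat = topk_ids.flatten()
    pad_id = flat.numel()                      # sentinel index marking padding slots
    sorted_ids, block_expert_ids = [], []
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0].tolist()
        pad = (-len(idx)) % block_size         # pad this expert's group to block_size
        idx += [pad_id] * pad
        sorted_ids += idx
        block_expert_ids += [e] * (len(idx) // block_size)  # one expert id per block
    return torch.tensor(sorted_ids), torch.tensor(block_expert_ids), len(sorted_ids)
```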