Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] add triton fused moe kernel for gptq/awq #12185

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
91 changes: 91 additions & 0 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
Expand Down Expand Up @@ -55,6 +57,95 @@ def test_fused_moe(
rtol=0)


@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype, group_size: int, has_zp: bool,
weight_bits: int):
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)

if weight_bits == 4:
pack_factor = 2
quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
elif weight_bits == 8:
pack_factor = 1
quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128

w1_ref = w1.clone()
w2_ref = w2.clone()
w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
device="cuda",
dtype=torch.uint8)
w2_qweight = torch.empty((e, k, n // pack_factor),
device="cuda",
dtype=torch.uint8)
w1_scales = torch.empty((e, 2 * n, k // group_size),
device="cuda",
dtype=dtype)
w2_scales = torch.empty((e, k, n // group_size),
device="cuda",
dtype=dtype)
w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
device="cuda",
dtype=torch.uint8)
w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
device="cuda",
dtype=torch.uint8)

for i in range(e * 2):
expert_id = i % e
if i // e == 0:
w, w_ref, w_qweight, w_scales, w_qzeros = \
w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
else:
w, w_ref, w_qweight, w_scales, w_qzeros = \
w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
weight, qweight, scales, qzeros = quantize_weights(
w[expert_id].T, quant_type, group_size, has_zp, False)
weight = weight.T
qweight = qweight.T.contiguous().to(torch.uint8)
scales = scales.T
if has_zp:
qzeros = qzeros.T.contiguous().to(torch.uint8)
if weight_bits == 4:
qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
if has_zp:
qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

w_ref[expert_id] = weight
w_qweight[expert_id] = qweight
w_scales[expert_id] = scales
if has_zp:
w_qzeros[expert_id] = qzeros

triton_output = fused_moe(a,
w1_qweight,
w2_qweight,
score,
topk,
renormalize=False,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)


@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
Expand Down
Loading
Loading