Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Kernel ] AWQ Fused MoE #6422

Closed
Closed
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d40fd4d
added files
robertgshaw2-redhat Jul 13, 2024
f1d5836
format
robertgshaw2-redhat Jul 13, 2024
16baf11
stash
robertgshaw2-redhat Jul 13, 2024
03d9d8e
torch library
robertgshaw2-redhat Jul 13, 2024
54d6a87
fixed another torch library
robertgshaw2-redhat Jul 13, 2024
524a94c
first end to end run with tp=1
robertgshaw2-redhat Jul 13, 2024
febb027
loaded but not running at fp16
robertgshaw2-redhat Jul 13, 2024
8bca009
correctness end-to-end!
robertgshaw2-redhat Jul 13, 2024
8527d6e
formatted
robertgshaw2-redhat Jul 13, 2024
36d1d82
updared the weight loading logic
robertgshaw2-redhat Jul 13, 2024
6943e80
stash
robertgshaw2-redhat Jul 13, 2024
71e5129
fixed fp8
robertgshaw2-redhat Jul 13, 2024
703e792
Merge branch 'main' into fused-moe-awq
robertgshaw2-redhat Jul 14, 2024
5b73064
merged
robertgshaw2-redhat Jul 14, 2024
2ef2c92
formatting
robertgshaw2-redhat Jul 14, 2024
db33c3f
better comments
robertgshaw2-redhat Jul 14, 2024
f6f60cd
added
robertgshaw2-redhat Jul 14, 2024
d9def7e
formatted
robertgshaw2-redhat Jul 14, 2024
16eacd0
stash
robertgshaw2-redhat Jul 14, 2024
0674d2f
Merge branch 'main' into fused-moe-awq
dsikka Jul 30, 2024
d6a032e
clean-up, fix tests
dsikka Jul 30, 2024
8d52ae5
normalize weights to prevent illegal memory
dsikka Aug 1, 2024
c08a5da
all MoE tests working
dsikka Aug 1, 2024
7325e78
revert to reproduce error
dsikka Aug 2, 2024
0538dcc
update to comply with main
dsikka Aug 2, 2024
0ba00ab
PR comments
dsikka Aug 4, 2024
419eb7d
fix tpu forward pass; use kwargs
dsikka Aug 5, 2024
5666fcb
fix triton import
dsikka Aug 6, 2024
8013ad4
further fix imports
dsikka Aug 6, 2024
be34dc0
fix
dsikka Aug 6, 2024
6e7bbf9
fix fp8
dsikka Aug 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros,
int64_t split_k_iters);

torch::Tensor awq_fused_moe(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, torch::Tensor _topk_weights,
torch::Tensor _sorted_token_ids_ptr,
torch::Tensor _expert_ids_ptr,
torch::Tensor _num_tokens_post_padded,
bool mul_weights, int64_t split_k_iters);

torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, int64_t split_k_iters,
Expand Down
424 changes: 417 additions & 7 deletions csrc/quantization/awq/gemm_kernels.cu

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("awq_gemm", &awq_gemm);
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

// Quantized Grouped GEMM for AWQ.
ops.def("awq_fused_moe", &awq_fused_moe);
ops.impl("awq_fused_moe", torch::kCUDA, &awq_fused_moe);

// Dequantization for AWQ.
ops.def("awq_dequantize", &awq_dequantize);
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
Expand Down
96 changes: 95 additions & 1 deletion tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import (fused_experts_awq, fused_moe,
fused_topk)
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.mixtral import MixtralMoE


Expand Down Expand Up @@ -99,3 +102,94 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])


def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, score,
topk):
score = torch.softmax(score.float(), dim=-1)
topk_weight, topk_ids = torch.topk(score, topk)
(B, D) = a.shape
a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
out = torch.zeros(B * topk_ids.shape[1],
w2.shape[2] * 8,
dtype=a.dtype,
device=a.device)
topk_ids = topk_ids.view(-1)
topk_weight = topk_weight.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
dw1 = ops.awq_dequantize(w1[i], w1_scale[i], w1_zero[i], 0, 0, 0)
dw2 = ops.awq_dequantize(w2[i], w2_scale[i], w2_zero[i], 0, 0, 0)
r1 = SiluAndMul()(torch.matmul(a[mask].half(), dw1))
out[mask] = torch.matmul(r1, dw2).to(out.dtype)
return (out.view(B, -1, w2.shape[2] * 8) *
topk_weight.view(B, -1, 1)).sum(dim=1).half()


@pytest.mark.parametrize("m", [1024, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", [8])
@pytest.mark.parametrize("topk", [2, 6])
def test_fused_moe_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
):
# awq requires minimum capability 75
if torch.version.hip is not None:
return
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < 75:
return

RANGE = 1000000000
groupsize = 128
a = torch.randn((m, k), device='cuda', dtype=torch.half) / 10
qw1 = torch.randint(-RANGE,
RANGE, (e, k, n * 2 // 8),
dtype=torch.int,
device='cuda')
qw2 = torch.randint(-RANGE,
RANGE, (e, n, k // 8),
dtype=torch.int,
device='cuda')

scale1 = torch.randn(
(e, k // groupsize, n * 2), dtype=torch.half, device='cuda') / 50
scale2 = torch.randn(
(e, n // groupsize, k), dtype=torch.half, device='cuda') / 50

zero1 = torch.randint(-RANGE,
RANGE, (e, k // groupsize, (n * 2 // 32) * 4),
dtype=torch.int32,
device='cuda')
zero2 = torch.randint(-RANGE,
RANGE, (e, n // groupsize, (k // 32) * 4),
dtype=torch.int32,
device='cuda')
w1 = {"qweight": qw1, "scales": scale1, "qzeros": zero1}
w2 = {"qweight": qw2, "scales": scale2, "qzeros": zero2}

score = torch.randn((m, e), device='cuda', dtype=torch.half)

quant_config = AWQConfig(4, groupsize, False)
torch_output = torch_moe_awq(a, qw1, scale1, zero1, qw2, scale2, zero2,
score, topk)

topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
cuda_output = fused_experts_awq(hidden_states=a,
w1=w1["qweight"],
w2=w2["qweight"],
w1_scales=w1["scales"],
w2_scales=w2["scales"],
w1_qzeros=w1["qzeros"],
w2_qzeros=w2["qzeros"],
topk_weights=topk_weights,
topk_ids=topk_ids,
pack_factor=quant_config.pack_factor)
assert torch.allclose(cuda_output, torch_output, atol=1e-2, rtol=0)
11 changes: 11 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,17 @@ def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


def awq_fused_moe(input: torch.Tensor, qweight: torch.Tensor,
scales: torch.Tensor, qzeros: torch.Tensor,
topk_weights: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, num_tokens_post_padded: int,
mul_weights: bool, pack_factor: int) -> torch.Tensor:
return torch.ops._C.awq_fused_moe(input, qweight, scales, qzeros,
topk_weights, sorted_token_ids,
expert_ids, num_tokens_post_padded,
mul_weights, pack_factor)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from vllm.model_executor.layers.fused_moe.fused_moe_awq import (
fused_experts_awq)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.triton_utils import HAS_TRITON

__all__ = [
"fused_experts_awq",
"FusedMoE",
"FusedMoEMethodBase",
]
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,8 @@ def fused_experts(hidden_states: torch.Tensor,
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
#assert w1.is_contiguous(), "Expert weights1 must be contiguous"
#assert w2.is_contiguous(), "Expert weights2 must be contiguous"
mgoin marked this conversation as resolved.
Show resolved Hide resolved
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
Expand Down
74 changes: 74 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe_awq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Fused MoE utilities for AWQ."""
import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger

from .fused_moe import fused_experts, moe_align_block_size
mgoin marked this conversation as resolved.
Show resolved Hide resolved

logger = init_logger(__name__)

NAIVE_THRESHOLD = 1024
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit high and it is worth commenting how it was calibrated (what model, benchmark, GPU used)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertgshaw2-neuralmagic do we know why this is 1024 specifically?



def fused_experts_awq(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scales: torch.Tensor,
w2_scales: torch.Tensor,
w1_qzeros: torch.Tensor,
w2_qzeros: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
pack_factor: int,
) -> torch.Tensor:
"""
This function computes an AWQ fused_expert.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scales (torch.Tensor): scale to be used for w1.
- w2_scales (torch.Tensor): scale to be used for w2.
- w1_qzeros (torch.Tensor): zero point to be used for w1.
- w2_qzeros (torch.Tensor): zero point to be used for w2.
- pack_factor (int): Weight packing factor (int4 in int32 == 8)

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""

# If large seq_len prefill, dequantize and use the fp16 MoE kernel.
do_naive_dequant = hidden_states.shape[:-1].numel() >= NAIVE_THRESHOLD
if do_naive_dequant:
# TODO: why is this not contiguous already?
# from @dsikka: because of the permutation operation
mgoin marked this conversation as resolved.
Show resolved Hide resolved
dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0,
0).permute(0, 2, 1)
dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0,
0).permute(0, 2, 1)

return fused_experts(hidden_states, dequant_w1, dequant_w2,
topk_weights, topk_ids)

(sorted_token_ids, expert_ids,
num_tokens_post_padded) = moe_align_block_size(topk_ids, 16, w1.shape[0])

x = hidden_states.view(hidden_states.shape[0], 1, *hidden_states.shape[1:])

gate_up = ops.awq_fused_moe(x, w1, w1_scales, w1_qzeros, topk_weights,
sorted_token_ids, expert_ids,
num_tokens_post_padded, False, pack_factor)
mgoin marked this conversation as resolved.
Show resolved Hide resolved

out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )),
dtype=hidden_states.dtype,
device=hidden_states.device)
ops.silu_and_mul(out, gate_up)

out = ops.awq_fused_moe(out, w2, w2_scales, w2_qzeros, topk_weights,
sorted_token_ids, expert_ids,
num_tokens_post_padded, True, pack_factor)

return torch.sum(out, dim=1)
Loading
Loading