
Commit

PR comments

dsikka committed Aug 4, 2024
1 parent 0538dcc commit 0ba00ab
Showing 5 changed files with 27 additions and 23 deletions.

vllm/model_executor/layers/fused_moe/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -18,8 +18,8 @@

__all__ += [
"fused_moe",
"fused_topk",
"fused_experts",
"fused_topk",
"get_config_file_name",
"grouped_topk",
]

vllm/model_executor/layers/fused_moe/fused_moe.py (4 changes: 2 additions & 2 deletions)

@@ -442,8 +442,8 @@ def fused_experts(hidden_states: torch.Tensor,
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
- #assert w1.is_contiguous(), "Expert weights1 must be contiguous"
- #assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+ assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+ assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]

vllm/model_executor/layers/fused_moe/fused_moe_awq.py (11 changes: 5 additions & 6 deletions)

@@ -3,8 +3,8 @@

from vllm import _custom_ops as ops
from vllm.logger import init_logger

- from .fused_moe import fused_experts, moe_align_block_size
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
+     fused_experts, moe_align_block_size)

logger = init_logger(__name__)

@@ -43,12 +43,11 @@ def fused_experts_awq(
# If large seq_len prefill, dequantize and use the fp16 MoE kernel.
do_naive_dequant = hidden_states.shape[:-1].numel() >= NAIVE_THRESHOLD
if do_naive_dequant:
- # TODO: why is this not contiguous already?
- # from @dsikka: because of the permutation operation
+ # NOTE: not contiguous because of the permutation operation
dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0,
- 0).permute(0, 2, 1)
+ 0).permute(0, 2, 1).contiguous()
dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0,
- 0).permute(0, 2, 1)
+ 0).permute(0, 2, 1).contiguous()

return fused_experts(hidden_states, dequant_w1, dequant_w2,
topk_weights, topk_ids)
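
The two .contiguous() calls above pair with the re-enabled asserts in fused_moe.py: permute() only swaps strides, so the dequantized weights would otherwise fail the is_contiguous() check in fused_experts. A minimal standalone sketch of that behavior (the tensor shape here is illustrative, not the real expert-weight shape):

import torch

# Stand-in for a dequantized expert weight tensor.
w = torch.randn(4, 8, 16)

permuted = w.permute(0, 2, 1)                 # swaps strides only, no data copy
print(permuted.is_contiguous())               # False: would trip the new assert
print(permuted.contiguous().is_contiguous())  # True: copies into contiguous memory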

vllm/model_executor/layers/fused_moe/layer.py (6 changes: 3 additions & 3 deletions)

@@ -225,9 +225,9 @@ def weight_loader(self, param: torch.nn.Parameter,

# If transposed, weight is saved as [input_dim, output_dim]
# Otherwise, weight is saved as [output_dim, input_dim]
- is_transposed = getattr(param, "is_transposed", False)
- input_dim = 0 if is_transposed else 1
- output_dim = 1 if is_transposed else 0
+ # Default is not transposed/input dim is dim 1
+ input_dim = getattr(param, "input_dim", 1)
+ output_dim = getattr(param, "output_dim", 0)

# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
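
As a minimal sketch of the lookup above (using plain setattr in place of vLLM's set_weight_attrs helper, with illustrative shapes): a parameter tagged with input_dim/output_dim reports the transposed layout, while an untagged one falls back to the default [output_dim, input_dim] layout.

import torch

plain = torch.nn.Parameter(torch.empty(64, 128), requires_grad=False)
tagged = torch.nn.Parameter(torch.empty(64, 128), requires_grad=False)
# Transposed (e.g. AWQ MoE) weights are tagged by the code that creates them.
tagged.input_dim = 0
tagged.output_dim = 1

for param in (plain, tagged):
    input_dim = getattr(param, "input_dim", 1)    # default: not transposed
    output_dim = getattr(param, "output_dim", 0)
    print(input_dim, output_dim)                  # plain -> 1 0, tagged -> 0 1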

vllm/model_executor/layers/quantization/awq.py (27 changes: 16 additions & 11 deletions)

@@ -1,4 +1,4 @@
- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

@@ -8,7 +8,7 @@
fused_experts_awq)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
- QuantizationConfig)
+ QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs


@@ -65,9 +65,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)

- def get_quant_method(
- self, layer: torch.nn.Module,
- prefix: str) -> Optional[Union["AWQMoEMethod", "AWQLinearMethod"]]:
+ def get_quant_method(self, layer: torch.nn.Module,
+                      prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return AWQLinearMethod(self)
elif isinstance(layer, FusedMoE):

@@ -202,7 +201,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
w13_qweight, {
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"is_transposed": True,
"input_dim": 0,
"output_dim": 1,
**extra_weight_attrs
})

@@ -217,7 +217,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
w2_qweight, {
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"is_transposed": True,
"input_dim": 0,
"output_dim": 1,
**extra_weight_attrs
})

@@ -231,7 +232,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
requires_grad=False)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, {
"is_transposed": True,
"input_dim": 0,
"output_dim": 1,
**extra_weight_attrs
})

@@ -243,7 +245,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
requires_grad=False)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, {
"is_transposed": True,
"input_dim": 0,
"output_dim": 1,
**extra_weight_attrs
})

@@ -260,7 +263,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
w13_qzeros, {
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"is_transposed": True,
"input_dim": 0,
"output_dim": 1,
**extra_weight_attrs
})

@@ -275,7 +279,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
w2_qzeros, {
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"is_transposed": True,
"input_dim": 0,
"output_dim": 1,
**extra_weight_attrs
})
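
These attribute dicts reach the parameters through vllm.model_executor.utils.set_weight_attrs (imported above), which, roughly, attaches each key/value pair to the parameter so the FusedMoE weight_loader can read input_dim and output_dim back via getattr. A rough, self-contained stand-in written for illustration (not copied from vLLM):

from typing import Any, Dict, Optional

import torch


def set_weight_attrs_sketch(weight: torch.Tensor,
                            weight_attrs: Optional[Dict[str, Any]]) -> None:
    # Simplified stand-in: attach every entry as an attribute on the parameter.
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        setattr(weight, key, value)


w2_scales = torch.nn.Parameter(torch.empty(8, 16, 32), requires_grad=False)
set_weight_attrs_sketch(w2_scales, {"input_dim": 0, "output_dim": 1})
print(w2_scales.input_dim, w2_scales.output_dim)  # 0 1, as the weight loader expects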

