Skip to content

Commit

Permalink
[Misc] Add CustomOp Interface to UnquantizedFusedMoEMethod (vllm-proj…
Browse files Browse the repository at this point in the history
…ect#6289)

Signed-off-by: Alvant <[email protected]>
  • Loading branch information
WoosukKwon authored and Alvant committed Oct 26, 2024
1 parent fc19a3c commit a95000a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 17 deletions.
48 changes: 35 additions & 13 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
Expand Down Expand Up @@ -36,7 +36,7 @@ def apply(self,
raise NotImplementedError


class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""

def create_weights(self, layer: torch.nn.Module, num_experts: int,
Expand All @@ -61,19 +61,37 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
) -> torch.Tensor:
return self.forward(x, layer.w13_weight, layer.w2_weight,
router_logits, top_k, renormalize,
use_grouped_topk, num_expert_group, topk_group)

def forward_cuda(
self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
w1,
w2,
router_logits,
top_k,
renormalize=renormalize,
Expand All @@ -82,6 +100,10 @@ def apply(self,
num_expert_group=num_expert_group,
topk_group=topk_group)

def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
"The CPU backend currently does not support MoE.")


class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
Expand Down
4 changes: 0 additions & 4 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,6 @@ def load_model(self, *, model_config: ModelConfig,
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
return model.eval()


Expand Down

0 comments on commit a95000a

Please sign in to comment.