From a95000ac19073e1574a47a8aec9ede6ecd660def Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 15 Jul 2024 12:02:14 -0700
Subject: [PATCH] [Misc] Add CustomOp Interface to UnquantizedFusedMoEMethod
 (#6289)

Signed-off-by: Alvant
---
 vllm/model_executor/layers/fused_moe/layer.py | 48 ++++++++++++++-----
 vllm/model_executor/model_loader/loader.py    |  4 --
 2 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 3904f3e3d0e76..7f0668601fac3 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -7,7 +7,7 @@
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -36,7 +36,7 @@ def apply(self,
         raise NotImplementedError
 
 
-class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -61,19 +61,37 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool = True,
-              use_grouped_topk: bool = False,
-              num_expert_group: Optional[int] = None,
-              topk_group: Optional[int] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+    ) -> torch.Tensor:
+        return self.forward(x, layer.w13_weight, layer.w2_weight,
+                            router_logits, top_k, renormalize,
+                            use_grouped_topk, num_expert_group, topk_group)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
         return fused_moe(x,
-                         layer.w13_weight,
-                         layer.w2_weight,
+                         w1,
+                         w2,
                          router_logits,
                          top_k,
                          renormalize=renormalize,
@@ -82,6 +100,10 @@ def apply(self,
                          num_expert_group=num_expert_group,
                          topk_group=topk_group)
 
+    def forward_cpu(self, *args, **kwargs):
+        raise NotImplementedError(
+            "The CPU backend currently does not support MoE.")
+
 
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 60547965063fa..0b269393294ae 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -279,10 +279,6 @@ def load_model(self, *, model_config: ModelConfig,
                 quant_method = getattr(module, "quant_method", None)
                 if quant_method is not None:
                     quant_method.process_weights_after_loading(module)
-                # FIXME: Remove this after Mixtral is updated
-                # to use quant_method.
-                if hasattr(module, "process_weights_after_loading"):
-                    module.process_weights_after_loading()
         return model.eval()
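
Note: with this change, UnquantizedFusedMoEMethod.apply() no longer calls the Triton
fused_moe kernel directly; it delegates to self.forward(), which the CustomOp base
class routes to a backend-specific method such as forward_cuda() or forward_cpu().
The sketch below is a minimal, standalone illustration of that dispatch pattern,
not vLLM's actual CustomOp code; the class name and the assumption that the backend
method is chosen once at construction time are illustrative only.

    # custom_op_dispatch_sketch.py -- simplified illustration of the
    # forward_cuda()/forward_cpu() dispatch that the patch relies on.
    import torch


    class CustomOpSketch:
        """Routes forward() to a backend-specific implementation."""

        def __init__(self):
            # Assumption for this sketch: pick the backend method once,
            # at construction time, so forward() itself stays cheap.
            if torch.cuda.is_available():
                self._forward_method = self.forward_cuda
            else:
                self._forward_method = self.forward_cpu

        def forward(self, *args, **kwargs):
            # apply() in the patch calls self.forward(); the selected
            # backend method does the real work.
            return self._forward_method(*args, **kwargs)

        def forward_cuda(self, *args, **kwargs):
            raise NotImplementedError

        def forward_cpu(self, *args, **kwargs):
            # Subclasses override only the backends they support; the patch
            # overrides this with an explicit NotImplementedError because
            # the CPU backend does not support MoE yet.
            return self.forward_cuda(*args, **kwargs)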