From 3644559b3d703246661fd8b7ad64d090149d0425 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Wed, 8 Jan 2025 14:33:29 -0500 Subject: [PATCH] [TPU][Quantization] TPU `W8A8` (#11785) Co-authored-by: Woosuk Kwon --- .buildkite/run-tpu-test.sh | 11 +- tests/tpu/test_quantization_accuracy.py | 49 +++++++ .../schemes/compressed_tensors_w8a8_int8.py | 105 ++++---------- .../schemes/compressed_tensors_wNa16.py | 2 +- .../layers/quantization/gptq_marlin.py | 2 +- .../layers/quantization/kernels/__init__.py | 74 ---------- .../{ => mixed_precision}/MPLinearKernel.py | 0 .../kernels/mixed_precision/__init__.py | 74 ++++++++++ .../kernels/{ => mixed_precision}/exllama.py | 0 .../kernels/{ => mixed_precision}/machete.py | 0 .../kernels/{ => mixed_precision}/marlin.py | 0 .../kernels/scaled_mm/ScaledMMLinearKernel.py | 64 +++++++++ .../kernels/scaled_mm/__init__.py | 84 +++++++++++ .../quantization/kernels/scaled_mm/cutlass.py | 134 ++++++++++++++++++ .../quantization/kernels/scaled_mm/xla.py | 101 +++++++++++++ .../layers/quantization/utils/w8a8_utils.py | 38 ----- vllm/model_executor/parameter.py | 13 ++ vllm/platforms/tpu.py | 4 +- 18 files changed, 565 insertions(+), 190 deletions(-) create mode 100644 tests/tpu/test_quantization_accuracy.py rename vllm/model_executor/layers/quantization/kernels/{ => mixed_precision}/MPLinearKernel.py (100%) create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py rename vllm/model_executor/layers/quantization/kernels/{ => mixed_precision}/exllama.py (100%) rename vllm/model_executor/layers/quantization/kernels/{ => mixed_precision}/machete.py (100%) rename vllm/model_executor/layers/quantization/kernels/{ => mixed_precision}/marlin.py (100%) create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 13605a3e97142..a8f021890f742 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -14,4 +14,13 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference/offline_inference_tpu.py" +docker run --privileged --net host --shm-size=16G -it \ + -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \ + vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ + && python3 -m pip install pytest \ + && python3 -m pip install lm_eval[api]==0.4.4 \ + && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \ + && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ + && python3 /workspace/vllm/tests/tpu/test_compilation.py \ + && python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ + && python3 /workspace/vllm/examples/offline_inference/offline_inference_tpu.py" diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py new file mode 100644 index 0000000000000..6cd5615c44e1e --- /dev/null +++ b/tests/tpu/test_quantization_accuracy.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass + +import lm_eval +import pytest + +TASK = "gsm8k" +FILTER = "exact_match,strict-match" +RTOL = 0.03 + + +@dataclass +class GSM8KAccuracyTestConfig: + model_name: str + excepted_value: float + + def get_model_args(self) -> str: + return (f"pretrained={self.model_name}," + "max_model_len=4096,max_num_seqs=32") + + +# NOTE: Accuracy scores measured on GPUs. +ACCURACY_CONFIGS = [ + GSM8KAccuracyTestConfig( + model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + excepted_value=0.76), # no bias + # NOTE(rob): We cannot re-initialize VLLM in the same process for TPU, + # so only one of these tests can run in a single call to pytest. As + # a follow up, move this into the LM-EVAL section of the CI. + # GSM8KAccuracyTestConfig( + # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", + # excepted_value=0.66), # bias in QKV layers +] + + +@pytest.mark.parametrize("config", ACCURACY_CONFIGS) +def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): + + results = lm_eval.simple_evaluate( + model="vllm", + model_args=config.get_model_args(), + tasks="gsm8k", + batch_size="auto", + ) + + EXPECTED_VALUE = config.excepted_value + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 6cbc58d61e970..0e3f4731775c5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -1,14 +1,13 @@ -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Set import torch from compressed_tensors.quantization import QuantizationStrategy -from torch.nn import Parameter from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_int8_linear, convert_to_channelwise) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, ModelWeightParameter, @@ -18,6 +17,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): + _kernel_backends_being_used: Set[str] = set() def __init__(self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool): @@ -30,74 +30,25 @@ def get_min_capability(cls) -> int: # turing and up return 75 - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # WEIGHT - # Cutlass kernels need transposed weight. - weight = layer.weight - layer.weight = Parameter(weight.t(), requires_grad=False) - - # WEIGHT SCALE - # Cutlass kernels support only per-tensor and per-channel. - # If we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), convert to the per-channel case. - is_fused_module = len(self.logical_widths) > 1 - if is_fused_module and self.strategy == QuantizationStrategy.TENSOR: - ws_channelwise = convert_to_channelwise(layer.weight_scale, - self.logical_widths) - layer.weight_scale = Parameter(ws_channelwise, requires_grad=False) - else: - layer.weight_scale = Parameter(layer.weight_scale.data, - requires_grad=False) - # INPUT SCALE - if self.is_static_input_scheme: - if self.input_symmetric: - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) - layer.input_zero_point = None - else: - # reconstruct the ranges - int8_traits = torch.iinfo(torch.int8) - azps = layer.input_zero_point.to(dtype=torch.int32) - range_max = (layer.input_scale * - (int8_traits.max - azps)).max() - range_min = (layer.input_scale * - (int8_traits.min - azps)).min() - - scale = (range_max - range_min) / (int8_traits.max - - int8_traits.min) - layer.input_scale = Parameter(scale, requires_grad=False) - - # AZP loaded as int8 but used as int32 - azp = (int8_traits.min - - range_min / scale).to(dtype=torch.int32) - layer.input_zero_point = Parameter(azp, requires_grad=False) - - else: - layer.input_scale = None - layer.input_zero_point = None - - # azp_adj is the AZP adjustment term, used to account for weights. - # It does not depend on scales or azp, so it is the same for - # static and dynamic quantization. - # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md - # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md - if not self.input_symmetric: - azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32) - if self.is_static_input_scheme: - # cutlass_w8a8 requires azp to be folded into azp_adj - # in the per-tensor case - azp_adj = layer.input_zero_point * azp_adj - - layer.azp_adj = azp_adj - else: - layer.azp_adj = None - def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - self.logical_widths = output_partition_sizes + layer.logical_widths = output_partition_sizes + + scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), + is_static_input_scheme=self.is_static_input_scheme, + input_symmetric=self.input_symmetric) + + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW8A8Int8", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) # WEIGHT weight = ModelWeightParameter(data=torch.empty( @@ -140,12 +91,18 @@ def create_weights(self, layer: torch.nn.Module, weight_loader=weight_loader) layer.register_parameter("input_zero_point", input_zero_point) + self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - return apply_int8_linear(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - input_zero_point=layer.input_zero_point, - azp_adj=layer.azp_adj, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index a515738017781..2dd243b9c3109 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -6,7 +6,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from vllm.model_executor.layers.quantization.kernels import ( +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_repeat_scales_on_all_ranks) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index a006d729cc627..2dbfca9b07690 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -11,7 +11,7 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.kernels import ( +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py index 94a3dc2584d6b..e69de29bb2d1d 100644 --- a/vllm/model_executor/layers/quantization/kernels/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -1,74 +0,0 @@ -from typing import List, Optional, Type - -import vllm.envs as envs -from vllm.model_executor.layers.quantization.kernels.exllama import ( - ExllamaLinearKernel) -from vllm.model_executor.layers.quantization.kernels.machete import ( - MacheteLinearKernel) -from vllm.model_executor.layers.quantization.kernels.marlin import ( - MarlinLinearKernel) -from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( - MPLinearKernel, MPLinearLayerConfig) -from vllm.platforms import current_platform - -# in priority/performance order (when available) -_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ - MacheteLinearKernel, - MarlinLinearKernel, - ExllamaLinearKernel, -] - - -def choose_mp_linear_kernel( - config: MPLinearLayerConfig, - compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: - """ - Choose an MPLinearKernel that can implement the given config for the given - compute capability. Attempts to choose the best kernel in terms of - performance. - - Args: - config (MPLinearLayerConfig): Description of the linear layer to be - implemented. - compute_capability (Optional[int], optional): The compute capability of - the target device, if None uses `current_platform` to get the compute - capability. Defaults to None. - - Raises: - ValueError: If no kernel can implement the given config. - - Returns: - Type[MPLinearKernel]: Chosen kernel. - """ - if compute_capability is None: - if current_platform is None: - raise ValueError("Cannot determine compute capability") - _cc = current_platform.get_device_capability() - compute_capability = _cc[0] * 10 + _cc[1] - - failure_reasons = [] - for kernel in _POSSIBLE_KERNELS: - if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: - failure_reasons.append( - f' {kernel.__name__} disabled by environment variable') - continue - - if kernel.get_min_capability() > compute_capability: - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel.get_min_capability()}, current compute capability " - f"is {compute_capability}") - continue - - can_implement, failure_reason = kernel.can_implement(config) - if can_implement: - return kernel - else: - failure_reasons.append( - f' {kernel.__name__} cannot implement due to: {failure_reason}' - ) - - raise ValueError( - "Failed to find a kernel that can implement the "\ - "WNA16 linear layer. Reasons: \n" - + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py rename to vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py new file mode 100644 index 0000000000000..83549870e3f0b --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -0,0 +1,74 @@ +from typing import List, Optional, Type + +import vllm.envs as envs +from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 + ExllamaLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 + MPLinearKernel, MPLinearLayerConfig) +from vllm.platforms import current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, + ExllamaLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + """ + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[MPLinearKernel]: Chosen kernel. + """ + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/exllama.py rename to vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/machete.py rename to vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py diff --git a/vllm/model_executor/layers/quantization/kernels/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/marlin.py rename to vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py new file mode 100644 index 0000000000000..75cf91f191136 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -0,0 +1,64 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + + +@dataclass +class ScaledMMLinearLayerConfig: + is_channelwise: bool + is_static_input_scheme: bool + input_symmetric: bool + + +class ScaledMMLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str, + w_s_param_name: str, i_s_param_name: str, + i_zp_param_name: str, azp_adj_param_name: str) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.i_s_name = i_s_param_name + self.i_zp_name = i_zp_param_name + self.azp_adj_name = azp_adj_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _get_weight_params( + self, layer: torch.nn.Module + ) -> Tuple[torch.Tensor, # weight + torch.Tensor, # weight_scale + Optional[torch.Tensor], # input_scale, + Optional[torch.Tensor], # input_zp + Optional[torch.Tensor], # azp_adj + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.i_s_name), + getattr(layer, self.i_zp_name), + getattr(layer, self.azp_adj_name), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py new file mode 100644 index 0000000000000..586752d3d34e3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -0,0 +1,84 @@ +import os +from typing import Dict, List, Optional, Type + +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearKernel, ScaledMMLinearLayerConfig) +# from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( +# TritonScaledMMLinear) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( + XLAScaledMMLinearKernel) +from vllm.platforms import PlatformEnum, current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { + PlatformEnum.CPU: [CutlassScaledMMLinearKernel], + PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], + # TODO(rob): Create TritonScaledMMLinear kernel. ROCM will + # incorrectly attempt to run AZP models if prompted to. + PlatformEnum.ROCM: [CutlassScaledMMLinearKernel], + PlatformEnum.TPU: [XLAScaledMMLinearKernel], +} + + +def choose_scaled_mm_linear_kernel( + config: ScaledMMLinearLayerConfig, + compute_capability: Optional[int] = None +) -> Type[ScaledMMLinearKernel]: + """ + Choose an ScalledMMLinearKernel that can implement the given config for the + given compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (ScaledMMLinearLayerConfig): Description of the linear layer + to be implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the + compute capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[ScaledMMLinearKernel]: Chosen kernel. + """ + + if compute_capability is None: + _cc = current_platform.get_device_capability() + if _cc is not None: + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS[current_platform._enum]: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + # If the current platform uses compute_capability, + # make sure the kernel supports the compute cability. + if compute_capability is not None: + kernel_min_capability = kernel.get_min_capability() + if (kernel_min_capability is not None + and kernel_min_capability > compute_capability): + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel_min_capability}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "ScaledMM linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py new file mode 100644 index 0000000000000..2e83a04286a0d --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -0,0 +1,134 @@ +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if (not current_platform.is_cuda() and not current_platform.is_cpu()): + return False, "CutlassScaledMM requires running on CUDA or CPU." + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False)) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + + if self.config.input_symmetric: + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False)) + setattr(layer, self.i_zp_name, None) + else: + input_zero_point = getattr(layer, self.i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - + int8_traits.min) + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(scale, requires_grad=False)) + + # AZP loaded as int8 but used as int32 + azp = (int8_traits.min - + range_min / scale).to(dtype=torch.int32) + replace_parameter(layer, self.i_zp_name, + torch.nn.Parameter(azp, requires_grad=False)) + + else: + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + + # azp_adj is the AZP adjustment term, used to account for weights. + # It does not depend on scales or azp, so it is the same for + # static and dynamic quantization. + # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md + # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md + if not self.config.input_symmetric: + weight = getattr(layer, self.w_q_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) + if self.config.is_static_input_scheme: + # cutlass_w8a8 requires azp to be folded into azp_adj + # in the per-tensor case + azp_adj = getattr(layer, self.i_zp_name) * azp_adj + setattr(layer, self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False)) + else: + setattr(layer, self.azp_adj_name, None) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. + symmetric = azp_adj is None + x_q, x_s, x_zp = ops.scaled_int8_quant(x, + i_s, + i_zp, + symmetric=symmetric) + + if x_zp is not None: + # Currently, static is always per-tensor and dynamic is per-token + static = i_zp is not None + azp = None if static else x_zp + return ops.cutlass_scaled_mm_azp(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + azp_adj=azp_adj, + azp=azp, + bias=bias) + return ops.cutlass_scaled_mm(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py new file mode 100644 index 0000000000000..9de668e658826 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -0,0 +1,101 @@ +import warnings +from typing import Optional, Tuple + +import torch +from functorch.experimental.control_flow import cond # noqa: F401 + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class XLAScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "TPU platform does have a concept of compute capability, " + "this method should not be called.") + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if not current_platform.is_tpu(): + return False, "ScaledMMXLA requires running on TPU." + + if c.is_static_input_scheme: + return False, "ScaledMMXLA requires dynamic activation scales." + + if not c.input_symmetric: + return False, "ScaledMMXLA requires symmetric activation scales." + + if not c.is_channelwise: + return False, "ScaledMMXLA requires channelwise weight scales" + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # [out, in] (different than cutlass_scaled_mm) + weight = getattr(layer, self.w_q_name) + replace_parameter(layer, self.w_q_name, + torch.nn.Parameter(weight.data, requires_grad=False)) + + # WEIGHT SCALE + # XLA kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + + # [out_channel,] (different than cutlass_scaled_mm) + weight_scale = weight_scale.squeeze(-1) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # Only support symmetric dynamic activation quantization. + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + setattr(layer, self.azp_adj_name, None) + + # Filter warning for cond usage in apply_weights. It is okay + # to specialize the graph since bias is not dynamic. + warnings.filterwarnings( + "ignore", + message= + "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." # noqa: E501 + ) + + def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + return x + + def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + return x + bias + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, _, _, _ = self._get_weight_params(layer) + + import torch_xla.experimental.xla_quantized_matmul # noqa: F401 + out = torch.ops.xla.quantized_matmul(x, + w_q, + w_s, + zero_point=None, + block_size=-1, + int4_weight=False, + quantize_activation=True) + + # Explicitly capture control flow to make dynamo happy. + # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 + return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index d89071f30a549..7cdce67cf1677 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -201,44 +201,6 @@ def apply_fp8_linear( return output.to(dtype=input.dtype).view(*output_shape) -def apply_int8_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - input_zero_point: Optional[torch.Tensor] = None, - azp_adj: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -): - # ops.scaled_int8_quant supports both dynamic and static quant. - # * dynamic, layer.input_scale is None and x_scale computed from x. - # * static, layer.input_scale is scalar and x_scale is input_scale. - symmetric = azp_adj is None - x_q, x_scale, x_zp = ops.scaled_int8_quant(input, - input_scale, - input_zero_point, - symmetric=symmetric) - - if x_zp is not None: - # Currently, static is always per-tensor and dynamic is per-token - static = input_zero_point is not None - azp = None if static else x_zp - return ops.cutlass_scaled_mm_azp(x_q, - weight, - scale_a=x_scale, - scale_b=weight_scale, - out_dtype=input.dtype, - azp_adj=azp_adj, - azp=azp, - bias=bias) - return ops.cutlass_scaled_mm(x_q, - weight, - scale_a=x_scale, - scale_b=weight_scale, - out_dtype=input.dtype, - bias=bias) - - def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor, diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 02d22a5ca62c0..fc5a3e7fba674 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -6,6 +6,7 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger +from vllm.model_executor.utils import _make_synced_weight_loader __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", @@ -37,6 +38,18 @@ def __init__(self, data: torch.Tensor, weight_loader: Callable): :returns: a torch.nn.parameter """ + # During weight loading, we often do something like: + # narrowed_tensor = param.data.narrow(0, offset, len) + # narrowed_tensor.copy_(real_weight) + # expecting narrowed_tensor and param.data to share the same storage. + # However, on TPUs, narrowed_tensor will lazily propagate to the base + # tensor, which is param.data, leading to the redundant memory usage. + # This sometimes causes OOM errors during model loading. To avoid this, + # we sync the param tensor after its weight loader is called. + from vllm.platforms import current_platform + if current_platform.is_tpu(): + weight_loader = _make_synced_weight_loader(weight_loader) + self._weight_loader = weight_loader @property diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 77f5c8401424b..d488daf056f1a 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -19,7 +19,9 @@ class TpuPlatform(Platform): device_name: str = "tpu" device_type: str = "tpu" dispatch_key: str = "XLA" - supported_quantization: list[str] = ["tpu_int8"] + supported_quantization: list[str] = [ + "tpu_int8", "compressed-tensors", "compressed_tensors" + ] @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: