From 3b0c8a6ccf4140a7fcc2466b6c750b84e8e20496 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 11 Oct 2024 23:34:11 +0000 Subject: [PATCH 01/63] w8a8 working --- examples/offline_inference_tpu.py | 4 +- vllm/config.py | 2 +- .../schemes/compressed_tensors_w8a8_int8.py | 106 ++++++------------ .../layers/quantization/utils/w8a8_utils.py | 35 ------ 4 files changed, 39 insertions(+), 108 deletions(-) diff --git a/examples/offline_inference_tpu.py b/examples/offline_inference_tpu.py index 251629b8027ce..9d004da90a66c 100644 --- a/examples/offline_inference_tpu.py +++ b/examples/offline_inference_tpu.py @@ -19,7 +19,9 @@ # Set `enforce_eager=True` to avoid ahead-of-time compilation. # In real workloads, `enforace_eager` should be `False`. -llm = LLM(model="google/gemma-2b", enforce_eager=True) +llm = LLM(model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + enforce_eager=False, + max_model_len=1024) outputs = llm.generate(prompts, sampling_params) for output, answer in zip(outputs, answers): prompt = output.prompt diff --git a/vllm/config.py b/vllm/config.py index 91ba45798b4ba..2791af4b4cc64 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -246,7 +246,7 @@ def _verify_quantization(self) -> None: "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", "experts_int8" ] - tpu_supported_quantization = ["tpu_int8"] + tpu_supported_quantization = ["tpu_int8", "compressed_tensors", "compressed-tensors"] neuron_supported_quantization = ["neuron_quant"] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 245a35c8783a2..39b36459216e9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -1,15 +1,14 @@ -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Set import torch -from torch.nn import Parameter from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( QuantizationStrategy) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_int8_linear, convert_to_channelwise) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, ModelWeightParameter, @@ -19,6 +18,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): + _kernel_backends_being_used: Set[str] = set() def __init__(self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool): @@ -31,70 +31,25 @@ def get_min_capability(cls) -> int: # turing and up return 75 - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # WEIGHT - # Cutlass kernels need transposed weight. - weight = layer.weight - layer.weight = Parameter(weight.t(), requires_grad=False) - - # WEIGHT SCALE - # Cutlass kernels support only per-tensor and per-channel. - # If we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), convert to the per-channel case. 
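As context for the per-tensor to per-channel conversion described in the comment above, a minimal sketch of that expansion, assuming it amounts to repeating each fused shard's scale across that shard's output channels (the helper name and widths here are illustrative, not vLLM's `convert_to_channelwise` itself):

```python
import torch

def expand_to_channelwise(shard_scales: torch.Tensor,
                          logical_widths: list) -> torch.Tensor:
    # Repeat each fused shard's per-tensor scale across its output channels.
    repeated = torch.repeat_interleave(shard_scales,
                                       torch.tensor(logical_widths))
    return repeated.unsqueeze(-1)  # shape: [sum(logical_widths), 1]

# Toy fused QKV projection with one per-tensor scale for each of q, k, v:
scales = torch.tensor([0.020, 0.010, 0.030])
print(expand_to_channelwise(scales, [8, 4, 4]).shape)  # torch.Size([16, 1])
```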
- is_fused_module = len(self.logical_widths) > 1 - if is_fused_module and self.strategy == QuantizationStrategy.TENSOR: - ws_channelwise = convert_to_channelwise(layer.weight_scale, - self.logical_widths) - layer.weight_scale = Parameter(ws_channelwise, requires_grad=False) - else: - layer.weight_scale = Parameter(layer.weight_scale.data, - requires_grad=False) - # INPUT SCALE - if self.is_static_input_scheme: - if self.input_symmetric: - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) - layer.input_zero_point = None - else: - # reconstruct the ranges - int8_traits = torch.iinfo(torch.int8) - azps = layer.input_zero_point.to(dtype=torch.int32) - range_max = (layer.input_scale * - (int8_traits.max - azps)).max() - range_min = (layer.input_scale * - (int8_traits.min - azps)).min() - - scale = (range_max - range_min) / (int8_traits.max - - int8_traits.min) - layer.input_scale = Parameter(scale, requires_grad=False) - - # AZP loaded as int8 but used as int32 - azp = (int8_traits.min - - range_min / scale).to(dtype=torch.int32) - layer.input_zero_point = Parameter(azp, requires_grad=False) - - else: - layer.input_scale = None - layer.input_zero_point = None - - # azp_adj is the AZP adjustment term, used to account for weights. - # It does not depend on scales or azp, so it is the same for - # static and dynamic quantization. - # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md - # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md - if not self.input_symmetric: - layer.azp_adj = layer.weight.sum(dim=0, - keepdim=True, - dtype=torch.int32) - else: - layer.azp_adj = None - def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - self.logical_widths = output_partition_sizes + layer.logical_widths = output_partition_sizes + + scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), + is_static_input_scheme=self.is_static_input_scheme, + input_symmetric=self.input_symmetric) + + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW8A8Int8", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) # WEIGHT weight = ModelWeightParameter(data=torch.empty( @@ -136,14 +91,23 @@ def create_weights(self, layer: torch.nn.Module, data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader) layer.register_parameter("input_zero_point", input_zero_point) + else: + layer.input_scale = None + layer.input_zero_point = None + layer.azp_adj = None + + self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. 
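For reference, a sketch of two configurations this `create_weights` can end up building; the dataclass mirrors the `ScaledMMLinearLayerConfig` added later in this series, and the recipe names are informal:

```python
from dataclasses import dataclass

@dataclass
class ScaledMMLinearLayerConfig:
    is_channelwise: bool
    is_static_input_scheme: bool
    input_symmetric: bool

# Channelwise weights, dynamic symmetric activations (the TPU-friendly recipe).
dynamic_w8a8 = ScaledMMLinearLayerConfig(is_channelwise=True,
                                         is_static_input_scheme=False,
                                         input_symmetric=True)

# Per-tensor weights, static asymmetric activations (exercises the AZP path).
static_asym_w8a8 = ScaledMMLinearLayerConfig(is_channelwise=False,
                                             is_static_input_scheme=True,
                                             input_symmetric=False)
```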
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - - return apply_int8_linear(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - input_zero_point=layer.input_zero_point, - azp_adj=layer.azp_adj, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 411af922149fd..af491bafc743f 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -187,41 +187,6 @@ def apply_fp8_linear( return output.to(dtype=input.dtype) -def apply_int8_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - input_zero_point: Optional[torch.Tensor] = None, - azp_adj: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -): - # ops.scaled_int8_quant supports both dynamic and static quant. - # * dynamic, layer.input_scale is None and x_scale computed from x. - # * static, layer.input_scale is scalar and x_scale is input_scale. - symmetric = azp_adj is None - x_q, x_scale, x_zp = ops.scaled_int8_quant(input, - input_scale, - input_zero_point, - symmetric=symmetric) - - if x_zp is not None: - return ops.cutlass_scaled_mm_azp(x_q, - weight, - scale_a=x_scale, - scale_b=weight_scale, - out_dtype=input.dtype, - azp_adj=azp_adj, - azp=x_zp, - bias=bias) - return ops.cutlass_scaled_mm(x_q, - weight, - scale_a=x_scale, - scale_b=weight_scale, - out_dtype=input.dtype, - bias=bias) - - def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor, From 36fc1dbd3e83818d1ba2a13141e4f540a104aea9 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 11 Oct 2024 23:37:34 +0000 Subject: [PATCH 02/63] format --- examples/offline_inference_tpu.py | 2 +- vllm/config.py | 4 +++- .../schemes/compressed_tensors_w8a8_int8.py | 8 ++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference_tpu.py b/examples/offline_inference_tpu.py index 9d004da90a66c..09c9c92589d21 100644 --- a/examples/offline_inference_tpu.py +++ b/examples/offline_inference_tpu.py @@ -19,7 +19,7 @@ # Set `enforce_eager=True` to avoid ahead-of-time compilation. # In real workloads, `enforace_eager` should be `False`. 
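The `apply_int8_linear` helper removed above, and the kernels that replace it, evaluate the same scaled int8 matmul. A rough pure-PyTorch reference for that math, written in float for clarity rather than with the fused CUTLASS ops (shapes and scales below are invented):

```python
import torch

M, K, N = 4, 32, 8
x_q = torch.randint(-128, 128, (M, K), dtype=torch.int8)  # quantized activations
w_q = torch.randint(-128, 128, (N, K), dtype=torch.int8)  # quantized weights
x_s = torch.rand(M, 1) * 0.02                             # per-token activation scales
w_s = torch.rand(N) * 0.01                                # per-channel weight scales

# y ~= (x_q * x_s) @ (w_q * w_s).T, which the fused kernel computes in int8.
y = (x_q.float() * x_s) @ (w_q.float() * w_s.unsqueeze(-1)).t()
print(y.shape)  # torch.Size([4, 8])
```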
-llm = LLM(model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", +llm = LLM(model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", enforce_eager=False, max_model_len=1024) outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/config.py b/vllm/config.py index 2791af4b4cc64..15a5423aa903f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -246,7 +246,9 @@ def _verify_quantization(self) -> None: "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", "experts_int8" ] - tpu_supported_quantization = ["tpu_int8", "compressed_tensors", "compressed-tensors"] + tpu_supported_quantization = [ + "tpu_int8", "compressed_tensors", "compressed-tensors" + ] neuron_supported_quantization = ["neuron_quant"] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 39b36459216e9..522ffc34e2250 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -42,10 +42,10 @@ def create_weights(self, layer: torch.nn.Module, is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), is_static_input_scheme=self.is_static_input_scheme, input_symmetric=self.input_symmetric) - + kernel_type = choose_scaled_mm_linear_kernel( scaled_mm_linear_kernel_config) - + if kernel_type.__name__ not in self._kernel_backends_being_used: logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__) @@ -102,7 +102,7 @@ def create_weights(self, layer: torch.nn.Module, i_s_param_name="input_scale", i_zp_param_name="input_zero_point", azp_adj_param_name="azp_adj") - + # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. 
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -110,4 +110,4 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - return self.kernel.apply_weights(layer, x, bias) \ No newline at end of file + return self.kernel.apply_weights(layer, x, bias) From d83c04c5c2b2a692bbe7d29a7698f1961a078c30 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 11 Oct 2024 23:38:16 +0000 Subject: [PATCH 03/63] added all kernels --- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 67 +++++++++ .../kernels/scaled_mm/__init__.py | 72 ++++++++++ .../quantization/kernels/scaled_mm/cutlass.py | 127 ++++++++++++++++++ .../quantization/kernels/scaled_mm/xla.py | 74 ++++++++++ 4 files changed, 340 insertions(+) create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py new file mode 100644 index 0000000000000..575f43d6d18a4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -0,0 +1,67 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +@dataclass +class ScaledMMLinearLayerConfig: + is_channelwise: bool + is_static_input_scheme: bool + input_symmetric: bool + + +class ScaledMMLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> Optional[int]: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: ScaledMMLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + i_s_param_name: str, + i_zp_param_name: str, + azp_adj_param_name: str) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.i_s_name = i_s_param_name + self.i_zp_name = i_zp_param_name + self.azp_adj_name = azp_adj_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _get_weight_params( + self, layer: torch.nn.Module + ) -> Tuple[torch.Tensor, # weight + torch.Tensor, # weight_scale + Optional[torch.Tensor], # input_scale, + Optional[torch.Tensor], # input_zp + Optional[torch.Tensor], # azp_adj + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.i_s_name), + getattr(layer, self.i_zp_name), + getattr(layer, self.azp_adj_name), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py new file mode 100644 index 0000000000000..0bb69f66ef727 --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -0,0 +1,72 @@ +import os +from typing import Dict, List, Optional, Type + +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( + XLAScaledMMLinearKernel) + +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( + ScaledMMLinearKernel, ScaledMMLinearLayerConfig) +from vllm.platforms import current_platform, PlatformEnum + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { + PlatformEnum.CPU: [CutlassScaledMMLinearKernel], + PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], + PlatformEnum.TPU: [XLAScaledMMLinearKernel], +} + +def choose_scaled_mm_linear_kernel( + config: ScaledMMLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[ScaledMMLinearKernel]: + """ + Choose an ScalledMMLinearKernel that can implement the given config for the + given compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (ScaledMMLinearLayerConfig): Description of the linear layer + to be implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the + compute capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[ScaledMMLinearKernel]: Chosen kernel. + """ + + if compute_capability is None: + compute_capability = current_platform.get_device_capability() + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS[current_platform._enum]: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if (compute_capability is not None and + kernel.get_min_capability() > compute_capability): + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "ScaledMM linear layer. 
Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py new file mode 100644 index 0000000000000..02ecc656775cf --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -0,0 +1,127 @@ +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, ScaledMMLinearLayerConfig) + +class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> Optional[int]: + return 75 + + @classmethod + def can_implement(cls, + c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if (not current_platform.is_cuda() and + not current_platform.is_cpu()): + return False, f"CutlassScaledMM requires running on CUDA or CPU." + + return True + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False)) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(layer.logical_widths) > 1 + if is_fused_module and not self.config.is_channelwise: + weight_scale = getattr(layer, self.w_s_name) + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + + if self.config.input_symmetric: + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False)) + else: + input_zero_point = getattr(layer, self.i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * + (int8_traits.max - azps)).max() + range_min = (input_scale * + (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - + int8_traits.min) + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(scale.max(), requires_grad=False)) + + # AZP loaded as int8 but used as int32 + azp = (int8_traits.min - + range_min / scale).to(dtype=torch.int32) + replace_parameter( + layer, self.i_zp_name, + torch.nn.Parameter(azp, requires_grad=False)) + + else: + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + + # azp_adj is the AZP adjustment term, used to account for weights. + # It does not depend on scales or azp, so it is the same for + # static and dynamic quantization. 
+ # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md + # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md + if not self.input_symmetric: + layer.azp_adj = layer.weight.sum(dim=0, + keepdim=True, + dtype=torch.int32) + else: + layer.azp_adj = None + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. + symmetric = azp_adj is None + x_q, x_s, x_zp = ops.scaled_int8_quant(x, + i_s, + i_zp, + symmetric=symmetric) + + if x_zp is not None: + return ops.cutlass_scaled_mm_azp(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + azp_adj=azp_adj, + azp=x_zp, + bias=bias) + return ops.cutlass_scaled_mm(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py new file mode 100644 index 0000000000000..f8208ade78362 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -0,0 +1,74 @@ +from typing import Optional, Tuple + +import torch +import torch_xla.experimental.xla_quantized_matmul + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig + +class XLAScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> Optional[int]: + return None + + @classmethod + def can_implement(cls, + c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if not current_platform.is_tpu(): + return False, f"ScaledMMXLA requires running on TPU." + + if c.is_static_input_scheme: + return False, f"ScaledMMXLA requires dynamic activation scales." + + if not c.input_symmetric: + return False, "ScaledMMXLA requires symmetric activation scales." + + if not c.is_channelwise: + return False, "ScaledMMXLA requires channelwise weight scales" + + return True, "" + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, self.w_q_name) + + # [out, in] (different than cutlass_scaled_mm) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.data, requires_grad=False)) + + # WEIGHT SCALE + # XLA kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. 
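A small numeric illustration of the asymmetric-activation handling in `CutlassScaledMMLinearKernel.process_weights_after_loading` above, where several static (scale, zero-point) pairs are collapsed into a single pair that covers the union of their ranges; the values here are invented:

```python
import torch

int8_min, int8_max = -128, 127
input_scale = torch.tensor([0.02, 0.03])          # per-shard static scales
azps = torch.tensor([10, -5], dtype=torch.int32)  # per-shard zero points

range_max = (input_scale * (int8_max - azps)).max()  # widest positive reach
range_min = (input_scale * (int8_min - azps)).min()  # widest negative reach

scale = (range_max - range_min) / (int8_max - int8_min)  # ~0.03
azp = (int8_min - range_min / scale).to(torch.int32)     # ~-5, up to rounding

print(scale.item(), azp.item())
```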
+ is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + + # [out_channel,] (different than cutlass_scaled_mm) + weight_scale = weight_scale.squeeze(-1).to(torch.bfloat16) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, i_s, i_zp, i_azp_adj = self._get_weight_params(layer) + assert i_s is None + assert i_zp is None + assert i_azp_adj is None + + return torch.ops.xla.quantized_matmul( + x, w_q, w_s, quantize_activation=True) From af9d0f4bd3aaa64a766e0679307e1d64be3196e1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Fri, 11 Oct 2024 23:53:29 +0000 Subject: [PATCH 04/63] format --- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 19 +++++----- .../kernels/scaled_mm/__init__.py | 33 +++++++++-------- .../quantization/kernels/scaled_mm/cutlass.py | 36 +++++++++---------- .../quantization/kernels/scaled_mm/xla.py | 34 +++++++++--------- 4 files changed, 62 insertions(+), 60 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 575f43d6d18a4..7bda48d8dff9c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -4,6 +4,7 @@ import torch + @dataclass class ScaledMMLinearLayerConfig: is_channelwise: bool @@ -20,17 +21,13 @@ def get_min_capability(cls) -> Optional[int]: @classmethod @abstractmethod - def can_implement(cls, - c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: raise NotImplementedError - def __init__(self, - c: ScaledMMLinearLayerConfig, - w_q_param_name: str, - w_s_param_name: str, - i_s_param_name: str, - i_zp_param_name: str, - azp_adj_param_name: str) -> None: + def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str, + w_s_param_name: str, i_s_param_name: str, + i_zp_param_name: str, azp_adj_param_name: str) -> None: assert self.can_implement(c) self.config = c self.w_q_name = w_q_param_name @@ -52,8 +49,8 @@ def apply_weights(self, def _get_weight_params( self, layer: torch.nn.Module - ) -> Tuple[torch.Tensor, # weight - torch.Tensor, # weight_scale + ) -> Tuple[torch.Tensor, # weight + torch.Tensor, # weight_scale Optional[torch.Tensor], # input_scale, Optional[torch.Tensor], # input_zp Optional[torch.Tensor], # azp_adj diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 0bb69f66ef727..0b3c61cccb0c5 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -3,12 +3,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( CutlassScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearKernel, ScaledMMLinearLayerConfig) from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import 
( XLAScaledMMLinearKernel) - -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( - ScaledMMLinearKernel, ScaledMMLinearLayerConfig) -from vllm.platforms import current_platform, PlatformEnum +from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { @@ -17,9 +16,11 @@ PlatformEnum.TPU: [XLAScaledMMLinearKernel], } + def choose_scaled_mm_linear_kernel( config: ScaledMMLinearLayerConfig, - compute_capability: Optional[int] = None) -> Type[ScaledMMLinearKernel]: + compute_capability: Optional[int] = None +) -> Type[ScaledMMLinearKernel]: """ Choose an ScalledMMLinearKernel that can implement the given config for the given compute capability. Attempts to choose the best kernel in terms of @@ -38,7 +39,7 @@ def choose_scaled_mm_linear_kernel( Returns: Type[ScaledMMLinearKernel]: Chosen kernel. """ - + if compute_capability is None: compute_capability = current_platform.get_device_capability() @@ -49,14 +50,18 @@ def choose_scaled_mm_linear_kernel( failure_reasons.append( f' {kernel.__name__} disabled by environment variable') continue - - if (compute_capability is not None and - kernel.get_min_capability() > compute_capability): - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel.get_min_capability()}, current compute capability " - f"is {compute_capability}") - continue + + # If the current platform uses compute_capability, + # make sure the kernel supports the compute cability. + if compute_capability is not None: + kernel_min_capability = kernel.get_min_capability() + if (kernel_min_capability is not None + and kernel_min_capability > compute_capability): + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel_min_capability}, current compute capability " + f"is {compute_capability}") + continue can_implement, failure_reason = kernel.can_implement(config) if can_implement: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 02ecc656775cf..5a6b2a499140d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -8,30 +8,31 @@ convert_to_channelwise) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import (ScaledMMLinearKernel, ScaledMMLinearLayerConfig) +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): - + @classmethod def get_min_capability(cls) -> Optional[int]: return 75 @classmethod - def can_implement(cls, - c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: - - if (not current_platform.is_cuda() and - not current_platform.is_cpu()): - return False, f"CutlassScaledMM requires running on CUDA or CPU." + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if (not current_platform.is_cuda() and not current_platform.is_cpu()): + return False, "CutlassScaledMM requires running on CUDA or CPU." - return True + return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # WEIGHT # Cutlass kernels need transposed weight. 
weight = getattr(layer, self.w_q_name) replace_parameter( - layer, self.w_q_name, + layer, self.w_q_name, torch.nn.Parameter(weight.t().data, requires_grad=False)) # WEIGHT SCALE @@ -61,10 +62,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # reconstruct the ranges int8_traits = torch.iinfo(torch.int8) azps = input_zero_point.to(dtype=torch.int32) - range_max = (input_scale * - (int8_traits.max - azps)).max() - range_min = (input_scale * - (int8_traits.min - azps)).min() + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) @@ -75,9 +74,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # AZP loaded as int8 but used as int32 azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) - replace_parameter( - layer, self.i_zp_name, - torch.nn.Parameter(azp, requires_grad=False)) + replace_parameter(layer, self.i_zp_name, + torch.nn.Parameter(azp, requires_grad=False)) else: setattr(layer, self.i_s_name, None) @@ -88,7 +86,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # static and dynamic quantization. # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md - if not self.input_symmetric: + if not self.config.input_symmetric: layer.azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32) @@ -100,7 +98,7 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) - + # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index f8208ade78362..b6bd4cae0bb02 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -1,14 +1,16 @@ from typing import Optional, Tuple import torch -import torch_xla.experimental.xla_quantized_matmul +import torch_xla.experimental.xla_quantized_matmul # noqa: F401 from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + class XLAScaledMMLinearKernel(ScaledMMLinearKernel): @@ -17,32 +19,31 @@ def get_min_capability(cls) -> Optional[int]: return None @classmethod - def can_implement(cls, - c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: - + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if not current_platform.is_tpu(): - return False, f"ScaledMMXLA requires running on TPU." + return False, "ScaledMMXLA requires running on TPU." if c.is_static_input_scheme: - return False, f"ScaledMMXLA requires dynamic activation scales." - + return False, "ScaledMMXLA requires dynamic activation scales." + if not c.input_symmetric: return False, "ScaledMMXLA requires symmetric activation scales." 
if not c.is_channelwise: return False, "ScaledMMXLA requires channelwise weight scales" - return True, "" + return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # WEIGHT # Cutlass kernels need transposed weight. weight = getattr(layer, self.w_q_name) - + # [out, in] (different than cutlass_scaled_mm) - replace_parameter( - layer, self.w_q_name, - torch.nn.Parameter(weight.data, requires_grad=False)) + replace_parameter(layer, self.w_q_name, + torch.nn.Parameter(weight.data, requires_grad=False)) # WEIGHT SCALE # XLA kernels support only per-tensor and per-channel. @@ -60,7 +61,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer, self.w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False)) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, @@ -70,5 +70,7 @@ def apply_weights(self, assert i_zp is None assert i_azp_adj is None - return torch.ops.xla.quantized_matmul( - x, w_q, w_s, quantize_activation=True) + return torch.ops.xla.quantized_matmul(x, + w_q, + w_s, + quantize_activation=True) From 0f9fd212a84221bb583cd0df2cc760bc15d8f9d4 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 12 Oct 2024 00:04:20 +0000 Subject: [PATCH 05/63] working on cuda --- examples/offline_inference_tpu.py | 5 +++-- .../layers/quantization/kernels/scaled_mm/__init__.py | 3 ++- .../layers/quantization/kernels/scaled_mm/cutlass.py | 2 +- .../layers/quantization/kernels/scaled_mm/xla.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference_tpu.py b/examples/offline_inference_tpu.py index 09c9c92589d21..7699bec2f2ffd 100644 --- a/examples/offline_inference_tpu.py +++ b/examples/offline_inference_tpu.py @@ -19,7 +19,8 @@ # Set `enforce_eager=True` to avoid ahead-of-time compilation. # In real workloads, `enforace_eager` should be `False`. 
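The XLA path above leaves activation quantization to the fused op (`quantize_activation=True`). As a rough reference for the dynamic, symmetric scheme this implies, written here with per-token granularity for illustration (the op itself decides the actual granularity):

```python
import torch

def dynamic_symmetric_int8(x: torch.Tensor):
    # One scale per token, no zero point (symmetric), derived from the data.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale

x = torch.randn(4, 16)
x_q, x_s = dynamic_symmetric_int8(x)
print(x_q.dtype, x_s.shape)  # torch.int8 torch.Size([4, 1])
```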
-llm = LLM(model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", +# llm = LLM(model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", +llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test", enforce_eager=False, max_model_len=1024) outputs = llm.generate(prompts, sampling_params) @@ -27,4 +28,4 @@ prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - assert generated_text.startswith(answer) + # assert generated_text.startswith(answer) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 0b3c61cccb0c5..ae6acd5d9ab42 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -41,7 +41,8 @@ def choose_scaled_mm_linear_kernel( """ if compute_capability is None: - compute_capability = current_platform.get_device_capability() + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] failure_reasons = [] for kernel in _POSSIBLE_KERNELS[current_platform._enum]: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 5a6b2a499140d..07432560d2569 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -40,8 +40,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # If we have a fused module (QKV, MLP) with per tensor scales (thus N # scales being passed to the kernel), convert to the per-channel case. 
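The `_cc[0] * 10 + _cc[1]` conversion above packs a (major, minor) device capability into the single integer the kernels compare against; a trivial sketch:

```python
def capability_as_int(major: int, minor: int) -> int:
    return major * 10 + minor

assert capability_as_int(7, 5) == 75  # Turing, the CUTLASS kernel's minimum
assert capability_as_int(9, 0) == 90  # Hopper
```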
is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) if is_fused_module and not self.config.is_channelwise: - weight_scale = getattr(layer, self.w_s_name) weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index b6bd4cae0bb02..6d1ed53f4cefb 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -1,7 +1,6 @@ from typing import Optional, Tuple import torch -import torch_xla.experimental.xla_quantized_matmul # noqa: F401 from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -70,6 +69,7 @@ def apply_weights(self, assert i_zp is None assert i_azp_adj is None + import torch_xla.experimental.xla_quantized_matmul # noqa: F401 return torch.ops.xla.quantized_matmul(x, w_q, w_s, From 7b3203f5b44e34ddb08466d51debe645e171f220 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 12 Oct 2024 00:08:47 +0000 Subject: [PATCH 06/63] added mixed precision directory --- examples/offline_inference_tpu.py | 5 +- .../layers/quantization/kernels/__init__.py | 72 ------------------- .../{ => mixed_precision}/MPLinearKernel.py | 0 .../kernels/mixed_precision/__init__.py | 72 +++++++++++++++++++ .../kernels/{ => mixed_precision}/machete.py | 0 .../kernels/{ => mixed_precision}/marlin.py | 0 6 files changed, 74 insertions(+), 75 deletions(-) rename vllm/model_executor/layers/quantization/kernels/{ => mixed_precision}/MPLinearKernel.py (100%) create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py rename vllm/model_executor/layers/quantization/kernels/{ => mixed_precision}/machete.py (100%) rename vllm/model_executor/layers/quantization/kernels/{ => mixed_precision}/marlin.py (100%) diff --git a/examples/offline_inference_tpu.py b/examples/offline_inference_tpu.py index 7699bec2f2ffd..880f5cc4d0d64 100644 --- a/examples/offline_inference_tpu.py +++ b/examples/offline_inference_tpu.py @@ -19,9 +19,8 @@ # Set `enforce_eager=True` to avoid ahead-of-time compilation. # In real workloads, `enforace_eager` should be `False`. 
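Moving the `torch_xla` import into `apply_weights` above means `torch_xla` is only required on workers that actually call the op. A minimal sketch of the same deferred-import pattern (the wrapper name is made up):

```python
import torch

def xla_w8a8_matmul(x: torch.Tensor, w_q: torch.Tensor,
                    w_s: torch.Tensor) -> torch.Tensor:
    # Imported only when the op is needed, i.e. on a TPU worker with torch_xla.
    import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
    return torch.ops.xla.quantized_matmul(x, w_q, w_s,
                                          quantize_activation=True)
```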
-# llm = LLM(model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", -llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test", - enforce_eager=False, +llm = LLM(model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + enforce_eager=True, max_model_len=1024) outputs = llm.generate(prompts, sampling_params) for output, answer in zip(outputs, answers): diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py index 47591c2aa644e..e69de29bb2d1d 100644 --- a/vllm/model_executor/layers/quantization/kernels/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -1,72 +0,0 @@ -import os -from typing import List, Optional, Type - -from vllm.model_executor.layers.quantization.kernels.machete import ( - MacheteLinearKernel) -from vllm.model_executor.layers.quantization.kernels.marlin import ( - MarlinLinearKernel) -from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( - MPLinearKernel, MPLinearLayerConfig) -from vllm.platforms import current_platform - -# in priority/performance order (when available) -_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ - MacheteLinearKernel, - MarlinLinearKernel, -] - - -def choose_mp_linear_kernel( - config: MPLinearLayerConfig, - compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: - """ - Choose an MPLinearKernel that can implement the given config for the given - compute capability. Attempts to choose the best kernel in terms of - performance. - - Args: - config (MPLinearLayerConfig): Description of the linear layer to be - implemented. - compute_capability (Optional[int], optional): The compute capability of - the target device, if None uses `current_platform` to get the compute - capability. Defaults to None. - - Raises: - ValueError: If no kernel can implement the given config. - - Returns: - Type[MPLinearKernel]: Chosen kernel. - """ - if compute_capability is None: - if current_platform is None: - raise ValueError("Cannot determine compute capability") - _cc = current_platform.get_device_capability() - compute_capability = _cc[0] * 10 + _cc[1] - - failure_reasons = [] - for kernel in _POSSIBLE_KERNELS: - if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ - .split(","): - failure_reasons.append( - f' {kernel.__name__} disabled by environment variable') - continue - - if kernel.get_min_capability() > compute_capability: - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel.get_min_capability()}, current compute capability " - f"is {compute_capability}") - continue - - can_implement, failure_reason = kernel.can_implement(config) - if can_implement: - return kernel - else: - failure_reasons.append( - f' {kernel.__name__} cannot implement due to: {failure_reason}' - ) - - raise ValueError( - "Failed to find a kernel that can implement the "\ - "WNA16 linear layer. 
Reasons: \n" - + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py rename to vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py new file mode 100644 index 0000000000000..47591c2aa644e --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -0,0 +1,72 @@ +import os +from typing import List, Optional, Type + +from vllm.model_executor.layers.quantization.kernels.machete import ( + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.marlin import ( + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( + MPLinearKernel, MPLinearLayerConfig) +from vllm.platforms import current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + """ + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[MPLinearKernel]: Chosen kernel. + """ + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. 
Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/machete.py rename to vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py diff --git a/vllm/model_executor/layers/quantization/kernels/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/marlin.py rename to vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py From bf50fa4df37a6cfb524ceca419d09205011a2624 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 12 Oct 2024 00:11:19 +0000 Subject: [PATCH 07/63] formatting --- .../schemes/compressed_tensors_wNa16.py | 2 +- vllm/model_executor/layers/quantization/gptq_marlin.py | 2 +- .../quantization/kernels/mixed_precision/__init__.py | 10 ++++------ .../layers/quantization/kernels/scaled_mm/xla.py | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index cb65557be8f90..dd18db1ef67e6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -7,7 +7,7 @@ CompressedTensorsScheme) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( ActivationOrdering) -from vllm.model_executor.layers.quantization.kernels import ( +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_repeat_scales_on_all_ranks) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index e77191796bd7e..165d15bc552d7 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,7 +10,7 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.kernels import ( +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 47591c2aa644e..e0cd215fb86c2 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -1,14 +1,12 @@ import os from typing import List, Optional, Type -from vllm.model_executor.layers.quantization.kernels.machete import ( - MacheteLinearKernel) -from vllm.model_executor.layers.quantization.kernels.marlin import ( - MarlinLinearKernel) -from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( - MPLinearKernel, MPLinearLayerConfig) from vllm.platforms 
import current_platform +from .machete import MacheteLinearKernel +from .marlin import MarlinLinearKernel +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + # in priority/performance order (when available) _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ MacheteLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 6d1ed53f4cefb..9ad44ec9cb4a8 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -69,7 +69,7 @@ def apply_weights(self, assert i_zp is None assert i_azp_adj is None - import torch_xla.experimental.xla_quantized_matmul # noqa: F401 + import torch_xla.experimental.xla_quantized_matmul # noqa: F401 return torch.ops.xla.quantized_matmul(x, w_q, w_s, From 226ef522f608b5ee072048d1badfcab40c64a7b0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 12 Oct 2024 02:01:28 +0000 Subject: [PATCH 08/63] cache current state - w8a16 running oom --- examples/offline_inference_tpu.py | 11 ++- .../kernels/mixed_precision/MPLinearKernel.py | 2 +- .../kernels/mixed_precision/__init__.py | 35 +++++--- .../kernels/mixed_precision/xla.py | 85 +++++++++++++++++++ .../kernels/scaled_mm/__init__.py | 3 +- 5 files changed, 116 insertions(+), 20 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/xla.py diff --git a/examples/offline_inference_tpu.py b/examples/offline_inference_tpu.py index 880f5cc4d0d64..b2ce40a3a6057 100644 --- a/examples/offline_inference_tpu.py +++ b/examples/offline_inference_tpu.py @@ -18,10 +18,13 @@ max_tokens=16) # Set `enforce_eager=True` to avoid ahead-of-time compilation. -# In real workloads, `enforace_eager` should be `False`. -llm = LLM(model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - enforce_eager=True, - max_model_len=1024) +# In real workloads, `enforce_eager` should be `False`. 
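This commit's headline change is the W8A16 mixed-precision XLA kernel added below: int8 weights with one scale per output channel, higher-precision activations. A rough plain-PyTorch reference for what that path computes (shapes and scales are invented):

```python
import torch

x = torch.randn(2, 8, dtype=torch.bfloat16)                # activations
w_q = torch.randint(-128, 128, (16, 8), dtype=torch.int8)  # [out, in] int8 weights
w_s = torch.rand(16, dtype=torch.bfloat16) * 0.01          # per-channel scales

# Weight-only dequantization: y ~= x @ (w_q * w_s_per_row).T
w_deq = w_q.to(torch.float32) * w_s.to(torch.float32).unsqueeze(-1)
y = (x.to(torch.float32) @ w_deq.t()).to(torch.bfloat16)
print(y.shape)  # torch.Size([2, 16])
```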
+llm = LLM( + # model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + # model="neuralmagic/gemma-2-2b-it-quantized.w8a16", + model="neuralmagic/SmolLM-1.7B-Instruct-quantized.w8a16", + enforce_eager=True, + max_model_len=1024) outputs = llm.generate(prompts, sampling_params) for output, answer in zip(outputs, answers): prompt = output.prompt diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index fe50c4930d043..d7fa854d2dcf2 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -23,7 +23,7 @@ class MPLinearKernel(ABC): @classmethod @abstractmethod - def get_min_capability(cls) -> int: + def get_min_capability(cls) -> Optional[int]: raise NotImplementedError @classmethod diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index e0cd215fb86c2..47e6669cec80d 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -1,17 +1,18 @@ import os -from typing import List, Optional, Type +from typing import Dict, List, Optional, Type -from vllm.platforms import current_platform +from vllm.platforms import PlatformEnum, current_platform from .machete import MacheteLinearKernel from .marlin import MarlinLinearKernel from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig +from .xla import XLAMixedPrecisionLinearKernel # in priority/performance order (when available) -_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ - MacheteLinearKernel, - MarlinLinearKernel, -] +_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[MPLinearKernel]]] = { + PlatformEnum.CUDA: [MacheteLinearKernel, MarlinLinearKernel], + PlatformEnum.TPU: [XLAMixedPrecisionLinearKernel] +} def choose_mp_linear_kernel( @@ -39,22 +40,28 @@ def choose_mp_linear_kernel( if current_platform is None: raise ValueError("Cannot determine compute capability") _cc = current_platform.get_device_capability() - compute_capability = _cc[0] * 10 + _cc[1] + if _cc is not None: + compute_capability = _cc[0] * 10 + _cc[1] failure_reasons = [] - for kernel in _POSSIBLE_KERNELS: + for kernel in _POSSIBLE_KERNELS[current_platform._enum]: if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ .split(","): failure_reasons.append( f' {kernel.__name__} disabled by environment variable') continue - if kernel.get_min_capability() > compute_capability: - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel.get_min_capability()}, current compute capability " - f"is {compute_capability}") - continue + # If the current platform uses compute_capability, + # make sure the kernel supports the compute cability. 
+ if compute_capability is not None: + kernel_min_capability = kernel.get_min_capability() + if (kernel_min_capability is not None + and kernel_min_capability > compute_capability): + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel_min_capability}, current compute capability " + f"is {compute_capability}") + continue can_implement, failure_reason = kernel.can_implement(config) if can_implement: diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/xla.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/xla.py new file mode 100644 index 0000000000000..ca46ee121d5f1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/xla.py @@ -0,0 +1,85 @@ +from typing import Optional, Tuple + +import torch + +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + +XLA_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128] +XLA_SUPPORTED_GROUP_SIZES = [-1] + + +class XLAMixedPrecisionLinearKernel(MPLinearKernel): + """ + XLAMixedPrecisionLinearKernel: WNA16 for TPU + + Kernel definition: + - https://github.com/pytorch/xla/blob/v2.5.0-rc9/torch_xla/experimental/xla_quantized_matmul.py#L78 + + Supported: + - w8a16 symmetric channelwise + + Currently unsupported: + - w8a16 + - w8a16 grouped + - w4a16 + - asymmetric + - activation_reordering + """ + + @classmethod + def get_min_capability(cls) -> Optional[int]: + return None + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if c.zero_points: + return False, "Zero points currently not supported by XLA" + + if c.weight_type not in XLA_SUPPORTED_QUANT_TYPES: + return False, f"Quant type ({c.weight_type}) not supported by XLA"\ + f" , supported types are: {XLA_SUPPORTED_QUANT_TYPES}" + + if c.group_size not in XLA_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "XLA, supported group sizes are: "\ + f"{XLA_SUPPORTED_GROUP_SIZES}" + + if c.has_g_idx: + return False, "Activation reordering is not supported by XLA" + + return True, None + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + def transform_w_int8(x): + # reinterpet_cast the packed int32 weights as int8 + # convert to [int, out] -> [out, int] + return x.view(dtype=torch.int8).t() + + def transform_s_channelwise(x): + # convert to [out] + return x.squeeze(-1).to(torch.bfloat16) + + self._transform_param(layer, self.w_q_name, transform_w_int8) + self._transform_param(layer, self.w_s_name, transform_s_channelwise) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + assert w_zp is None and w_gidx is None + + #weight_scale = weight_scale.squeeze(-1).to(torch.bfloat16) + import torch_xla.experimental.xla_quantized_matmul # noqa: F401 + return torch.ops.xla.quantized_matmul(x, + w_q, + w_s, + quantize_activation=False) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index ae6acd5d9ab42..01db32e8a8783 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py 
@@ -42,7 +42,8 @@ def choose_scaled_mm_linear_kernel( if compute_capability is None: _cc = current_platform.get_device_capability() - compute_capability = _cc[0] * 10 + _cc[1] + if _cc is not None: + compute_capability = _cc[0] * 10 + _cc[1] failure_reasons = [] for kernel in _POSSIBLE_KERNELS[current_platform._enum]: From bb7c74146d9ef44fb73121f2fc00befe8f40450e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 16 Oct 2024 23:57:04 +0000 Subject: [PATCH 09/63] [TPU] Ensure torch._sync(param) is called after param.data.copy_() --- vllm/model_executor/utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index d7eec818cbba4..0e66e29f5d377 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -4,6 +4,7 @@ import torch from vllm.utils import seed_everything +from vllm.platforms import current_platform def set_random_seed(seed: int) -> None: @@ -28,4 +29,19 @@ def set_weight_attrs( for key, value in weight_attrs.items(): assert not hasattr( weight, key), (f"Overwriting existing tensor attribute: {key}") + + # NOTE(woosuk): For TPU, param.data.copy_(weight) happens lazily, + # which means that the param and weight tensors co-exist until the param + # tensor is used by other operations. This causes excessive memory usage + # during model loading. To avoid this, we sync the param tensor after + # its weight loader is called. + # TODO(woosuk): Remove this hack once we have a better solution. + if current_platform.is_tpu() and key == "weight_loader": + original_weight_loader = value + + def _synced_weight_loader(param, *args, **kwargs): + original_weight_loader(param, *args, **kwargs) + torch._sync(param) + + value = _synced_weight_loader setattr(weight, key, value) From cf842bdb033a341bc45ac5bdc6d34a00438a7c8a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 17 Oct 2024 00:02:57 +0000 Subject: [PATCH 10/63] yapf --- vllm/model_executor/utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 0e66e29f5d377..eaa89be11038b 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -3,8 +3,8 @@ import torch -from vllm.utils import seed_everything from vllm.platforms import current_platform +from vllm.utils import seed_everything def set_random_seed(seed: int) -> None: @@ -37,11 +37,14 @@ def set_weight_attrs( # its weight loader is called. # TODO(woosuk): Remove this hack once we have a better solution. 
if current_platform.is_tpu() and key == "weight_loader": - original_weight_loader = value + value = _make_synced_weight_loader(value) + setattr(weight, key, value) - def _synced_weight_loader(param, *args, **kwargs): - original_weight_loader(param, *args, **kwargs) - torch._sync(param) - value = _synced_weight_loader - setattr(weight, key, value) +def _make_synced_weight_loader(original_weight_loader): + + def _synced_weight_loader(param, *args, **kwargs): + original_weight_loader(param, *args, **kwargs) + torch._sync(param) + + return _synced_weight_loader From 67039bcd5840237c4636ed5820f7b8d7f29b05f0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 17 Oct 2024 00:24:22 +0000 Subject: [PATCH 11/63] [TPU] Correctly profile peak memory usage --- vllm/worker/tpu_worker.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index fe819b9f4b3a8..de6f7ab0072fd 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -133,18 +133,19 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Synchronize before measuring the memory usage. xm.wait_device_ops() - dtype_btyes = get_dtype_size(self.cache_dtype) - block_size = self.cache_config.block_size - block_size_bytes = (dtype_btyes * block_size * num_layers * 2 * - head_size * num_kv_heads) - - # Calculate the TPU KV cache size based on profiling. + # Get the maximum amount of memory used by the model weights and + # intermediate activations. m = xm.get_memory_info(self.device) total_memory_size = m["bytes_limit"] + profiled = m["peak_bytes_used"] # Weights + intermediate activations. + + # Calculate the TPU KV cache size based on profiling. usable_memory_size = int(total_memory_size * self.cache_config.gpu_memory_utilization) - profiled = m["bytes_used"] # Weights + intermediate activations. tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) + dtype_btyes = get_dtype_size(self.cache_dtype) + block_size_bytes = (dtype_btyes * self.cache_config.block_size * + num_layers * 2 * head_size * num_kv_heads) num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. From 0695f77c216715472a7b9f016072f0bcf30c9b69 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 17 Oct 2024 00:25:25 +0000 Subject: [PATCH 12/63] Upgrade PyTorch XLA --- Dockerfile.tpu | 2 +- docs/source/getting_started/tpu-installation.rst | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile.tpu b/Dockerfile.tpu index d8f1a42c45177..c1c60584ee246 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -1,4 +1,4 @@ -ARG NIGHTLY_DATE="20240828" +ARG NIGHTLY_DATE="20241017" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index 217028839e347..edba209986f6a 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -56,8 +56,8 @@ First, install the dependencies: $ pip uninstall torch torch-xla -y $ # Install PyTorch and PyTorch XLA. 
- $ export DATE="20240828" - $ export TORCH_VERSION="2.5.0" + $ export DATE="20241017" + $ export TORCH_VERSION="2.6.0" $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl From e016e3860be9c6db07e60da24cb21ec60078c1b9 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Oct 2024 17:52:48 +0000 Subject: [PATCH 13/63] stash --- .../layers/quantization/kernels/scaled_mm/xla.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 9ad44ec9cb4a8..ab125c38f178a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -68,9 +68,14 @@ def apply_weights(self, assert i_s is None assert i_zp is None assert i_azp_adj is None + assert bias is None, "Bias is not supported for XLA yet." import torch_xla.experimental.xla_quantized_matmul # noqa: F401 - return torch.ops.xla.quantized_matmul(x, - w_q, - w_s, - quantize_activation=True) + return torch.ops.xla.quantized_matmul( + x, + w_q, + w_s, + zero_point=None, + block_size=-1, + int4_weight=False, + quantize_activation=True) From c84873521c80d79d8dc2fc57031eb483c3a16ed4 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Oct 2024 17:55:48 +0000 Subject: [PATCH 14/63] proper merge --- .../layers/quantization/kernels/__init__.py | 74 ----------------- .../kernels/scaled_mm/__init__.py | 83 +++++++++---------- 2 files changed, 39 insertions(+), 118 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py index 94a3dc2584d6b..e69de29bb2d1d 100644 --- a/vllm/model_executor/layers/quantization/kernels/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -1,74 +0,0 @@ -from typing import List, Optional, Type - -import vllm.envs as envs -from vllm.model_executor.layers.quantization.kernels.exllama import ( - ExllamaLinearKernel) -from vllm.model_executor.layers.quantization.kernels.machete import ( - MacheteLinearKernel) -from vllm.model_executor.layers.quantization.kernels.marlin import ( - MarlinLinearKernel) -from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( - MPLinearKernel, MPLinearLayerConfig) -from vllm.platforms import current_platform - -# in priority/performance order (when available) -_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ - MacheteLinearKernel, - MarlinLinearKernel, - ExllamaLinearKernel, -] - - -def choose_mp_linear_kernel( - config: MPLinearLayerConfig, - compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: - """ - Choose an MPLinearKernel that can implement the given config for the given - compute capability. Attempts to choose the best kernel in terms of - performance. - - Args: - config (MPLinearLayerConfig): Description of the linear layer to be - implemented. - compute_capability (Optional[int], optional): The compute capability of - the target device, if None uses `current_platform` to get the compute - capability. Defaults to None. - - Raises: - ValueError: If no kernel can implement the given config. - - Returns: - Type[MPLinearKernel]: Chosen kernel. 
- """ - if compute_capability is None: - if current_platform is None: - raise ValueError("Cannot determine compute capability") - _cc = current_platform.get_device_capability() - compute_capability = _cc[0] * 10 + _cc[1] - - failure_reasons = [] - for kernel in _POSSIBLE_KERNELS: - if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: - failure_reasons.append( - f' {kernel.__name__} disabled by environment variable') - continue - - if kernel.get_min_capability() > compute_capability: - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel.get_min_capability()}, current compute capability " - f"is {compute_capability}") - continue - - can_implement, failure_reason = kernel.can_implement(config) - if can_implement: - return kernel - else: - failure_reasons.append( - f' {kernel.__name__} cannot implement due to: {failure_reason}' - ) - - raise ValueError( - "Failed to find a kernel that can implement the "\ - "WNA16 linear layer. Reasons: \n" - + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 01db32e8a8783..94a3dc2584d6b 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -1,69 +1,64 @@ -import os -from typing import Dict, List, Optional, Type +from typing import List, Optional, Type -from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( - CutlassScaledMMLinearKernel) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearKernel, ScaledMMLinearLayerConfig) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( - XLAScaledMMLinearKernel) -from vllm.platforms import PlatformEnum, current_platform +import vllm.envs as envs +from vllm.model_executor.layers.quantization.kernels.exllama import ( + ExllamaLinearKernel) +from vllm.model_executor.layers.quantization.kernels.machete import ( + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.marlin import ( + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( + MPLinearKernel, MPLinearLayerConfig) +from vllm.platforms import current_platform # in priority/performance order (when available) -_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { - PlatformEnum.CPU: [CutlassScaledMMLinearKernel], - PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], - PlatformEnum.TPU: [XLAScaledMMLinearKernel], -} +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, + ExllamaLinearKernel, +] -def choose_scaled_mm_linear_kernel( - config: ScaledMMLinearLayerConfig, - compute_capability: Optional[int] = None -) -> Type[ScaledMMLinearKernel]: +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: """ - Choose an ScalledMMLinearKernel that can implement the given config for the - given compute capability. Attempts to choose the best kernel in terms of - performance. + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. Args: - config (ScaledMMLinearLayerConfig): Description of the linear layer - to be implemented. 
+ config (MPLinearLayerConfig): Description of the linear layer to be + implemented. compute_capability (Optional[int], optional): The compute capability of - the target device, if None uses `current_platform` to get the - compute capability. Defaults to None. + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. Raises: ValueError: If no kernel can implement the given config. Returns: - Type[ScaledMMLinearKernel]: Chosen kernel. + Type[MPLinearKernel]: Chosen kernel. """ - if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") _cc = current_platform.get_device_capability() - if _cc is not None: - compute_capability = _cc[0] * 10 + _cc[1] + compute_capability = _cc[0] * 10 + _cc[1] failure_reasons = [] - for kernel in _POSSIBLE_KERNELS[current_platform._enum]: - if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ - .split(","): + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: failure_reasons.append( f' {kernel.__name__} disabled by environment variable') continue - # If the current platform uses compute_capability, - # make sure the kernel supports the compute cability. - if compute_capability is not None: - kernel_min_capability = kernel.get_min_capability() - if (kernel_min_capability is not None - and kernel_min_capability > compute_capability): - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel_min_capability}, current compute capability " - f"is {compute_capability}") - continue + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue can_implement, failure_reason = kernel.can_implement(config) if can_implement: @@ -75,5 +70,5 @@ def choose_scaled_mm_linear_kernel( raise ValueError( "Failed to find a kernel that can implement the "\ - "ScaledMM linear layer. Reasons: \n" + "WNA16 linear layer. 
Reasons: \n" + '\n'.join(failure_reasons)) From 15399154bd4092183c1865bc82aaddd9b7cbdb31 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Oct 2024 19:49:29 +0000 Subject: [PATCH 15/63] add mixed precision --- .../schemes/compressed_tensors_w8a8_int8.py | 5 +- .../kernels/mixed_precision/__init__.py | 44 +++++----- .../kernels/{ => mixed_precision}/exllama.py | 0 .../kernels/mixed_precision/xla.py | 85 ------------------- .../kernels/scaled_mm/__init__.py | 83 +++++++++--------- 5 files changed, 65 insertions(+), 152 deletions(-) rename vllm/model_executor/layers/quantization/kernels/{ => mixed_precision}/exllama.py (100%) delete mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/xla.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 3de107c788a81..0d174a54159b4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -1,14 +1,13 @@ from typing import Callable, List, Optional, Set import torch -from compressed_tensors.quantization import QuantizationStrategy from torch.nn import Parameter +from compressed_tensors.quantization import QuantizationStrategy + from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - QuantizationStrategy) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) from vllm.model_executor.parameter import (BasevLLMParameter, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 47e6669cec80d..6e88552694636 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -1,18 +1,19 @@ -import os -from typing import Dict, List, Optional, Type - -from vllm.platforms import PlatformEnum, current_platform +from typing import List, Optional, Type +import vllm.envs as envs +from .exllama import ExllamaLinearKernel from .machete import MacheteLinearKernel from .marlin import MarlinLinearKernel from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig -from .xla import XLAMixedPrecisionLinearKernel + +from vllm.platforms import current_platform # in priority/performance order (when available) -_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[MPLinearKernel]]] = { - PlatformEnum.CUDA: [MacheteLinearKernel, MarlinLinearKernel], - PlatformEnum.TPU: [XLAMixedPrecisionLinearKernel] -} +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, + ExllamaLinearKernel, +] def choose_mp_linear_kernel( @@ -40,28 +41,21 @@ def choose_mp_linear_kernel( if current_platform is None: raise ValueError("Cannot determine compute capability") _cc = current_platform.get_device_capability() - if _cc is not None: - compute_capability = _cc[0] * 10 + _cc[1] + compute_capability = _cc[0] * 10 + _cc[1] failure_reasons = [] - for kernel in _POSSIBLE_KERNELS[current_platform._enum]: - if kernel.__name__ in 
os.environ.get("VLLM_DISABLED_KERNELS", "")\ - .split(","): + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: failure_reasons.append( f' {kernel.__name__} disabled by environment variable') continue - # If the current platform uses compute_capability, - # make sure the kernel supports the compute cability. - if compute_capability is not None: - kernel_min_capability = kernel.get_min_capability() - if (kernel_min_capability is not None - and kernel_min_capability > compute_capability): - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel_min_capability}, current compute capability " - f"is {compute_capability}") - continue + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue can_implement, failure_reason = kernel.can_implement(config) if can_implement: diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py similarity index 100% rename from vllm/model_executor/layers/quantization/kernels/exllama.py rename to vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/xla.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/xla.py deleted file mode 100644 index ca46ee121d5f1..0000000000000 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/xla.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Optional, Tuple - -import torch - -from vllm.scalar_type import scalar_types - -from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig - -XLA_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128] -XLA_SUPPORTED_GROUP_SIZES = [-1] - - -class XLAMixedPrecisionLinearKernel(MPLinearKernel): - """ - XLAMixedPrecisionLinearKernel: WNA16 for TPU - - Kernel definition: - - https://github.com/pytorch/xla/blob/v2.5.0-rc9/torch_xla/experimental/xla_quantized_matmul.py#L78 - - Supported: - - w8a16 symmetric channelwise - - Currently unsupported: - - w8a16 - - w8a16 grouped - - w4a16 - - asymmetric - - activation_reordering - """ - - @classmethod - def get_min_capability(cls) -> Optional[int]: - return None - - @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: - - if c.zero_points: - return False, "Zero points currently not supported by XLA" - - if c.weight_type not in XLA_SUPPORTED_QUANT_TYPES: - return False, f"Quant type ({c.weight_type}) not supported by XLA"\ - f" , supported types are: {XLA_SUPPORTED_QUANT_TYPES}" - - if c.group_size not in XLA_SUPPORTED_GROUP_SIZES: - return False, f"Group size ({c.group_size}) not supported by "\ - "XLA, supported group sizes are: "\ - f"{XLA_SUPPORTED_GROUP_SIZES}" - - if c.has_g_idx: - return False, "Activation reordering is not supported by XLA" - - return True, None - - # note assumes that - # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} - # `weight_scale` is: {input_dim = 0, output_dim = 1} - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - - def transform_w_int8(x): - # reinterpet_cast the packed int32 weights as int8 - # convert to [int, out] -> [out, int] - return x.view(dtype=torch.int8).t() - - def transform_s_channelwise(x): - # convert to [out] - return x.squeeze(-1).to(torch.bfloat16) - - 
self._transform_param(layer, self.w_q_name, transform_w_int8) - self._transform_param(layer, self.w_s_name, transform_s_channelwise) - - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) - assert w_zp is None and w_gidx is None - - #weight_scale = weight_scale.squeeze(-1).to(torch.bfloat16) - import torch_xla.experimental.xla_quantized_matmul # noqa: F401 - return torch.ops.xla.quantized_matmul(x, - w_q, - w_s, - quantize_activation=False) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 94a3dc2584d6b..01db32e8a8783 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -1,64 +1,69 @@ -from typing import List, Optional, Type +import os +from typing import Dict, List, Optional, Type -import vllm.envs as envs -from vllm.model_executor.layers.quantization.kernels.exllama import ( - ExllamaLinearKernel) -from vllm.model_executor.layers.quantization.kernels.machete import ( - MacheteLinearKernel) -from vllm.model_executor.layers.quantization.kernels.marlin import ( - MarlinLinearKernel) -from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( - MPLinearKernel, MPLinearLayerConfig) -from vllm.platforms import current_platform +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearKernel, ScaledMMLinearLayerConfig) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( + XLAScaledMMLinearKernel) +from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) -_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ - MacheteLinearKernel, - MarlinLinearKernel, - ExllamaLinearKernel, -] +_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { + PlatformEnum.CPU: [CutlassScaledMMLinearKernel], + PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], + PlatformEnum.TPU: [XLAScaledMMLinearKernel], +} -def choose_mp_linear_kernel( - config: MPLinearLayerConfig, - compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: +def choose_scaled_mm_linear_kernel( + config: ScaledMMLinearLayerConfig, + compute_capability: Optional[int] = None +) -> Type[ScaledMMLinearKernel]: """ - Choose an MPLinearKernel that can implement the given config for the given - compute capability. Attempts to choose the best kernel in terms of - performance. + Choose an ScalledMMLinearKernel that can implement the given config for the + given compute capability. Attempts to choose the best kernel in terms of + performance. Args: - config (MPLinearLayerConfig): Description of the linear layer to be - implemented. + config (ScaledMMLinearLayerConfig): Description of the linear layer + to be implemented. compute_capability (Optional[int], optional): The compute capability of - the target device, if None uses `current_platform` to get the compute - capability. Defaults to None. + the target device, if None uses `current_platform` to get the + compute capability. Defaults to None. Raises: ValueError: If no kernel can implement the given config. Returns: - Type[MPLinearKernel]: Chosen kernel. 
+ Type[ScaledMMLinearKernel]: Chosen kernel. """ + if compute_capability is None: - if current_platform is None: - raise ValueError("Cannot determine compute capability") _cc = current_platform.get_device_capability() - compute_capability = _cc[0] * 10 + _cc[1] + if _cc is not None: + compute_capability = _cc[0] * 10 + _cc[1] failure_reasons = [] - for kernel in _POSSIBLE_KERNELS: - if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: + for kernel in _POSSIBLE_KERNELS[current_platform._enum]: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): failure_reasons.append( f' {kernel.__name__} disabled by environment variable') continue - if kernel.get_min_capability() > compute_capability: - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel.get_min_capability()}, current compute capability " - f"is {compute_capability}") - continue + # If the current platform uses compute_capability, + # make sure the kernel supports the compute cability. + if compute_capability is not None: + kernel_min_capability = kernel.get_min_capability() + if (kernel_min_capability is not None + and kernel_min_capability > compute_capability): + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel_min_capability}, current compute capability " + f"is {compute_capability}") + continue can_implement, failure_reason = kernel.can_implement(config) if can_implement: @@ -70,5 +75,5 @@ def choose_mp_linear_kernel( raise ValueError( "Failed to find a kernel that can implement the "\ - "WNA16 linear layer. Reasons: \n" + "ScaledMM linear layer. Reasons: \n" + '\n'.join(failure_reasons)) From f00412a02c8e57b97b4ef5ad4f974980b9fa11bf Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Oct 2024 19:54:44 +0000 Subject: [PATCH 16/63] format --- .../schemes/compressed_tensors_w8a8_int8.py | 2 -- .../kernels/mixed_precision/__init__.py | 4 ++-- .../layers/quantization/kernels/scaled_mm/xla.py | 15 +++++++-------- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 0d174a54159b4..87c0de1abbd67 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -1,8 +1,6 @@ from typing import Callable, List, Optional, Set import torch -from torch.nn import Parameter - from compressed_tensors.quantization import QuantizationStrategy from vllm.logger import init_logger diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 6e88552694636..1632d419b4614 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -1,13 +1,13 @@ from typing import List, Optional, Type import vllm.envs as envs +from vllm.platforms import current_platform + from .exllama import ExllamaLinearKernel from .machete import MacheteLinearKernel from .marlin import MarlinLinearKernel from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig -from vllm.platforms import current_platform - # in priority/performance order (when available) _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ 
MacheteLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index ab125c38f178a..4fc96f94f9b5d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -71,11 +71,10 @@ def apply_weights(self, assert bias is None, "Bias is not supported for XLA yet." import torch_xla.experimental.xla_quantized_matmul # noqa: F401 - return torch.ops.xla.quantized_matmul( - x, - w_q, - w_s, - zero_point=None, - block_size=-1, - int4_weight=False, - quantize_activation=True) + return torch.ops.xla.quantized_matmul(x, + w_q, + w_s, + zero_point=None, + block_size=-1, + int4_weight=False, + quantize_activation=True) From b0a6b70a723865c13a99a746387fa57d65821b42 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Oct 2024 20:40:46 +0000 Subject: [PATCH 17/63] stash --- .../layers/quantization/kernels/scaled_mm/xla.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 4fc96f94f9b5d..081cc2ea4dfcf 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -1,6 +1,7 @@ from typing import Optional, Tuple import torch +import torch_xla from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -35,12 +36,10 @@ def can_implement( return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # WEIGHT - # Cutlass kernels need transposed weight. 
- weight = getattr(layer, self.w_q_name) - # [out, in] (different than cutlass_scaled_mm) + weight = getattr(layer, self.w_q_name) replace_parameter(layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)) From 764dda1e9344ae81456c081492d393421bae71d5 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Oct 2024 23:02:21 +0000 Subject: [PATCH 18/63] stash --- .../schemes/compressed_tensors_w8a8_int8.py | 1 - .../quantization/kernels/scaled_mm/xla.py | 7 +++-- vllm/model_executor/model_loader/loader.py | 26 +++++++++++++++++++ .../model_loader/weight_utils.py | 1 + vllm/model_executor/parameter.py | 6 +++++ vllm/worker/tpu_worker.py | 18 +++++++++++++ 6 files changed, 56 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 87c0de1abbd67..fe2cc499c9a53 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -58,7 +58,6 @@ def create_weights(self, layer: torch.nn.Module, input_dim=1, output_dim=0, weight_loader=weight_loader) - layer.register_parameter("weight", weight) # WEIGHT SCALE diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 081cc2ea4dfcf..febf2df5d73e3 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -2,6 +2,7 @@ import torch import torch_xla +import torch_xla.core.xla_model as xm from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -41,7 +42,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # [out, in] (different than cutlass_scaled_mm) weight = getattr(layer, self.w_q_name) replace_parameter(layer, self.w_q_name, - torch.nn.Parameter(weight.data, requires_grad=False)) + torch.nn.Parameter(weight.data.contiguous(), + requires_grad=False)) # WEIGHT SCALE # XLA kernels support only per-tensor and per-channel. 
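For context on the weight-scale handling in this hunk and the next one: the kernel consumes either a single per-tensor scale or one scale per output channel, so a fused projection (e.g. QKV) that was quantized with one scale per logical partition has those per-partition scales expanded across each partition's output width before being squeezed to a flat [out_channel] vector — this is what the layer.logical_widths bookkeeping in this file is for. A minimal sketch of that expansion, using a stand-in helper rather than vLLM's actual conversion utility; the partition widths and scale values below are made-up examples:

    import torch

    def expand_to_channelwise(per_partition_scales: torch.Tensor,
                              logical_widths: list) -> torch.Tensor:
        # One scalar scale per fused partition -> one scale per output channel.
        return torch.repeat_interleave(per_partition_scales,
                                       torch.tensor(logical_widths))

    # Hypothetical fused QKV projection: three partitions, three per-tensor scales.
    scales = torch.tensor([0.011, 0.009, 0.013])
    channelwise = expand_to_channelwise(scales, [4096, 1024, 1024])
    print(channelwise.shape)  # torch.Size([6144]): one scale per output row of the weight
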
@@ -57,7 +59,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight_scale = weight_scale.squeeze(-1).to(torch.bfloat16) replace_parameter( layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False)) + torch.nn.Parameter(weight_scale.data.contiguous(), + requires_grad=False)) def apply_weights(self, layer: torch.nn.Module, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 813f58339da37..10fe377e67525 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -395,11 +395,27 @@ def load_model(self, *, model_config: ModelConfig, target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: + import torch_xla.core.xla_model as xm + xm.wait_device_ops() + m = xm.get_memory_info() + print(f"before _initialize_model: {m=}") + breakpoint() model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + import torch_xla.core.xla_model as xm + xm.mark_step() + xm.wait_device_ops() + m = xm.get_memory_info() + print(f"after _initialize_model: {m=}") + breakpoint() model.load_weights(self._get_all_weights(model_config, model)) + xm.wait_device_ops() + m = xm.get_memory_info() + print(f"after load_weights: {m=}") + + breakpoint() for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) @@ -411,6 +427,16 @@ def load_model(self, *, model_config: ModelConfig, # parameters onto device for processing and back off after. with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) + + if current_platform.is_tpu(): + import torch_xla.core.xla_model as xm + xm.mark_step() + + xm.wait_device_ops() + m = xm.get_memory_info() + print(f"after process_weights: {m=}") + print(model) + breakpoint() return model.eval() diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 0c51314bc90df..231ce753ec47b 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -407,6 +407,7 @@ def safetensors_weights_iterator( with safe_open(st_file, framework="pt") as f: for name in f.keys(): # noqa: SIM118 param = f.get_tensor(name) + print("name") yield name, param diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 7a6d7c90f34d5..1d92bd734c9a2 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -6,6 +6,9 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger +from vllm.model_executor.utils import _make_synced_weight_loader +from vllm.platforms import current_platform + __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", @@ -37,6 +40,9 @@ def __init__(self, data: torch.Tensor, weight_loader: Callable): :returns: a torch.nn.parameter """ + if current_platform.is_tpu(): + weight_loader = _make_synced_weight_loader(weight_loader) + self._weight_loader = weight_loader @property diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index de6f7ab0072fd..af754ab2e1141 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -124,6 +124,18 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: torch.tensor([], dtype=torch.float32, device=self.device)) for _ in range(num_layers)] + + xm.wait_device_ops() + 
m = xm.get_memory_info(self.device) + total_memory_size = m["bytes_limit"] + profiled = m["peak_bytes_used"] # Weights + intermediate activations. + print("\n\nBEFORE") + print(m) + print(f"{total_memory_size=}") + print(f"{profiled=}") + print("\n\n") + + print(f"{self.scheduler_config.max_num_batched_tokens}") self.model_runner._dummy_run( batch_size=1, seq_len=self.scheduler_config.max_num_batched_tokens, @@ -135,9 +147,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Get the maximum amount of memory used by the model weights and # intermediate activations. + m = xm.get_memory_info(self.device) total_memory_size = m["bytes_limit"] profiled = m["peak_bytes_used"] # Weights + intermediate activations. + print("\n\nAFTER") + print(m) + print(f"{total_memory_size=}") + print(f"{profiled=}") + print("\n\n") # Calculate the TPU KV cache size based on profiling. usable_memory_size = int(total_memory_size * From 87b2ae65718682e04d0e211c847cdf4a3c08dc7a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Oct 2024 23:10:10 +0000 Subject: [PATCH 19/63] remove name --- .../layers/quantization/kernels/scaled_mm/xla.py | 6 +++--- vllm/model_executor/model_loader/weight_utils.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index febf2df5d73e3..2f1c82f860f06 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -42,7 +42,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # [out, in] (different than cutlass_scaled_mm) weight = getattr(layer, self.w_q_name) replace_parameter(layer, self.w_q_name, - torch.nn.Parameter(weight.data.contiguous(), + torch.nn.Parameter(weight.data, requires_grad=False)) # WEIGHT SCALE @@ -56,10 +56,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.logical_widths) # [out_channel,] (different than cutlass_scaled_mm) - weight_scale = weight_scale.squeeze(-1).to(torch.bfloat16) + weight_scale = weight_scale.squeeze(-1) replace_parameter( layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data.contiguous(), + torch.nn.Parameter(weight_scale.data, requires_grad=False)) def apply_weights(self, diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 231ce753ec47b..0c51314bc90df 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -407,7 +407,6 @@ def safetensors_weights_iterator( with safe_open(st_file, framework="pt") as f: for name in f.keys(): # noqa: SIM118 param = f.get_tensor(name) - print("name") yield name, param From e813ff8040ad32572adfb562f709a6f5601d65fb Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Oct 2024 23:43:38 +0000 Subject: [PATCH 20/63] revert woosuk change --- vllm/model_executor/model_loader/loader.py | 25 -------------------- vllm/worker/tpu_worker.py | 27 ++++++---------------- 2 files changed, 7 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 10fe377e67525..e73a2bc9eac0d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -395,27 +395,11 @@ def load_model(self, *, model_config: ModelConfig, target_device = 
torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: - import torch_xla.core.xla_model as xm - xm.wait_device_ops() - m = xm.get_memory_info() - print(f"before _initialize_model: {m=}") - breakpoint() model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) - import torch_xla.core.xla_model as xm - xm.mark_step() - xm.wait_device_ops() - m = xm.get_memory_info() - print(f"after _initialize_model: {m=}") - breakpoint() model.load_weights(self._get_all_weights(model_config, model)) - xm.wait_device_ops() - m = xm.get_memory_info() - print(f"after load_weights: {m=}") - - breakpoint() for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) @@ -428,15 +412,6 @@ def load_model(self, *, model_config: ModelConfig, with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) - if current_platform.is_tpu(): - import torch_xla.core.xla_model as xm - xm.mark_step() - - xm.wait_device_ops() - m = xm.get_memory_info() - print(f"after process_weights: {m=}") - print(model) - breakpoint() return model.eval() diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index af754ab2e1141..bdc0c4a6ececd 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -124,18 +124,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: torch.tensor([], dtype=torch.float32, device=self.device)) for _ in range(num_layers)] - - xm.wait_device_ops() - m = xm.get_memory_info(self.device) - total_memory_size = m["bytes_limit"] - profiled = m["peak_bytes_used"] # Weights + intermediate activations. - print("\n\nBEFORE") - print(m) - print(f"{total_memory_size=}") - print(f"{profiled=}") - print("\n\n") - - print(f"{self.scheduler_config.max_num_batched_tokens}") + self.model_runner._dummy_run( batch_size=1, seq_len=self.scheduler_config.max_num_batched_tokens, @@ -145,21 +134,19 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Synchronize before measuring the memory usage. xm.wait_device_ops() - # Get the maximum amount of memory used by the model weights and - # intermediate activations. + dtype_btyes = get_dtype_size(self.cache_dtype) + block_size = self.cache_config.block_size + block_size_bytes = (dtype_btyes * block_size * num_layers * 2 * + head_size * num_kv_heads) + # Calculate the TPU KV cache size based on profiling. m = xm.get_memory_info(self.device) total_memory_size = m["bytes_limit"] - profiled = m["peak_bytes_used"] # Weights + intermediate activations. - print("\n\nAFTER") - print(m) - print(f"{total_memory_size=}") - print(f"{profiled=}") - print("\n\n") # Calculate the TPU KV cache size based on profiling. usable_memory_size = int(total_memory_size * self.cache_config.gpu_memory_utilization) + profiled = m["bytes_used"] # Weights + intermediate activations. 
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) dtype_btyes = get_dtype_size(self.cache_dtype) block_size_bytes = (dtype_btyes * self.cache_config.block_size * From 8cfaa1bc45266983efe95ba62bd655b22872ce4f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Oct 2024 23:58:19 +0000 Subject: [PATCH 21/63] format --- .../layers/quantization/kernels/scaled_mm/xla.py | 10 +++------- vllm/model_executor/parameter.py | 1 - 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 2f1c82f860f06..5f41ad9674a8d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -1,8 +1,6 @@ from typing import Optional, Tuple import torch -import torch_xla -import torch_xla.core.xla_model as xm from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -37,13 +35,12 @@ def can_implement( return True, None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # WEIGHT # [out, in] (different than cutlass_scaled_mm) weight = getattr(layer, self.w_q_name) replace_parameter(layer, self.w_q_name, - torch.nn.Parameter(weight.data, - requires_grad=False)) + torch.nn.Parameter(weight.data, requires_grad=False)) # WEIGHT SCALE # XLA kernels support only per-tensor and per-channel. @@ -59,8 +56,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight_scale = weight_scale.squeeze(-1) replace_parameter( layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, - requires_grad=False)) + torch.nn.Parameter(weight_scale.data, requires_grad=False)) def apply_weights(self, layer: torch.nn.Module, diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 1d92bd734c9a2..92dd257da2dbf 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -9,7 +9,6 @@ from vllm.model_executor.utils import _make_synced_weight_loader from vllm.platforms import current_platform - __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", "ModelWeightParameter", "ChannelQuantScaleParameter", From bbc9741acdec24b3b2ccc719b3cefdbdb078c156 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 21 Oct 2024 02:11:23 +0000 Subject: [PATCH 22/63] update --- vllm/worker/tpu_worker.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index bdc0c4a6ececd..01cc8d66227cb 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -124,7 +124,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: torch.tensor([], dtype=torch.float32, device=self.device)) for _ in range(num_layers)] - self.model_runner._dummy_run( batch_size=1, seq_len=self.scheduler_config.max_num_batched_tokens, @@ -138,8 +137,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: block_size = self.cache_config.block_size block_size_bytes = (dtype_btyes * block_size * num_layers * 2 * head_size * num_kv_heads) - - # Calculate the TPU KV cache size based on profiling. 
m = xm.get_memory_info(self.device) total_memory_size = m["bytes_limit"] @@ -148,9 +145,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: self.cache_config.gpu_memory_utilization) profiled = m["bytes_used"] # Weights + intermediate activations. tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) - dtype_btyes = get_dtype_size(self.cache_dtype) - block_size_bytes = (dtype_btyes * self.cache_config.block_size * - num_layers * 2 * head_size * num_kv_heads) num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. From eb3f39e0dd3e77e87478aed6097972d088d7c3c8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 21 Oct 2024 02:11:47 +0000 Subject: [PATCH 23/63] fix nit --- vllm/model_executor/model_loader/loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index e73a2bc9eac0d..8e156981c7608 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1147,7 +1147,6 @@ def load_model(self, *, model_config: ModelConfig, lora_config, cache_config) self._load_weights(model_config, model) - return model.eval() From bb2fbe1defd9451cdf59351e1ae568690de1ece8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 21 Oct 2024 02:13:09 +0000 Subject: [PATCH 24/63] update --- examples/offline_inference_tpu.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/examples/offline_inference_tpu.py b/examples/offline_inference_tpu.py index b2ce40a3a6057..251629b8027ce 100644 --- a/examples/offline_inference_tpu.py +++ b/examples/offline_inference_tpu.py @@ -18,16 +18,11 @@ max_tokens=16) # Set `enforce_eager=True` to avoid ahead-of-time compilation. -# In real workloads, `enforce_eager` should be `False`. -llm = LLM( - # model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - # model="neuralmagic/gemma-2-2b-it-quantized.w8a16", - model="neuralmagic/SmolLM-1.7B-Instruct-quantized.w8a16", - enforce_eager=True, - max_model_len=1024) +# In real workloads, `enforace_eager` should be `False`. +llm = LLM(model="google/gemma-2b", enforce_eager=True) outputs = llm.generate(prompts, sampling_params) for output, answer in zip(outputs, answers): prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - # assert generated_text.startswith(answer) + assert generated_text.startswith(answer) From 14ccb90bdefaeb555a549e26b50d9ccdca5d4287 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 21 Oct 2024 02:14:04 +0000 Subject: [PATCH 25/63] fix spurious --- vllm/model_executor/model_loader/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 8e156981c7608..813f58339da37 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -411,7 +411,6 @@ def load_model(self, *, model_config: ModelConfig, # parameters onto device for processing and back off after. 
with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) - return model.eval() @@ -1147,6 +1146,7 @@ def load_model(self, *, model_config: ModelConfig, lora_config, cache_config) self._load_weights(model_config, model) + return model.eval() From 4092be2925513cb562a56f1c23a7fe7c6ec62244 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 23 Oct 2024 22:05:02 +0000 Subject: [PATCH 26/63] stash branch for brittany --- vllm/executor/ray_tpu_executor.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index d02fecb46f007..9eccd473de4f6 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -81,14 +81,18 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # and instead sets these from within the Ray process. Therefore we # need to override the Ray environment variables manually. override_env = {} - if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ: - override_env.update({ - "TPU_CHIPS_PER_HOST_BOUNDS": - os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] - }) - if "TPU_HOST_BOUNDS" in os.environ: - override_env.update( - {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]}) + OVERRIDE_VARS = [ + "TPU_CHIPS_PER_HOST_BOUNDS", + "TPU_CHIPS_PER_PROCESS_BOUNDS", + "TPU_HOST_BOUNDS", + "TPU_SKIP_MDS_QUERY", + "TPU_PROCESS_BOUNDS", + "TPU_ACCELERATOR_TYPE", + "XLA_ALWAYS_ALLREDUCE" + ] + for var in OVERRIDE_VARS: + if var in os.environ: + override_env.update({var: os.environ[var]}) worker = ray.remote( num_cpus=0, From 48aa54b5ae490865de292934c93e855e03b3d1bf Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:12:13 +0000 Subject: [PATCH 27/63] revert --- vllm/config.py | 4 ---- vllm/platforms/tpu.py | 3 ++- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 89d9a35754acb..8b824a1fca511 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -556,10 +556,6 @@ def _verify_quantization(self) -> None: "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", "experts_int8" ] - tpu_supported_quantization = [ - "tpu_int8", "compressed_tensors", "compressed-tensors" - ] - neuron_supported_quantization = ["neuron_quant"] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 77f5c8401424b..91c00cdeb1b22 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -19,7 +19,8 @@ class TpuPlatform(Platform): device_name: str = "tpu" device_type: str = "tpu" dispatch_key: str = "XLA" - supported_quantization: list[str] = ["tpu_int8"] + supported_quantization: list[str] = [ + "tpu_int8", "compressed-tensors", "compressed_tensors"] @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: From 4efe9151cfdefaf7ac36fd88d8dc7501d785564c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:13:06 +0000 Subject: [PATCH 28/63] fix --- vllm/executor/ray_tpu_executor.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index be6041930abac..5118c13934f0d 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -73,18 +73,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # and instead sets these from within the Ray process. 
Therefore we # need to override the Ray environment variables manually. override_env = {} - OVERRIDE_VARS = [ - "TPU_CHIPS_PER_HOST_BOUNDS", - "TPU_CHIPS_PER_PROCESS_BOUNDS", - "TPU_HOST_BOUNDS", - "TPU_SKIP_MDS_QUERY", - "TPU_PROCESS_BOUNDS", - "TPU_ACCELERATOR_TYPE", - "XLA_ALWAYS_ALLREDUCE" - ] - for var in OVERRIDE_VARS: - if var in os.environ: - override_env.update({var: os.environ[var]}) + if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ: + override_env.update({ + "TPU_CHIPS_PER_HOST_BOUNDS": + os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] + }) + if "TPU_HOST_BOUNDS" in os.environ: + override_env.update( + {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]}) worker = ray.remote( num_cpus=0, From e98b79c4e47881f758dd52feb4900050ece6d2c0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:14:12 +0000 Subject: [PATCH 29/63] updated --- vllm/model_executor/layers/quantization/kernels/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/kernels/__init__.py diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 From 5a896687f211a95cdb40dad2a9ff724833d2a3d6 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:16:44 +0000 Subject: [PATCH 30/63] reduce cruft --- .../kernels/mixed_precision/__init__.py | 13 +++++++++---- vllm/platforms/tpu.py | 3 ++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 1632d419b4614..69f5572cf8470 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -1,12 +1,17 @@ from typing import List, Optional, Type import vllm.envs as envs +from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( + ExllamaLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( + MPLinearKernel, MPLinearLayerConfig) from vllm.platforms import current_platform -from .exllama import ExllamaLinearKernel -from .machete import MacheteLinearKernel -from .marlin import MarlinLinearKernel -from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + # in priority/performance order (when available) _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 91c00cdeb1b22..d488daf056f1a 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -20,7 +20,8 @@ class TpuPlatform(Platform): device_type: str = "tpu" dispatch_key: str = "XLA" supported_quantization: list[str] = [ - "tpu_int8", "compressed-tensors", "compressed_tensors"] + "tpu_int8", "compressed-tensors", "compressed_tensors" + ] @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: From 57cbf5c58cd48b98425f7e9db510fe3e85c9bc45 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:17:11 +0000 Subject: [PATCH 31/63] reduce cruft --- .../kernels/mixed_precision/__init__.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 
deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 69f5572cf8470..ae40895edbb73 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -1,18 +1,16 @@ from typing import List, Optional, Type import vllm.envs as envs -from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( - ExllamaLinearKernel) -from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( - MacheteLinearKernel) -from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( - MarlinLinearKernel) -from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( +from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( + ExllamaLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( MPLinearKernel, MPLinearLayerConfig) from vllm.platforms import current_platform - - # in priority/performance order (when available) _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ MacheteLinearKernel, From 3451c4d3bca4b6f73578e004902422bcf8c390b4 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:18:27 +0000 Subject: [PATCH 32/63] updated --- .../quantization/kernels/mixed_precision/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index ae40895edbb73..7d2a1924c84a5 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -1,13 +1,13 @@ from typing import List, Optional, Type import vllm.envs as envs -from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( +from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 ExllamaLinearKernel) -from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( +from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 MacheteLinearKernel) -from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( +from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 MarlinLinearKernel) -from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( +from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( # noqa: E501 MPLinearKernel, MPLinearLayerConfig) from vllm.platforms import current_platform From 0c2e62ae4ae148b1c99814f60d0ef00308401002 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:21:59 +0000 Subject: [PATCH 33/63] update comment --- .../quantization/kernels/mixed_precision/__init__.py | 2 +- vllm/model_executor/parameter.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py 
b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 7d2a1924c84a5..b2c0eef8472bc 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -7,7 +7,7 @@ MacheteLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 MarlinLinearKernel) -from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( # noqa: E501 +from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 MPLinearKernel, MPLinearLayerConfig) from vllm.platforms import current_platform diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 2f19409cdc390..6d68e4dd4e6ab 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -39,6 +39,15 @@ def __init__(self, data: torch.Tensor, weight_loader: Callable): :returns: a torch.nn.parameter """ + # During weight loading, we often do something like: + # narrowed_tensor = param.data.narrow(0, offset, len) + # narrowed_tensor.copy_(real_weight) + # expecting narrowed_tensor and param.data to share the same storage. + # However, on TPUs, narrowed_tensor will lazily propagate to the base + # tensor, which is param.data, leading to the redundant memory usage. + # This sometimes causes OOM errors during model loading. To avoid this, + # we sync the param tensor after its weight loader is called. + # TODO(woosuk): Remove this hack once we have a better solution. if current_platform.is_tpu(): weight_loader = _make_synced_weight_loader(weight_loader) From 172c9cac0fa36d06516f795a1503c3767bdbe1d9 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:24:44 +0000 Subject: [PATCH 34/63] revert spurious change --- .../compressed_tensors/schemes/compressed_tensors_w8a8_int8.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index fe2cc499c9a53..87c0de1abbd67 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -58,6 +58,7 @@ def create_weights(self, layer: torch.nn.Module, input_dim=1, output_dim=0, weight_loader=weight_loader) + layer.register_parameter("weight", weight) # WEIGHT SCALE From 938ca8117b8262480e6b1de050aa7df9c4e53099 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:40:48 +0000 Subject: [PATCH 35/63] remove cruft --- .../schemes/compressed_tensors_w8a8_int8.py | 4 ---- .../layers/quantization/kernels/scaled_mm/cutlass.py | 7 +++---- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 87c0de1abbd67..0e3f4731775c5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -90,10 +90,6 @@ def create_weights(self, layer: torch.nn.Module, data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader) 
layer.register_parameter("input_zero_point", input_zero_point) - else: - layer.input_scale = None - layer.input_zero_point = None - layer.azp_adj = None self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, w_q_param_name="weight", diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 07432560d2569..b50ccbb956844 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -56,6 +56,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: replace_parameter( layer, self.i_s_name, torch.nn.Parameter(input_scale.max(), requires_grad=False)) + setattr(layer, self.i_zp_name, None) else: input_zero_point = getattr(layer, self.i_zp_name) @@ -69,7 +70,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: int8_traits.min) replace_parameter( layer, self.i_s_name, - torch.nn.Parameter(scale.max(), requires_grad=False)) + torch.nn.Parameter(scale, requires_grad=False)) # AZP loaded as int8 but used as int32 azp = (int8_traits.min - @@ -87,9 +88,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md if not self.config.input_symmetric: - layer.azp_adj = layer.weight.sum(dim=0, - keepdim=True, - dtype=torch.int32) + layer.azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32) else: layer.azp_adj = None From 9e189118c4358e568378423444074d641883a90d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:41:54 +0000 Subject: [PATCH 36/63] cruft reduction --- vllm/model_executor/parameter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 6d68e4dd4e6ab..a5c62f2a4ea12 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -47,7 +47,6 @@ def __init__(self, data: torch.Tensor, weight_loader: Callable): # tensor, which is param.data, leading to the redundant memory usage. # This sometimes causes OOM errors during model loading. To avoid this, # we sync the param tensor after its weight loader is called. - # TODO(woosuk): Remove this hack once we have a better solution. 
if current_platform.is_tpu(): weight_loader = _make_synced_weight_loader(weight_loader) From 5f58ec73a76acb913b5f3dad2c2d58ecd2578919 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 00:51:05 +0000 Subject: [PATCH 37/63] update docs --- docs/source/features/quantization/supported_hardware.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/features/quantization/supported_hardware.md b/docs/source/features/quantization/supported_hardware.md index 988288a82d9bc..f692f912aff51 100644 --- a/docs/source/features/quantization/supported_hardware.md +++ b/docs/source/features/quantization/supported_hardware.md @@ -61,7 +61,7 @@ The table below shows the compatibility of various quantization implementations - ✗ - ✗ - ✅︎ - - ✗ + - ✅︎ - ✗ * - FP8 (W8A8) - ✗ From af9f298bc8f291a75d25be76dafcdcde7fc1d428 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 01:11:17 +0000 Subject: [PATCH 38/63] added integration test --- tests/tpu/test_quantization_accuracy.py | 22 +++++++++++++++++++ .../quantization/kernels/scaled_mm/cutlass.py | 4 +++- .../quantization/kernels/scaled_mm/xla.py | 5 +++++ 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 tests/tpu/test_quantization_accuracy.py diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py new file mode 100644 index 0000000000000..ec56480da818b --- /dev/null +++ b/tests/tpu/test_quantization_accuracy.py @@ -0,0 +1,22 @@ +import lm_eval + +MODEL_NAME="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8" +MODEL_ARGS = f"pretrained={MODEL_NAME},max_model_len=4096,max_num_seqs=128" +TASK = "gsm8k" +FILTER = "exact_match,strict-match" +RTOL = 0.03 +EXPECTED_VALUE = 0.58 + +def test_w8a8_single(): + results = lm_eval.simple_evaluate( + model="vllm", + model_args=MODEL_NAME, + tasks="gsm8k", + batch_size="auto", + ) + print(results) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index b50ccbb956844..d40891271068a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -88,7 +88,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md if not self.config.input_symmetric: - layer.azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32) + layer.azp_adj = layer.weight.sum(dim=0, + keepdim=True, + dtype=torch.int32) else: layer.azp_adj = None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 5f41ad9674a8d..3d9c9b0c05727 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -58,6 +58,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer, self.w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False)) + # Only support symmetric dynamic activation quantization. 
+ setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + setattr(layer, self.azp_adj_name, None) + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, From 6fe2f623b7bd81ac8ca6eb70398c8f17f4104517 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 01:43:03 +0000 Subject: [PATCH 39/63] updated --- tests/tpu/test_quantization_accuracy.py | 36 +++++++++++++++---- .../kernels/mixed_precision/MPLinearKernel.py | 2 +- .../kernels/scaled_mm/ScaledMMLinearKernel.py | 2 +- .../quantization/kernels/scaled_mm/cutlass.py | 2 +- .../quantization/kernels/scaled_mm/xla.py | 12 +++---- 5 files changed, 38 insertions(+), 16 deletions(-) diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index ec56480da818b..75d61b7390300 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -1,21 +1,45 @@ +from dataclasses import dataclass + import lm_eval +import pytest -MODEL_NAME="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8" -MODEL_ARGS = f"pretrained={MODEL_NAME},max_model_len=4096,max_num_seqs=128" TASK = "gsm8k" FILTER = "exact_match,strict-match" RTOL = 0.03 -EXPECTED_VALUE = 0.58 -def test_w8a8_single(): + +@dataclass +class GSM8KAccuracyTestConfig: + model_name: str + excepted_value: float + + def get_model_args(self) -> str: + return (f"pretrained={self.model_name}," + "max_model_len=4096,max_num_seqs=128") + + +# Accuracy scores measured on GPUs. +ACCURACY_CONFIGS = [ + GSM8KAccuracyTestConfig( + model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + excepted_value=0.76), # no bias + GSM8KAccuracyTestConfig( + model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", + excepted_value=0.66), # bias +] + + +@pytest.mark.parametrize("config", ACCURACY_CONFIGS) +def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): + results = lm_eval.simple_evaluate( model="vllm", - model_args=MODEL_NAME, + model_args=config.get_model_args(), tasks="gsm8k", batch_size="auto", ) - print(results) + EXPECTED_VALUE = config.excepted_value measured_value = results["results"][TASK][FILTER] assert (measured_value - RTOL < EXPECTED_VALUE and measured_value + RTOL > EXPECTED_VALUE diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index 0c17326f54413..b04612a9b00d9 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -23,7 +23,7 @@ class MPLinearKernel(ABC): @classmethod @abstractmethod - def get_min_capability(cls) -> Optional[int]: + def get_min_capability(cls) -> int: raise NotImplementedError @classmethod diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 7bda48d8dff9c..75cf91f191136 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -16,7 +16,7 @@ class ScaledMMLinearKernel(ABC): @classmethod @abstractmethod - def get_min_capability(cls) -> Optional[int]: + def get_min_capability(cls) -> int: raise NotImplementedError @classmethod diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index d40891271068a..36cb3dafd11a3 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -15,7 +15,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): @classmethod - def get_min_capability(cls) -> Optional[int]: + def get_min_capability(cls) -> int: return 75 @classmethod diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 3d9c9b0c05727..ec639fe8dc18d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -14,8 +14,10 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): @classmethod - def get_min_capability(cls) -> Optional[int]: - return None + def get_min_capability(cls) -> int: + raise NotImplementedError( + "TPU platform does have a concept of compute capability, " + "this method should not be called.") @classmethod def can_implement( @@ -67,11 +69,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - w_q, w_s, i_s, i_zp, i_azp_adj = self._get_weight_params(layer) - assert i_s is None - assert i_zp is None - assert i_azp_adj is None - assert bias is None, "Bias is not supported for XLA yet." + w_q, w_s, _, _, _ = self._get_weight_params(layer) import torch_xla.experimental.xla_quantized_matmul # noqa: F401 return torch.ops.xla.quantized_matmul(x, From f2c0beb4af38c2213c6083c6c57f7f7bcc8b8db5 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 01:49:01 +0000 Subject: [PATCH 40/63] Add bias back --- .../layers/quantization/kernels/scaled_mm/xla.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index ec639fe8dc18d..ee2a2ed293991 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -72,10 +72,11 @@ def apply_weights(self, w_q, w_s, _, _, _ = self._get_weight_params(layer) import torch_xla.experimental.xla_quantized_matmul # noqa: F401 - return torch.ops.xla.quantized_matmul(x, - w_q, - w_s, - zero_point=None, - block_size=-1, - int4_weight=False, - quantize_activation=True) + out = torch.ops.xla.quantized_matmul(x, + w_q, + w_s, + zero_point=None, + block_size=-1, + int4_weight=False, + quantize_activation=True) + return out + bias From 8b2971824f81574d9426d0d4f469f240b3ad2d34 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 02:00:22 +0000 Subject: [PATCH 41/63] add bias support --- tests/tpu/test_quantization_accuracy.py | 8 ++++---- .../layers/quantization/kernels/scaled_mm/__init__.py | 3 +++ .../layers/quantization/kernels/scaled_mm/xla.py | 4 +++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index 75d61b7390300..b6368388e7c3c 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -18,11 +18,11 @@ def get_model_args(self) -> str: "max_model_len=4096,max_num_seqs=128") -# Accuracy scores measured on GPUs. +# NOTE(rob): Accuracy scores measured on GPUs. 
ACCURACY_CONFIGS = [ - GSM8KAccuracyTestConfig( - model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - excepted_value=0.76), # no bias + # GSM8KAccuracyTestConfig( + # model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + # excepted_value=0.76), # no bias GSM8KAccuracyTestConfig( model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", excepted_value=0.66), # bias diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 01db32e8a8783..05e5e965e0f35 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -5,6 +5,8 @@ CutlassScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 ScaledMMLinearKernel, ScaledMMLinearLayerConfig) +# from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( +# TritonScaledMMLinear) from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( XLAScaledMMLinearKernel) from vllm.platforms import PlatformEnum, current_platform @@ -13,6 +15,7 @@ _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CutlassScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], + # PlatformEnum.ROCM: [TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], } diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index ee2a2ed293991..466b9c834bf7e 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -79,4 +79,6 @@ def apply_weights(self, block_size=-1, int4_weight=False, quantize_activation=True) - return out + bias + if bias: + return out + bias + return out From 1e2a373b236f3d35214a0fd6dcab615dce090cda Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 02:26:56 +0000 Subject: [PATCH 42/63] updated --- .../quantization/kernels/scaled_mm/cutlass.py | 5 ++--- .../layers/quantization/kernels/scaled_mm/xla.py | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 36cb3dafd11a3..c6c6252750726 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -88,9 +88,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md if not self.config.input_symmetric: - layer.azp_adj = layer.weight.sum(dim=0, - keepdim=True, - dtype=torch.int32) + layer.azp_adj = getattr(layer, "w_q_name".sum( + dim=0, keepdim=True, dtype=torch.int32) else: layer.azp_adj = None diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 466b9c834bf7e..7d8100491016a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -1,5 +1,6 @@ 
from typing import Optional, Tuple +from functorch.experimental.control_flow import cond # noqa: F401 import torch from vllm.model_executor.layers.quantization.utils import replace_parameter @@ -65,6 +66,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: setattr(layer, self.i_zp_name, None) setattr(layer, self.azp_adj_name, None) + def _no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + return x + + def _add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + return x + bias + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, @@ -79,6 +87,7 @@ def apply_weights(self, block_size=-1, int4_weight=False, quantize_activation=True) - if bias: - return out + bias - return out + + # Explicitly capture control flow to make dynamo happy. + # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 + return cond(bias is None, self._no_add_bias, self._add_bias, [out, bias]) From 2a359ef78dacc562c44cfad56abb1f0ce064981a Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 7 Jan 2025 02:38:06 +0000 Subject: [PATCH 43/63] stash --- .../layers/quantization/kernels/scaled_mm/cutlass.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 36cb3dafd11a3..c26436da0bade 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -88,11 +88,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md if not self.config.input_symmetric: - layer.azp_adj = layer.weight.sum(dim=0, - keepdim=True, - dtype=torch.int32) + azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32) + setattr(layer, self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False)) else: - layer.azp_adj = None + setattr(layer, self.azp_adj_name, None) def apply_weights(self, layer: torch.nn.Module, From 0d4c3fd129ee4bf4cab1301ddbc4d6505efaacfc Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 02:40:24 +0000 Subject: [PATCH 44/63] fix --- tests/tpu/test_quantization_accuracy.py | 19 ++++++++++--------- .../quantization/kernels/scaled_mm/xla.py | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index b6368388e7c3c..3444117ae3d86 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -15,14 +15,14 @@ class GSM8KAccuracyTestConfig: def get_model_args(self) -> str: return (f"pretrained={self.model_name}," - "max_model_len=4096,max_num_seqs=128") + "max_model_len=4096,max_num_seqs=128,enforce_eager=True") # NOTE(rob): Accuracy scores measured on GPUs. 
ACCURACY_CONFIGS = [ - # GSM8KAccuracyTestConfig( - # model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - # excepted_value=0.76), # no bias + GSM8KAccuracyTestConfig( + model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + excepted_value=0.76), # no bias GSM8KAccuracyTestConfig( model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", excepted_value=0.66), # bias @@ -37,10 +37,11 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): model_args=config.get_model_args(), tasks="gsm8k", batch_size="auto", + limit=1, ) - EXPECTED_VALUE = config.excepted_value - measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + # EXPECTED_VALUE = config.excepted_value + # measured_value = results["results"][TASK][FILTER] + # assert (measured_value - RTOL < EXPECTED_VALUE + # and measured_value + RTOL > EXPECTED_VALUE + # ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 7d8100491016a..797eb262a5ad6 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -90,4 +90,4 @@ def apply_weights(self, # Explicitly capture control flow to make dynamo happy. # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 - return cond(bias is None, self._no_add_bias, self._add_bias, [out, bias]) + return cond(bias, self._add_bias, self._no_add_bias, [out, bias]) From 57340d285179163fbb8027dbbb37d66c05954ee4 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 02:46:32 +0000 Subject: [PATCH 45/63] update --- tests/tpu/test_quantization_accuracy.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index 3444117ae3d86..f20dd95b1cfd8 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -15,17 +15,20 @@ class GSM8KAccuracyTestConfig: def get_model_args(self) -> str: return (f"pretrained={self.model_name}," - "max_model_len=4096,max_num_seqs=128,enforce_eager=True") + "max_model_len=4096,max_num_seqs=128") -# NOTE(rob): Accuracy scores measured on GPUs. +# NOTE: Accuracy scores measured on GPUs. ACCURACY_CONFIGS = [ GSM8KAccuracyTestConfig( model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", excepted_value=0.76), # no bias - GSM8KAccuracyTestConfig( - model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", - excepted_value=0.66), # bias + # NOTE(rob): We cannot re-initialize VLLM in the same process for TPU, + # so only one of these tests can run in a single call to pytest. As + # a follow up, move this into the LM-EVAL section of the CI. 
+ # GSM8KAccuracyTestConfig( + # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", + # excepted_value=0.66), # bias in QKV layers ] @@ -37,7 +40,6 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): model_args=config.get_model_args(), tasks="gsm8k", batch_size="auto", - limit=1, ) # EXPECTED_VALUE = config.excepted_value From 38291d563e89d62ff8d898069a9024e416ca300a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 02:54:29 +0000 Subject: [PATCH 46/63] trigger test in CI --- .buildkite/run-tpu-test.sh | 10 +++++++++- tests/tpu/test_quantization_accuracy.py | 10 +++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 770dad6ffa3a1..3f3551b49df96 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -14,4 +14,12 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" +docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c \ + "python3 -m pip install git+https://github.com/thuml/depyf.git \ + && python3 -m pip install pytest \ + && python3 -m pip install lm_eval[api]==0.4.4 \ + && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \ + && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ + && python3 /workspace/vllm/tests/tpu/test_compilation.py \ + && python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ + && python3 /workspace/vllm/examples/offline_inference_tpu.py" diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index f20dd95b1cfd8..509d241099e3b 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -42,8 +42,8 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): batch_size="auto", ) - # EXPECTED_VALUE = config.excepted_value - # measured_value = results["results"][TASK][FILTER] - # assert (measured_value - RTOL < EXPECTED_VALUE - # and measured_value + RTOL > EXPECTED_VALUE - # ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + EXPECTED_VALUE = config.excepted_value + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" From ead1e9445cbe099f92a7d8fa9a7d39aa6b19a944 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 7 Jan 2025 02:56:22 +0000 Subject: [PATCH 47/63] fix AZP --- .../layers/quantization/kernels/scaled_mm/__init__.py | 4 +++- .../layers/quantization/kernels/scaled_mm/cutlass.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 05e5e965e0f35..586752d3d34e3 100644 --- 
a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -15,7 +15,9 @@ _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CutlassScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], - # PlatformEnum.ROCM: [TritonScaledMMLinearKernel], + # TODO(rob): Create TritonScaledMMLinear kernel. ROCM will + # incorrectly attempt to run AZP models if prompted to. + PlatformEnum.ROCM: [CutlassScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], } diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 2a0f2ef968ab8..4dc6c5841ef16 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -90,6 +90,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if not self.config.input_symmetric: weight = getattr(layer, self.w_q_name) azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32) + if self.is_static_input_scheme: + # cutlass_w8a8 requires azp to be folded into azp_adj + # in the per-tensor case + azp_adj = getattr(layer, self.i_zp_name) * azp_adj setattr(layer, self.azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False)) else: From cea5e541c572850d0630d4efa8cb580dc9e0600e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 7 Jan 2025 02:56:39 +0000 Subject: [PATCH 48/63] fixed! --- .../layers/quantization/kernels/scaled_mm/cutlass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 4dc6c5841ef16..31965ebff86c7 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -90,7 +90,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if not self.config.input_symmetric: weight = getattr(layer, self.w_q_name) azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32) - if self.is_static_input_scheme: + if self.config.is_static_input_scheme: # cutlass_w8a8 requires azp to be folded into azp_adj # in the per-tensor case azp_adj = getattr(layer, self.i_zp_name) * azp_adj From 84a5b29bd4971159605a544347c154c3cdd6cdb0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 7 Jan 2025 02:57:43 +0000 Subject: [PATCH 49/63] fix azp adju --- .../layers/quantization/kernels/scaled_mm/cutlass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 31965ebff86c7..94e47ecc2584c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -89,7 +89,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md if not self.config.input_symmetric: weight = getattr(layer, self.w_q_name) - azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32) + azp_adj = weight.sum(dim=0, keepdim=True, 
dtype=torch.int32) if self.config.is_static_input_scheme: # cutlass_w8a8 requires azp to be folded into azp_adj # in the per-tensor case From a1d7b4a1f9d301aa34716cc94e9977ec1506af3d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 02:58:21 +0000 Subject: [PATCH 50/63] make docker command look better on gh --- .buildkite/run-tpu-test.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 3f3551b49df96..d579145ea496a 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -14,8 +14,9 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c \ - "python3 -m pip install git+https://github.com/thuml/depyf.git \ +docker run --privileged --net host --shm-size=16G -it \ + -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu \ + /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ && python3 -m pip install pytest \ && python3 -m pip install lm_eval[api]==0.4.4 \ && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \ From 2b4ecfd4996c4120a46e1b1cf8f72cde5cfcbcea Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 03:08:00 +0000 Subject: [PATCH 51/63] remove torch warnings --- tests/tpu/test_quantization_accuracy.py | 4 ++-- .../layers/quantization/kernels/scaled_mm/cutlass.py | 2 +- .../layers/quantization/kernels/scaled_mm/xla.py | 9 ++++----- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index 509d241099e3b..571542639a49e 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -15,7 +15,7 @@ class GSM8KAccuracyTestConfig: def get_model_args(self) -> str: return (f"pretrained={self.model_name}," - "max_model_len=4096,max_num_seqs=128") + "max_model_len=4096,max_num_seqs=128,tensor_parallel_size=4") # NOTE: Accuracy scores measured on GPUs. @@ -28,7 +28,7 @@ def get_model_args(self) -> str: # a follow up, move this into the LM-EVAL section of the CI. 
# GSM8KAccuracyTestConfig( # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", - # excepted_value=0.66), # bias in QKV layers + # excepted_value=0.66), # bias in QKV layers ] diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 94e47ecc2584c..2c166ab4d6a12 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -94,7 +94,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # cutlass_w8a8 requires azp to be folded into azp_adj # in the per-tensor case azp_adj = getattr(layer, self.i_zp_name) * azp_adj - setattr(layer, self.azp_adj_name, + setattr(layer, self.azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False)) else: setattr(layer, self.azp_adj_name, None) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 797eb262a5ad6..ccea0c0934263 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -1,7 +1,7 @@ from typing import Optional, Tuple -from functorch.experimental.control_flow import cond # noqa: F401 import torch +from functorch.experimental.control_flow import cond # noqa: F401 from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -66,13 +66,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: setattr(layer, self.i_zp_name, None) setattr(layer, self.azp_adj_name, None) - def _no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): return x - def _add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): return x + bias - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, @@ -90,4 +89,4 @@ def apply_weights(self, # Explicitly capture control flow to make dynamo happy. 
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 - return cond(bias, self._add_bias, self._no_add_bias, [out, bias]) + return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) From 186c108de452e0411ea83ee3d120cb24c9235f29 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 7 Jan 2025 03:32:45 +0000 Subject: [PATCH 52/63] stash --- .../layers/quantization/kernels/scaled_mm/cutlass.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 94e47ecc2584c..fe4b6009756e6 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -94,11 +94,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # cutlass_w8a8 requires azp to be folded into azp_adj # in the per-tensor case azp_adj = getattr(layer, self.i_zp_name) * azp_adj - setattr(layer, self.azp_adj_name, + setattr(layer, self.azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False)) + print(f"{weight.shape=}") + print(f"{azp_adj.shape=}") else: setattr(layer, self.azp_adj_name, None) + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, From de773cd6087d28d614fcacdac39a69507c00a2de Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 7 Jan 2025 03:52:02 +0000 Subject: [PATCH 53/63] fix AZP --- .../layers/quantization/kernels/scaled_mm/cutlass.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index fe4b6009756e6..c7dbe85d5fe81 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -96,8 +96,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: azp_adj = getattr(layer, self.i_zp_name) * azp_adj setattr(layer, self.azp_adj_name, torch.nn.Parameter(azp_adj, requires_grad=False)) - print(f"{weight.shape=}") - print(f"{azp_adj.shape=}") else: setattr(layer, self.azp_adj_name, None) @@ -118,13 +116,16 @@ def apply_weights(self, symmetric=symmetric) if x_zp is not None: + # Currently, static is always per-tensor and dynamic is per-token + static = i_zp is not None + azp = None if static else x_zp return ops.cutlass_scaled_mm_azp(x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, azp_adj=azp_adj, - azp=x_zp, + azp=azp, bias=bias) return ops.cutlass_scaled_mm(x_q, w_q, From 3a53d7d6b8ab99a6f07d7b11e6009e0eb4709a23 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 14:49:45 +0000 Subject: [PATCH 54/63] merged --- tests/tpu/test_quantization_accuracy.py | 14 +++++++------- .../layers/quantization/kernels/scaled_mm/xla.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index 571542639a49e..73b9862c37bf2 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -15,20 +15,20 @@ class GSM8KAccuracyTestConfig: def get_model_args(self) -> str: return (f"pretrained={self.model_name}," - "max_model_len=4096,max_num_seqs=128,tensor_parallel_size=4") + "max_model_len=4096,max_num_seqs=32") # NOTE: Accuracy scores 
measured on GPUs. ACCURACY_CONFIGS = [ - GSM8KAccuracyTestConfig( - model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - excepted_value=0.76), # no bias + # GSM8KAccuracyTestConfig( + # model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + # excepted_value=0.76), # no bias # NOTE(rob): We cannot re-initialize VLLM in the same process for TPU, # so only one of these tests can run in a single call to pytest. As # a follow up, move this into the LM-EVAL section of the CI. - # GSM8KAccuracyTestConfig( - # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", - # excepted_value=0.66), # bias in QKV layers + GSM8KAccuracyTestConfig( + model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", + excepted_value=0.66), # bias in QKV layers ] diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index ccea0c0934263..fe95881998b7e 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional, Tuple import torch @@ -66,6 +67,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: setattr(layer, self.i_zp_name, None) setattr(layer, self.azp_adj_name, None) + # Filter warning for cond usage in apply_weights. It is okay to + # specialize the graph since bias is not dynamic. + warnings.filterwarnings( + "ignore", + message= + "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." # noqa: E501 + ) + def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): return x @@ -89,4 +98,5 @@ def apply_weights(self, # Explicitly capture control flow to make dynamo happy. # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 + # This throws a lot of warnings. 
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) From 0be5f693c6329370b28ae454673185b90010c317 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 14:50:27 +0000 Subject: [PATCH 55/63] added --- vllm/model_executor/layers/quantization/kernels/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/__init__.py diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d From cb69ba72f1d534696e88ee591dabd8a5e72a3b88 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 15:18:24 +0000 Subject: [PATCH 56/63] fix formatting --- .../quantization/kernels/mixed_precision/__init__.py | 8 ++++---- .../layers/quantization/kernels/scaled_mm/xla.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index b2c0eef8472bc..83549870e3f0b 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -1,13 +1,13 @@ from typing import List, Optional, Type import vllm.envs as envs -from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 +from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 ExllamaLinearKernel) -from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 +from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 MacheteLinearKernel) -from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 +from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 MarlinLinearKernel) -from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 +from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 MPLinearKernel, MPLinearLayerConfig) from vllm.platforms import current_platform diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index fe95881998b7e..92abb56e7819d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -67,8 +67,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: setattr(layer, self.i_zp_name, None) setattr(layer, self.azp_adj_name, None) - # Filter warning for cond usage in apply_weights. It is okay to - # specialize the graph since bias is not dynamic. + # Filter warning for cond usage in apply_weights. It is okay + # to specialize the graph since bias is not dynamic. 
warnings.filterwarnings( "ignore", message= From 3896f6cebb316b3eba948e9488e848fc6c1f0d3e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 16:49:46 +0000 Subject: [PATCH 57/63] remove comment --- vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 92abb56e7819d..5860f0ebae0ee 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -98,5 +98,4 @@ def apply_weights(self, # Explicitly capture control flow to make dynamo happy. # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 - # This throws a lot of warnings. return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) From 33e1e13ad792836f6d195712ff72dff234427eb0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 16:54:22 +0000 Subject: [PATCH 58/63] formatted --- .../layers/quantization/kernels/scaled_mm/cutlass.py | 1 - .../model_executor/layers/quantization/kernels/scaled_mm/xla.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index c7dbe85d5fe81..2e83a04286a0d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -99,7 +99,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: else: setattr(layer, self.azp_adj_name, None) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 5860f0ebae0ee..9de668e658826 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -67,7 +67,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: setattr(layer, self.i_zp_name, None) setattr(layer, self.azp_adj_name, None) - # Filter warning for cond usage in apply_weights. It is okay + # Filter warning for cond usage in apply_weights. It is okay # to specialize the graph since bias is not dynamic. warnings.filterwarnings( "ignore", From dde72d606a8f26e49395d1141a6f32cf562e3e4f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 7 Jan 2025 17:16:11 +0000 Subject: [PATCH 59/63] add llama to ci --- tests/tpu/test_quantization_accuracy.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index 73b9862c37bf2..6cd5615c44e1e 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -20,15 +20,15 @@ def get_model_args(self) -> str: # NOTE: Accuracy scores measured on GPUs. 
ACCURACY_CONFIGS = [ - # GSM8KAccuracyTestConfig( - # model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - # excepted_value=0.76), # no bias + GSM8KAccuracyTestConfig( + model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + excepted_value=0.76), # no bias # NOTE(rob): We cannot re-initialize VLLM in the same process for TPU, # so only one of these tests can run in a single call to pytest. As # a follow up, move this into the LM-EVAL section of the CI. - GSM8KAccuracyTestConfig( - model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", - excepted_value=0.66), # bias in QKV layers + # GSM8KAccuracyTestConfig( + # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", + # excepted_value=0.66), # bias in QKV layers ] From db9f79575d20fc0d9b03a8d48f42bd684e4872de Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 7 Jan 2025 18:01:28 -0500 Subject: [PATCH 60/63] Update supported_hardware.md --- docs/source/features/quantization/supported_hardware.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/features/quantization/supported_hardware.md b/docs/source/features/quantization/supported_hardware.md index f692f912aff51..8d442a9dc163a 100644 --- a/docs/source/features/quantization/supported_hardware.md +++ b/docs/source/features/quantization/supported_hardware.md @@ -60,7 +60,7 @@ The table below shows the compatibility of various quantization implementations - ✅︎ - ✗ - ✗ - - ✅︎ + - ✗ - ✅︎ - ✗ * - FP8 (W8A8) From 09ad869190f7a2179fcc74ab2800c8c8572943e7 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 7 Jan 2025 18:01:59 -0500 Subject: [PATCH 61/63] Update supported_hardware.md --- docs/source/features/quantization/supported_hardware.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/features/quantization/supported_hardware.md b/docs/source/features/quantization/supported_hardware.md index 8d442a9dc163a..988288a82d9bc 100644 --- a/docs/source/features/quantization/supported_hardware.md +++ b/docs/source/features/quantization/supported_hardware.md @@ -60,9 +60,9 @@ The table below shows the compatibility of various quantization implementations - ✅︎ - ✗ - ✗ - - ✗ - ✅︎ - ✗ + - ✗ * - FP8 (W8A8) - ✗ - ✗ From b74c88a3c24e0e5a60293008e05dbe25808ab75c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 8 Jan 2025 17:17:22 +0000 Subject: [PATCH 62/63] ixed docs build --- vllm/model_executor/parameter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index a5c62f2a4ea12..fc5a3e7fba674 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -7,7 +7,6 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.utils import _make_synced_weight_loader -from vllm.platforms import current_platform __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", @@ -47,6 +46,7 @@ def __init__(self, data: torch.Tensor, weight_loader: Callable): # tensor, which is param.data, leading to the redundant memory usage. # This sometimes causes OOM errors during model loading. To avoid this, # we sync the param tensor after its weight loader is called. 
+ from vllm.platforms import current_platform if current_platform.is_tpu(): weight_loader = _make_synced_weight_loader(weight_loader) From f353c43f98bbc4e13fb460b537811d0b33cac91f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 8 Jan 2025 17:33:15 +0000 Subject: [PATCH 63/63] fix CI --- .buildkite/run-tpu-test.sh | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 13605a3e97142..a8f021890f742 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -14,4 +14,13 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference/offline_inference_tpu.py" +docker run --privileged --net host --shm-size=16G -it \ + -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \ + vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ + && python3 -m pip install pytest \ + && python3 -m pip install lm_eval[api]==0.4.4 \ + && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \ + && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ + && python3 /workspace/vllm/tests/tpu/test_compilation.py \ + && python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ + && python3 /workspace/vllm/examples/offline_inference/offline_inference_tpu.py"
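
Note: a minimal standalone sketch of the GSM8K accuracy check added in tests/tpu/test_quantization_accuracy.py, for reproducing the CI result by hand. It assumes the same environment as the Docker command above (a vllm-tpu build with lm_eval[api]==0.4.4 installed); the model string, task, filter key, expected score, and tolerance are taken directly from the test.

    import lm_eval

    # Same configuration as the GSM8KAccuracyTestConfig enabled in the test above.
    MODEL = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
    TASK = "gsm8k"
    FILTER = "exact_match,strict-match"
    EXPECTED_VALUE = 0.76  # GPU reference score
    RTOL = 0.03

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=f"pretrained={MODEL},max_model_len=4096,max_num_seqs=32",
        tasks=TASK,
        batch_size="auto",
    )

    measured_value = results["results"][TASK][FILTER]
    # Two-sided tolerance check, equivalent to the assertion in the test.
    assert abs(measured_value - EXPECTED_VALUE) < RTOL, (
        f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}")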