diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 586752d3d34e3..4824a11804163 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -5,8 +5,8 @@ CutlassScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 ScaledMMLinearKernel, ScaledMMLinearLayerConfig) -# from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( -# TritonScaledMMLinear) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( + TritonScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( XLAScaledMMLinearKernel) from vllm.platforms import PlatformEnum, current_platform @@ -15,9 +15,7 @@ _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CutlassScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], - # TODO(rob): Create TritonScaledMMLinear kernel. ROCM will - # incorrectly attempt to run AZP models if prompted to. - PlatformEnum.ROCM: [CutlassScaledMMLinearKernel], + PlatformEnum.ROCM: [TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], } diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py new file mode 100644 index 0000000000000..97ec8cb0500d7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -0,0 +1,38 @@ +from typing import Optional, Tuple + +import torch + +from vllm.platforms import current_platform + +from .cutlass import CutlassScaledMMLinearKernel +from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig + + +class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if current_platform.is_cpu(): + return ( + False, + "TritonScaledMMLinearKernel requires Triton which is not " + + "currently supported on CPU.") + if not c.input_symmetric: + return (False, + "TritonScaledMMLinearKernel only supports symmetric " + + "quantization.") + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return super().apply_weights(layer, x, bias)