From 9e8bad6c020730acd3c75b3487704bf240c8158d Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Tue, 21 Jan 2025 13:01:04 -0600
Subject: [PATCH 1/4] TritonScaledMMLinearKernel implementation

Signed-off-by: Randall Smith
---
 .../kernels/scaled_mm/__init__.py             |  8 ++--
 .../quantization/kernels/scaled_mm/triton.py  | 38 +++++++++++++++++++
 2 files changed, 41 insertions(+), 5 deletions(-)
 create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py

diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index 586752d3d34e3..4824a11804163 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -5,8 +5,8 @@
     CutlassScaledMMLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
     ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
-# from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
-#     TritonScaledMMLinear)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
+    TritonScaledMMLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
     XLAScaledMMLinearKernel)
 from vllm.platforms import PlatformEnum, current_platform
@@ -15,9 +15,7 @@
 _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
     PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
     PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
-    # TODO(rob): Create TritonScaledMMLinear kernel. ROCM will
-    # incorrectly attempt to run AZP models if prompted to.
-    PlatformEnum.ROCM: [CutlassScaledMMLinearKernel],
+    PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
     PlatformEnum.TPU: [XLAScaledMMLinearKernel],
 }
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
new file mode 100644
index 0000000000000..97ec8cb0500d7
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
@@ -0,0 +1,38 @@
+from typing import Optional, Tuple
+
+import torch
+
+from vllm.platforms import current_platform
+
+from .cutlass import CutlassScaledMMLinearKernel
+from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+
+
+class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 75
+
+    @classmethod
+    def can_implement(
+            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+        if current_platform.is_cpu():
+            return (
+                False,
+                "TritonScaledMMLinearKernel requires Triton which is not " +
+                "currently supported on CPU.")
+        if not c.input_symmetric:
+            return (False,
+                    "TritonScaledMMLinearKernel only supports symmetric " +
+                    "quantization.")
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return super().apply_weights(layer, x, bias)

From daf9a719dfdc0e9b417d35627e5ba302f283ddd3 Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Wed, 22 Jan 2025 18:31:58 +0000
Subject: [PATCH 2/4] Add regression test for rocm w8a8

Signed-off-by: Randall Smith
---
 tests/kernels/test_triton_scaled_mm.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tests/kernels/test_triton_scaled_mm.py b/tests/kernels/test_triton_scaled_mm.py
index 8e96a2f70d751..b5b760d60e1c7 100644
--- a/tests/kernels/test_triton_scaled_mm.py
+++ b/tests/kernels/test_triton_scaled_mm.py
@@ -8,6 +8,7 @@
 import pytest
 import torch
 
+from tests.models.utils import check_logprobs_close
 from vllm.platforms import current_platform
 
 device = "cuda"
@@ -39,6 +40,23 @@ def get_8bit_types():
     return types
 
 
+# This test is to check regressions for int8 support on ROCm.
+@pytest.mark.parametrize("model_path", [
+    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [10])
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="Should only run on ROCm")
+def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
+                                      max_tokens, num_logprobs):
+    dtype = "bfloat16"
+
+    with vllm_runner(model_path, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+
 @pytest.mark.parametrize("M", [1, 33, 64, 512])
 @pytest.mark.parametrize("N", [256, 971, 20486])
 @pytest.mark.parametrize("K", [128, 496, 1024])

From 9c11d5c1b0acd31bb87a1afa7446f1fe962f76da Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Wed, 22 Jan 2025 18:50:37 +0000
Subject: [PATCH 3/4] remove unused import

Signed-off-by: Randall Smith
---
 tests/kernels/test_triton_scaled_mm.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/kernels/test_triton_scaled_mm.py b/tests/kernels/test_triton_scaled_mm.py
index b5b760d60e1c7..fd22a42e85342 100644
--- a/tests/kernels/test_triton_scaled_mm.py
+++ b/tests/kernels/test_triton_scaled_mm.py
@@ -8,7 +8,6 @@
 import pytest
 import torch
 
-from tests.models.utils import check_logprobs_close
 from vllm.platforms import current_platform
 
 device = "cuda"

From 4e4d633e7774347ccebe6c8c6b4c1e88a994c1c6 Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Wed, 22 Jan 2025 18:57:19 +0000
Subject: [PATCH 4/4] ruff

Signed-off-by: Randall Smith
---
 tests/kernels/test_triton_scaled_mm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/test_triton_scaled_mm.py b/tests/kernels/test_triton_scaled_mm.py
index fd22a42e85342..a5aab3c2ea4b0 100644
--- a/tests/kernels/test_triton_scaled_mm.py
+++ b/tests/kernels/test_triton_scaled_mm.py
@@ -52,8 +52,8 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
     dtype = "bfloat16"
 
     with vllm_runner(model_path, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+        vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
+                                            num_logprobs)
 
 
 @pytest.mark.parametrize("M", [1, 33, 64, 512])
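
Reviewer note, not part of the patches above: a minimal sketch of how the capability check added in PATCH 1/4 might be probed from Python. The ScaledMMLinearLayerConfig constructor fields other than input_symmetric are assumptions about the existing dataclass, not taken from this diff, and the snippet is illustrative only.

# Illustrative sketch only (hypothetical usage, not part of the patches).
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    TritonScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
    ScaledMMLinearLayerConfig)

# Field names besides input_symmetric are assumed, not confirmed by this diff.
cfg = ScaledMMLinearLayerConfig(is_channelwise=False,
                                is_static_input_scheme=True,
                                input_symmetric=False)

# Asymmetric (AZP) input quantization is rejected by the can_implement branch
# in triton.py, so ok should be False and reason should explain why.
ok, reason = TritonScaledMMLinearKernel.can_implement(cfg)
print(ok, reason)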