Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD][Quantization] Add TritonScaledMMLinearKernel since int8 is broken for AMD #12282

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
CutlassScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
# from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
# TritonScaledMMLinear)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel)
from vllm.platforms import PlatformEnum, current_platform
Expand All @@ -15,9 +15,7 @@
# Scaled-mm linear kernel classes listed for each platform.
# ROCm maps to the Triton-based kernel class instead of the Cutlass one;
# a dict literal must not repeat a key, since only the last value is kept.
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
    PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
    PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
    PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Optional, Tuple

import torch

from vllm.platforms import current_platform

from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig


class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
    """Scaled-mm linear kernel built on top of ``CutlassScaledMMLinearKernel``.

    Only the capability floor and the implementability check are specific
    to this class; weight post-processing and the forward computation are
    forwarded unchanged to the parent class.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this kernel."""
        # NOTE(review): presumably an encoded device compute capability;
        # confirm against the ScaledMMLinearKernel base contract.
        return 75

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Report whether this kernel can serve the given layer config.

        Args:
            c: Layer configuration; only ``c.input_symmetric`` is read.

        Returns:
            ``(True, None)`` when usable, otherwise ``(False, reason)``
            where ``reason`` explains the rejection. The CPU platform check
            runs before the symmetric-quantization check.
        """
        failure_reason: Optional[str] = None
        if current_platform.is_cpu():
            failure_reason = (
                "TritonScaledMMLinearKernel requires Triton which is not "
                "currently supported on CPU.")
        elif not c.input_symmetric:
            failure_reason = (
                "TritonScaledMMLinearKernel only supports symmetric "
                "quantization.")
        return (failure_reason is None, failure_reason)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Forward weight post-processing for ``layer`` to the parent class."""
        super().process_weights_after_loading(layer)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the parent class's ``apply_weights`` and return its result.

        Args:
            layer: Module passed through to the parent implementation.
            x: Input tensor passed through to the parent implementation.
            bias: Optional bias tensor passed through to the parent.

        Returns:
            Whatever the parent implementation returns for these arguments.
        """
        output = super().apply_weights(layer, x, bias)
        return output
Loading