Skip to content

Commit

Permalink
[Misc] Support register quantization method out-of-tree (#11969)
Browse files Browse the repository at this point in the history
  • Loading branch information
ice-tong authored Jan 19, 2025
1 parent 6d0e3d3 commit 32eb0da
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 0 deletions.
117 changes: 117 additions & 0 deletions tests/quantization/test_register_quantization_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Tests register custom quantization config.
See https://github.com/vllm-project/vllm/issues/11926 for more details.
Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
from typing import Any, Dict, List, Optional

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearBase # noqa: E501
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)


class FakeQuantLinearMethod(UnquantizedLinearMethod):
"""Fake quantization linear method for per-token dynamic quantization."""

def __init__(self, num_bits: int = 8) -> None:
"""Initialize the quantization method."""
super().__init__()
self.num_bits = num_bits

def apply(self,
layer: "torch.nn.Module",
x: "torch.Tensor",
bias: Optional["torch.Tensor"] = None) -> "torch.Tensor":
"""Perform fake quantization before the linear layer."""

# Calculate the scales dynamically
max_val = torch.amax(x, dim=(0, -1), keepdims=True)
min_val = torch.amin(x, dim=(0, -1), keepdims=True)
scales = (max_val - min_val) / (2**self.num_bits - 1)

# Fake quantize the input
quant_x = torch.clamp(torch.round(x / scales), -2**(self.num_bits - 1),
2**(self.num_bits - 1) - 1)
dequant_x = quant_x * scales

return F.linear(dequant_x, layer.weight, bias)


@register_quantization_config("custom_quant")
class CustomQuantConfig(QuantizationConfig):
"""Custom quantization config for per-token dynamic fake quantization."""

def __init__(self, num_bits: int = 8) -> None:
"""Initialize the quantization config."""
self.num_bits = num_bits

def get_name(self) -> str:
"""Name of the quantization method."""
return "custom_quant"

def get_supported_act_dtypes(self) -> List["torch.dtype"]:
"""List of supported activation dtypes."""
return [torch.float16, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
"""Minimum GPU capability to support the quantization method."""
return -1

@staticmethod
def get_config_filenames() -> List[str]:
"""List of filenames to search for in the model directory."""
return []

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig":
"""Create a config class from the model's quantization config."""
return CustomQuantConfig(num_bits=config.get("num_bits", 8))

def get_quant_method(self, layer: "torch.nn.Module",
prefix: str) -> Optional["FakeQuantLinearMethod"]:
"""Get the quantize method to use for the quantized layer."""
if isinstance(layer, LinearBase):
return FakeQuantLinearMethod(num_bits=self.num_bits)
return None


def test_register_quantization_config():
"""Test register custom quantization config."""

# The quantization method `custom_quant` should be registered.
assert get_quantization_config("custom_quant") == CustomQuantConfig

# The quantization method `custom_quant` is already exists,
# should raise an error.
with pytest.raises(ValueError):
register_quantization_config("custom_quant")(CustomQuantConfig)


@pytest.mark.parametrize(argnames="model",
argvalues=[
"meta-llama/Meta-Llama-3-8B-Instruct",
])
def test_custom_quant(vllm_runner, model):
"""Test infer with the custom quantization method."""
with vllm_runner(model_name=model,
quantization="custom_quant",
enforce_eager=True) as llm:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj

# Check the quantization method is FakeQuantLinearMethod
assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
41 changes: 41 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,45 @@
"quark"
]

# The customized quantization methods which will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}


def register_quantization_config(quantization: str):
"""Register a customized vllm quantization config.
When a quantization method is not supported by vllm, you can register a customized
quantization config to support it.
Args:
quantization (str): The quantization method name.
Examples:
>>> from vllm.model_executor.layers.quantization import register_quantization_config
>>> from vllm.model_executor.layers.quantization import get_quantization_config
>>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
>>>
>>> @register_quantization_config("my_quant")
... class MyQuantConfig(QuantizationConfig):
... pass
>>>
>>> get_quantization_config("my_quant")
<class 'MyQuantConfig'>
""" # noqa: E501

def _wrapper(quant_config_cls):
if quantization in QUANTIZATION_METHODS:
raise ValueError(
f"The quantization method `{quantization}` is already exists.")
if not issubclass(quant_config_cls, QuantizationConfig):
raise ValueError("The quantization config must be a subclass of "
"`QuantizationConfig`.")
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
QUANTIZATION_METHODS.append(quantization)
return quant_config_cls

return _wrapper


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
if quantization not in QUANTIZATION_METHODS:
Expand Down Expand Up @@ -84,6 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"ipex": IPEXConfig,
"quark": QuarkConfig
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

return method_to_config[quantization]

Expand Down

0 comments on commit 32eb0da

Please sign in to comment.