Add: Support for Sparse24Bitmask Compressed Models #12097

Open. Wants to merge 6 commits into base: main.
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.6353
- name: "exact_match,flexible-extract"
value: 0.637
limit: null
num_fewshot: null
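For reference, the config above lists a model name plus, per task, the expected metric values (with optional limit/num_fewshot overrides). How these configs are consumed in CI is not shown in this diff; the sketch below only illustrates reading that schema, and the file path and helper name are hypothetical.

from typing import Any, Dict

import yaml  # assumption: the config is parsed as plain YAML


def load_expected_metrics(config_path: str) -> Dict[str, Dict[str, float]]:
    """Map each task name to its expected metric values from an eval config."""
    with open(config_path) as f:
        cfg: Dict[str, Any] = yaml.safe_load(f)
    return {
        task["name"]: {m["name"]: m["value"] for m in task["metrics"]}
        for task in cfg["tasks"]
    }


# For the config above this would yield:
# {"gsm8k": {"exact_match,strict-match": 0.6353,
#            "exact_match,flexible-extract": 0.637}}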
105 changes: 103 additions & 2 deletions tests/quantization/test_compressed_tensors.py
@@ -242,7 +242,10 @@ def test_compressed_tensors_kv_cache(vllm_runner):

@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.")
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
def _test_2of4_quant_models(qkv_proj,
weight_strategy,
input_strategy,
format="dense"):
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)

@@ -251,7 +254,7 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
assert qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").format == format
assert sparsity_map.get("Linear").sparsity_structure == "2:4"


@@ -285,6 +288,72 @@ def check_model(model):
assert output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
"channel", "token"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
"channel", "tensor"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
"tensor", "token"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
"tensor", "tensor"),
])
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:

def check_model(model):
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
_test_2of4_quant_models(qkv_proj,
weight_strategy,
input_strategy,
format="sparse-24-bitmask")

llm.apply_model(check_model)

output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output


@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="cutlass is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
"channel", "token"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
"channel", "tensor"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
"tensor", "token"),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
"tensor", "tensor"),
])
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:

def check_model(model):
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.int8
_test_2of4_quant_models(qkv_proj,
weight_strategy,
input_strategy,
format="sparse-24-bitmask")

llm.apply_model(check_model)

output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output


@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
@@ -343,3 +412,35 @@ def check_model(model):
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output


@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Cutlass is not yet supported on this GPU type.")
@pytest.mark.parametrize(
"args_2of4",
[("nm-testing/llama2.c-stories42M-pruned2.4-compressed")])
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
model = args_2of4
with vllm_runner(model) as llm:

def check_model(model):
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)

assert qkv_proj.scheme.weight_quant is None
assert qkv_proj.scheme.input_quant is None
assert not qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"

llm.apply_model(check_model)

output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
@@ -396,10 +396,16 @@ def get_scheme(
sparsity_scheme=sparsity_scheme):
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
scheme = CompressedTensors24(quantized=weight_quant is not None
or input_quant is not None,
weight_quant=weight_quant,
input_quant=input_quant)
model_compression_config = (None if sparsity_scheme is None
or sparsity_scheme.format == "dense"
else self.config)

scheme = CompressedTensors24(
quantized=weight_quant is not None or input_quant is not None,
weight_quant=weight_quant,
input_quant=input_quant,
model_compression_config=model_compression_config,
)
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts( # type: ignore
@@ -447,10 +453,21 @@ def supports_cutlass_24(
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
is_valid_sparsity = (sparsity_scheme is not None
and sparsity_scheme.sparsity_structure
== SparsityStructure.TWO_FOUR.value
and sparsity_scheme.format == "dense")
if sparsity_scheme is None:
return False

is_valid_sparsity_structure: bool = (
sparsity_scheme.sparsity_structure ==
SparsityStructure.TWO_FOUR.value)

valid_compressors = {
CompressionFormat.dense.value,
CompressionFormat.sparse_24_bitmask.value
}

is_valid_sparsity = (is_valid_sparsity_structure
and sparsity_scheme.format in valid_compressors)

if not is_valid_sparsity:
return False

@@ -1,11 +1,15 @@
from typing import Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from compressed_tensors import CompressionFormat, ModelCompressor
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.utils import combine_shards

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -20,14 +24,24 @@

class CompressedTensors24(CompressedTensorsScheme):

def __init__(self,
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None):
def __init__(
self,
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None,
model_compression_config: Optional[Dict[str, Any]] = None,
):

self.quantized = quantized
self.weight_quant = weight_quant
self.input_quant = input_quant
self.model_compressor = (
ModelCompressor.from_compression_config(model_compression_config)
if model_compression_config is not None else None)
self.do_sparse_decompress = (
self.model_compressor is not None
and self.model_compressor.sparsity_config.format
== CompressionFormat.sparse_24_bitmask.value)

@classmethod
def get_min_capability(cls) -> int:
@@ -47,6 +61,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,

self.output_dtype = params_dtype
layer.logical_widths = output_partition_sizes
layer.input_size = input_size
layer.input_size_per_partition = input_size_per_partition
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

# parameter to store uncompressed weight
@@ -57,6 +73,34 @@
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
if self.do_sparse_decompress:
assert all(partition_size % 8 == 0
           for partition_size in output_partition_sizes), (
               "All partitions must be divisible by 8 for "
               "2:4 sparse compressed models")

shape = BasevLLMParameter(data=torch.empty(2, 1,
dtype=torch.int64),
weight_loader=weight_loader)
compressed_weight = ModelWeightParameter(
data=torch.empty(sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=self.weights_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

bitmask = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 8,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

layer.register_parameter("shape", shape)
layer.register_parameter("compressed", compressed_weight)
layer.register_parameter("bitmask", bitmask)

# Check if quantized, not just 2:4 Sparse
if self.quantized:
@@ -112,6 +156,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
:param layer: The layer with the weights to be processed
"""
if self.do_sparse_decompress:
layer.weight.data = self._decompress_bitmask_compressed_weight(
compressed=layer.compressed,
bitmask=layer.bitmask,
layer=layer,
Reviewer comment (Member, on lines +161 to +162): Should we delete layer.compressed and layer.bitmask after decompressing them?
)

# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
@@ -201,8 +252,55 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:

raise ValueError("Quantization type not supported by Cutlass")

def _decompress_bitmask_compressed_weight(
self, compressed: torch.Tensor, bitmask: torch.Tensor,
layer: torch.nn.Module) -> torch.Tensor:
"""
Decompress a compressed 2:4 sparse weight tensor
using the bitmask and return the result.
This function also supports sharded decompression.
:param compressed: The 2:4 sparse weight tensor
compressed using the sparse-24-bitmask compressor.
:param bitmask: The 2:4 bitmask associated with the compressed weights.
:param layer: The layer whose weights need to be processed
after loading.
:return: The decompressed 2:4 sparse weight tensor.
"""

def check_24(tensor):
new_tensor = tensor.view(-1, 4)
zero_counts = (new_tensor == 0).sum(dim=1)
return (zero_counts >= 2).all().item()
sparsity_compressor = self.model_compressor.sparsity_compressor

def _process_split(bitmask_compressed_weight: torch.Tensor, shape,
bitmask: torch.Tensor) -> torch.Tensor:
weight_data = dict(
compressed=bitmask_compressed_weight,
shape=shape,
bitmask=bitmask,
)
return sparsity_compressor.decompress_weight(weight_data)

split_weights: List[torch.Tensor] = []
split_bitmask: List[torch.Tensor] = []
split_shape: List[Tuple[int, int]] = []

if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, layer.logical_widths)
split_bitmask = torch.split(bitmask, layer.logical_widths)
split_shape = [(out, layer.input_size_per_partition)
for out in layer.logical_widths]

if split_weights:
decompressed_shards = [
_process_split(compressed_weight, shape, bitmask)
for compressed_weight, shape, bitmask in zip(
split_weights, split_shape, split_bitmask)
]
decompressed = combine_shards(decompressed_shards)
else:
decompressed = sparsity_compressor.decompress_weight(
dict(compressed=compressed,
shape=(layer.logical_widths[0],
layer.input_size_per_partition),
bitmask=bitmask))
return decompressed
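In the PR itself the actual decompression is delegated to compressed-tensors' sparsity_compressor.decompress_weight, as shown above. The standalone sketch below only illustrates the bitmask idea behind the sparse-24-bitmask format: packed nonzero values plus a uint8 bitmask are scattered back into a dense tensor. The helper name and the LSB-first bit order are assumptions for illustration, not the library's actual layout.

import torch


def decompress_24_bitmask(values: torch.Tensor, bitmask: torch.Tensor,
                          shape: tuple) -> torch.Tensor:
    """Scatter the kept values back into a dense (rows, cols) tensor."""
    rows, cols = shape
    # Unpack each uint8 into 8 bits; LSB-first bit order is an assumption here.
    bits = torch.stack([(bitmask >> i) & 1 for i in range(8)], dim=-1)
    mask = bits.reshape(rows, -1)[:, :cols].bool()
    dense = torch.zeros(shape, dtype=values.dtype)
    # Boolean-mask assignment fills the True positions in row-major order,
    # matching how the kept values are stored per row.
    dense[mask] = values.reshape(-1)
    return dense


# Toy check: one row of 8 values pruned 2:4 -> 4 kept values, 1 bitmask byte.
dense_ref = torch.tensor([[1., 0., 2., 0., 0., 3., 0., 4.]])
kept = torch.tensor([[1., 2., 3., 4.]])
mask_byte = torch.tensor([[0b10100101]], dtype=torch.uint8)
assert torch.equal(decompress_24_bitmask(kept, mask_byte, (1, 8)), dense_ref)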