From e2e0f431290954b0e9fd9ed6d77a27627193d08d Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 17 Dec 2024 05:38:24 +0000 Subject: [PATCH 1/6] Add: Support for Sparse24Bitmask Compressed Models Signed-off-by: Rahul Tuli --- .../compressed_tensors/compressed_tensors.py | 39 +++++-- .../schemes/compressed_tensors_24.py | 106 ++++++++++++++++-- vllm/model_executor/parameter.py | 26 ++++- 3 files changed, 152 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b2fc2360f47f1..96bfde9655961 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -396,10 +396,13 @@ def get_scheme( sparsity_scheme=sparsity_scheme): # Have a valid sparsity scheme # Validate layer is supported by Cutlass 2:4 Kernel - scheme = CompressedTensors24(quantized=weight_quant is not None - or input_quant is not None, - weight_quant=weight_quant, - input_quant=input_quant) + scheme = CompressedTensors24( + quantized=weight_quant is not None or input_quant is not None, + weight_quant=weight_quant, + input_quant=input_quant, + model_compression_config=self._get_model_compression_config( + sparsity_scheme), + ) else: # Find the quant_scheme scheme = self._get_scheme_from_parts( # type: ignore @@ -447,10 +450,17 @@ def supports_cutlass_24( :return: True if the layer is supported by the Cutlass 2:4 Kernel False otherwise """ - is_valid_sparsity = (sparsity_scheme is not None - and sparsity_scheme.sparsity_structure - == SparsityStructure.TWO_FOUR.value - and sparsity_scheme.format == "dense") + is_valid_sparsity_structure = (sparsity_scheme is not None + and sparsity_scheme.sparsity_structure + == SparsityStructure.TWO_FOUR.value) + valid_compressors = { + CompressionFormat.dense.value, + CompressionFormat.sparse_24_bitmask.value + } + + is_valid_sparsity = (is_valid_sparsity_structure + and sparsity_scheme.format in valid_compressors) + if not is_valid_sparsity: return False @@ -481,6 +491,19 @@ def supports_cutlass_24( return weight_quant.num_bits == input_quant.num_bits == 8 + def _get_model_compression_config( + self, sparsity_scheme: Optional[SparsityCompressionConfig] = None): + """ + Get the model compressor config from the sparsity scheme + + :param sparsity_scheme: The sparsity scheme + :return: The model compressor config + """ + if sparsity_scheme is None or sparsity_scheme.format == "dense": + return None + + return self.config + class CompressedTensorsLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 21e6fe7a22616..5d12ff07aab44 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -1,16 +1,21 @@ -from typing import Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch +from compressed_tensors import CompressionFormat, ModelCompressor from compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy, QuantizationType) +from compressed_tensors.utils import combine_shards from vllm import _custom_ops as ops +from 
vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, sparse_cutlass_supported) from vllm.model_executor.parameter import (BasevLLMParameter, + BitMaskShapeParameter, ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -20,14 +25,24 @@ class CompressedTensors24(CompressedTensorsScheme): - def __init__(self, - quantized: bool = False, - weight_quant: Optional[QuantizationArgs] = None, - input_quant: Optional[QuantizationArgs] = None): + def __init__( + self, + quantized: bool = False, + weight_quant: Optional[QuantizationArgs] = None, + input_quant: Optional[QuantizationArgs] = None, + model_compression_config: Optional[Dict[str, Any]] = None, + ): self.quantized = quantized self.weight_quant = weight_quant self.input_quant = input_quant + self.model_compressor = ( + ModelCompressor.from_compression_config(model_compression_config) + if model_compression_config is not None else None) + self.do_sparse_decompress = ( + self.model_compressor is not None + and self.model_compressor.sparsity_config.format + == CompressionFormat.sparse_24_bitmask.value) @classmethod def get_min_capability(cls) -> int: @@ -47,6 +62,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, self.output_dtype = params_dtype layer.logical_widths = output_partition_sizes + layer.input_size = input_size + layer.input_size_per_partition = input_size_per_partition self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype) # parameter to store uncompressed weight @@ -57,6 +74,34 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, input_dim=1, output_dim=0, weight_loader=weight_loader) + if self.do_sparse_decompress: + assert all( + partition_size % 8 == 0 + for partition_size in output_partition_sizes + ), "All partitions must be divisible by 8 for 2:4 compressed models" + + shape = BitMaskShapeParameter(data=torch.empty( + 2 * len(output_partition_sizes), 1, dtype=torch.uint64), + weight_loader=weight_loader) + compressed = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=self.weights_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + bitmask = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 8, + dtype=torch.uint8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("shape", shape) + layer.register_parameter("compressed", compressed) + layer.register_parameter("bitmask", bitmask) # Check if quantized, not just 2:4 Sparse if self.quantized: @@ -112,6 +157,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: :param layer: The layer with the weights to be processed """ + if self.do_sparse_decompress: + layer.weight.data = self._decompress_bitmask_compressed_weight( + compressed=layer.compressed, + bitmask=layer.bitmask, + layer=layer, + ) + # torch.compile workaround if hasattr(layer, "input_scale"): layer.input_scale = torch.nn.Parameter(layer.input_scale.data, @@ -201,8 +253,42 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: raise ValueError("Quantization type not supported by Cutlass") - -def check_24(tensor): - new_tensor = tensor.view(-1, 4) - zero_counts = (new_tensor == 0).sum(dim=1) - return 
(zero_counts >= 2).all().item() + def _decompress_bitmask_compressed_weight( + self, compressed: torch.Tensor, bitmask: torch.Tensor, + layer: torch.nn.Module) -> torch.Tensor: + + sparsity_compressor = self.model_compressor.sparsity_compressor + + def _process_split(bitmask_compressed_weight: torch.Tensor, shape, + bitmask: torch.Tensor) -> torch.Tensor: + weight_data = dict( + compressed=bitmask_compressed_weight, + shape=shape, + bitmask=bitmask, + ) + return sparsity_compressor.decompress_weight(weight_data) + + split_weights = None + split_bitmask = None + split_shape = None + + if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): + split_weights = torch.split(compressed, layer.logical_widths) + split_bitmask = torch.split(bitmask, layer.logical_widths) + split_shape = [(out, layer.input_size_per_partition) + for out in layer.logical_widths] + + if split_weights is not None: + decompressed_shards = [ + _process_split(compressed_weight, shape, bitmask) + for compressed_weight, shape, bitmask in zip( + split_weights, split_shape, split_bitmask) + ] + decompressed = combine_shards(decompressed_shards) + else: + decompressed = sparsity_compressor.decompress_weight( + dict(compressed=compressed, + shape=(layer.logical_widths[0], + layer.input_size_per_partition), + bitmask=bitmask)) + return decompressed diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index a9ce8af15d3bb..89d234d08545c 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -11,7 +11,8 @@ __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", "ModelWeightParameter", "ChannelQuantScaleParameter", - "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter" + "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", + "BitMaskShapeParameter" ] logger = init_logger(__name__) @@ -429,3 +430,26 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, shard_offset=shard_offset, marlin_tile_size=marlin_tile_size) return shard_size, shard_offset + + +class BitMaskShapeParameter(PerTensorScaleParameter): + """ + Parameter class for the shape of the bitmask tensor. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _load_into_shard_id(self, loaded_weight: torch.Tensor, + shard_id: Union[str, int], **kwargs): + """ + Slice the parameter data based on the shard id for + loading. + + Note: Assumes the loaded weight is a 1D tensor + with 2 elements. 
+ """ + param_data = self.data + shard_id = self._shard_id_as_int(shard_id) + start_index = shard_id * 2 + param_data[start_index:start_index + 2].copy_(loaded_weight) From c4dd0fa51c52e5e66feafe228032aadf632b2bd5 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 15 Jan 2025 20:58:21 +0000 Subject: [PATCH 2/6] Fix: mypy errors Signed-off-by: Rahul Tuli --- .../compressed_tensors/compressed_tensors.py | 10 +++++++--- .../schemes/compressed_tensors_24.py | 10 +++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 96bfde9655961..c972e1e4c8ba2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -450,9 +450,13 @@ def supports_cutlass_24( :return: True if the layer is supported by the Cutlass 2:4 Kernel False otherwise """ - is_valid_sparsity_structure = (sparsity_scheme is not None - and sparsity_scheme.sparsity_structure - == SparsityStructure.TWO_FOUR.value) + if sparsity_scheme is None: + return False + + is_valid_sparsity_structure: bool = ( + sparsity_scheme.sparsity_structure == + SparsityStructure.TWO_FOUR.value) + valid_compressors = { CompressionFormat.dense.value, CompressionFormat.sparse_24_bitmask.value diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 5d12ff07aab44..f315eb654187a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from compressed_tensors import CompressionFormat, ModelCompressor @@ -268,9 +268,9 @@ def _process_split(bitmask_compressed_weight: torch.Tensor, shape, ) return sparsity_compressor.decompress_weight(weight_data) - split_weights = None - split_bitmask = None - split_shape = None + split_weights: List[torch.Tensor] = [] + split_bitmask: List[torch.Tensor] = [] + split_shape: List[Tuple[int, int]] = [] if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): split_weights = torch.split(compressed, layer.logical_widths) @@ -278,7 +278,7 @@ def _process_split(bitmask_compressed_weight: torch.Tensor, shape, split_shape = [(out, layer.input_size_per_partition) for out in layer.logical_widths] - if split_weights is not None: + if split_weights: decompressed_shards = [ _process_split(compressed_weight, shape, bitmask) for compressed_weight, shape, bitmask in zip( From 6f9287776c7870ad642ddb47e29c6cce1ccf0209 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 22 Jan 2025 19:29:14 +0000 Subject: [PATCH 3/6] Removed BitmaskShape Parameter Renamed `compressed` to `compressed_weight` Address review commits from @dsikka Signed-off-by: Rahul Tuli --- .../compressed_tensors/compressed_tensors.py | 20 +++-------- .../schemes/compressed_tensors_24.py | 33 +++++++++---------- vllm/model_executor/parameter.py | 26 +-------------- 3 files changed, 22 insertions(+), 57 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index c972e1e4c8ba2..0298526fea3c5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -396,12 +396,15 @@ def get_scheme( sparsity_scheme=sparsity_scheme): # Have a valid sparsity scheme # Validate layer is supported by Cutlass 2:4 Kernel + model_compression_config = (None if sparsity_scheme is None + or sparsity_scheme.format == "dense" + else self.config) + scheme = CompressedTensors24( quantized=weight_quant is not None or input_quant is not None, weight_quant=weight_quant, input_quant=input_quant, - model_compression_config=self._get_model_compression_config( - sparsity_scheme), + model_compression_config=model_compression_config, ) else: # Find the quant_scheme @@ -495,19 +498,6 @@ def supports_cutlass_24( return weight_quant.num_bits == input_quant.num_bits == 8 - def _get_model_compression_config( - self, sparsity_scheme: Optional[SparsityCompressionConfig] = None): - """ - Get the model compressor config from the sparsity scheme - - :param sparsity_scheme: The sparsity scheme - :return: The model compressor config - """ - if sparsity_scheme is None or sparsity_scheme.format == "dense": - return None - - return self.config - class CompressedTensorsLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index f315eb654187a..0635ecd3f4c6c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, sparse_cutlass_supported) from vllm.model_executor.parameter import (BasevLLMParameter, - BitMaskShapeParameter, ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -75,21 +74,21 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, output_dim=0, weight_loader=weight_loader) if self.do_sparse_decompress: - assert all( - partition_size % 8 == 0 - for partition_size in output_partition_sizes - ), "All partitions must be divisible by 8 for 2:4 compressed models" - - shape = BitMaskShapeParameter(data=torch.empty( - 2 * len(output_partition_sizes), 1, dtype=torch.uint64), - weight_loader=weight_loader) - compressed = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // 2, - dtype=self.weights_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + assert all(partition_size % 8 == 0 + for partition_size in output_partition_sizes + ), "All partitions must be divisible by 8 for " + "2:4 sparse compressed models" + + shape = BasevLLMParameter(data=torch.empty(2, 1, + dtype=torch.int64), + weight_loader=weight_loader) + compressed_weight = ModelWeightParameter( + data=torch.empty(sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=self.weights_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) bitmask = ModelWeightParameter(data=torch.empty( sum(output_partition_sizes), @@ -100,7 +99,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, weight_loader=weight_loader) layer.register_parameter("shape", shape) - 
layer.register_parameter("compressed", compressed) + layer.register_parameter("compressed", compressed_weight) layer.register_parameter("bitmask", bitmask) # Check if quantized, not just 2:4 Sparse diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 89d234d08545c..a9ce8af15d3bb 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -11,8 +11,7 @@ __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", "ModelWeightParameter", "ChannelQuantScaleParameter", - "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", - "BitMaskShapeParameter" + "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter" ] logger = init_logger(__name__) @@ -430,26 +429,3 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, shard_offset=shard_offset, marlin_tile_size=marlin_tile_size) return shard_size, shard_offset - - -class BitMaskShapeParameter(PerTensorScaleParameter): - """ - Parameter class for the shape of the bitmask tensor. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def _load_into_shard_id(self, loaded_weight: torch.Tensor, - shard_id: Union[str, int], **kwargs): - """ - Slice the parameter data based on the shard id for - loading. - - Note: Assumes the loaded weight is a 1D tensor - with 2 elements. - """ - param_data = self.data - shard_id = self._shard_id_as_int(shard_id) - start_index = shard_id * 2 - param_data[start_index:start_index + 2].copy_(loaded_weight) From 7e5d8280f8cf8ab1d5af682c5568fee2e4a4c5ae Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 22 Jan 2025 21:15:28 +0000 Subject: [PATCH 4/6] Add: lm-eval, fp8, int8 tests Signed-off-by: Rahul Tuli --- .../SparseLlama3.1_2of4_fp8_compressed.yaml | 11 +++ tests/quantization/test_compressed_tensors.py | 73 ++++++++++++++++++- .../schemes/compressed_tensors_24.py | 12 +++ 3 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 .buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml diff --git a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml new file mode 100644 index 0000000000000..2928d75ce4469 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml @@ -0,0 +1,11 @@ +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2 +model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.6353 + - name: "exact_match,flexible-extract" + value: 0.637 +limit: null +num_fewshot: null diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 0cd86cef0a475..5be1a20479fc4 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -242,7 +242,10 @@ def test_compressed_tensors_kv_cache(vllm_runner): @pytest.mark.skipif(not sparse_cutlass_supported(), reason="Sparse FP8 is not yet supported on this GPU type.") -def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy): +def _test_2of4_quant_models(qkv_proj, + weight_strategy, + input_strategy, + format="dense"): assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, 
CompressedTensors24) @@ -251,7 +254,7 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy): assert qkv_proj.scheme.quantized assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 - assert sparsity_map.get("Linear").format == "dense" + assert sparsity_map.get("Linear").format == format assert sparsity_map.get("Linear").sparsity_structure == "2:4" @@ -285,6 +288,72 @@ def check_model(model): assert output +@pytest.mark.skipif(not current_platform.has_device_capability(90), + reason="Sparse FP8 is not yet supported on this GPU type.") +@pytest.mark.parametrize("args_2of4", [ + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM", + "channel", "token"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM", + "channel", "tensor"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM", + "tensor", "token"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM", + "tensor", "tensor"), +]) +def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4): + model, weight_strategy, input_strategy = args_2of4 + with vllm_runner(model) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn + _test_2of4_quant_models(qkv_proj, + weight_strategy, + input_strategy, + format="sparse-24-bitmask") + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output + + +@pytest.mark.skipif(not sparse_cutlass_supported(), + reason="cutlass is not yet supported on this GPU type.") +@pytest.mark.parametrize("args_2of4", [ + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM", + "channel", "token"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM", + "channel", "tensor"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM", + "tensor", "token"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM", + "tensor", "tensor"), +]) +def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4): + model, weight_strategy, input_strategy = args_2of4 + with vllm_runner(model) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert qkv_proj.scheme.weights_dtype == torch.int8 + _test_2of4_quant_models(qkv_proj, + weight_strategy, + input_strategy, + format="sparse-24-bitmask") + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output + + @pytest.mark.skipif(not sparse_cutlass_supported(), reason="Sparse FP8 is not yet supported on this GPU type.") @pytest.mark.parametrize("args_2of4", [ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 0635ecd3f4c6c..785577a4d4170 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -255,6 +255,18 @@ def 
_get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: def _decompress_bitmask_compressed_weight( self, compressed: torch.Tensor, bitmask: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor: + """ + Decompress a compressed 2:4 sparse weight tensor + using the bitmask and return the result. + + This function also supports sharded decompression. + + :param compressed: The 2:4 sparse weight tensor + compressed using the sparse-24-bitmask compressor. + :param bitmask: The 2:4 bitmask associated with the compressed weights. + :param layer: The layer whose weights need to be processed after loading. + :return: The decompressed 2:4 sparse weight tensor. + """ sparsity_compressor = self.model_compressor.sparsity_compressor From 0de40424e0097186d44d071e75451e25de6e10e5 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 22 Jan 2025 21:35:44 +0000 Subject: [PATCH 5/6] Add: 2:4 Sparse only compressed test Signed-off-by: Rahul Tuli --- tests/quantization/test_compressed_tensors.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 5be1a20479fc4..ac9cdd0aa5ae1 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -412,3 +412,35 @@ def check_model(model): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) assert output + + +@pytest.mark.skipif(not sparse_cutlass_supported(), + reason="Cutlass is not yet supported on this GPU type.") +@pytest.mark.parametrize( + "args_2of4", + [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]) +def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): + model = args_2of4 + with vllm_runner(model) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensors24) + + assert qkv_proj.scheme.weight_quant is None + assert qkv_proj.scheme.input_quant is None + assert not qkv_proj.scheme.quantized + assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map + sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 + assert sparsity_map.get("Linear").format == "sparse-24-bitmask" + assert sparsity_map.get("Linear").sparsity_structure == "2:4" + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output \ No newline at end of file From 96f376ea4278fb7bfdbb2832794b7809c141596d Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 22 Jan 2025 21:39:31 +0000 Subject: [PATCH 6/6] Lint Signed-off-by: Rahul Tuli --- .../compressed_tensors/schemes/compressed_tensors_24.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 785577a4d4170..da652e00bde8d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -264,7 +264,8 @@ def _decompress_bitmask_compressed_weight( :param compressed: The 2:4 sparse weight tensor compressed using the sparse-24-bitmask compressor. 
         :param bitmask: The 2:4 bitmask associated with the compressed weights.
-        :param layer: The layer whose weights need to be processed after loading.
+        :param layer: The layer whose weights need to be processed
+            after loading.
         :return: The decompressed 2:4 sparse weight tensor.
         """
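Conceptually, the sparse-24-bitmask format stores only the values that survive 2:4 pruning plus a one-bit-per-element mask, and decompression scatters those values back into a dense weight. The short PyTorch sketch below illustrates that idea; it is not part of any patch in this series, and the helper name and LSB-first bit order are assumptions made for illustration. The real path delegates to compressed-tensors' sparsity_compressor.decompress_weight(), and for fused layers (QKVParallelLinear, MergedColumnParallelLinear) the scheme splits the compressed tensor and bitmask per logical width and recombines the decompressed shards with combine_shards, as the diffs above show.

from typing import Tuple

import torch


def decompress_2of4_bitmask(compressed: torch.Tensor, bitmask: torch.Tensor,
                            shape: Tuple[int, int]) -> torch.Tensor:
    """Scatter packed non-zero values back into a dense tensor.

    compressed: [rows, cols // 2] values kept by 2:4 pruning
    bitmask:    [rows, cols // 8] uint8 mask, one bit per dense element
    shape:      (rows, cols) of the dense weight
    """
    rows, cols = shape
    # Unpack each uint8 into 8 boolean flags (LSB-first here; the real
    # compressor's bit order may differ).
    bit_positions = torch.arange(8, dtype=torch.uint8)
    mask = (bitmask.unsqueeze(-1) >> bit_positions) & 1
    mask = mask.reshape(rows, -1)[:, :cols].bool()

    dense = torch.zeros(rows, cols, dtype=compressed.dtype)
    # Row-major boolean indexing consumes the packed values in order.
    dense[mask] = compressed.reshape(-1)
    return dense


# Round-trip check on a single 2:4-sparse row.
dense_row = torch.tensor([[1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 4.0]])
packed_vals = dense_row[dense_row != 0].reshape(1, -1)         # [[1., 2., 3., 4.]]
packed_mask = torch.tensor([[0b10100101]], dtype=torch.uint8)  # bits 0, 2, 5, 7 set
assert torch.equal(
    decompress_2of4_bitmask(packed_vals, packed_mask, (1, 8)), dense_row)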