Skip to content

Commit

Permalink
[ Misc ] Improve Min Capability Checking in compressed-tensors (vll…
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgshaw2-redhat authored and jimpang committed Jul 24, 2024
1 parent 0e3b2e6 commit aab67e7
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:

@classmethod
def get_min_capability(cls) -> int:
return 75
return 70

def get_name(self) -> str:
return "compressed_tensors"
Expand Down Expand Up @@ -85,13 +85,14 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
def get_config_filenames(cls) -> List[str]:
return []

def _check_gptq_and_marlin_can_run(self):
def _check_scheme_supported(self, min_capability: int):
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < 80:
raise RuntimeError("The quantization config is not supported for ",
"the current GPU. Minimum capability: 80. ",
f"Current capability: {capability}.")
if capability < min_capability:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.")

def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
Expand Down Expand Up @@ -171,7 +172,6 @@ def _get_schema(self, weight_quant: BaseModel,

# Detect If Mixed Precision
if self._is_wNa16_group_channel(weight_quant, input_quant):
self._check_gptq_and_marlin_can_run()
if (self.quant_format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
return CompressedTensorsW4A16Sparse24(
Expand Down Expand Up @@ -222,10 +222,16 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
raise ValueError(
f"Could not find quantization details for {layer}.")

return self._get_schema(
scheme = self._get_schema(
weight_quant=layer_quant_details["weights"],
input_quant=layer_quant_details["input_activations"])

# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())

return scheme


class CompressedTensorsLinearMethod(LinearMethodBase):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC):
of different quantization schemes supported by CompressedTensors.
"""

@abstractmethod
def get_min_capability(self) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError

@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation.
"""

def get_min_capability(self) -> int:
# volta and up
return 70

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def __init__(self,
raise ValueError(
"group_size must be given when using strategy group")

def get_min_capability(self) -> int:
# ampere + up
return 80

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def __init__(self, strategy: str, is_static_input_scheme: bool):
"Consider quantizing with per tensor scales or upgrading "
"to Hopper.")

def get_min_capability(self) -> int:
# lovelace and up
return 89

def process_weights_after_loading(self, layer) -> None:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme

def get_min_capability(self) -> int:
# turing and up
return 75

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Cutlass kernels need transposed weight.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(self,
group_size=self.group_size,
is_sym=True)

def get_min_capability(self) -> int:
# ampere and up
return 80

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
Expand Down

0 comments on commit aab67e7

Please sign in to comment.