
Commit

fix typo
robertgshaw2-redhat committed Jul 17, 2024
1 parent 9e32a49 commit 36425d2
Showing 4 changed files with 9 additions and 8 deletions.

File 1 of 4:

@@ -84,14 +84,14 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
     def get_config_filenames(cls) -> List[str]:
         return []

-    def _check_scheme_supported_on_device(self, min_capability: int):
+    def _check_scheme_supported(self, min_capability: int):
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         if capability < min_capability:
-            raise RuntimeError("Quantization scheme is not supported for ",
-                               f"the current GPU. Min capability: {min_capability}. ",
-                               f"Current capability: {capability}.")
-
+            raise RuntimeError(
+                "Quantization scheme is not supported for ",
+                f"the current GPU. Min capability: {min_capability}. ",
+                f"Current capability: {capability}.")

     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:

@@ -228,6 +228,7 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":

         return scheme

+
 class CompressedTensorsLinearMethod(LinearMethodBase):

     def __init__(self, quantization_config: CompressedTensorsConfig):

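As a side note on the hunk above: the check folds the device's (major, minor) compute capability into a single integer and compares it against the scheme's minimum. A minimal standalone sketch of that logic follows; the function name and the example capability values are illustrative and not part of the commit.

    def check_scheme_supported(device_capability, min_capability):
        # device_capability is a (major, minor) tuple, e.g. (8, 0) for an A100.
        capability = device_capability[0] * 10 + device_capability[1]
        if capability < min_capability:
            # Note: the comma-separated strings in the RuntimeError call shown in
            # the diff are stored as separate entries in the exception's args
            # tuple; a single f-string keeps the message in one piece.
            raise RuntimeError(
                f"Quantization scheme is not supported for the current GPU. "
                f"Min capability: {min_capability}. "
                f"Current capability: {capability}.")

    check_scheme_supported((8, 0), 80)    # Ampere A100 -> 80, passes
    # check_scheme_supported((7, 5), 80)  # Turing T4 -> 75, would raise RuntimeError
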
File 2 of 4:

@@ -18,7 +18,7 @@ def get_min_capability(self) -> int:
         Get minimum device capability.
         """
         raise NotImplementedError

     @abstractmethod
     def create_weights(self, *args, **kwargs):
         """

File 3 of 4:

@@ -32,7 +32,7 @@ def __init__(self,
     def get_min_capability(self) -> int:
         # ampere + up
         return 80

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         pass

File 4 of 4:

@@ -45,7 +45,7 @@ def __init__(self,
     def get_min_capability(self) -> int:
         # ampere and up
         return 80

     def create_weights(self, layer: torch.nn.Module, input_size: int,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,

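For reference, both scheme files above advertise a minimum capability of 80 (Ampere), which is the value passed into the config's _check_scheme_supported. A small self-contained sketch of that interface follows; the class names are illustrative, and only get_min_capability and create_weights come from the diff.

    from abc import ABC, abstractmethod


    class Scheme(ABC):
        # Illustrative stand-in for the compressed-tensors scheme base class.

        @abstractmethod
        def get_min_capability(self) -> int:
            """Get minimum device capability."""
            raise NotImplementedError

        @abstractmethod
        def create_weights(self, *args, **kwargs):
            raise NotImplementedError


    class AmpereOnlyScheme(Scheme):
        def get_min_capability(self) -> int:
            # ampere and up, as in the schemes in this commit
            return 80

        def create_weights(self, *args, **kwargs):
            pass


    assert AmpereOnlyScheme().get_min_capability() == 80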
