Commit

[ Misc ] More Cleanup of Marlin (#6359)
Co-authored-by: Robert Shaw <[email protected]>
robertgshaw2-redhat authored Jul 13, 2024
1 parent 9da4aad commit babf52d
Showing 3 changed files with 44 additions and 48 deletions.
@@ -3,7 +3,7 @@
 # We use this for fp8, which HF does not support.
 #
 # Make sure you have lm-eval-harness installed:
-# pip install lm-eval==0.4.2
+# pip install lm-eval==0.4.3
 
 usage() {
     echo``
78 changes: 31 additions & 47 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -10,8 +10,9 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
+    apply_marlin_linear, check_marlin_supported, marlin_is_k_full,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
     verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

@@ -145,6 +146,7 @@ def create_weights(
     ) -> None:
         del output_size
         output_size_per_partition = sum(output_partition_sizes)
+        is_row_parallel = input_size != input_size_per_partition
 
         # Normalize group_size
         if self.quant_config.group_size != -1:
@@ -158,32 +160,19 @@
             input_size=input_size,
             group_size=group_size)
 
-        # Detect sharding of scales/zp
-
-        # By default, no sharding over "input dim"
-        scales_and_zp_size = input_size // group_size
-        scales_and_zp_input_dim = None
-
-        if self.quant_config.desc_act:
-            # Act-order case
-            assert self.quant_config.group_size != -1
-
-            is_k_full = input_size_per_partition == input_size
-
-        else:
-            # No act-order case
-
-            # K is always full due to full alignment with
-            # group-size and shard of scales/zp
-            is_k_full = True
-
-            # If this is a row-parallel case, then shard scales/zp
-            if (input_size != input_size_per_partition
-                    and self.quant_config.group_size != -1):
-                scales_and_zp_size = input_size_per_partition // group_size
-                scales_and_zp_input_dim = 0
-
-        # Init buffers
+        # Determine sharding
+        if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
+                                             self.quant_config.group_size,
+                                             is_row_parallel):
+            # By setting scale_dim == None, weight_loader will
+            # repeat the scales on each GPU in TP>1 case.
+            scales_and_zp_input_dim = None
+            scales_and_zp_size = input_size // group_size
+        else:
+            # By setting scale_dim == 0, weight_loader will
+            # shard the scales in TP>1 case.
+            scales_and_zp_input_dim = 0
+            scales_and_zp_size = input_size_per_partition // group_size
 
         # Quantized weights
         qweight = Parameter(
@@ -268,13 +257,15 @@ def create_weights(
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size
-        layer.is_k_full = is_k_full
+        layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
+                                           is_row_parallel)
 
     # Checkpoints are serialized in AutoGPTQ format, which is different from the
     # marlin format. This function is called after the weights are loaded.
     # Here, we handle the repacking, including the activation reordering case.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         device = layer.qweight.device
+
         # Allocate marlin workspace
         layer.workspace = marlin_make_workspace(
             layer.output_size_per_partition, device)
@@ -312,22 +303,15 @@ def apply(
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        reshaped_x = x.reshape(-1, x.shape[-1])
-        out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
-
-        output = ops.gptq_marlin_gemm(reshaped_x,
-                                      layer.qweight,
-                                      layer.scales,
-                                      g_idx=layer.g_idx,
-                                      perm=layer.g_idx_sort_indices,
-                                      workspace=layer.workspace,
-                                      num_bits=self.quant_config.weight_bits,
-                                      size_m=reshaped_x.shape[0],
-                                      size_n=layer.output_size_per_partition,
-                                      size_k=layer.input_size_per_partition,
-                                      is_k_full=layer.is_k_full)
-
-        if bias is not None:
-            output.add_(bias)  # In-place add
-
-        return output.reshape(out_shape)
+        return apply_marlin_linear(
+            input=x,
+            weight=layer.qweight,
+            weight_scale=layer.scales,
+            g_idx=layer.g_idx,
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=layer.workspace,
+            num_bits=self.quant_config.weight_bits,
+            output_size_per_partition=layer.output_size_per_partition,
+            input_size_per_partition=layer.input_size_per_partition,
+            is_k_full=layer.is_k_full,
+            bias=bias)
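
Note on the apply() change above: the removed body shows what the new apply_marlin_linear helper is expected to wrap — flatten the activations to 2-D, call ops.gptq_marlin_gemm, add the bias in place if present, and restore the original output shape. The helper's real implementation lives in marlin_utils.py and is not part of this diff, so the following is only a sketch reconstructed from the removed code; the _sketch suffix and the `from vllm import _custom_ops as ops` import path are assumptions.

from typing import Optional

import torch

# Assumed import path for the Marlin GEMM custom op (referenced as `ops` in gptq_marlin.py).
from vllm import _custom_ops as ops


def apply_marlin_linear_sketch(input: torch.Tensor,
                               weight: torch.Tensor,
                               weight_scale: torch.Tensor,
                               g_idx: torch.Tensor,
                               g_idx_sort_indices: torch.Tensor,
                               workspace: torch.Tensor,
                               num_bits: int,
                               output_size_per_partition: int,
                               input_size_per_partition: int,
                               is_k_full: bool,
                               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Collapse leading dims so the kernel sees a 2-D (M, K) activation.
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition, )

    # Same kernel invocation as the removed apply() body above.
    output = ops.gptq_marlin_gemm(reshaped_x,
                                  weight,
                                  weight_scale,
                                  g_idx=g_idx,
                                  perm=g_idx_sort_indices,
                                  workspace=workspace,
                                  num_bits=num_bits,
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full)

    if bias is not None:
        output.add_(bias)  # In-place add, as in the removed code.

    return output.reshape(out_shape)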
12 changes: 12 additions & 0 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -91,6 +91,18 @@ def marlin_make_workspace(output_size_per_partition: int,
                        requires_grad=False)
 
 
+def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
+    return (not act_order) or (act_order and not is_row_parallel)
+
+
+def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
+                                      is_row_parallel: bool) -> bool:
+    # Need to repeat scales on every rank if act_ordering or
+    # channelwise and RowParallelLinear
+    is_channelwise = group_size == -1
+    return act_order or (is_channelwise and is_row_parallel)
+
+
 def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
     return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                               requires_grad=False)
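
To make the behavior of the two new predicates concrete, the standalone snippet below (not part of the diff) copies them verbatim and asserts the intended cases: K is only partial when activation reordering meets a row-parallel (K-sharded) layer, and scales are replicated on every rank when act-order is enabled or when channelwise quantization (group_size == -1) meets a row-parallel layer.

# Copied verbatim from the diff above so this snippet runs on its own.
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    return (not act_order) or (act_order and not is_row_parallel)


def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)


# K is only partial when act-order is combined with a row-parallel shard.
assert marlin_is_k_full(act_order=False, is_row_parallel=True)
assert marlin_is_k_full(act_order=True, is_row_parallel=False)
assert not marlin_is_k_full(act_order=True, is_row_parallel=True)

# Scales are repeated on every rank for act-order, and for channelwise
# quantization (group_size == -1) on a row-parallel layer; otherwise they
# are sharded along the input dim.
assert marlin_repeat_scales_on_all_ranks(True, 128, is_row_parallel=False)
assert marlin_repeat_scales_on_all_ranks(False, -1, is_row_parallel=True)
assert not marlin_repeat_scales_on_all_ranks(False, 128, is_row_parallel=True)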
