[Misc] More Cleanup of Marlin #6359

Merged

127 commits
1dfc42d
added
Jun 26, 2024
aa4a9f5
nits
Jun 26, 2024
27f9a03
cleanup
Jun 26, 2024
de7a064
stash
Jun 27, 2024
ec6a833
refactor gptq marlin
robertgshaw2-redhat Jul 3, 2024
966f7be
back out w4a16 act-order compressed tensors
robertgshaw2-redhat Jul 3, 2024
d391f44
back out w4a16 act-order compressed tensors
robertgshaw2-redhat Jul 3, 2024
db075c3
missed
robertgshaw2-redhat Jul 3, 2024
695dc05
formatted'
robertgshaw2-redhat Jul 3, 2024
75c8a11
fix models without gidx
robertgshaw2-redhat Jul 3, 2024
525cf08
format
robertgshaw2-redhat Jul 3, 2024
81f028e
fix test failure
robertgshaw2-redhat Jul 3, 2024
a8fbe89
fix perms not being on gpu
robertgshaw2-redhat Jul 3, 2024
cc843ad
stash
robertgshaw2-redhat Jul 3, 2024
b260c90
stage
robertgshaw2-redhat Jul 3, 2024
c8e97b1
updated
robertgshaw2-redhat Jul 3, 2024
e58063d
nit
robertgshaw2-redhat Jul 3, 2024
383e471
added
robertgshaw2-redhat Jul 3, 2024
865b743
format
robertgshaw2-redhat Jul 3, 2024
9c24525
newline
robertgshaw2-redhat Jul 3, 2024
8b5ac5a
formatting
robertgshaw2-redhat Jul 3, 2024
0e46e4b
working
robertgshaw2-redhat Jul 3, 2024
a47a251
added compressed tensors fp8 to automation
robertgshaw2-redhat Jul 3, 2024
c6be536
missed file
robertgshaw2-redhat Jul 3, 2024
0441171
format
robertgshaw2-redhat Jul 3, 2024
d404f00
remove unnecessary file changes
robertgshaw2-redhat Jul 3, 2024
6569323
restructure quant ops
robertgshaw2-redhat Jul 3, 2024
aa56475
updated to transpose in process_after_loading
robertgshaw2-redhat Jul 3, 2024
d94d07e
updated with varuns suggestion
robertgshaw2-redhat Jul 3, 2024
54308d7
fixed nit
robertgshaw2-redhat Jul 3, 2024
173b93b
name change
robertgshaw2-redhat Jul 3, 2024
afa1ee1
format
robertgshaw2-redhat Jul 3, 2024
5ffe0e4
fixed
robertgshaw2-redhat Jul 3, 2024
4c0e565
fixed tests
robertgshaw2-redhat Jul 3, 2024
ee58d33
Merge branch 'unify-w8a8' into compressed-tensors-fp8
robertgshaw2-redhat Jul 3, 2024
282a038
merge w8a8 unify
robertgshaw2-redhat Jul 3, 2024
a0fd035
fix nit
robertgshaw2-redhat Jul 3, 2024
ba1116b
nits
robertgshaw2-redhat Jul 3, 2024
c1d4375
cleanup
robertgshaw2-redhat Jul 3, 2024
a12bfd5
stash
robertgshaw2-redhat Jul 6, 2024
6aad8f6
Merge branch 'main' into compressed-tensors-fp8
robertgshaw2-redhat Jul 6, 2024
4fc0177
autofp8 working
robertgshaw2-redhat Jul 6, 2024
1d99867
stash
robertgshaw2-redhat Jul 6, 2024
ccee126
stash
robertgshaw2-redhat Jul 6, 2024
0969c67
format
robertgshaw2-redhat Jul 6, 2024
b2eeb84
fix imported marlin_permute_scales
robertgshaw2-redhat Jul 6, 2024
9316f92
format
robertgshaw2-redhat Jul 7, 2024
4ff23c8
added w8a8 to correctness testing
robertgshaw2-redhat Jul 7, 2024
08a8e4e
added testing
robertgshaw2-redhat Jul 7, 2024
4238ac9
format
robertgshaw2-redhat Jul 7, 2024
d1c7517
merged
robertgshaw2-redhat Jul 7, 2024
94d6b35
stash
robertgshaw2-redhat Jul 7, 2024
d48ba9d
readded
robertgshaw2-redhat Jul 7, 2024
0dd2c6a
remove nm-vllm-env
robertgshaw2-redhat Jul 7, 2024
29f40f5
remove old qwen2 moe
robertgshaw2-redhat Jul 7, 2024
ad17c88
readded utils
robertgshaw2-redhat Jul 7, 2024
fd7d825
format
robertgshaw2-redhat Jul 7, 2024
697edfa
Update models-small.txt
robertgshaw2-redhat Jul 7, 2024
e30bd57
gptq marlin tests passing
robertgshaw2-redhat Jul 7, 2024
382d230
add missing files
robertgshaw2-redhat Jul 7, 2024
ba4c7b3
refactoring in progress
robertgshaw2-redhat Jul 7, 2024
0916182
Update models-small.txt
robertgshaw2-redhat Jul 7, 2024
de0242f
stash
robertgshaw2-redhat Jul 7, 2024
9fe4fce
removed lm-eval
robertgshaw2-redhat Jul 7, 2024
c044a86
stash
robertgshaw2-redhat Jul 7, 2024
a5f0aee
remove run
robertgshaw2-redhat Jul 7, 2024
d3299f8
Merge branch 'main' into compressed-tensors-fp8
robertgshaw2-redhat Jul 7, 2024
bcfcd38
added integration test for compressed-tensors-w4-a16
robertgshaw2-redhat Jul 7, 2024
763ab2c
formatting
robertgshaw2-redhat Jul 7, 2024
950de45
Merge branch 'compressed-tensors-fp8' into refactor-gptq-marlin
robertgshaw2-redhat Jul 7, 2024
eb2fdfa
removed
robertgshaw2-redhat Jul 7, 2024
2f49425
Merge branch 'refactor-gptq-marlin' of https://github.com/neuralmagic…
robertgshaw2-redhat Jul 7, 2024
93812eb
add comment
robertgshaw2-redhat Jul 7, 2024
d4b25cf
Update w8a8_utils.py
robertgshaw2-redhat Jul 7, 2024
48b220e
Update w8a8_utils.py
robertgshaw2-redhat Jul 7, 2024
f1d8ee4
cleanup unnessary changes
robertgshaw2-redhat Jul 7, 2024
cfe27be
Merge branch 'refactor-gptq-marlin' of https://github.com/neuralmagic…
robertgshaw2-redhat Jul 7, 2024
72b9368
fix gptq marlin
robertgshaw2-redhat Jul 7, 2024
73ae598
formatting
robertgshaw2-redhat Jul 7, 2024
f854c54
cleanup
robertgshaw2-redhat Jul 7, 2024
13d4e93
Merge branch 'main' into refactor-gptq-marlin
robertgshaw2-redhat Jul 7, 2024
4e09688
Update benchmark_marlin.py
robertgshaw2-redhat Jul 7, 2024
db694e0
Update compressed_tensors_wNa16.py
robertgshaw2-redhat Jul 7, 2024
4b2dba2
Update marlin_utils_test.py
robertgshaw2-redhat Jul 7, 2024
9d8d12f
Update test_marlin_gemm.py
robertgshaw2-redhat Jul 7, 2024
54cf4f2
format
robertgshaw2-redhat Jul 7, 2024
7abc2b1
Merge branch 'refactor-gptq-marlin' of https://github.com/neuralmagic…
robertgshaw2-redhat Jul 7, 2024
ed178d4
formatting
robertgshaw2-redhat Jul 7, 2024
03b11b2
more formatting
robertgshaw2-redhat Jul 7, 2024
e2a5e7a
fix
robertgshaw2-redhat Jul 7, 2024
6f62ada
yapf
robertgshaw2-redhat Jul 7, 2024
933bec3
fixed failing tests
robertgshaw2-redhat Jul 8, 2024
fe6ae88
tweak scores
robertgshaw2-redhat Jul 8, 2024
8285ef6
tweak scores
robertgshaw2-redhat Jul 8, 2024
fcc8925
stash
robertgshaw2-redhat Jul 9, 2024
c0b5d13
format
robertgshaw2-redhat Jul 9, 2024
f6910a5
seems to still be working
robertgshaw2-redhat Jul 9, 2024
84ed30f
stash
robertgshaw2-redhat Jul 11, 2024
62368af
added tests
robertgshaw2-redhat Jul 11, 2024
b618961
seems to be working!
robertgshaw2-redhat Jul 12, 2024
f2755f2
Update build.sh
robertgshaw2-redhat Jul 12, 2024
cd392f5
Merge branch 'main' into act-order
robertgshaw2-redhat Jul 12, 2024
b092079
Merge branch 'act-order' of https://github.com/neuralmagic/nm-vllm in…
robertgshaw2-redhat Jul 12, 2024
5cbed16
cleanup bad merge
robertgshaw2-redhat Jul 12, 2024
054e2db
removed files that should not have been added
robertgshaw2-redhat Jul 12, 2024
7e0b0ec
Update run-lm-eval-gsm-vllm-baseline.sh
robertgshaw2-redhat Jul 12, 2024
bddf9d3
Update test_compressed_tensors.py
robertgshaw2-redhat Jul 12, 2024
ad43c4e
undo
robertgshaw2-redhat Jul 12, 2024
0aa9181
undo bad merge
robertgshaw2-redhat Jul 12, 2024
777e74b
last undo?
robertgshaw2-redhat Jul 12, 2024
77988d3
twas not last
robertgshaw2-redhat Jul 12, 2024
39ed988
cleanup
robertgshaw2-redhat Jul 12, 2024
7d2fff8
stash
robertgshaw2-redhat Jul 12, 2024
2e74b0b
remove more
robertgshaw2-redhat Jul 12, 2024
a845475
fix
robertgshaw2-redhat Jul 12, 2024
2e7bf61
format
robertgshaw2-redhat Jul 12, 2024
18596e2
format
robertgshaw2-redhat Jul 12, 2024
48aae94
more cleanup
robertgshaw2-redhat Jul 12, 2024
bfb3fed
undo ct changes
robertgshaw2-redhat Jul 12, 2024
3288794
cleanup
robertgshaw2-redhat Jul 12, 2024
d6302dd
format
robertgshaw2-redhat Jul 12, 2024
fdf6a90
format
robertgshaw2-redhat Jul 12, 2024
58a4053
final cleanup
robertgshaw2-redhat Jul 12, 2024
bf5e657
formatting
robertgshaw2-redhat Jul 12, 2024
8d06c70
fixed
robertgshaw2-redhat Jul 12, 2024
0f87813
updated
robertgshaw2-redhat Jul 12, 2024
3400f38
Merge branch 'main' into use-shared-gptq-marlin
robertgshaw2-redhat Jul 12, 2024
78 changes: 31 additions & 47 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -10,8 +10,9 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
apply_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

@@ -145,6 +146,7 @@ def create_weights(
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition

# Normalize group_size
if self.quant_config.group_size != -1:
@@ -158,32 +160,19 @@
input_size=input_size,
group_size=group_size)

# Detect sharding of scales/zp

# By default, no sharding over "input dim"
scales_and_zp_size = input_size // group_size
scales_and_zp_input_dim = None

if self.quant_config.desc_act:
# Act-order case
assert self.quant_config.group_size != -1

is_k_full = input_size_per_partition == input_size

# Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
is_row_parallel):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each GPU in TP>1 case.
scales_and_zp_input_dim = None
scales_and_zp_size = input_size // group_size
else:
# No act-order case

# K is always full due to full alignment with
# group-size and shard of scales/zp
is_k_full = True

# If this is a row-parallel case, then shard scales/zp
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
scales_and_zp_size = input_size_per_partition // group_size
scales_and_zp_input_dim = 0

# Init buffers
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
scales_and_zp_input_dim = 0
scales_and_zp_size = input_size_per_partition // group_size

# Quantized weights
qweight = Parameter(
@@ -268,13 +257,15 @@ def create_weights(
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = is_k_full
layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
is_row_parallel)

# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device

# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
@@ -312,22 +303,15 @@ def apply(
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (layer.output_size_per_partition, )

output = ops.gptq_marlin_gemm(reshaped_x,
layer.qweight,
layer.scales,
g_idx=layer.g_idx,
perm=layer.g_idx_sort_indices,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
size_m=reshaped_x.shape[0],
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
is_k_full=layer.is_k_full)

if bias is not None:
output.add_(bias) # In-place add

return output.reshape(out_shape)
return apply_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,
bias=bias)
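For context, the old inline body of `apply` (removed above) shows everything the new shared helper has to do: flatten the input to a 2-D matrix, run the Marlin GEMM, add the bias in place, and restore the leading dimensions. A minimal sketch of such a helper, reconstructed from the removed lines rather than taken from the actual `marlin_utils.py` implementation, might look like this (the `_custom_ops` import path is an assumption; the diff only shows call sites as `ops.gptq_marlin_gemm(...)`):

```python
from typing import Optional

import torch

# Assumed import path for the custom CUDA ops used by the removed inline code.
from vllm import _custom_ops as ops


def apply_marlin_linear(input: torch.Tensor,
                        weight: torch.Tensor,
                        weight_scale: torch.Tensor,
                        g_idx: torch.Tensor,
                        g_idx_sort_indices: torch.Tensor,
                        workspace: torch.Tensor,
                        num_bits: int,
                        output_size_per_partition: int,
                        input_size_per_partition: int,
                        is_k_full: bool,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Collapse all leading dims into a single M dim for the GEMM.
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition, )

    output = ops.gptq_marlin_gemm(reshaped_x,
                                  weight,
                                  weight_scale,
                                  g_idx=g_idx,
                                  perm=g_idx_sort_indices,
                                  workspace=workspace,
                                  num_bits=num_bits,
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full)

    if bias is not None:
        output.add_(bias)  # In-place add, as in the removed code.

    return output.reshape(out_shape)
```

With the forward pass expressed through one helper, the GPTQ Marlin path and the compressed-tensors wNa16 path touched elsewhere in this PR can share the same code.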
12 changes: 12 additions & 0 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -91,6 +91,18 @@ def marlin_make_workspace(output_size_per_partition: int,
requires_grad=False)


def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (act_order and not is_row_parallel)
Collaborator review comment:
The condition here is correct, but it would be better to pass is_k_full for the case where act_order == False.
Like this:
return (not act_order) or (act_order and not is_row_parallel)
(Both variants are enumerated in the sketch after this file's diff.)

def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
is_row_parallel: bool) -> bool:
# Need to repeat scales on every rank if act_ordering or
# channelwise and RowParallelLinear
is_channelwise = group_size == -1
return act_order or (is_channelwise and is_row_parallel)


def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)
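To make the tensor-parallel cases concrete, here is a self-contained sketch (not part of the PR) that restates the two predicates from the hunk above and adds the reviewer's suggested variant of marlin_is_k_full under a hypothetical name, then enumerates every combination of act_order and is_row_parallel:

```python
from itertools import product


def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    # As written in this revision of the PR: K only counts as full with
    # act-order when the layer is not row-parallel.
    return act_order and not is_row_parallel


def marlin_is_k_full_suggested(act_order: bool, is_row_parallel: bool) -> bool:
    # Reviewer's suggested variant: also report K as full whenever there is
    # no activation reordering.
    return (not act_order) or (act_order and not is_row_parallel)


def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
    # Scales are replicated on every rank for act-order, or for channelwise
    # quantization (group_size == -1) under RowParallelLinear.
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)


if __name__ == "__main__":
    for act_order, is_row_parallel in product((False, True), repeat=2):
        print(f"act_order={act_order!s:<5} row_parallel={is_row_parallel!s:<5} "
              f"k_full(PR)={marlin_is_k_full(act_order, is_row_parallel)!s:<5} "
              f"k_full(suggested)="
              f"{marlin_is_k_full_suggested(act_order, is_row_parallel)!s:<5} "
              f"repeat_scales(group=128)="
              f"{marlin_repeat_scales_on_all_ranks(act_order, 128, is_row_parallel)}")
```

The two variants disagree only when act_order is False; the reviewer's note above treats the merged condition as still correct for that case, but prefers the form that states it explicitly.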