-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TPU][Quantization] TPU W8A8
#11785
Merged
Merged
[TPU][Quantization] TPU W8A8
#11785
Changes from 69 commits
Commits
Show all changes
73 commits
Select commit
Hold shift + click to select a range
3b0c8a6
w8a8 working
robertgshaw2-redhat 36fc1db
format
robertgshaw2-redhat d83c04c
added all kernels
robertgshaw2-redhat af9d0f4
format
robertgshaw2-redhat 0f9fd21
working on cuda
robertgshaw2-redhat 7b3203f
added mixed precision directory
robertgshaw2-redhat bf50fa4
formatting
robertgshaw2-redhat 226ef52
cache current state - w8a16 running oom
robertgshaw2-redhat bb7c741
[TPU] Ensure torch._sync(param) is called after param.data.copy_()
WoosukKwon cf842bd
yapf
WoosukKwon 67039bc
[TPU] Correctly profile peak memory usage
WoosukKwon 0695f77
Upgrade PyTorch XLA
WoosukKwon 11cf82f
Merge branch 'main' into tpu-peak-mem
WoosukKwon e016e38
stash
robertgshaw2-redhat 717b859
Merge branch 'main' into compressed-tensors-tpu
robertgshaw2-redhat c848735
proper merge
robertgshaw2-redhat 1539915
add mixed precision
robertgshaw2-redhat f00412a
format
robertgshaw2-redhat b0a6b70
stash
robertgshaw2-redhat e812d7e
Merge branch 'tpu-peak-mem' into compressed-tensors-tpu
robertgshaw2-redhat 764dda1
stash
robertgshaw2-redhat 87b2ae6
remove name
robertgshaw2-redhat e813ff8
revert woosuk change
robertgshaw2-redhat 8cfaa1b
format
robertgshaw2-redhat bbc9741
update
robertgshaw2-redhat eb3f39e
fix nit
robertgshaw2-redhat bb2fbe1
update
robertgshaw2-redhat 14ccb90
fix spurious
robertgshaw2-redhat 4092be2
stash branch for brittany
robertgshaw2-redhat 1aaa628
Merge branch 'main' into tpu-w8a8
robertgshaw2-redhat 48aa54b
revert
robertgshaw2-redhat 4efe915
fix
robertgshaw2-redhat e98b79c
updated
robertgshaw2-redhat 5a89668
reduce cruft
robertgshaw2-redhat 57cbf5c
reduce cruft
robertgshaw2-redhat 3451c4d
updated
robertgshaw2-redhat 0c2e62a
update comment
robertgshaw2-redhat 172c9ca
revert spurious change
robertgshaw2-redhat 938ca81
remove cruft
robertgshaw2-redhat 9e18911
cruft reduction
robertgshaw2-redhat 5f58ec7
update docs
robertgshaw2-redhat af9f298
added integration test
robertgshaw2-redhat 6fe2f62
updated
robertgshaw2-redhat f2c0beb
Add bias back
robertgshaw2-redhat 8b29718
add bias support
robertgshaw2-redhat 1e2a373
updated
robertgshaw2-redhat 2a359ef
stash
robertgshaw2-redhat f7e8975
Merge branch 'main' into remove-async-stream
robertgshaw2-redhat 0d4c3fd
fix
robertgshaw2-redhat 57340d2
update
robertgshaw2-redhat 38291d5
trigger test in CI
robertgshaw2-redhat ead1e94
fix AZP
robertgshaw2-redhat cea5e54
fixed!
robertgshaw2-redhat 940ddde
Merge branch 'tpu-w8a8' of https://github.com/neuralmagic/vllm into t…
robertgshaw2-redhat 84a5b29
fix azp adju
robertgshaw2-redhat a1d7b4a
make docker command look better on gh
robertgshaw2-redhat 2b4ecfd
remove torch warnings
robertgshaw2-redhat 186c108
stash
robertgshaw2-redhat 7e8598a
Merge branch 'tpu-w8a8' of https://github.com/neuralmagic/vllm into t…
robertgshaw2-redhat de773cd
fix AZP
robertgshaw2-redhat 3a53d7d
merged
robertgshaw2-redhat 0be5f69
added
robertgshaw2-redhat cb69ba7
fix formatting
robertgshaw2-redhat 3896f6c
remove comment
robertgshaw2-redhat 33e1e13
formatted
robertgshaw2-redhat dde72d6
add llama to ci
robertgshaw2-redhat d7a9c93
Merge branch 'main' into tpu-w8a8
robertgshaw2-redhat db9f795
Update supported_hardware.md
robertgshaw2-redhat 09ad869
Update supported_hardware.md
robertgshaw2-redhat b74c88a
ixed docs build
robertgshaw2-redhat da4369e
Merge branch 'tpu-w8a8' of https://github.com/neuralmagic/vllm into t…
robertgshaw2-redhat 5ddcac2
Merge branch 'main' into tpu-w8a8
robertgshaw2-redhat f353c43
fix CI
robertgshaw2-redhat File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from dataclasses import dataclass | ||
|
||
import lm_eval | ||
import pytest | ||
|
||
TASK = "gsm8k" | ||
FILTER = "exact_match,strict-match" | ||
RTOL = 0.03 | ||
|
||
|
||
@dataclass | ||
class GSM8KAccuracyTestConfig: | ||
model_name: str | ||
excepted_value: float | ||
|
||
def get_model_args(self) -> str: | ||
return (f"pretrained={self.model_name}," | ||
"max_model_len=4096,max_num_seqs=32") | ||
|
||
|
||
# NOTE: Accuracy scores measured on GPUs. | ||
ACCURACY_CONFIGS = [ | ||
GSM8KAccuracyTestConfig( | ||
model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", | ||
excepted_value=0.76), # no bias | ||
# NOTE(rob): We cannot re-initialize VLLM in the same process for TPU, | ||
# so only one of these tests can run in a single call to pytest. As | ||
# a follow up, move this into the LM-EVAL section of the CI. | ||
# GSM8KAccuracyTestConfig( | ||
# model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", | ||
# excepted_value=0.66), # bias in QKV layers | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("config", ACCURACY_CONFIGS) | ||
def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): | ||
|
||
results = lm_eval.simple_evaluate( | ||
model="vllm", | ||
model_args=config.get_model_args(), | ||
tasks="gsm8k", | ||
batch_size="auto", | ||
) | ||
|
||
EXPECTED_VALUE = config.excepted_value | ||
measured_value = results["results"][TASK][FILTER] | ||
assert (measured_value - RTOL < EXPECTED_VALUE | ||
and measured_value + RTOL > EXPECTED_VALUE | ||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 0 additions & 74 deletions
74
vllm/model_executor/layers/quantization/kernels/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,74 +0,0 @@ | ||
from typing import List, Optional, Type | ||
|
||
import vllm.envs as envs | ||
from vllm.model_executor.layers.quantization.kernels.exllama import ( | ||
ExllamaLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.machete import ( | ||
MacheteLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.marlin import ( | ||
MarlinLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( | ||
MPLinearKernel, MPLinearLayerConfig) | ||
from vllm.platforms import current_platform | ||
|
||
# in priority/performance order (when available) | ||
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ | ||
MacheteLinearKernel, | ||
MarlinLinearKernel, | ||
ExllamaLinearKernel, | ||
] | ||
|
||
|
||
def choose_mp_linear_kernel( | ||
config: MPLinearLayerConfig, | ||
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: | ||
""" | ||
Choose an MPLinearKernel that can implement the given config for the given | ||
compute capability. Attempts to choose the best kernel in terms of | ||
performance. | ||
|
||
Args: | ||
config (MPLinearLayerConfig): Description of the linear layer to be | ||
implemented. | ||
compute_capability (Optional[int], optional): The compute capability of | ||
the target device, if None uses `current_platform` to get the compute | ||
capability. Defaults to None. | ||
|
||
Raises: | ||
ValueError: If no kernel can implement the given config. | ||
|
||
Returns: | ||
Type[MPLinearKernel]: Chosen kernel. | ||
""" | ||
if compute_capability is None: | ||
if current_platform is None: | ||
raise ValueError("Cannot determine compute capability") | ||
_cc = current_platform.get_device_capability() | ||
compute_capability = _cc[0] * 10 + _cc[1] | ||
|
||
failure_reasons = [] | ||
for kernel in _POSSIBLE_KERNELS: | ||
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: | ||
failure_reasons.append( | ||
f' {kernel.__name__} disabled by environment variable') | ||
continue | ||
|
||
if kernel.get_min_capability() > compute_capability: | ||
failure_reasons.append( | ||
f"{kernel.__name__} requires capability " | ||
f"{kernel.get_min_capability()}, current compute capability " | ||
f"is {compute_capability}") | ||
continue | ||
|
||
can_implement, failure_reason = kernel.can_implement(config) | ||
if can_implement: | ||
return kernel | ||
else: | ||
failure_reasons.append( | ||
f' {kernel.__name__} cannot implement due to: {failure_reason}' | ||
) | ||
|
||
raise ValueError( | ||
"Failed to find a kernel that can implement the "\ | ||
"WNA16 linear layer. Reasons: \n" | ||
+ '\n'.join(failure_reasons)) | ||
File renamed without changes.
74 changes: 74 additions & 0 deletions
74
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import List, Optional, Type | ||
|
||
import vllm.envs as envs | ||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 | ||
ExllamaLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 | ||
MacheteLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 | ||
MarlinLinearKernel) | ||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 | ||
MPLinearKernel, MPLinearLayerConfig) | ||
from vllm.platforms import current_platform | ||
|
||
# in priority/performance order (when available) | ||
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ | ||
MacheteLinearKernel, | ||
MarlinLinearKernel, | ||
ExllamaLinearKernel, | ||
] | ||
|
||
|
||
def choose_mp_linear_kernel( | ||
config: MPLinearLayerConfig, | ||
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: | ||
""" | ||
Choose an MPLinearKernel that can implement the given config for the given | ||
compute capability. Attempts to choose the best kernel in terms of | ||
performance. | ||
Args: | ||
config (MPLinearLayerConfig): Description of the linear layer to be | ||
implemented. | ||
compute_capability (Optional[int], optional): The compute capability of | ||
the target device, if None uses `current_platform` to get the compute | ||
capability. Defaults to None. | ||
Raises: | ||
ValueError: If no kernel can implement the given config. | ||
Returns: | ||
Type[MPLinearKernel]: Chosen kernel. | ||
""" | ||
if compute_capability is None: | ||
if current_platform is None: | ||
raise ValueError("Cannot determine compute capability") | ||
_cc = current_platform.get_device_capability() | ||
compute_capability = _cc[0] * 10 + _cc[1] | ||
|
||
failure_reasons = [] | ||
for kernel in _POSSIBLE_KERNELS: | ||
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: | ||
failure_reasons.append( | ||
f' {kernel.__name__} disabled by environment variable') | ||
continue | ||
|
||
if kernel.get_min_capability() > compute_capability: | ||
failure_reasons.append( | ||
f"{kernel.__name__} requires capability " | ||
f"{kernel.get_min_capability()}, current compute capability " | ||
f"is {compute_capability}") | ||
continue | ||
|
||
can_implement, failure_reason = kernel.can_implement(config) | ||
if can_implement: | ||
return kernel | ||
else: | ||
failure_reasons.append( | ||
f' {kernel.__name__} cannot implement due to: {failure_reason}' | ||
) | ||
|
||
raise ValueError( | ||
"Failed to find a kernel that can implement the "\ | ||
"WNA16 linear layer. Reasons: \n" | ||
+ '\n'.join(failure_reasons)) |
File renamed without changes.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NOTE for reviewer - this file is not changed, it is just moved