From 2d43e96289998099cf8e9811f981e8dab57d9ddb Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 30 Jul 2024 16:37:01 -0400 Subject: [PATCH] [Kernel] Remove scaled_fp8_quant kernel padding footgun (#6842) --- tests/quantization/test_fp8.py | 2 +- vllm/_custom_ops.py | 24 ++++++++++--------- .../layers/quantization/utils/w8a8_utils.py | 5 ++-- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index ad92f1f189f65..a020f7bf37262 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -123,7 +123,7 @@ def per_tensor_dequantize(tensor, inv_scale, dtype): assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype)) # Padding - y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17) + y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17) assert y.shape[0] == 17 assert torch.allclose( ref_y, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ad9f01be6ddd4..6ca667eb85640 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -307,7 +307,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, - batch_dim_padding: Optional[int] = None, + num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -317,7 +317,7 @@ def scaled_fp8_quant( This function supports both static and dynamic quantization: If you provide the scale, it will use static scaling and if you omit it, the scale will be determined dynamically. The function also allows - optional padding of the output tensor for downstream kernels that + optional padding of the output tensors for downstream kernels that will benefit from padding. Args: @@ -325,7 +325,7 @@ def scaled_fp8_quant( scale: Optional scaling factor for the FP8 quantization scale_ub: Optional upper bound for scaling factor in dynamic per token case - batch_dim_padding: If specified, pad the first dimension + num_token_padding: If specified, pad the first dimension of the output to at least this value. use_per_token_if_dynamic: Whether to do per_tensor or per_token in the dynamic quantization case. @@ -334,16 +334,16 @@ def scaled_fp8_quant( Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. """ - if batch_dim_padding: - shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) - output = torch.empty(shape, - device=input.device, - dtype=torch.float8_e4m3fn) - else: - output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + # This code assumes batch_dim and num_tokens are flattened + assert (input.ndim == 2) + shape = input.shape + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn) + if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((input.numel() // input.shape[-1], 1), + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) torch.ops._C.dynamic_per_token_scaled_fp8_quant( @@ -352,6 +352,8 @@ def scaled_fp8_quant( scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: + # num_token_padding not implemented for this case + assert (scale.numel() == 1 or num_token_padding is None) torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 20100c76bd690..dbe86902853cd 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -139,7 +139,7 @@ def apply_fp8_linear( qinput, x_scale = ops.scaled_fp8_quant( input, input_scale, - batch_dim_padding=17, + num_token_padding=17, use_per_token_if_dynamic=use_per_token_if_dynamic) per_tensor_weights = (weight_scale.numel() == 1) @@ -177,8 +177,9 @@ def apply_fp8_linear( output, _ = torch._scaled_mm(qinput, weight, out_dtype=torch.float32) - # Unpad (undo batch_dim_padding) + # Unpad (undo num_token_padding) output = torch.narrow(output, 0, 0, input.shape[0]) + x_scale = torch.narrow(x_scale, 0, 0, input.shape[0]) # DQ # C = sw * sx * (X * W) + bias