Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Kernel] Remove scaled_fp8_quant kernel padding footgun (#6842)
commit d7a299edaa (parent 052b6f8ca4), committed by GitHub
@@ -123,7 +123,7 @@ def test_scaled_fp8_quant(dtype) -> None:
     assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
 
     # Padding
-    y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
+    y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
     assert y.shape[0] == 17
     assert torch.allclose(
         ref_y,
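For context, the padding semantics the updated test exercises can be sketched in plain PyTorch (an illustrative stand-in, not vLLM's kernel; it assumes a torch build with float8_e4m3fn support):

import torch

# Quantize x to fp8 with a per-tensor scale, padding the token (first)
# dimension of the output to at least num_token_padding rows.
x = torch.randn(4, 16, dtype=torch.float16)
scale = x.abs().max().to(torch.float32) / torch.finfo(torch.float8_e4m3fn).max
num_token_padding = 17

rows = max(num_token_padding, x.shape[0])
y = torch.zeros((rows, x.shape[1]), dtype=torch.float8_e4m3fn)
y[:x.shape[0]] = (x.to(torch.float32) / scale).to(torch.float8_e4m3fn)

assert y.shape[0] == 17
# Dequantizing only the first x.shape[0] rows recovers x up to fp8 rounding.
dequant = y[:x.shape[0]].to(torch.float32) * scale
assert torch.allclose(x.to(torch.float32), dequant, rtol=0.2, atol=1e-2)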
@@ -307,7 +307,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
-    batch_dim_padding: Optional[int] = None,
+    num_token_padding: Optional[int] = None,
     scale_ub: Optional[torch.Tensor] = None,
     use_per_token_if_dynamic: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -317,7 +317,7 @@ def scaled_fp8_quant(
     This function supports both static and dynamic quantization: If you
     provide the scale, it will use static scaling and if you omit it,
     the scale will be determined dynamically. The function also allows
-    optional padding of the output tensor for downstream kernels that
+    optional padding of the output tensors for downstream kernels that
     will benefit from padding.
 
     Args:
@@ -325,7 +325,7 @@ def scaled_fp8_quant(
         scale: Optional scaling factor for the FP8 quantization
         scale_ub: Optional upper bound for scaling factor in dynamic
             per token case
-        batch_dim_padding: If specified, pad the first dimension
+        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
         use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.
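With the rename applied, the signature reads like this from the caller's side (a sketch assuming a CUDA build of vLLM where vllm._custom_ops is importable; shapes and values are illustrative):

import torch
from vllm import _custom_ops as ops

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-tensor quantization: the scale is computed from x.
y, scale = ops.scaled_fp8_quant(x)

# Static quantization with a precomputed per-tensor scale.
static_scale = torch.tensor([0.05], dtype=torch.float32, device="cuda")
y_static, _ = ops.scaled_fp8_quant(x, static_scale)

# Pad the token (first) dimension of the output to at least 32 rows for
# downstream kernels that benefit from padding.
y_padded, _ = ops.scaled_fp8_quant(x, num_token_padding=32)
assert y_padded.shape[0] == 32

# Dynamic per-token quantization: one scale per row, shape (num_tokens, 1).
y_tok, tok_scales = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)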
@@ -334,16 +334,16 @@ def scaled_fp8_quant(
         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
             scaling factor.
     """
-    if batch_dim_padding:
-        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
-        output = torch.empty(shape,
-                             device=input.device,
-                             dtype=torch.float8_e4m3fn)
-    else:
-        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    # This code assumes batch_dim and num_tokens are flattened
+    assert (input.ndim == 2)
+    shape = input.shape
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
+
     if scale is None:
         if use_per_token_if_dynamic:
-            scale = torch.empty((input.numel() // input.shape[-1], 1),
+            scale = torch.empty((shape[0], 1),
                                 device=input.device,
                                 dtype=torch.float32)
             torch.ops._C.dynamic_per_token_scaled_fp8_quant(
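The important consequence of the new allocation path is that both the output and, in the dynamic per-token case, the scale tensor are sized from the padded shape. A small sketch of that shape arithmetic (plain Python, no vLLM required; names mirror the diff):

import torch

def padded_shapes(input: torch.Tensor, num_token_padding=None):
    # Mirrors the allocation logic above: inputs are expected to already be
    # flattened to 2-D (num_tokens, hidden_size).
    assert input.ndim == 2
    shape = tuple(input.shape)
    if num_token_padding:
        # Pad only the token (first) dimension, and never shrink it.
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output_shape = shape
    per_token_scale_shape = (shape[0], 1)  # one scale per (padded) row
    return output_shape, per_token_scale_shape

x = torch.empty(5, 128, dtype=torch.float16)
print(padded_shapes(x))                        # ((5, 128), (5, 1))
print(padded_shapes(x, num_token_padding=17))  # ((17, 128), (17, 1))
print(padded_shapes(x, num_token_padding=3))   # ((5, 128), (5, 1))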
@@ -352,6 +352,8 @@ def scaled_fp8_quant(
             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
+        # num_token_padding not implemented for this case
+        assert (scale.numel() == 1 or num_token_padding is None)
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
 
     return output, scale
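The added assertion spells out which combinations support padding; an illustrative helper with the same logic (inferred from the diff, not part of vLLM):

from typing import Optional
import torch

def check_padding_support(scale: Optional[torch.Tensor],
                          num_token_padding: Optional[int]) -> None:
    # Padding is allowed for dynamic quantization (scale is None) and for
    # static per-tensor scales; static per-token scales are not supported
    # together with num_token_padding.
    if scale is not None:
        assert scale.numel() == 1 or num_token_padding is None, (
            "num_token_padding not implemented for per-token static scales")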
@@ -139,7 +139,7 @@ def apply_fp8_linear(
         qinput, x_scale = ops.scaled_fp8_quant(
             input,
             input_scale,
-            batch_dim_padding=17,
+            num_token_padding=17,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
 
     per_tensor_weights = (weight_scale.numel() == 1)
@@ -177,8 +177,9 @@ def apply_fp8_linear(
         output, _ = torch._scaled_mm(qinput,
                                      weight,
                                      out_dtype=torch.float32)
-        # Unpad (undo batch_dim_padding)
+        # Unpad (undo num_token_padding)
         output = torch.narrow(output, 0, 0, input.shape[0])
+        x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
 
         # DQ
         # C = sw * sx * (X * W) + bias
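A minimal sketch of the pad-then-unpad pattern apply_fp8_linear relies on here (plain PyTorch stand-ins instead of the fp8 kernels and torch._scaled_mm; the value 17 follows the diff and is presumably a minimum row count required by the fp8 GEMM path):

import torch

num_tokens, hidden, out_features = 5, 64, 32
padded = max(17, num_tokens)

qinput = torch.randn(padded, hidden)   # stands in for the padded fp8 activation
x_scale = torch.rand(padded, 1)        # per-token scales, padded alongside qinput
weight = torch.randn(hidden, out_features)

output = qinput @ weight               # stands in for torch._scaled_mm

# Unpad (undo num_token_padding): drop the padded rows from both tensors.
output = torch.narrow(output, 0, 0, num_tokens)
x_scale = torch.narrow(x_scale, 0, 0, num_tokens)

# DQ: C = sw * sx * (X * W), with sx applied row-wise via the per-token scales.
w_scale = torch.tensor(0.02)
dequant = w_scale * x_scale * output
assert dequant.shape == (num_tokens, out_features)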