Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
[FP8] Extend per-token-group quantization support to QuantFP8 (#24342)
Signed-off-by: Tahsin Tunan <tahsintunan@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
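A minimal usage sketch of the new per-group path, pieced together from the test file added below (tensor shapes and values are illustrative only, not part of the commit):

    import torch

    from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    # Dynamic quantization with 128-wide groups along the hidden dimension.
    quant_op = QuantFP8(static=False,
                        group_shape=GroupShape(1, 128),
                        column_major_scales=False)

    x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
    # forward_cuda dispatches to fp8_utils.per_token_group_quant_fp8 for this case;
    # forward_native uses the pure-torch fallback added in this commit.
    x_quant, scales = quant_op.forward_native(x)
    assert x_quant.shape == x.shape
    assert scales.shape == (16, 1024 // 128)  # one scale per (token, group)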
Changes to the per-token FP8 quantization benchmark script:

@@ -2,14 +2,25 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import itertools
 from typing import Callable
+from unittest.mock import patch
 
+import pandas as pd
 import torch
 
-from vllm import _custom_ops as ops
-from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.triton_utils import triton
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+
+
+def with_triton_mode(fn):
+    """Temporarily force the Triton fallback path"""
+
+    def wrapped(*args, **kwargs):
+        with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
+            return fn(*args, **kwargs)
+
+    return wrapped
 
 
 # TODO(luka): use standalone_compile utility
@@ -21,78 +32,236 @@ def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
     return inner
 
 
-torch._dynamo.config.recompile_limit = 8888
-compilation_config = CompilationConfig(custom_ops=["none"])
-with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
-    torch_per_token_quant_fp8 = torch.compile(
-        QuantFP8(False, GroupShape.PER_TOKEN),
-        fullgraph=True,
-        dynamic=False,  # recompile for different shapes
-    )
+def bench_compile(fn: Callable):
+    # recompile for different shapes
+    fwd = torch.compile(fn, fullgraph=True, dynamic=False)
 
     # First dim is explicitly dynamic to simulate vLLM usage
-    torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
+    return with_dyn_arg(fwd, 0, 0)
 
 
-def cuda_per_token_quant_fp8(
-    input: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    return ops.scaled_fp8_quant(input)
+torch._dynamo.config.recompile_limit = 8888
 
 
-def calculate_diff(batch_size: int, seq_len: int):
-    """Calculate difference between Triton and CUDA implementations."""
+def calculate_diff(
+    batch_size: int,
+    hidden_size: int,
+    group_shape: GroupShape,
+    dtype: torch.dtype,
+):
+    """Calculate the difference between Inductor and CUDA implementations."""
     device = torch.device("cuda")
-    x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
+    x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)
 
-    torch_out, torch_scale = torch_per_token_quant_fp8(x)
-    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
+    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)
 
-    if torch.allclose(
-        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
-    ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
+    torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
+    torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
+    cuda_out, cuda_scale = quant_fp8.forward_cuda(x)
+
+    out_allclose = lambda o1, o2: torch.allclose(
+        o1.to(torch.float32),
+        o2.to(torch.float32),
+        rtol=1e-3,
+        atol=1e-5,
+    )
+    scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)
+
+    if (
+        out_allclose(cuda_out, torch_out)
+        and scale_allclose(cuda_scale, torch_scale)
+        and out_allclose(cuda_out, torch_eager_out)
+        and scale_allclose(cuda_scale, torch_eager_scale)
+    ):
         print("✅ All implementations match")
     else:
         print("❌ Implementations differ")
 
 
-batch_size_range = [1, 16, 32, 64, 128]
-seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
-
-configs = list(itertools.product(batch_size_range, seq_len_range))
+configs = []
 
 
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["batch_size", "seq_len"],
-        x_vals=configs,
-        line_arg="provider",
-        line_vals=["torch", "cuda"],
-        line_names=["Torch", "CUDA"],
-        styles=[("blue", "-"), ("green", "-")],
-        ylabel="us",
-        plot_name="per-token-dynamic-quant-fp8-performance",
-        args={},
-    )
-)
-def benchmark_quantization(batch_size, seq_len, provider):
-    dtype = torch.float16
+def benchmark_quantization(
+    batch_size,
+    hidden_size,
+    provider,
+    group_shape: GroupShape,
+    col_major: bool,
+    dtype: torch.dtype,
+):
     device = torch.device("cuda")
 
-    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
+    x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)
 
     quantiles = [0.5, 0.2, 0.8]
+    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)
 
     if provider == "torch":
-        fn = lambda: torch_per_token_quant_fp8(x.clone())
+        fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
     elif provider == "cuda":
-        fn = lambda: cuda_per_token_quant_fp8(x.clone())
+        fn = lambda: quant_fp8.forward_cuda(x.clone())
+    elif provider == "triton":
+        if not group_shape.is_per_group():
+            # Triton only supported for per-group
+            return 0, 0, 0
+
+        fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())
 
     ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
 
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
 
 
+# TODO(luka) extract to utils
+def compute_geomean_speedups(
+    df: pd.DataFrame,
+    baseline_col: str,
+    speedup_cols: list[str],
+    groupby_cols: list[str] | None = None,
+) -> pd.DataFrame:
+    """
+    Compute geometric mean speedups over a baseline column.
+
+    Args:
+        df: Input dataframe
+        baseline_col: Column to use as baseline
+        speedup_cols: Columns to compute speedups for
+        groupby_cols: Columns to group by. If None, compute over entire df.
+
+    Returns:
+        pd.DataFrame with geometric mean speedups
+    """
+    from scipy.stats import gmean
+
+    def geo_speedup(group: pd.DataFrame) -> pd.Series:
+        ratios = {
+            col: (group[baseline_col] / group[col]).values for col in speedup_cols
+        }
+        return pd.Series({col: gmean(vals) for col, vals in ratios.items()})
+
+    if groupby_cols is None:
+        result = geo_speedup(df).to_frame().T
+    else:
+        result = (
+            df.groupby(groupby_cols)
+            .apply(geo_speedup, include_groups=False)
+            .reset_index()
+        )
+
+    return result
+
+
 if __name__ == "__main__":
-    calculate_diff(batch_size=4, seq_len=4096)
-    benchmark_quantization.run(print_data=True)
+    parser = FlexibleArgumentParser(
+        description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
+    )
+    parser.add_argument("-c", "--check", action="store_true")
+    parser.add_argument(
+        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+    )
+    parser.add_argument(
+        "--hidden-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)",
+    )
+    parser.add_argument(
+        "--batch-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Batch sizes to benchmark (default: 1,16,32,64,128)",
+    )
+    parser.add_argument(
+        "--group-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Group sizes for GroupShape(1,N) to benchmark. "
+        "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)",
+    )
+    parser.add_argument(
+        "--no-column-major",
+        action="store_true",
+        help="Disable column-major scales testing",
+    )
+
+    args = parser.parse_args()
+    assert args
+
+    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
+
+    hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
+    batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128]
+
+    if args.group_sizes is not None:
+        group_shapes = []
+        for size in args.group_sizes:
+            if size == 0:
+                group_shapes.append(GroupShape.PER_TENSOR)
+            elif size == -1:
+                group_shapes.append(GroupShape.PER_TOKEN)
+            else:
+                group_shapes.append(GroupShape(1, size))
+    else:
+        group_shapes = [
+            GroupShape.PER_TENSOR,
+            GroupShape.PER_TOKEN,
+            GroupShape(1, 64),
+            GroupShape(1, 128),
+        ]
+
+    column_major_scales = [False] if args.no_column_major else [True, False]
+
+    config_gen = itertools.product(
+        group_shapes,
+        column_major_scales,
+        batch_sizes,
+        hidden_sizes,
+    )
+
+    # filter out column-major scales for non-group, reverse order
+    configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))
+
+    print(f"Running {len(configs)} configurations:")
+    print(f"  Hidden sizes: {hidden_sizes}")
+    print(f"  Batch sizes: {batch_sizes}")
+    print(f"  Group shapes: {[str(g) for g in group_shapes]}")
+    print(f"  Column major scales: {column_major_scales}")
+    print()
+
+    if args.check:
+        for group_shape in group_shapes:
+            group_size = group_shape[1]
+            print(f"{group_size=}")
+            calculate_diff(
+                batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
+            )
+
+    benchmark = triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
+            x_vals=configs,
+            line_arg="provider",
+            line_vals=["torch", "cuda", "triton"],
+            line_names=["Torch (Compiled)", "CUDA", "Triton"],
+            styles=[("blue", "-"), ("green", "-"), ("black", "-")],
+            ylabel="us",
+            plot_name="QuantFP8 performance",
+            args={},
+        )
+    )(benchmark_quantization)
+
+    df = benchmark.run(print_data=True, dtype=dtype, return_df=True)
+
+    # Print geomean speedups
+    geo_table_grouped = compute_geomean_speedups(
+        df,
+        baseline_col="Torch (Compiled)",
+        speedup_cols=["CUDA", "Triton"],
+        groupby_cols=["col_major", "group_shape"],
+    )
+
+    print("Speedup over Torch (Compiled)")
+    print(geo_table_grouped.to_string(index=False))
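Usage note (not part of the diff): with the new CLI, a run can be narrowed via --batch-sizes, --hidden-sizes, and --group-sizes, where a group size of 0 selects PER_TENSOR and -1 selects PER_TOKEN; -c/--check first cross-checks the implementations via calculate_diff, and --no-column-major skips the column-major scale layouts.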
tests/kernels/quantization/test_fp8_quant_group.py (new file, 150 lines):

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for QuantFP8 Group Quantization implementation."""

import pytest
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.platforms import current_platform


@pytest.mark.parametrize(
    "batch_size,hidden_dim,group_size",
    [
        (16, 256, 32),  # Small
        (64, 1024, 64),  # Medium
        (128, 2048, 128),  # Large
        (8, 513, 64),  # Non-divisible (native only)
    ])
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
                                      group_size: int, seed: int) -> None:
    """Test QuantFP8 group quantization with various configurations.

    Tests both CUDA and native implementations, column-major scales,
    and verifies consistency between implementations.
    """
    current_platform.seed_everything(seed)

    x = torch.randn(
        (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
    expected_num_groups = (hidden_dim + group_size - 1) // group_size
    is_divisible = hidden_dim % group_size == 0

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    # 1. Test native implementation (always available)
    x_quant_native, scales_native = quant_op.forward_native(x.clone())
    assert x_quant_native.shape == x.shape
    assert scales_native.shape == (batch_size, expected_num_groups)

    # 2. Test column-major scales configuration
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True)
    _, scales_col = quant_op_col.forward_native(x.clone())
    assert scales_col.shape == (expected_num_groups, batch_size)

    # 3. Test CUDA implementation (only for divisible dimensions)
    if is_divisible:
        x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
        assert x_quant_cuda.shape == x.shape
        assert scales_cuda.shape == (batch_size, expected_num_groups)

        # Verify CUDA/native consistency
        assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)

        # Quantized values should mostly match
        diff_count = (x_quant_cuda != x_quant_native).sum().item()
        diff_ratio = diff_count / x_quant_cuda.numel()
        assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int) -> None:
    current_platform.seed_everything(seed)

    group_size = 64

    # Test with 3D input
    batch1, batch2, hidden_dim = 4, 8, 512
    x_3d = torch.randn(
        (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    x_quant, scales = quant_op.forward_native(x_3d.clone())
    assert x_quant.shape == x_3d.shape
    assert scales.shape == (batch1, batch2, hidden_dim // group_size)

    # Test column_major_scales with multi-dim
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True)
    _, scales_col = quant_op_col.forward_native(x_3d.clone())
    assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

    # Test with 4D input
    batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
    x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
                       dtype=torch.bfloat16,
                       device="cuda") * 8

    x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
    assert x_quant_4d.shape == x_4d.shape
    assert scales_4d.shape == (batch1, batch2, batch3,
                               hidden_dim // group_size)

    _, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
    assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
                                   batch3)


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_edge_cases(seed: int) -> None:
    current_platform.seed_everything(seed)

    batch_size = 16
    group_size = 64

    # Test with single group (group_size >= hidden_dim)
    x_small = torch.randn(
        (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
    assert x_quant_small.shape == x_small.shape
    assert scales_small.shape == (batch_size, 1)

    # Test with zero inputs
    x_zero = torch.zeros((batch_size, 256),
                         dtype=torch.bfloat16,
                         device="cuda")
    x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
    assert x_quant_zero.shape == x_zero.shape
    assert (scales_zero > 0).all(), "Scales should be clamped to minimum"

    # Test very large values
    x_large = torch.full((batch_size, 256),
                         1000.0,
                         dtype=torch.bfloat16,
                         device="cuda")
    x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
    assert x_quant_large.shape == x_large.shape
    # FP8 max is typically 448 or 224, so scales should be > 1
    assert (scales_large > 1.0).all(), "Large values should have scales > 1"
Changes to the fused MoE imports (per_token_group_quant_fp8 now comes from fp8_utils):

@@ -32,9 +32,11 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
+    _resize_cache, moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     calculate_tile_tokens_dim)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform
Changes to the QuantFP8 custom op:

@@ -23,28 +23,39 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
 @CustomOp.register("quant_fp8")
 class QuantFP8(CustomOp):
     """
-    Quantize input tensor to per-tensor or per-token FP8.
+    Quantize input tensor to FP8 (per-tensor, per-token, or per-group).
     This CustomOp supports both static and dynamic quantization.
     """
 
     def __init__(self,
                  static: bool,
                  group_shape: GroupShape,
-                 num_token_padding: Optional[int] = None):
+                 num_token_padding: Optional[int] = None,
+                 column_major_scales: bool = False):
         """
         :param static: static or dynamic quantization
-        :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR)
-        :param num_token_padding: Pad the token dimension of output to this size
+        :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
+            or arbitrary block size)
+        :param num_token_padding: Pad the token dimension of output to this
+            size
+        :param column_major_scales: For group quantization, output scales in
+            column major format
         """
         super().__init__()
-        self.num_token_padding = num_token_padding
-        assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
-        assert not static or group_shape == GroupShape.PER_TENSOR, \
-            "Only per-tensor scales supported for static quantization."
         self.static = static
         self.group_shape = group_shape
-        self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
+        self.num_token_padding = num_token_padding
+        self.column_major_scales = column_major_scales
+
+        self.is_group_quant = group_shape.is_per_group()
+        if self.is_group_quant:
+            assert not static, "Group quantization only supports dynamic mode"
+            self.group_size = group_shape.col
+        else:
+            assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
+            assert not static or group_shape == GroupShape.PER_TENSOR, \
+                "Only per-tensor scales supported for static quantization."
+            self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
 
     def forward_cuda(
         self,
@@ -52,11 +63,19 @@ class QuantFP8(CustomOp):
         scale: Optional[torch.Tensor] = None,
         scale_ub: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.is_group_quant:
+            assert scale is None, "Group quantization is always dynamic"
+            from vllm.model_executor.layers.quantization.utils import fp8_utils
+            return fp8_utils.per_token_group_quant_fp8(
+                x,
+                group_size=self.group_size,
+                column_major_scales=self.column_major_scales,
+                dtype=_FP8_DTYPE)
+
         assert (scale is not None) == self.static
         assert scale_ub is None or (not self.static and self.group_shape
                                     == GroupShape.PER_TOKEN
                                     and scale_ub.numel() == 1)
 
         return ops.scaled_fp8_quant(
             x,
             scale,
@@ -70,6 +89,10 @@ class QuantFP8(CustomOp):
         scale: Optional[torch.Tensor] = None,
         scale_ub: Optional[torch.Tensor] = None,
     ):
+        if self.is_group_quant:
+            assert scale is None, "Group quantization is always dynamic"
+            return self._quantize_group_native(x)
+
         assert (scale is not None) == self.static
         assert scale_ub is None or (not self.static and self.group_shape
                                     == GroupShape.PER_TOKEN
@@ -84,8 +107,7 @@ class QuantFP8(CustomOp):
         else:
             x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
-            scale = x_max / _FP8_MAX
-            scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR)
+            scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
 
         # Even for dynamic per-token scales,
         # reciprocal performs slightly better than division
@@ -101,3 +123,34 @@ class QuantFP8(CustomOp):
             out = F.pad(out, (0, 0, 0, padding), "constant", 0.0)
 
         return out, scale
+
+    def _quantize_group_native(
+            self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        orig_shape = x.shape
+        hidden_dim = x.shape[-1]
+        num_groups = (hidden_dim + self.group_size - 1) // self.group_size
+        padded_dim = num_groups * self.group_size
+
+        if padded_dim != hidden_dim:
+            padding = padded_dim - hidden_dim
+            x = F.pad(x, (0, padding), mode='constant', value=0.0)
+
+        x_grouped = x.view(-1, num_groups, self.group_size)
+        absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
+        scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+        x_scaled = x_grouped / scales
+        x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
+
+        x_quant = x_quant.view(-1, padded_dim)
+        if padded_dim != hidden_dim:
+            x_quant = x_quant[..., :hidden_dim]
+        x_quant = x_quant.view(orig_shape)
+
+        scales = scales.squeeze(-1)
+        scales = scales.reshape(orig_shape[:-1] + (num_groups, ))
+
+        if self.column_major_scales:
+            scales = scales.transpose(-2, -1).contiguous()
+
+        return x_quant, scales
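A small worked example of the shape bookkeeping in _quantize_group_native, using the non-divisible case from the new tests (illustrative only, not part of the change):

    # x: (8, 513), group_size = 64
    # num_groups = (513 + 64 - 1) // 64 = 9    -> ceiling division
    # padded_dim = 9 * 64 = 576                -> x is zero-padded by 63 columns
    # scales: (8, 9)                           -> one fp32 scale per (row, group)
    # with column_major_scales=True: (9, 8)    -> last two dims transposed
    # x_quant is sliced back to 513 columns and reshaped to the input shape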
Changes to GroupShape:

@@ -34,6 +34,15 @@ class GroupShape(_GroupShape):
     PER_TENSOR: ClassVar['GroupShape']
     PER_TOKEN: ClassVar['GroupShape']
 
+    def is_per_tensor(self) -> bool:
+        return self.row == -1 and self.col == -1
+
+    def is_per_token(self) -> bool:
+        return self.row == 1 and self.col == -1
+
+    def is_per_group(self) -> bool:
+        return self.row == 1 and self.col >= 1
+
 
 GroupShape.PER_TENSOR = GroupShape(-1, -1)
 GroupShape.PER_TOKEN = GroupShape(1, -1)
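For reference, the sentinel encoding these helpers rely on, spelled out as illustrative assertions (not part of the diff):

    # PER_TENSOR is encoded as GroupShape(-1, -1), PER_TOKEN as GroupShape(1, -1),
    # and any GroupShape(1, N) with N >= 1 counts as per-group.
    assert GroupShape.PER_TENSOR.is_per_tensor()
    assert GroupShape.PER_TOKEN.is_per_token()
    assert GroupShape(1, 128).is_per_group()
    assert not GroupShape.PER_TOKEN.is_per_group()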