[Transform] [Quantization] Add QuTLASS support to vLLM (#24440)

Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Signed-off-by: Andrei Panferov <andrei@panferov.org>
Co-authored-by: Andrei Panferov <andrei@panferov.org>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Roberto L. Castro
2025-10-10 18:43:40 +02:00
committed by GitHub
parent 8d2b8c0ff2
commit 96ad65b7fe
12 changed files with 1848 additions and 1 deletions

View File

@ -834,6 +834,8 @@ steps:
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
- label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60

View File

@ -1007,6 +1007,7 @@ endif()
# For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA")
include(cmake/external_projects/flashmla.cmake)
include(cmake/external_projects/qutlass.cmake)
# vllm-flash-attn should be last as it overwrites some CMake functions
include(cmake/external_projects/vllm_flash_attn.cmake)

View File

@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import copy
import itertools
import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton
PROVIDER_CFGS = {
"torch-bf16": dict(enabled=True),
"mxfp4": dict(no_a_quant=False, enabled=True),
"mxfp4-noquant": dict(no_a_quant=True, enabled=True),
}
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
return (
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
* group_size**-0.5
)
def _quant_weight_mxfp4(
b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str
):
weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx(
b, forward_hadamard_matrix, method="abs_max"
)
weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton")
return weight_hf_e2m1, weight_hf_scale_block
def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
b, forward_hadamard_matrix, device
)
alpha = torch.tensor([1.0], device="cuda")
if cfg["no_a_quant"]:
# Pre-quantize activation
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
a, forward_hadamard_matrix, method="abs_max"
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
def run():
return matmul_mxf4_bf16_tn(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
)
return run
# Quantize activation on-the-fly
def run():
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
a, forward_hadamard_matrix, method="abs_max"
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
return matmul_mxf4_bf16_tn(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
)
return run
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[
1,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
24576,
32768,
],
x_log=False,
line_arg="provider",
line_vals=_enabled,
line_names=_enabled,
ylabel="TFLOP/s (larger is better)",
plot_name="BF16 vs MXFP4 GEMMs",
args={},
)
)
def benchmark(batch_size, provider, N, K, had_size):
M = batch_size
device = "cuda"
dtype = torch.bfloat16
a = torch.randn((M, K), device=device, dtype=dtype)
b = torch.randn((N, K), device=device, dtype=dtype)
forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch-bf16":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
)
else:
cfg = PROVIDER_CFGS[provider]
run_quant = build_mxfp4_runner(
cfg, a, b, forward_hadamard_matrix, dtype, device
)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: run_quant(), rep=200, quantiles=quantiles
)
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
out = []
for model, tp_size in itertools.product(args.models, args.tp_sizes):
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_dim] //= tp_size
KN.append(model)
out.append(KN)
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.3-70B-Instruct"],
choices=list(WEIGHT_SHAPES.keys()),
)
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
args = parser.parse_args()
for K, N, model in prepare_shapes(args):
for had_size in [32, 64, 128]:
print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:")
benchmark.run(
print_data=True,
show_plots=True,
save_path=f"bench_mxfp4_res_n{N}_k{K}",
N=N,
K=K,
had_size=had_size,
)
print("Benchmark finished!")

View File

@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import copy
import itertools
import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm
from vllm._custom_ops import fusedQuantizeNv
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton
PROVIDER_CFGS = {
"torch-bf16": dict(enabled=True),
"nvfp4": dict(no_a_quant=False, enabled=True),
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
}
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
return (
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
* group_size**-0.5
)
def _quant_weight_nvfp4(
b: torch.Tensor,
forward_hadamard_matrix: torch.Tensor,
global_scale: torch.Tensor,
device: str,
M: int,
N: int,
K: int,
):
weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv(
b, forward_hadamard_matrix, global_scale
)
weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view(
-1, K // 16
)
return weight_hf_e2m1, weight_hf_scale_block
def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K):
alpha = torch.tensor([1.0], device="cuda")
global_scale = torch.tensor([1.0], device="cuda")
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4(
b, forward_hadamard_matrix, global_scale, device, M, N, K
)
if cfg["no_a_quant"]:
# Pre-quantize activation
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
a, forward_hadamard_matrix, global_scale
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
-1, K // 16
)
def run():
return ops.cutlass_scaled_fp4_mm(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
torch.bfloat16,
)
return run
# Quantize activation on-the-fly
def run():
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
a, forward_hadamard_matrix, global_scale
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
-1, K // 16
)
return ops.cutlass_scaled_fp4_mm(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
torch.bfloat16,
)
return run
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[
1,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
24576,
32768,
],
x_log=False,
line_arg="provider",
line_vals=_enabled,
line_names=_enabled,
ylabel="TFLOP/s (larger is better)",
plot_name="BF16 vs NVFP4 GEMMs",
args={},
)
)
def benchmark(batch_size, provider, N, K, had_size):
M = batch_size
device = "cuda"
dtype = torch.bfloat16
a = torch.randn((M, K), device=device, dtype=dtype)
b = torch.randn((N, K), device=device, dtype=dtype)
forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch-bf16":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
)
else:
cfg = PROVIDER_CFGS[provider]
run_quant = build_nvfp4_runner(
cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K
)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: run_quant(), rep=200, quantiles=quantiles
)
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
out = []
for model, tp_size in itertools.product(args.models, args.tp_sizes):
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_dim] //= tp_size
KN.append(model)
out.append(KN)
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.3-70B-Instruct"],
choices=list(WEIGHT_SHAPES.keys()),
)
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
args = parser.parse_args()
for K, N, model in prepare_shapes(args):
for had_size in [16, 32, 64, 128]:
print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:")
benchmark.run(
print_data=True,
show_plots=True,
save_path=f"bench_nvfp4_res_n{N}_k{K}",
N=N,
K=K,
had_size=had_size,
)
print("Benchmark finished!")

View File

@ -0,0 +1,97 @@
include(FetchContent)
set(CUTLASS_INCLUDE_DIR "${CUTLASS_INCLUDE_DIR}" CACHE PATH "Path to CUTLASS include/ directory")
if(DEFINED ENV{QUTLASS_SRC_DIR})
set(QUTLASS_SRC_DIR $ENV{QUTLASS_SRC_DIR})
endif()
if(QUTLASS_SRC_DIR)
FetchContent_Declare(
qutlass
SOURCE_DIR ${QUTLASS_SRC_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
)
else()
FetchContent_Declare(
qutlass
GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git
GIT_TAG 830d2c4537c7396e14a02a46fbddd18b5d107c65
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
)
FetchContent_Populate(qutlass)
set(qutlass_SOURCE_DIR "${qutlass_SOURCE_DIR}")
endif()
if(NOT qutlass_SOURCE_DIR)
message(FATAL_ERROR "[QUTLASS] source directory could not be resolved.")
endif()
message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}")
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND QUTLASS_ARCHS)
if(QUTLASS_ARCHS MATCHES "10\\.0a")
set(QUTLASS_TARGET_CC 100)
elseif(QUTLASS_ARCHS MATCHES "12\\.0a")
set(QUTLASS_TARGET_CC 120)
else()
message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.")
endif()
set(QUTLASS_SOURCES
${qutlass_SOURCE_DIR}/qutlass/csrc/bindings.cpp
${qutlass_SOURCE_DIR}/qutlass/csrc/gemm.cu
${qutlass_SOURCE_DIR}/qutlass/csrc/gemm_ada.cu
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx.cu
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv.cu
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx_sm100.cu
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv_sm100.cu
)
set(QUTLASS_INCLUDES
${qutlass_SOURCE_DIR}
${qutlass_SOURCE_DIR}/qutlass
${qutlass_SOURCE_DIR}/qutlass/csrc/include
${qutlass_SOURCE_DIR}/qutlass/csrc/include/cutlass_extensions
)
if(CUTLASS_INCLUDE_DIR AND EXISTS "${CUTLASS_INCLUDE_DIR}/cutlass/cutlass.h")
list(APPEND QUTLASS_INCLUDES "${CUTLASS_INCLUDE_DIR}")
elseif(EXISTS "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include/cutlass/cutlass.h")
list(APPEND QUTLASS_INCLUDES "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include")
message(STATUS "[QUTLASS] Using QuTLASS vendored CUTLASS headers (no vLLM CUTLASS detected).")
else()
message(FATAL_ERROR "[QUTLASS] CUTLASS headers not found. "
"Set -DCUTLASS_INCLUDE_DIR=/path/to/cutlass/include")
endif()
set_gencode_flags_for_srcs(
SRCS "${QUTLASS_SOURCES}"
CUDA_ARCHS "${QUTLASS_ARCHS}"
)
target_sources(_C PRIVATE ${QUTLASS_SOURCES})
target_include_directories(_C PRIVATE ${QUTLASS_INCLUDES})
target_compile_definitions(_C PRIVATE
QUTLASS_DISABLE_PYBIND=1
TARGET_CUDA_ARCH=${QUTLASS_TARGET_CC}
)
set_property(SOURCE ${QUTLASS_SOURCES} APPEND PROPERTY COMPILE_OPTIONS
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr --use_fast_math -O3>
)
else()
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8")
message(STATUS
"[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).")
else()
message(STATUS
"[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in "
"CUDA_ARCHS='${CUDA_ARCHS}'.")
endif()
endif()

View File

@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import pytest
import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.platforms import current_platform
if not torch.cuda.is_available():
pytest.skip("CUDA required for these tests.", allow_module_level=True)
if not (
current_platform.has_device_capability(100)
or current_platform.has_device_capability(120)
):
pytest.skip(
reason="Tests require compute capability 10.0 (100) or 12.0 (120).",
allow_module_level=True,
)
# ----- Helpers -----
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
return (
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
* group_size**-0.5
)
def _rtne_fp4(x: torch.Tensor):
device = x.device
grid = torch.tensor(
[
-6.0,
-4.0,
-3.0,
-2.0,
-1.5,
-1.0,
-0.5,
-0.0,
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
],
dtype=x.dtype,
device=x.device,
)
grid_int = torch.tensor(
[-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7],
dtype=torch.uint8,
device=device,
)
inds = torch.bucketize(x, grid)
lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15)
g_lo, g_hi = grid[lo], grid[hi]
pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0)
y = torch.where(pick_hi, g_hi, g_lo)
y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo])
y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF
return y, y_int_packed
def _dq_fp4(x_e2m1: torch.Tensor, x_e8m0: torch.Tensor, alpha: float):
device = x_e2m1.device
x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32)
x_e2m1_unpacked = torch.stack(
[x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1
).flatten(start_dim=-2)
grid_dq = torch.tensor(
[
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
],
dtype=torch.float64,
device=device,
)
x_fp4_dq = grid_dq[x_e2m1_unpacked]
scales_dq = x_e8m0.to(torch.float64)
x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 32)) * scales_dq[..., None]).flatten(
start_dim=-2
) / alpha
return x_dq, x_fp4_dq, scales_dq
def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor:
clip_mask_unpacked_dq = torch.zeros(
*clip_mask.shape[:-1],
clip_mask.size(-1) * 8,
dtype=torch.bool,
device=clip_mask.device,
)
for i in range(8):
clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1
return clip_mask_unpacked_dq
def _forward_quantize_ref(
x: torch.Tensor, h: torch.Tensor, rot_size: int, quest: bool = True
):
device = x.device
xh_ref64 = (
x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64)
@ h.reshape(rot_size, rot_size).to(dtype=torch.float64)
).flatten(start_dim=-2)
if quest:
scales_ref64_ = (
xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).std(dim=-1, correction=0)
* (2.92247856 / 6.0)
+ 1e-8
)
else:
abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).abs().amax(dim=-1)
scales_ref64_ = abs_max + 1e-8
xh_e8m0_ref = scales_ref64_.log2().floor().exp2().to(dtype=torch.float8_e8m0fnu)
scales_ref64 = xh_e8m0_ref.to(dtype=torch.float64)
xh_scaled_ref64 = (
xh_ref64.unflatten(dim=-1, sizes=(-1, 32)) / scales_ref64[..., None]
).flatten(start_dim=-2)
if not quest:
xh_scaled_ref64 *= 3
clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0
clip_mask_ref = torch.zeros(
*x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device
)
for i in range(8):
clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i
xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64)
xh_dq, xh_fp4_dq, scales_dq = _dq_fp4(
xh_e2m1_ref, xh_e8m0_ref, alpha=1.0 if quest else 3.0
)
clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref)
assert xh_fp4_dq.equal(xh_fp4_ref)
assert scales_dq.equal(scales_ref64)
assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref)
return (
xh_dq,
clip_mask_unpacked_ref,
(xh_e2m1_ref, xh_e8m0_ref, clip_mask_ref),
)
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda:0")
ROT_SIZES = [32, 64, 128]
SEEDS = [0]
BATCHES = [1, 16]
LLAMA_MODELS = {
"7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
"13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
"33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
"70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
}
@pytest.fixture(autouse=True)
def _seed_each_test():
current_platform.seed_everything(0)
np.random.seed(0)
torch.random.manual_seed(0)
@pytest.mark.parametrize("rot_size", ROT_SIZES)
@torch.inference_mode()
def test_fused_quantization_absmax(rot_size: int):
dtype, device = DTYPE, DEVICE
h = get_hadamard_matrix(rot_size, dtype, device)
x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=False)
xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="abs_max")
xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)
xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=3.0)
torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
m, n, k = 1, 504, 4096
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max")
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max")
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
out_ref = a_dq @ b_dq.transpose(-2, -1)
a_scale_block = to_blocked(a_e8m0, backend="triton")
b_scale_block = to_blocked(b_e8m0, backend="triton")
alpha = torch.tensor([1.0], device=device)
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha)
assert out.equal(out_ref.to(dtype=out.dtype))
@pytest.mark.parametrize("rot_size", ROT_SIZES)
@torch.inference_mode()
def test_fused_quantization_quest(rot_size: int):
dtype, device = DTYPE, DEVICE
h = get_hadamard_matrix(rot_size, dtype, device)
x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=True)
xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="quest")
xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)
xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=1.0)
torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
m, n, k = 504, 504, 2048
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
out_ref = a_dq @ b_dq.transpose(-2, -1)
a_scale_block = to_blocked(a_e8m0, backend="triton")
b_scale_block = to_blocked(b_e8m0, backend="triton")
alpha = torch.tensor([1.0], device=device)
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha)
assert out.equal(out_ref.to(dtype=out.dtype))
@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys()))
@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3])
@pytest.mark.parametrize("batch", [1, 16])
@pytest.mark.parametrize("had_size", ROT_SIZES)
@torch.inference_mode()
def test_llama_shapes(model: str, layer_idx: int, batch: int, had_size: int):
dtype, device = DTYPE, DEVICE
m = batch
k, n = LLAMA_MODELS[model][layer_idx]
h = get_hadamard_matrix(had_size, dtype, device)
a = torch.rand(m, k, dtype=dtype, device=device) * 25.0
b = torch.rand(n, k, dtype=dtype, device=device) * 25.0
a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
out_ref = a_dq @ b_dq.transpose(-2, -1)
a_scale_block = to_blocked(a_e8m0, backend="triton")
b_scale_block = to_blocked(b_e8m0, backend="triton")
alpha = torch.tensor([1.0], device=device)
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha)
assert out.equal(out_ref.to(dtype=out.dtype))

View File

@ -0,0 +1,268 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import pytest
import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm
from vllm._custom_ops import fusedQuantizeNv
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.platforms import current_platform
if not torch.cuda.is_available():
pytest.skip("CUDA required for these tests.", allow_module_level=True)
if not (
current_platform.has_device_capability(100)
or current_platform.has_device_capability(120)
):
pytest.skip(
reason="Tests require compute capability 10.0 (100) or 12.0 (120).",
allow_module_level=True,
)
# ----- Helpers -----
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
return (
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
* group_size**-0.5
)
def _rtne_fp4(x: torch.Tensor):
device = x.device
grid = torch.tensor(
[
-6.0,
-4.0,
-3.0,
-2.0,
-1.5,
-1.0,
-0.5,
-0.0,
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
],
dtype=x.dtype,
device=x.device,
)
grid_int = torch.tensor(
[-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7],
dtype=torch.uint8,
device=device,
)
inds = torch.bucketize(x, grid)
lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15)
g_lo, g_hi = grid[lo], grid[hi]
pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0)
y = torch.where(pick_hi, g_hi, g_lo)
y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo])
y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF
return y, y_int_packed
def _dq_fp4(x_e2m1: torch.Tensor, x_e4m3: torch.Tensor, alpha: float):
device = x_e2m1.device
x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32)
x_e2m1_unpacked = torch.stack(
[x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1
).flatten(start_dim=-2)
grid_dq = torch.tensor(
[
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
],
dtype=torch.float64,
device=device,
)
x_fp4_dq = grid_dq[x_e2m1_unpacked]
scales_dq = x_e4m3.to(torch.float64)
x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 16)) * scales_dq[..., None]).flatten(
start_dim=-2
) / alpha # * (4. / 3.)
return x_dq, x_fp4_dq, scales_dq
def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor:
clip_mask_unpacked_dq = torch.zeros(
*clip_mask.shape[:-1],
clip_mask.size(-1) * 8,
dtype=torch.bool,
device=clip_mask.device,
)
for i in range(8):
clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1
return clip_mask_unpacked_dq
def _forward_quantize_ref(x: torch.Tensor, h: torch.Tensor, rot_size: int):
device = x.device
xh_ref64 = (
x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64)
@ h.reshape(rot_size, rot_size).to(dtype=torch.float64)
).flatten(start_dim=-2)
abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 16)).abs().amax(dim=-1)
scales_ref64_ = abs_max + 1e-8
xh_e4m3_ref = scales_ref64_.to(dtype=torch.float8_e4m3fn)
scales_ref64 = xh_e4m3_ref.to(dtype=torch.float64)
xh_scaled_ref64 = (
xh_ref64.unflatten(dim=-1, sizes=(-1, 16)) / scales_ref64[..., None]
).flatten(start_dim=-2)
xh_scaled_ref64 *= 6.0
clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0
clip_mask_ref = torch.zeros(
*x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device
)
for i in range(8):
clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i
xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64)
xh_dq, xh_fp4_dq, scales_dq = _dq_fp4(xh_e2m1_ref, xh_e4m3_ref, 6.0)
clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref)
assert xh_fp4_dq.equal(xh_fp4_ref)
assert scales_dq.equal(scales_ref64)
assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref)
return (
xh_dq,
clip_mask_unpacked_ref,
(xh_e2m1_ref, xh_e4m3_ref, clip_mask_ref),
)
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda:0")
ROT_SIZES = [16, 32, 64, 128]
GLOBAL_SCALES = [6.0]
LLAMA_MODELS = {
"7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
"13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
"33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
"70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
}
@pytest.fixture(autouse=True)
def _seed_each_test():
current_platform.seed_everything(0)
np.random.seed(0)
torch.random.manual_seed(0)
@pytest.mark.parametrize("rot_size", ROT_SIZES)
@pytest.mark.parametrize("global_scale_value", GLOBAL_SCALES)
@torch.inference_mode()
def test_fused_quantization(rot_size: int, global_scale_value: float):
dtype, device = DTYPE, DEVICE
h = get_hadamard_matrix(rot_size, dtype, device)
x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
global_scale = torch.tensor([global_scale_value], device=device)
xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size)
xh_e2m1, xh_e4m3 = fusedQuantizeNv(x, h, global_scale)
xh_e4m3 = xh_e4m3.reshape(2, 4096, 4096 // 16)
xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e4m3, alpha=global_scale_value)
torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
assert (xh_dq != xh_dq_ref).float().mean() <= 1e-1
m, n, k = 504, 4096 * 2, 4096
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale)
b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale)
a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0)
b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0)
out_ref = a_dq @ b_dq.transpose(-2, -1)
a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16)
b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16)
alpha = torch.tensor([1.0], device=device)
out = ops.cutlass_scaled_fp4_mm(
a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16
)
assert out.equal(out_ref.to(dtype=out.dtype))
@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys()))
@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3])
@pytest.mark.parametrize("batch", [1, 16])
@pytest.mark.parametrize("rot_size", ROT_SIZES)
@torch.inference_mode()
def test_llama_shapes(model: str, layer_idx: int, batch: int, rot_size: int):
dtype, device = DTYPE, DEVICE
m = batch
k, n = LLAMA_MODELS[model][layer_idx]
h = get_hadamard_matrix(rot_size, dtype, device)
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
global_scale = torch.tensor([1.0], device=device)
a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale)
b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale)
a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0)
b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0)
out_ref = a_dq @ b_dq.transpose(-2, -1)
a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16)
b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16)
alpha = torch.tensor([1.0], device=device)
out = ops.cutlass_scaled_fp4_mm(
a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16
)
assert out.equal(out_ref.to(dtype=out.dtype))

View File

@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test model set-up and inference for quantized HF models supported
on the GPU backend using FPQuant.
Validating the configuration and printing results for manual checking.
Run `pytest tests/quantization/test_fp_quant.py`.
"""
import pytest
from tests.quantization.utils import is_quant_method_supported
MODELS = [
"ISTA-DASLab/Qwen3-0.6B-RTN-NVFP4",
"ISTA-DASLab/Qwen3-0.6B-RTN-MXFP4",
]
DTYPE = ["bfloat16"]
EAGER = [True, False]
@pytest.mark.skipif(
not is_quant_method_supported("fp_quant"),
reason="FPQuant is not supported on this GPU type.",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("eager", EAGER)
def test_fpquant(vllm_runner, model, eager):
with vllm_runner(model, enforce_eager=eager) as llm:
output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)
assert output[0][1] == "1 2 3 4 5 6"

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
@ -2483,6 +2483,144 @@ def onednn_scaled_mm(
return output
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
@register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
def _fake_matmul_mxf4_bf16_tn(
a: torch.Tensor,
b: torch.Tensor,
a_sf: torch.Tensor,
b_sf: torch.Tensor,
alpha: torch.Tensor,
):
return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16)
def matmul_mxf4_bf16_tn(
a: torch.Tensor,
b: torch.Tensor,
a_sf: torch.Tensor,
b_sf: torch.Tensor,
alpha: torch.Tensor,
) -> torch.Tensor:
return torch.ops._qutlass_C.matmul_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha)
if hasattr(torch.ops._qutlass_C, "matmul_ada_mxf4_bf16_tn"):
@register_fake("_qutlass_C::matmul_ada_mxf4_bf16_tn")
def _fake_matmul_ada_mxf4_bf16_tn(
a: torch.Tensor,
b: torch.Tensor,
a_sf: torch.Tensor,
b_sf: torch.Tensor,
alpha: torch.Tensor,
):
return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16)
def matmul_ada_mxf4_bf16_tn(
a: torch.Tensor,
b: torch.Tensor,
a_sf: torch.Tensor,
b_sf: torch.Tensor,
alpha: torch.Tensor,
) -> torch.Tensor:
return torch.ops._qutlass_C.matmul_ada_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha)
def ceil_div(a, b):
return (a + b - 1) // b
if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxQuest"):
@register_fake("_qutlass_C::fusedQuantizeMxQuest")
def _fake_fused_quantize_mx_quest(
a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor
):
return xh_e2m1, xh_e8m0
if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxAbsMax"):
@register_fake("_qutlass_C::fusedQuantizeMxAbsMax")
def _fake_fused_quantize_mx_absmax(
a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor
):
return xh_e2m1, xh_e8m0
def fusedQuantizeMx(
a: torch.Tensor, b: torch.Tensor, *, method: Literal["quest", "abs_max"] = "quest"
) -> tuple[torch.Tensor, torch.Tensor]:
if a.dim() == 0:
raise ValueError("`a` must have at least 1 dimension.")
if a.size(-1) % 32 != 0:
raise ValueError(f"last dim of `a` must be divisible by 32, got {a.size(-1)}.")
if b.device != a.device:
raise ValueError("`a` and `b` must be on the same device.")
xh_e2m1 = torch.empty(
*a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device
)
rows, cols = a.numel() // a.size(-1), a.size(-1) // 32
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4
xh_e8m0 = torch.empty(
padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device
)
if not hasattr(torch.ops, "_qutlass_C"):
raise RuntimeError(
"The `_qutlass_C` extension is not loaded. "
"Make sure your custom op library is imported before calling fusedQuantizeMx."
)
if method == "quest":
return torch.ops._qutlass_C.fusedQuantizeMxQuest(a, b, xh_e2m1, xh_e8m0)
elif method == "abs_max":
return torch.ops._qutlass_C.fusedQuantizeMxAbsMax(a, b, xh_e2m1, xh_e8m0)
else:
raise ValueError(f"invalid method {method!r}, must be 'quest' or 'abs_max'")
if hasattr(torch.ops._qutlass_C, "fusedQuantizeNv"):
@register_fake("_qutlass_C::fusedQuantizeNv")
def _fake_fused_quantize_nv(
a: torch.Tensor,
b: torch.Tensor,
xh_e2m1: torch.Tensor,
xh_e4m3: torch.Tensor,
global_scale: torch.Tensor,
):
return xh_e2m1, xh_e4m3
def fusedQuantizeNv(
a: torch.Tensor, b: torch.Tensor, global_scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
xh_e2m1 = torch.empty(
*a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device
)
rows, cols = a.numel() // a.size(-1), a.size(-1) // 16
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4
xh_e4m3 = torch.empty(
padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=a.device
)
return torch.ops._qutlass_C.fusedQuantizeNv(a, b, xh_e2m1, xh_e4m3, global_scale)
def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
"""
Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832)

View File

@ -12,6 +12,7 @@ QuantizationMethods = Literal[
"fp8",
"ptpc_fp8",
"fbgemm_fp8",
"fp_quant",
"modelopt",
"modelopt_fp4",
"bitblas",
@ -102,6 +103,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .experts_int8 import ExpertsInt8Config
from .fbgemm_fp8 import FBGEMMFp8Config
from .fp8 import Fp8Config
from .fp_quant import FPQuantConfig
from .gguf import GGUFConfig
from .gptq import GPTQConfig
from .gptq_bitblas import GPTQBitBLASConfig
@ -125,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"fp_quant": FPQuantConfig,
"modelopt": ModelOptFp8Config,
"modelopt_fp4": ModelOptNvFp4Config,
"bitblas": BitBLASConfig,

View File

@ -0,0 +1,420 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
from vllm._custom_ops import (
cutlass_scaled_fp4_mm,
fusedQuantizeMx,
fusedQuantizeNv,
matmul_mxf4_bf16_tn,
)
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
class FPQuantConfig(QuantizationConfig):
"""Config class for FPQuant."""
def __init__(
self,
hadamard_group_size: int = 32,
forward_dtype: str = "mxfp4",
forward_method: str = "abs_max",
pseudoquantization: bool = False,
modules_to_not_convert: Optional[list[str]] = None,
) -> None:
super().__init__()
self.hadamard_group_size = hadamard_group_size
self.forward_dtype = forward_dtype
self.forward_method = forward_method
self.pseudoquantization = pseudoquantization
self.modules_to_not_convert = modules_to_not_convert
if pseudoquantization:
raise ValueError("Pseudoquantization is not supported for vLLM")
def __repr__(self) -> str:
return (
f"FPQuantConfig(hadamard_group_size={self.hadamard_group_size}, "
f"forward_dtype={self.forward_dtype}, "
f"forward_method={self.forward_method}, "
f"pseudoquantization={self.pseudoquantization}, "
f"modules_to_not_convert={self.modules_to_not_convert})"
)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "fp_quant"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 100
@classmethod
def get_config_filenames(cls) -> list[str]:
return [] # no extra configs.
@classmethod
def from_config(cls, config: dict[str, Any]) -> "FPQuantConfig":
hadamard_group_size = cls.get_from_keys(config, ["hadamard_group_size"])
forward_dtype = cls.get_from_keys(config, ["forward_dtype"])
forward_method = cls.get_from_keys(config, ["forward_method"])
pseudoquantization = cls.get_from_keys(config, ["pseudoquantization"])
modules_to_not_convert = cls.get_from_keys(config, ["modules_to_not_convert"])
return cls(
hadamard_group_size,
forward_dtype,
forward_method,
pseudoquantization,
modules_to_not_convert,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[LinearMethodBase]:
if self.modules_to_not_convert is not None and any(
prefix.endswith(module) for module in self.modules_to_not_convert
):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
return FPQuantLinearMethod(self)
return None
class FPQuantLinearMethod(LinearMethodBase):
"""Linear method for FPQuant.
Args:
quant_config: The FPQuant quantization config.
"""
def __init__(self, quant_config: FPQuantConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
del input_size # Unused.
if params_dtype != torch.bfloat16:
raise ValueError("Only bfloat16 is currently supported by FPQuant")
if input_size_per_partition % self.quant_config.hadamard_group_size != 0: # noqa: E501
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size. Or other skill issues."
)
assert self.quant_config.forward_dtype in ["mxfp4", "nvfp4"], (
"Only mxfp4 and nvfp4 are supported for now"
)
if self.quant_config.forward_dtype == "mxfp4":
group_size = 32
elif self.quant_config.forward_dtype == "nvfp4":
group_size = 16
else:
raise ValueError(
f"Unsupported forward_dtype: {self.quant_config.forward_dtype}"
)
qweight = Parameter(
torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=torch.uint8,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 1,
"output_dim": 0,
"packed_dim": 1,
"pack_factor": 2,
}
| extra_weight_attrs,
)
layer.register_parameter("qweight", qweight)
scales = Parameter(
torch.empty(
sum(output_partition_sizes),
input_size_per_partition // group_size,
dtype=torch.uint8,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"input_dim": 1,
"output_dim": 0,
"packed_dim": 1,
"pack_factor": group_size,
}
| extra_weight_attrs,
)
layer.register_parameter("scales", scales)
weight_global_scale = Parameter(
torch.empty(1, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(
weight_global_scale, {"ignore_warning": True} | extra_weight_attrs
)
layer.register_parameter("weight_global_scale", weight_global_scale)
act_global_scale = Parameter(
torch.empty(1, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(
act_global_scale, {"ignore_warning": True} | extra_weight_attrs
)
layer.register_parameter("act_global_scale", act_global_scale)
forward_hadamard_matrix = Parameter(
torch.empty(
self.quant_config.hadamard_group_size,
self.quant_config.hadamard_group_size,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
forward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs
)
layer.register_parameter("forward_hadamard_matrix", forward_hadamard_matrix)
backward_hadamard_matrix = Parameter(
torch.empty(
self.quant_config.hadamard_group_size,
self.quant_config.hadamard_group_size,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
backward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs
)
layer.register_parameter("backward_hadamard_matrix", backward_hadamard_matrix)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return quantized_forward(
x,
layer.qweight,
layer.scales,
layer.weight_global_scale,
layer.act_global_scale,
bias,
layer.forward_hadamard_matrix,
self.quant_config.forward_method,
self.quant_config.forward_dtype,
)
def ceil_div(a, b):
return (a + b - 1) // b
def fused_quantize_mx(
x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str
) -> tuple[torch.Tensor, torch.Tensor]:
return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method)
def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method):
rows, cols = x_flat.size(0), x_flat.size(1) // 32
padded_rows = ((rows + 128 - 1) // 128) * 128
padded_cols = ((cols + 4 - 1) // 4) * 4
xh_e2m1 = torch.empty(
x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device
)
xh_e8m0 = torch.empty(
padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=x_flat.device
)
return xh_e2m1, xh_e8m0
direct_register_custom_op(
op_name="fused_quantize_mx",
op_func=fused_quantize_mx,
mutates_args=[],
fake_impl=fused_quantize_mx_fake,
dispatch_key=current_platform.dispatch_key,
)
def matmul_mxf4_bf16(
x: torch.Tensor,
w: torch.Tensor,
xs: torch.Tensor,
ws: torch.Tensor,
alpha: torch.Tensor,
) -> torch.Tensor:
return matmul_mxf4_bf16_tn(
x,
w,
to_blocked(xs, backend="triton").view(torch.float8_e8m0fnu),
to_blocked(ws, backend="triton").view(torch.float8_e8m0fnu),
alpha,
)
def matmul_mxf4_bf16_fake(x, w, xs, ws, alpha):
return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device)
direct_register_custom_op(
op_name="matmul_mxf4_bf16",
op_func=matmul_mxf4_bf16,
mutates_args=[],
fake_impl=matmul_mxf4_bf16_fake,
dispatch_key=current_platform.dispatch_key,
)
def fused_quantize_nv(
x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, global_scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
return fusedQuantizeNv(x_flat, hadamard_matrix, global_scale)
def fused_quantize_nv_fake(x_flat, hadamard_matrix, global_scale):
rows, cols = x_flat.size(0), x_flat.size(1) // 16
padded_rows = ((rows + 128 - 1) // 128) * 128
padded_cols = ((cols + 4 - 1) // 4) * 4
xh_e2m1 = torch.empty(
x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device
)
xh_e8m0 = torch.empty(
padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=x_flat.device
)
return xh_e2m1, xh_e8m0
direct_register_custom_op(
op_name="fused_quantize_nv",
op_func=fused_quantize_nv,
mutates_args=[],
fake_impl=fused_quantize_nv_fake,
dispatch_key=current_platform.dispatch_key,
)
def matmul_nvf4_bf16(
x: torch.Tensor,
w: torch.Tensor,
xs: torch.Tensor,
ws: torch.Tensor,
alpha: torch.Tensor,
) -> torch.Tensor:
return cutlass_scaled_fp4_mm(
x,
w,
to_blocked(xs, backend="triton")
.view(torch.float8_e4m3fn)
.view(-1, x.shape[1] // 8), # *2//16
to_blocked(ws, backend="triton")
.view(torch.float8_e4m3fn)
.view(-1, x.shape[1] // 8),
alpha,
torch.bfloat16,
)
def matmul_nvf4_bf16_fake(x, w, xs, ws, alpha):
return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device)
direct_register_custom_op(
op_name="matmul_nvf4_bf16",
op_func=matmul_nvf4_bf16,
mutates_args=[],
fake_impl=matmul_nvf4_bf16_fake,
dispatch_key=current_platform.dispatch_key,
)
def quantized_forward(
x: torch.Tensor,
qweight: torch.Tensor,
weight_scales: torch.Tensor,
weight_global_scale: torch.Tensor,
act_global_scale: torch.Tensor,
bias: Optional[torch.Tensor],
forward_hadamard_matrix: torch.Tensor,
forward_method: str,
forward_dtype: str,
) -> torch.Tensor:
x_flat = x.contiguous().flatten(end_dim=-2)
if forward_dtype == "mxfp4":
x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_mx(
x_flat, forward_hadamard_matrix, forward_method
)
y = torch.ops.vllm.matmul_mxf4_bf16(
x_flat_q,
qweight,
x_flat_scales,
weight_scales,
1 / (weight_global_scale * act_global_scale),
)
elif forward_dtype == "nvfp4":
x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_nv(
x_flat, forward_hadamard_matrix, act_global_scale
)
y = torch.ops.vllm.matmul_nvf4_bf16(
x_flat_q,
qweight,
x_flat_scales,
weight_scales,
1 / (weight_global_scale * act_global_scale),
)
else:
raise ValueError(f"Unsupported forward_dtype: {forward_dtype}")
y = y.view(*x.shape[:-1], y.shape[-1])
if bias is not None:
y += bias
return y

View File

@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Modified by Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
#
# Copied from https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Literal
import torch
import triton
import triton.language as tl
from torch.library import wrap_triton
@triton.jit
def triton_scale_swizzle(
scale_ptr: torch.Tensor,
scale_rows: int,
scale_cols: int,
output_ptr: torch.Tensor,
input_row_stride: int,
output_block_stride: int,
BLOCK_ROWS: tl.constexpr,
BLOCK_COLS: tl.constexpr,
):
"""
Rearranges tensor data from row-major to block-scaled swizzle format.
Args:
scale_ptr: Pointer to the input scale tensor
scale_rows: Number of rows in the scale tensor
scale_cols: Number of columns in the scale tensor
output_ptr: Pointer to the output tensor
input_row_stride: Stride between rows in the input tensor
output_block_stride: Stride between blocks in the output tensor
BLOCK_ROWS: Number of rows in a tile (compile-time constant)
BLOCK_COLS: Number of columns in a tile (compile-time constant)
"""
pid_row = tl.program_id(0)
pid_col = tl.program_id(1)
rows = tl.arange(0, BLOCK_ROWS)[:, None]
cols = tl.arange(0, BLOCK_COLS)[None, :]
# Calculate starting row and column for this tile
start_row = pid_row * BLOCK_ROWS
start_col = pid_col * BLOCK_COLS
global_rows = start_row + rows
global_cols = start_col + cols
mask = (global_rows < scale_rows) & (global_cols < scale_cols)
input_scales = tl.load(
scale_ptr + global_rows * input_row_stride + global_cols,
mask=mask,
other=0.0,
)
r_div_32 = rows // 32
r_mod_32 = rows % 32
# 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates
dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
# Flatten
dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))
# Calculate block offset using provided output block stride
LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
tl.store(
output_ptr + block_offset + dest_indices_flat,
scales_flat,
)
def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
"""
Rearranges an E8M0 tensor scale from row-major format to
block-scaled swizzle format.
This format is suitable for Tmem as described in NVIDIA documentation:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
Args:
scale_tensor: Input tensor in row-major format with 8-bit elements
Returns:
Rearranged tensor in block-scaled swizzle format
"""
assert scale_tensor.element_size() == 1, (
"Expected element size to be 1 byte (8 bits)"
)
assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
rows, cols = scale_tensor.shape
# Calculate blocks needed
n_row_blocks = triton.cdiv(rows, 128)
n_col_blocks = triton.cdiv(cols, 4)
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4
out = scale_tensor.new_empty((padded_rows, padded_cols))
# Input stride (for row-major format)
input_row_stride = cols
# We probably want handle multiple blocks per tile but
# for now keep it simple
BLOCK_ROWS, BLOCK_COLS = 128, 4
# Output block stride for the rearranged format
output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
grid = lambda META: (
triton.cdiv(padded_rows, BLOCK_ROWS),
triton.cdiv(padded_cols, BLOCK_COLS),
)
wrap_triton(triton_scale_swizzle)[grid](
scale_tensor.view(torch.uint8),
rows,
cols,
out.view(torch.uint8),
input_row_stride,
output_block_stride,
BLOCK_ROWS=BLOCK_ROWS,
BLOCK_COLS=BLOCK_COLS,
)
return out
def ceil_div(a, b):
return (a + b - 1) // b
def to_blocked(
input_matrix: torch.Tensor, backend: Literal["torch", "triton"] = "triton"
) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying
the rearrangement pattern.
See:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
Args:
input_matrix: Input tensor of shape (H, W)
backend: "torch" (PyTorch path) or "triton" (Triton kernel)
Returns:
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
"""
if backend == "triton":
return triton_mx_block_rearrange(input_matrix).flatten()
elif backend != "torch":
raise ValueError(f'backend must be "torch" or "triton", got {backend!r}')
rows, cols = input_matrix.shape
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)
# Calculate the padded shape
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4
padded = input_matrix
assert (rows, cols) == (padded_rows, padded_cols)
# Rearrange the blocks
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
return rearranged.flatten()