From 96ad65b7fe515663da8ede09a1aa7f74aa500c97 Mon Sep 17 00:00:00 2001 From: "Roberto L. Castro" <38211239+LopezCastroRoberto@users.noreply.github.com> Date: Fri, 10 Oct 2025 18:43:40 +0200 Subject: [PATCH] [Transform] [Quantization] Add QuTLASS support to vLLM (#24440) Signed-off-by: LopezCastroRoberto Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com> Signed-off-by: Andrei Panferov Co-authored-by: Andrei Panferov Co-authored-by: Michael Goin --- .buildkite/test-pipeline.yaml | 2 + CMakeLists.txt | 1 + benchmarks/kernels/bench_mxfp4_qutlass.py | 191 ++++++++ benchmarks/kernels/bench_nvfp4_qutlass.py | 207 +++++++++ cmake/external_projects/qutlass.cmake | 97 ++++ .../quantization/test_mxfp4_qutlass.py | 303 +++++++++++++ .../quantization/test_nvfp4_qutlass.py | 268 +++++++++++ tests/quantization/fp_quant.py | 32 ++ vllm/_custom_ops.py | 140 +++++- .../layers/quantization/__init__.py | 3 + .../layers/quantization/fp_quant.py | 420 ++++++++++++++++++ .../layers/quantization/qutlass_utils.py | 185 ++++++++ 12 files changed, 1848 insertions(+), 1 deletion(-) create mode 100644 benchmarks/kernels/bench_mxfp4_qutlass.py create mode 100644 benchmarks/kernels/bench_nvfp4_qutlass.py create mode 100644 cmake/external_projects/qutlass.cmake create mode 100644 tests/kernels/quantization/test_mxfp4_qutlass.py create mode 100644 tests/kernels/quantization/test_nvfp4_qutlass.py create mode 100644 tests/quantization/fp_quant.py create mode 100644 vllm/model_executor/layers/quantization/fp_quant.py create mode 100644 vllm/model_executor/layers/quantization/qutlass_utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d80ae054d6..681503e983 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -834,6 +834,8 @@ steps: - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py + - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 diff --git a/CMakeLists.txt b/CMakeLists.txt index db72244e0e..0055904453 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1007,6 +1007,7 @@ endif() # For CUDA we also build and ship some external projects. if (VLLM_GPU_LANG STREQUAL "CUDA") include(cmake/external_projects/flashmla.cmake) + include(cmake/external_projects/qutlass.cmake) # vllm-flash-attn should be last as it overwrites some CMake functions include(cmake/external_projects/vllm_flash_attn.cmake) diff --git a/benchmarks/kernels/bench_mxfp4_qutlass.py b/benchmarks/kernels/bench_mxfp4_qutlass.py new file mode 100644 index 0000000000..dfc7721876 --- /dev/null +++ b/benchmarks/kernels/bench_mxfp4_qutlass.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import copy +import itertools + +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix +from weight_shapes import WEIGHT_SHAPES + +from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "mxfp4": dict(no_a_quant=False, enabled=True), + "mxfp4-noquant": dict(no_a_quant=True, enabled=True), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _quant_weight_mxfp4( + b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str +): + weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx( + b, forward_hadamard_matrix, method="abs_max" + ) + weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton") + return weight_hf_e2m1, weight_hf_scale_block + + +def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device): + weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4( + b, forward_hadamard_matrix, device + ) + alpha = torch.tensor([1.0], device="cuda") + + if cfg["no_a_quant"]: + # Pre-quantize activation + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx( + a, forward_hadamard_matrix, method="abs_max" + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton") + + def run(): + return matmul_mxf4_bf16_tn( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + ) + + return run + + # Quantize activation on-the-fly + def run(): + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx( + a, forward_hadamard_matrix, method="abs_max" + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton") + return matmul_mxf4_bf16_tn( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + ) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[ + 1, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 24576, + 32768, + ], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs MXFP4 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K, had_size): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_mxfp4_runner( + cfg, a, b, forward_hadamard_matrix, dtype, device + ) + ms, 
min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), rep=200, quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.3-70B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + for had_size in [32, 64, 128]: + print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_mxfp4_res_n{N}_k{K}", + N=N, + K=K, + had_size=had_size, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/bench_nvfp4_qutlass.py b/benchmarks/kernels/bench_nvfp4_qutlass.py new file mode 100644 index 0000000000..6fecc816f9 --- /dev/null +++ b/benchmarks/kernels/bench_nvfp4_qutlass.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import argparse +import copy +import itertools + +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm +from vllm._custom_ops import fusedQuantizeNv +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "nvfp4": dict(no_a_quant=False, enabled=True), + "nvfp4-noquant": dict(no_a_quant=True, enabled=True), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _quant_weight_nvfp4( + b: torch.Tensor, + forward_hadamard_matrix: torch.Tensor, + global_scale: torch.Tensor, + device: str, + M: int, + N: int, + K: int, +): + weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv( + b, forward_hadamard_matrix, global_scale + ) + weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + return weight_hf_e2m1, weight_hf_scale_block + + +def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K): + alpha = torch.tensor([1.0], device="cuda") + global_scale = torch.tensor([1.0], device="cuda") + weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4( + b, forward_hadamard_matrix, global_scale, device, M, N, K + ) + + if cfg["no_a_quant"]: + # Pre-quantize activation + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv( + a, forward_hadamard_matrix, global_scale + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + + def run(): + return ops.cutlass_scaled_fp4_mm( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + torch.bfloat16, + ) + + return run + + # Quantize activation on-the-fly + def run(): + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv( + a, forward_hadamard_matrix, global_scale + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + return ops.cutlass_scaled_fp4_mm( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + torch.bfloat16, + ) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[ + 1, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 24576, + 32768, + ], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs NVFP4 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K, had_size): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_nvfp4_runner( + cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K + ) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), rep=200, quantiles=quantiles 
+ ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.3-70B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + for had_size in [16, 32, 64, 128]: + print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_nvfp4_res_n{N}_k{K}", + N=N, + K=K, + had_size=had_size, + ) + + print("Benchmark finished!") diff --git a/cmake/external_projects/qutlass.cmake b/cmake/external_projects/qutlass.cmake new file mode 100644 index 0000000000..9aace76930 --- /dev/null +++ b/cmake/external_projects/qutlass.cmake @@ -0,0 +1,97 @@ +include(FetchContent) + +set(CUTLASS_INCLUDE_DIR "${CUTLASS_INCLUDE_DIR}" CACHE PATH "Path to CUTLASS include/ directory") + +if(DEFINED ENV{QUTLASS_SRC_DIR}) + set(QUTLASS_SRC_DIR $ENV{QUTLASS_SRC_DIR}) +endif() + +if(QUTLASS_SRC_DIR) + FetchContent_Declare( + qutlass + SOURCE_DIR ${QUTLASS_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +else() + FetchContent_Declare( + qutlass + GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git + GIT_TAG 830d2c4537c7396e14a02a46fbddd18b5d107c65 + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) + FetchContent_Populate(qutlass) + set(qutlass_SOURCE_DIR "${qutlass_SOURCE_DIR}") +endif() + +if(NOT qutlass_SOURCE_DIR) + message(FATAL_ERROR "[QUTLASS] source directory could not be resolved.") +endif() +message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}") + +cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND QUTLASS_ARCHS) + + if(QUTLASS_ARCHS MATCHES "10\\.0a") + set(QUTLASS_TARGET_CC 100) + elseif(QUTLASS_ARCHS MATCHES "12\\.0a") + set(QUTLASS_TARGET_CC 120) + else() + message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.") + endif() + + set(QUTLASS_SOURCES + ${qutlass_SOURCE_DIR}/qutlass/csrc/bindings.cpp + ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm_ada.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx_sm100.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv_sm100.cu + ) + + set(QUTLASS_INCLUDES + ${qutlass_SOURCE_DIR} + ${qutlass_SOURCE_DIR}/qutlass + ${qutlass_SOURCE_DIR}/qutlass/csrc/include + ${qutlass_SOURCE_DIR}/qutlass/csrc/include/cutlass_extensions + ) + + if(CUTLASS_INCLUDE_DIR AND EXISTS "${CUTLASS_INCLUDE_DIR}/cutlass/cutlass.h") + list(APPEND QUTLASS_INCLUDES "${CUTLASS_INCLUDE_DIR}") + elseif(EXISTS "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include/cutlass/cutlass.h") + list(APPEND QUTLASS_INCLUDES "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include") + message(STATUS "[QUTLASS] Using QuTLASS vendored CUTLASS headers (no vLLM CUTLASS detected).") + else() + 
message(FATAL_ERROR "[QUTLASS] CUTLASS headers not found. " + "Set -DCUTLASS_INCLUDE_DIR=/path/to/cutlass/include") + endif() + + set_gencode_flags_for_srcs( + SRCS "${QUTLASS_SOURCES}" + CUDA_ARCHS "${QUTLASS_ARCHS}" + ) + + target_sources(_C PRIVATE ${QUTLASS_SOURCES}) + target_include_directories(_C PRIVATE ${QUTLASS_INCLUDES}) + target_compile_definitions(_C PRIVATE + QUTLASS_DISABLE_PYBIND=1 + TARGET_CUDA_ARCH=${QUTLASS_TARGET_CC} + ) + + set_property(SOURCE ${QUTLASS_SOURCES} APPEND PROPERTY COMPILE_OPTIONS + $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr --use_fast_math -O3> + ) + +else() + if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8") + message(STATUS + "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).") + else() + message(STATUS + "[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in " + "CUDA_ARCHS='${CUDA_ARCHS}'.") + endif() +endif() diff --git a/tests/kernels/quantization/test_mxfp4_qutlass.py b/tests/kernels/quantization/test_mxfp4_qutlass.py new file mode 100644 index 0000000000..0bacbef204 --- /dev/null +++ b/tests/kernels/quantization/test_mxfp4_qutlass.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# +import numpy as np +import pytest +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix + +from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.platforms import current_platform + +if not torch.cuda.is_available(): + pytest.skip("CUDA required for these tests.", allow_module_level=True) + +if not ( + current_platform.has_device_capability(100) + or current_platform.has_device_capability(120) +): + pytest.skip( + reason="Tests require compute capability 10.0 (100) or 12.0 (120).", + allow_module_level=True, + ) + + +# ----- Helpers ----- +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _rtne_fp4(x: torch.Tensor): + device = x.device + grid = torch.tensor( + [ + -6.0, + -4.0, + -3.0, + -2.0, + -1.5, + -1.0, + -0.5, + -0.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + ], + dtype=x.dtype, + device=x.device, + ) + grid_int = torch.tensor( + [-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7], + dtype=torch.uint8, + device=device, + ) + inds = torch.bucketize(x, grid) + lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15) + g_lo, g_hi = grid[lo], grid[hi] + pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0) + y = torch.where(pick_hi, g_hi, g_lo) + y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo]) + y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF + return y, y_int_packed + + +def _dq_fp4(x_e2m1: torch.Tensor, x_e8m0: torch.Tensor, alpha: float): + device = x_e2m1.device + + x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32) + x_e2m1_unpacked = torch.stack( + [x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1 + ).flatten(start_dim=-2) + + grid_dq = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float64, + device=device, + ) + x_fp4_dq = grid_dq[x_e2m1_unpacked] + scales_dq = x_e8m0.to(torch.float64) + + x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 32)) * scales_dq[..., None]).flatten( + start_dim=-2 + ) / alpha + return x_dq, x_fp4_dq, scales_dq + + +def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor: + clip_mask_unpacked_dq = torch.zeros( + *clip_mask.shape[:-1], + clip_mask.size(-1) * 8, + dtype=torch.bool, + device=clip_mask.device, + ) + for i in range(8): + clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1 + return clip_mask_unpacked_dq + + +def _forward_quantize_ref( + x: torch.Tensor, h: torch.Tensor, rot_size: int, quest: bool = True +): + device = x.device + xh_ref64 = ( + x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64) + @ h.reshape(rot_size, rot_size).to(dtype=torch.float64) + ).flatten(start_dim=-2) + + if quest: + scales_ref64_ = ( + xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).std(dim=-1, correction=0) + * (2.92247856 / 6.0) + + 1e-8 + ) + else: + abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).abs().amax(dim=-1) + scales_ref64_ = abs_max + 1e-8 + + xh_e8m0_ref = scales_ref64_.log2().floor().exp2().to(dtype=torch.float8_e8m0fnu) + scales_ref64 = xh_e8m0_ref.to(dtype=torch.float64) + + xh_scaled_ref64 = ( + xh_ref64.unflatten(dim=-1, sizes=(-1, 32)) / scales_ref64[..., None] + ).flatten(start_dim=-2) + if not quest: + xh_scaled_ref64 
*= 3 + + clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0 + clip_mask_ref = torch.zeros( + *x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device + ) + for i in range(8): + clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i + + xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64) + xh_dq, xh_fp4_dq, scales_dq = _dq_fp4( + xh_e2m1_ref, xh_e8m0_ref, alpha=1.0 if quest else 3.0 + ) + clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref) + + assert xh_fp4_dq.equal(xh_fp4_ref) + assert scales_dq.equal(scales_ref64) + assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref) + + return ( + xh_dq, + clip_mask_unpacked_ref, + (xh_e2m1_ref, xh_e8m0_ref, clip_mask_ref), + ) + + +DTYPE = torch.bfloat16 +DEVICE = torch.device("cuda:0") + +ROT_SIZES = [32, 64, 128] +SEEDS = [0] +BATCHES = [1, 16] + +LLAMA_MODELS = { + "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], + "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], + "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], + "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], +} + + +@pytest.fixture(autouse=True) +def _seed_each_test(): + current_platform.seed_everything(0) + np.random.seed(0) + torch.random.manual_seed(0) + + +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@torch.inference_mode() +def test_fused_quantization_absmax(rot_size: int): + dtype, device = DTYPE, DEVICE + h = get_hadamard_matrix(rot_size, dtype, device) + x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0 + + xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=False) + xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="abs_max") + xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32) + xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=3.0) + + torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100) + assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4 + + m, n, k = 1, 504, 4096 + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max") + b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max") + a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e8m0, backend="triton") + b_scale_block = to_blocked(b_e8m0, backend="triton") + alpha = torch.tensor([1.0], device=device) + out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha) + assert out.equal(out_ref.to(dtype=out.dtype)) + + +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@torch.inference_mode() +def test_fused_quantization_quest(rot_size: int): + dtype, device = DTYPE, DEVICE + h = get_hadamard_matrix(rot_size, dtype, device) + x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0 + + xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=True) + xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="quest") + xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32) + xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=1.0) + + torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100) + assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4 + + m, n, k = 504, 504, 2048 + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest") + b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest") + 
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e8m0, backend="triton") + b_scale_block = to_blocked(b_e8m0, backend="triton") + alpha = torch.tensor([1.0], device=device) + out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha) + assert out.equal(out_ref.to(dtype=out.dtype)) + + +@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys())) +@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3]) +@pytest.mark.parametrize("batch", [1, 16]) +@pytest.mark.parametrize("had_size", ROT_SIZES) +@torch.inference_mode() +def test_llama_shapes(model: str, layer_idx: int, batch: int, had_size: int): + dtype, device = DTYPE, DEVICE + m = batch + k, n = LLAMA_MODELS[model][layer_idx] + + h = get_hadamard_matrix(had_size, dtype, device) + + a = torch.rand(m, k, dtype=dtype, device=device) * 25.0 + b = torch.rand(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest") + b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest") + + a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e8m0, backend="triton") + b_scale_block = to_blocked(b_e8m0, backend="triton") + alpha = torch.tensor([1.0], device=device) + out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha) + assert out.equal(out_ref.to(dtype=out.dtype)) diff --git a/tests/kernels/quantization/test_nvfp4_qutlass.py b/tests/kernels/quantization/test_nvfp4_qutlass.py new file mode 100644 index 0000000000..3824a080f5 --- /dev/null +++ b/tests/kernels/quantization/test_nvfp4_qutlass.py @@ -0,0 +1,268 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import numpy as np +import pytest +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix + +from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm +from vllm._custom_ops import fusedQuantizeNv +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.platforms import current_platform + +if not torch.cuda.is_available(): + pytest.skip("CUDA required for these tests.", allow_module_level=True) + +if not ( + current_platform.has_device_capability(100) + or current_platform.has_device_capability(120) +): + pytest.skip( + reason="Tests require compute capability 10.0 (100) or 12.0 (120).", + allow_module_level=True, + ) + + +# ----- Helpers ----- +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _rtne_fp4(x: torch.Tensor): + device = x.device + grid = torch.tensor( + [ + -6.0, + -4.0, + -3.0, + -2.0, + -1.5, + -1.0, + -0.5, + -0.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + ], + dtype=x.dtype, + device=x.device, + ) + grid_int = torch.tensor( + [-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7], + dtype=torch.uint8, + device=device, + ) + inds = torch.bucketize(x, grid) + lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15) + g_lo, g_hi = grid[lo], grid[hi] + pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0) + y = torch.where(pick_hi, g_hi, g_lo) + y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo]) + y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF + return y, y_int_packed + + +def _dq_fp4(x_e2m1: torch.Tensor, x_e4m3: torch.Tensor, alpha: float): + device = x_e2m1.device + + x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32) + x_e2m1_unpacked = torch.stack( + [x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1 + ).flatten(start_dim=-2) + + grid_dq = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float64, + device=device, + ) + x_fp4_dq = grid_dq[x_e2m1_unpacked] + + scales_dq = x_e4m3.to(torch.float64) + x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 16)) * scales_dq[..., None]).flatten( + start_dim=-2 + ) / alpha # * (4. / 3.) 
+ return x_dq, x_fp4_dq, scales_dq + + +def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor: + clip_mask_unpacked_dq = torch.zeros( + *clip_mask.shape[:-1], + clip_mask.size(-1) * 8, + dtype=torch.bool, + device=clip_mask.device, + ) + for i in range(8): + clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1 + return clip_mask_unpacked_dq + + +def _forward_quantize_ref(x: torch.Tensor, h: torch.Tensor, rot_size: int): + device = x.device + + xh_ref64 = ( + x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64) + @ h.reshape(rot_size, rot_size).to(dtype=torch.float64) + ).flatten(start_dim=-2) + + abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 16)).abs().amax(dim=-1) + scales_ref64_ = abs_max + 1e-8 + + xh_e4m3_ref = scales_ref64_.to(dtype=torch.float8_e4m3fn) + scales_ref64 = xh_e4m3_ref.to(dtype=torch.float64) + xh_scaled_ref64 = ( + xh_ref64.unflatten(dim=-1, sizes=(-1, 16)) / scales_ref64[..., None] + ).flatten(start_dim=-2) + + xh_scaled_ref64 *= 6.0 + + clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0 + clip_mask_ref = torch.zeros( + *x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device + ) + for i in range(8): + clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i + + xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64) + xh_dq, xh_fp4_dq, scales_dq = _dq_fp4(xh_e2m1_ref, xh_e4m3_ref, 6.0) + clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref) + + assert xh_fp4_dq.equal(xh_fp4_ref) + assert scales_dq.equal(scales_ref64) + assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref) + + return ( + xh_dq, + clip_mask_unpacked_ref, + (xh_e2m1_ref, xh_e4m3_ref, clip_mask_ref), + ) + + +DTYPE = torch.bfloat16 +DEVICE = torch.device("cuda:0") +ROT_SIZES = [16, 32, 64, 128] +GLOBAL_SCALES = [6.0] + +LLAMA_MODELS = { + "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], + "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], + "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], + "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], +} + + +@pytest.fixture(autouse=True) +def _seed_each_test(): + current_platform.seed_everything(0) + np.random.seed(0) + torch.random.manual_seed(0) + + +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@pytest.mark.parametrize("global_scale_value", GLOBAL_SCALES) +@torch.inference_mode() +def test_fused_quantization(rot_size: int, global_scale_value: float): + dtype, device = DTYPE, DEVICE + h = get_hadamard_matrix(rot_size, dtype, device) + x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0 + global_scale = torch.tensor([global_scale_value], device=device) + + xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size) + xh_e2m1, xh_e4m3 = fusedQuantizeNv(x, h, global_scale) + xh_e4m3 = xh_e4m3.reshape(2, 4096, 4096 // 16) + xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e4m3, alpha=global_scale_value) + + torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100) + assert (xh_dq != xh_dq_ref).float().mean() <= 1e-1 + + m, n, k = 504, 4096 * 2, 4096 + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale) + b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale) + + a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16) + 
b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16) + alpha = torch.tensor([1.0], device=device) + out = ops.cutlass_scaled_fp4_mm( + a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16 + ) + assert out.equal(out_ref.to(dtype=out.dtype)) + + +@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys())) +@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3]) +@pytest.mark.parametrize("batch", [1, 16]) +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@torch.inference_mode() +def test_llama_shapes(model: str, layer_idx: int, batch: int, rot_size: int): + dtype, device = DTYPE, DEVICE + m = batch + k, n = LLAMA_MODELS[model][layer_idx] + + h = get_hadamard_matrix(rot_size, dtype, device) + + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + global_scale = torch.tensor([1.0], device=device) + + a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale) + b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale) + + a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16) + b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16) + alpha = torch.tensor([1.0], device=device) + out = ops.cutlass_scaled_fp4_mm( + a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16 + ) + assert out.equal(out_ref.to(dtype=out.dtype)) diff --git a/tests/quantization/fp_quant.py b/tests/quantization/fp_quant.py new file mode 100644 index 0000000000..664ce9d111 --- /dev/null +++ b/tests/quantization/fp_quant.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test model set-up and inference for quantized HF models supported +on the GPU backend using FPQuant. + +Validating the configuration and printing results for manual checking. + +Run `pytest tests/quantization/test_fp_quant.py`. 
+""" + +import pytest + +from tests.quantization.utils import is_quant_method_supported + +MODELS = [ + "ISTA-DASLab/Qwen3-0.6B-RTN-NVFP4", + "ISTA-DASLab/Qwen3-0.6B-RTN-MXFP4", +] +DTYPE = ["bfloat16"] +EAGER = [True, False] + + +@pytest.mark.skipif( + not is_quant_method_supported("fp_quant"), + reason="FPQuant is not supported on this GPU type.", +) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("eager", EAGER) +def test_fpquant(vllm_runner, model, eager): + with vllm_runner(model, enforce_eager=eager) as llm: + output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2) + assert output[0][1] == "1 2 3 4 5 6" diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b2a7f8e808..eac0a5009e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import torch @@ -2483,6 +2483,144 @@ def onednn_scaled_mm( return output +if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"): + + @register_fake("_qutlass_C::matmul_mxf4_bf16_tn") + def _fake_matmul_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, + ): + return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16) + + +def matmul_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return torch.ops._qutlass_C.matmul_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha) + + +if hasattr(torch.ops._qutlass_C, "matmul_ada_mxf4_bf16_tn"): + + @register_fake("_qutlass_C::matmul_ada_mxf4_bf16_tn") + def _fake_matmul_ada_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, + ): + return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16) + + +def matmul_ada_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return torch.ops._qutlass_C.matmul_ada_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha) + + +def ceil_div(a, b): + return (a + b - 1) // b + + +if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxQuest"): + + @register_fake("_qutlass_C::fusedQuantizeMxQuest") + def _fake_fused_quantize_mx_quest( + a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor + ): + return xh_e2m1, xh_e8m0 + + +if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxAbsMax"): + + @register_fake("_qutlass_C::fusedQuantizeMxAbsMax") + def _fake_fused_quantize_mx_absmax( + a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor + ): + return xh_e2m1, xh_e8m0 + + +def fusedQuantizeMx( + a: torch.Tensor, b: torch.Tensor, *, method: Literal["quest", "abs_max"] = "quest" +) -> tuple[torch.Tensor, torch.Tensor]: + if a.dim() == 0: + raise ValueError("`a` must have at least 1 dimension.") + if a.size(-1) % 32 != 0: + raise ValueError(f"last dim of `a` must be divisible by 32, got {a.size(-1)}.") + if b.device != a.device: + raise ValueError("`a` and `b` must be on the same device.") + + xh_e2m1 = torch.empty( + *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device + ) + + rows, cols = a.numel() // a.size(-1), a.size(-1) // 32 + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + xh_e8m0 = 
torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device + ) + + if not hasattr(torch.ops, "_qutlass_C"): + raise RuntimeError( + "The `_qutlass_C` extension is not loaded. " + "Make sure your custom op library is imported before calling fusedQuantizeMx." + ) + + if method == "quest": + return torch.ops._qutlass_C.fusedQuantizeMxQuest(a, b, xh_e2m1, xh_e8m0) + elif method == "abs_max": + return torch.ops._qutlass_C.fusedQuantizeMxAbsMax(a, b, xh_e2m1, xh_e8m0) + else: + raise ValueError(f"invalid method {method!r}, must be 'quest' or 'abs_max'") + + +if hasattr(torch.ops._qutlass_C, "fusedQuantizeNv"): + + @register_fake("_qutlass_C::fusedQuantizeNv") + def _fake_fused_quantize_nv( + a: torch.Tensor, + b: torch.Tensor, + xh_e2m1: torch.Tensor, + xh_e4m3: torch.Tensor, + global_scale: torch.Tensor, + ): + return xh_e2m1, xh_e4m3 + + +def fusedQuantizeNv( + a: torch.Tensor, b: torch.Tensor, global_scale: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + xh_e2m1 = torch.empty( + *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device + ) + + rows, cols = a.numel() // a.size(-1), a.size(-1) // 16 + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + xh_e4m3 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=a.device + ) + + return torch.ops._qutlass_C.fusedQuantizeNv(a, b, xh_e2m1, xh_e4m3, global_scale) + + def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor: """ Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 9d1c66e56e..b92fb8d266 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -12,6 +12,7 @@ QuantizationMethods = Literal[ "fp8", "ptpc_fp8", "fbgemm_fp8", + "fp_quant", "modelopt", "modelopt_fp4", "bitblas", @@ -102,6 +103,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .experts_int8 import ExpertsInt8Config from .fbgemm_fp8 import FBGEMMFp8Config from .fp8 import Fp8Config + from .fp_quant import FPQuantConfig from .gguf import GGUFConfig from .gptq import GPTQConfig from .gptq_bitblas import GPTQBitBLASConfig @@ -125,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "tpu_int8": Int8TpuConfig, "fp8": Fp8Config, "fbgemm_fp8": FBGEMMFp8Config, + "fp_quant": FPQuantConfig, "modelopt": ModelOptFp8Config, "modelopt_fp4": ModelOptNvFp4Config, "bitblas": BitBLASConfig, diff --git a/vllm/model_executor/layers/quantization/fp_quant.py b/vllm/model_executor/layers/quantization/fp_quant.py new file mode 100644 index 0000000000..929e603149 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp_quant.py @@ -0,0 +1,420 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202 + +from typing import Any, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm._custom_ops import ( + cutlass_scaled_fp4_mm, + fusedQuantizeMx, + fusedQuantizeNv, + matmul_mxf4_bf16_tn, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationMethods +from 
vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + + +class FPQuantConfig(QuantizationConfig): + """Config class for FPQuant.""" + + def __init__( + self, + hadamard_group_size: int = 32, + forward_dtype: str = "mxfp4", + forward_method: str = "abs_max", + pseudoquantization: bool = False, + modules_to_not_convert: Optional[list[str]] = None, + ) -> None: + super().__init__() + self.hadamard_group_size = hadamard_group_size + self.forward_dtype = forward_dtype + self.forward_method = forward_method + self.pseudoquantization = pseudoquantization + self.modules_to_not_convert = modules_to_not_convert + + if pseudoquantization: + raise ValueError("Pseudoquantization is not supported for vLLM") + + def __repr__(self) -> str: + return ( + f"FPQuantConfig(hadamard_group_size={self.hadamard_group_size}, " + f"forward_dtype={self.forward_dtype}, " + f"forward_method={self.forward_method}, " + f"pseudoquantization={self.pseudoquantization}, " + f"modules_to_not_convert={self.modules_to_not_convert})" + ) + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "fp_quant" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] # no extra configs. + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "FPQuantConfig": + hadamard_group_size = cls.get_from_keys(config, ["hadamard_group_size"]) + forward_dtype = cls.get_from_keys(config, ["forward_dtype"]) + forward_method = cls.get_from_keys(config, ["forward_method"]) + pseudoquantization = cls.get_from_keys(config, ["pseudoquantization"]) + modules_to_not_convert = cls.get_from_keys(config, ["modules_to_not_convert"]) + return cls( + hadamard_group_size, + forward_dtype, + forward_method, + pseudoquantization, + modules_to_not_convert, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[LinearMethodBase]: + if self.modules_to_not_convert is not None and any( + prefix.endswith(module) for module in self.modules_to_not_convert + ): + return UnquantizedLinearMethod() + + if isinstance(layer, LinearBase): + return FPQuantLinearMethod(self) + return None + + +class FPQuantLinearMethod(LinearMethodBase): + """Linear method for FPQuant. + + Args: + quant_config: The FPQuant quantization config. + """ + + def __init__(self, quant_config: FPQuantConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + del input_size # Unused. + + if params_dtype != torch.bfloat16: + raise ValueError("Only bfloat16 is currently supported by FPQuant") + if input_size_per_partition % self.quant_config.hadamard_group_size != 0: # noqa: E501 + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size. Or other skill issues." 
+ ) + + assert self.quant_config.forward_dtype in ["mxfp4", "nvfp4"], ( + "Only mxfp4 and nvfp4 are supported for now" + ) + if self.quant_config.forward_dtype == "mxfp4": + group_size = 32 + elif self.quant_config.forward_dtype == "nvfp4": + group_size = 16 + else: + raise ValueError( + f"Unsupported forward_dtype: {self.quant_config.forward_dtype}" + ) + + qweight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": 2, + } + | extra_weight_attrs, + ) + layer.register_parameter("qweight", qweight) + + scales = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition // group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": group_size, + } + | extra_weight_attrs, + ) + layer.register_parameter("scales", scales) + + weight_global_scale = Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs( + weight_global_scale, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("weight_global_scale", weight_global_scale) + + act_global_scale = Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs( + act_global_scale, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("act_global_scale", act_global_scale) + + forward_hadamard_matrix = Parameter( + torch.empty( + self.quant_config.hadamard_group_size, + self.quant_config.hadamard_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + forward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("forward_hadamard_matrix", forward_hadamard_matrix) + + backward_hadamard_matrix = Parameter( + torch.empty( + self.quant_config.hadamard_group_size, + self.quant_config.hadamard_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + backward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("backward_hadamard_matrix", backward_hadamard_matrix) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return quantized_forward( + x, + layer.qweight, + layer.scales, + layer.weight_global_scale, + layer.act_global_scale, + bias, + layer.forward_hadamard_matrix, + self.quant_config.forward_method, + self.quant_config.forward_dtype, + ) + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def fused_quantize_mx( + x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str +) -> tuple[torch.Tensor, torch.Tensor]: + return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method) + + +def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method): + rows, cols = x_flat.size(0), x_flat.size(1) // 32 + padded_rows = ((rows + 128 - 1) // 128) * 128 + padded_cols = ((cols + 4 - 1) // 4) * 4 + + xh_e2m1 = torch.empty( + x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device + ) + xh_e8m0 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=x_flat.device + ) + + return xh_e2m1, xh_e8m0 + + +direct_register_custom_op( + op_name="fused_quantize_mx", + op_func=fused_quantize_mx, + mutates_args=[], + 
fake_impl=fused_quantize_mx_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def matmul_mxf4_bf16( + x: torch.Tensor, + w: torch.Tensor, + xs: torch.Tensor, + ws: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return matmul_mxf4_bf16_tn( + x, + w, + to_blocked(xs, backend="triton").view(torch.float8_e8m0fnu), + to_blocked(ws, backend="triton").view(torch.float8_e8m0fnu), + alpha, + ) + + +def matmul_mxf4_bf16_fake(x, w, xs, ws, alpha): + return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device) + + +direct_register_custom_op( + op_name="matmul_mxf4_bf16", + op_func=matmul_mxf4_bf16, + mutates_args=[], + fake_impl=matmul_mxf4_bf16_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def fused_quantize_nv( + x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, global_scale: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + return fusedQuantizeNv(x_flat, hadamard_matrix, global_scale) + + +def fused_quantize_nv_fake(x_flat, hadamard_matrix, global_scale): + rows, cols = x_flat.size(0), x_flat.size(1) // 16 + padded_rows = ((rows + 128 - 1) // 128) * 128 + padded_cols = ((cols + 4 - 1) // 4) * 4 + + xh_e2m1 = torch.empty( + x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device + ) + xh_e8m0 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=x_flat.device + ) + + return xh_e2m1, xh_e8m0 + + +direct_register_custom_op( + op_name="fused_quantize_nv", + op_func=fused_quantize_nv, + mutates_args=[], + fake_impl=fused_quantize_nv_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def matmul_nvf4_bf16( + x: torch.Tensor, + w: torch.Tensor, + xs: torch.Tensor, + ws: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return cutlass_scaled_fp4_mm( + x, + w, + to_blocked(xs, backend="triton") + .view(torch.float8_e4m3fn) + .view(-1, x.shape[1] // 8), # *2//16 + to_blocked(ws, backend="triton") + .view(torch.float8_e4m3fn) + .view(-1, x.shape[1] // 8), + alpha, + torch.bfloat16, + ) + + +def matmul_nvf4_bf16_fake(x, w, xs, ws, alpha): + return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device) + + +direct_register_custom_op( + op_name="matmul_nvf4_bf16", + op_func=matmul_nvf4_bf16, + mutates_args=[], + fake_impl=matmul_nvf4_bf16_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def quantized_forward( + x: torch.Tensor, + qweight: torch.Tensor, + weight_scales: torch.Tensor, + weight_global_scale: torch.Tensor, + act_global_scale: torch.Tensor, + bias: Optional[torch.Tensor], + forward_hadamard_matrix: torch.Tensor, + forward_method: str, + forward_dtype: str, +) -> torch.Tensor: + x_flat = x.contiguous().flatten(end_dim=-2) + + if forward_dtype == "mxfp4": + x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_mx( + x_flat, forward_hadamard_matrix, forward_method + ) + y = torch.ops.vllm.matmul_mxf4_bf16( + x_flat_q, + qweight, + x_flat_scales, + weight_scales, + 1 / (weight_global_scale * act_global_scale), + ) + elif forward_dtype == "nvfp4": + x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_nv( + x_flat, forward_hadamard_matrix, act_global_scale + ) + y = torch.ops.vllm.matmul_nvf4_bf16( + x_flat_q, + qweight, + x_flat_scales, + weight_scales, + 1 / (weight_global_scale * act_global_scale), + ) + else: + raise ValueError(f"Unsupported forward_dtype: {forward_dtype}") + + y = y.view(*x.shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + + return y diff --git a/vllm/model_executor/layers/quantization/qutlass_utils.py 
b/vllm/model_executor/layers/quantization/qutlass_utils.py new file mode 100644 index 0000000000..395bde76d0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/qutlass_utils.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Modified by Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# +# Copied from https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Literal + +import torch +import triton +import triton.language as tl +from torch.library import wrap_triton + + +@triton.jit +def triton_scale_swizzle( + scale_ptr: torch.Tensor, + scale_rows: int, + scale_cols: int, + output_ptr: torch.Tensor, + input_row_stride: int, + output_block_stride: int, + BLOCK_ROWS: tl.constexpr, + BLOCK_COLS: tl.constexpr, +): + """ + Rearranges tensor data from row-major to block-scaled swizzle format. + + Args: + scale_ptr: Pointer to the input scale tensor + scale_rows: Number of rows in the scale tensor + scale_cols: Number of columns in the scale tensor + output_ptr: Pointer to the output tensor + input_row_stride: Stride between rows in the input tensor + output_block_stride: Stride between blocks in the output tensor + BLOCK_ROWS: Number of rows in a tile (compile-time constant) + BLOCK_COLS: Number of columns in a tile (compile-time constant) + """ + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + rows = tl.arange(0, BLOCK_ROWS)[:, None] + cols = tl.arange(0, BLOCK_COLS)[None, :] + + # Calculate starting row and column for this tile + start_row = pid_row * BLOCK_ROWS + start_col = pid_col * BLOCK_COLS + global_rows = start_row + rows + global_cols = start_col + cols + + mask = (global_rows < scale_rows) & (global_cols < scale_cols) + + input_scales = tl.load( + scale_ptr + global_rows * input_row_stride + global_cols, + mask=mask, + other=0.0, + ) + + r_div_32 = rows // 32 + r_mod_32 = rows % 32 + + # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates + dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols + + # Flatten + dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) + scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) + + # Calculate block offset using provided output block stride + LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS + block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride) + + tl.store( + output_ptr + block_offset + dest_indices_flat, + scales_flat, + ) + + +def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: + """ + Rearranges an E8M0 tensor scale from row-major format to + block-scaled swizzle format. 
+ + This format is suitable for Tmem as described in NVIDIA documentation: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + scale_tensor: Input tensor in row-major format with 8-bit elements + + Returns: + Rearranged tensor in block-scaled swizzle format + """ + assert scale_tensor.element_size() == 1, ( + "Expected element size to be 1 byte (8 bits)" + ) + assert scale_tensor.is_contiguous(), "Input tensor must be contiguous" + + rows, cols = scale_tensor.shape + + # Calculate blocks needed + n_row_blocks = triton.cdiv(rows, 128) + n_col_blocks = triton.cdiv(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + out = scale_tensor.new_empty((padded_rows, padded_cols)) + + # Input stride (for row-major format) + input_row_stride = cols + + # We probably want handle multiple blocks per tile but + # for now keep it simple + BLOCK_ROWS, BLOCK_COLS = 128, 4 + + # Output block stride for the rearranged format + output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS) + + grid = lambda META: ( + triton.cdiv(padded_rows, BLOCK_ROWS), + triton.cdiv(padded_cols, BLOCK_COLS), + ) + + wrap_triton(triton_scale_swizzle)[grid]( + scale_tensor.view(torch.uint8), + rows, + cols, + out.view(torch.uint8), + input_row_stride, + output_block_stride, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) + + return out + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def to_blocked( + input_matrix: torch.Tensor, backend: Literal["torch", "triton"] = "triton" +) -> torch.Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying + the rearrangement pattern. + + See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + backend: "torch" (PyTorch path) or "triton" (Triton kernel) + + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + if backend == "triton": + return triton_mx_block_rearrange(input_matrix).flatten() + elif backend != "torch": + raise ValueError(f'backend must be "torch" or "triton", got {backend!r}') + + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + assert (rows, cols) == (padded_rows, padded_cols) + + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten()
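Usage sketch (kernel level): the snippet below is a minimal sketch of the MXFP4 path added by this patch, distilled from benchmarks/kernels/bench_mxfp4_qutlass.py and tests/kernels/quantization/test_mxfp4_qutlass.py. It assumes the _qutlass_C extension was built per cmake/external_projects/qutlass.cmake (CUDA 12.8+, arch 10.0a or 12.0a) and that the reduction dimension is a multiple of the MXFP4 group size of 32; shapes are illustrative only.

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix

from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked

device, dtype = torch.device("cuda"), torch.bfloat16
m, n, k, had_size = 16, 4096, 4096, 32  # k must be a multiple of 32 (MXFP4 group size)

# Orthonormal Hadamard rotation applied inside the fused quantization kernel.
h = deterministic_hadamard_matrix(had_size, dtype=dtype, device=device) * had_size**-0.5

a = torch.randn(m, k, dtype=dtype, device=device)  # activations
b = torch.randn(n, k, dtype=dtype, device=device)  # weights (row-major, K last)

# Fused rotate + quantize to packed FP4 (e2m1, two values per byte)
# with one E8M0 scale per 32 elements.
a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max")
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max")

# Scales must be rearranged into the block-scaled swizzle layout expected by the GEMM.
a_scales = to_blocked(a_e8m0, backend="triton")
b_scales = to_blocked(b_e8m0, backend="triton")

alpha = torch.tensor([1.0], device=device)
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scales, b_scales, alpha)  # (m, n) bf16

The NVFP4 path is analogous: fusedQuantizeNv takes an extra global scale, produces E4M3 scales over groups of 16, and the GEMM goes through the existing cutlass_scaled_fp4_mm, as shown in bench_nvfp4_qutlass.py.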
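Usage sketch (model level): the fp_quant method is picked up from the checkpoint's quantization config rather than an explicit flag. The sketch below mirrors tests/quantization/fp_quant.py using the standard vllm.LLM API; the checkpoint name comes from that test and the sampling settings are illustrative assumptions.

from vllm import LLM, SamplingParams

# Assumes a GPU with compute capability >= 10.0 and bfloat16 activations;
# "fp_quant" is inferred from the checkpoint config, so no quantization flag is needed.
llm = LLM(model="ISTA-DASLab/Qwen3-0.6B-RTN-MXFP4", enforce_eager=True)
out = llm.generate(["1 2 3 4 5"], SamplingParams(temperature=0.0, max_tokens=2))
print(out[0].outputs[0].text)  # greedy decoding should continue the sequence, e.g. " 6"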