[Transform] [Quantization] Add QuTLASS support to vLLM (#24440)

Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Signed-off-by: Andrei Panferov <andrei@panferov.org>
Co-authored-by: Andrei Panferov <andrei@panferov.org>
Co-authored-by: Michael Goin <mgoin64@gmail.com>

Committed by: GitHub
Parent: 8d2b8c0ff2
Commit: 96ad65b7fe
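The quantization path added by this commit is exercised end to end in tests/quantization/fp_quant.py below. For orientation, a minimal offline-inference sketch along the same lines; the model name and flags are taken from that test, and the snippet itself is illustrative rather than part of the diff:

# Hypothetical usage sketch, assuming a vLLM build with the QuTLASS kernels compiled in
# and a GPU with compute capability 10.0 or 12.0 (see the skip conditions in the new tests).
from vllm import LLM, SamplingParams

llm = LLM(
    model="ISTA-DASLab/Qwen3-0.6B-RTN-MXFP4",  # FP-Quant checkpoint used by the new test
    quantization="fp_quant",                   # method name registered by this commit
    dtype="bfloat16",                          # the only activation dtype FPQuantConfig supports
)
out = llm.generate(["1 2 3 4 5"], SamplingParams(temperature=0.0, max_tokens=2))
print(out[0].outputs[0].text)  # greedy continuation; the new test expects "1 2 3 4 5 6"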
@@ -834,6 +834,8 @@ steps:
    - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
    - pytest -v -s tests/kernels/moe/test_flashinfer.py
    - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
    - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
    - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py

- label: Blackwell GPT-OSS Eval
  timeout_in_minutes: 60
@@ -1007,6 +1007,7 @@ endif()
# For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA")
  include(cmake/external_projects/flashmla.cmake)
  include(cmake/external_projects/qutlass.cmake)

  # vllm-flash-attn should be last as it overwrites some CMake functions
  include(cmake/external_projects/vllm_flash_attn.cmake)
benchmarks/kernels/bench_mxfp4_qutlass.py (new file, 191 lines)
@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import copy
import itertools

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "mxfp4": dict(no_a_quant=False, enabled=True),
    "mxfp4-noquant": dict(no_a_quant=True, enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return (
        deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
        * group_size**-0.5
    )


def _quant_weight_mxfp4(
    b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str
):
    weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx(
        b, forward_hadamard_matrix, method="abs_max"
    )
    weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton")
    return weight_hf_e2m1, weight_hf_scale_block


def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
        b, forward_hadamard_matrix, device
    )
    alpha = torch.tensor([1.0], device="cuda")

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")

        def run():
            return matmul_mxf4_bf16_tn(
                input_hf_e2m1,
                weight_hf_e2m1,
                input_hf_scale_block,
                weight_hf_scale_block,
                alpha,
            )

        return run

    # Quantize activation on-the-fly
    def run():
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
        return matmul_mxf4_bf16_tn(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            alpha,
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1, 4, 8, 16, 32, 64, 128, 256,
            512, 1024, 2048, 4096, 8192, 16384, 24576, 32768,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs MXFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_mxfp4_runner(
            cfg, a, b, forward_hadamard_matrix, dtype, device
        )
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), rep=200, quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        for had_size in [32, 64, 128]:
            print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path=f"bench_mxfp4_res_n{N}_k{K}",
                N=N,
                K=K,
                had_size=had_size,
            )

    print("Benchmark finished!")
benchmarks/kernels/bench_nvfp4_qutlass.py (new file, 207 lines)
@@ -0,0 +1,207 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm
|
||||
from vllm._custom_ops import fusedQuantizeNv
|
||||
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
PROVIDER_CFGS = {
|
||||
"torch-bf16": dict(enabled=True),
|
||||
"nvfp4": dict(no_a_quant=False, enabled=True),
|
||||
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||
}
|
||||
|
||||
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||
|
||||
|
||||
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
|
||||
return (
|
||||
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
|
||||
* group_size**-0.5
|
||||
)
|
||||
|
||||
|
||||
def _quant_weight_nvfp4(
|
||||
b: torch.Tensor,
|
||||
forward_hadamard_matrix: torch.Tensor,
|
||||
global_scale: torch.Tensor,
|
||||
device: str,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
):
|
||||
weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv(
|
||||
b, forward_hadamard_matrix, global_scale
|
||||
)
|
||||
weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view(
|
||||
-1, K // 16
|
||||
)
|
||||
return weight_hf_e2m1, weight_hf_scale_block
|
||||
|
||||
|
||||
def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K):
|
||||
alpha = torch.tensor([1.0], device="cuda")
|
||||
global_scale = torch.tensor([1.0], device="cuda")
|
||||
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4(
|
||||
b, forward_hadamard_matrix, global_scale, device, M, N, K
|
||||
)
|
||||
|
||||
if cfg["no_a_quant"]:
|
||||
# Pre-quantize activation
|
||||
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
|
||||
a, forward_hadamard_matrix, global_scale
|
||||
)
|
||||
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
|
||||
-1, K // 16
|
||||
)
|
||||
|
||||
def run():
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
input_hf_e2m1,
|
||||
weight_hf_e2m1,
|
||||
input_hf_scale_block,
|
||||
weight_hf_scale_block,
|
||||
alpha,
|
||||
torch.bfloat16,
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
# Quantize activation on-the-fly
|
||||
def run():
|
||||
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
|
||||
a, forward_hadamard_matrix, global_scale
|
||||
)
|
||||
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
|
||||
-1, K // 16
|
||||
)
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
input_hf_e2m1,
|
||||
weight_hf_e2m1,
|
||||
input_hf_scale_block,
|
||||
weight_hf_scale_block,
|
||||
alpha,
|
||||
torch.bfloat16,
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
4096,
|
||||
8192,
|
||||
16384,
|
||||
24576,
|
||||
32768,
|
||||
],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=_enabled,
|
||||
line_names=_enabled,
|
||||
ylabel="TFLOP/s (larger is better)",
|
||||
plot_name="BF16 vs NVFP4 GEMMs",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K, had_size):
|
||||
M = batch_size
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||
forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch-bf16":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
|
||||
)
|
||||
else:
|
||||
cfg = PROVIDER_CFGS[provider]
|
||||
run_quant = build_nvfp4_runner(
|
||||
cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K
|
||||
)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: run_quant(), rep=200, quantiles=quantiles
|
||||
)
|
||||
|
||||
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
out = []
|
||||
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_dim] //= tp_size
|
||||
KN.append(model)
|
||||
out.append(KN)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.3-70B-Instruct"],
|
||||
choices=list(WEIGHT_SHAPES.keys()),
|
||||
)
|
||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||
args = parser.parse_args()
|
||||
|
||||
for K, N, model in prepare_shapes(args):
|
||||
for had_size in [16, 32, 64, 128]:
|
||||
print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path=f"bench_nvfp4_res_n{N}_k{K}",
|
||||
N=N,
|
||||
K=K,
|
||||
had_size=had_size,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
cmake/external_projects/qutlass.cmake (new file, 97 lines)
@@ -0,0 +1,97 @@
include(FetchContent)

set(CUTLASS_INCLUDE_DIR "${CUTLASS_INCLUDE_DIR}" CACHE PATH "Path to CUTLASS include/ directory")

if(DEFINED ENV{QUTLASS_SRC_DIR})
  set(QUTLASS_SRC_DIR $ENV{QUTLASS_SRC_DIR})
endif()

if(QUTLASS_SRC_DIR)
  FetchContent_Declare(
    qutlass
    SOURCE_DIR ${QUTLASS_SRC_DIR}
    CONFIGURE_COMMAND ""
    BUILD_COMMAND ""
  )
else()
  FetchContent_Declare(
    qutlass
    GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git
    GIT_TAG 830d2c4537c7396e14a02a46fbddd18b5d107c65
    GIT_PROGRESS TRUE
    CONFIGURE_COMMAND ""
    BUILD_COMMAND ""
  )
  FetchContent_Populate(qutlass)
  set(qutlass_SOURCE_DIR "${qutlass_SOURCE_DIR}")
endif()

if(NOT qutlass_SOURCE_DIR)
  message(FATAL_ERROR "[QUTLASS] source directory could not be resolved.")
endif()
message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}")

cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND QUTLASS_ARCHS)

  if(QUTLASS_ARCHS MATCHES "10\\.0a")
    set(QUTLASS_TARGET_CC 100)
  elseif(QUTLASS_ARCHS MATCHES "12\\.0a")
    set(QUTLASS_TARGET_CC 120)
  else()
    message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.")
  endif()

  set(QUTLASS_SOURCES
    ${qutlass_SOURCE_DIR}/qutlass/csrc/bindings.cpp
    ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm.cu
    ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm_ada.cu
    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx.cu
    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv.cu
    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx_sm100.cu
    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv_sm100.cu
  )

  set(QUTLASS_INCLUDES
    ${qutlass_SOURCE_DIR}
    ${qutlass_SOURCE_DIR}/qutlass
    ${qutlass_SOURCE_DIR}/qutlass/csrc/include
    ${qutlass_SOURCE_DIR}/qutlass/csrc/include/cutlass_extensions
  )

  if(CUTLASS_INCLUDE_DIR AND EXISTS "${CUTLASS_INCLUDE_DIR}/cutlass/cutlass.h")
    list(APPEND QUTLASS_INCLUDES "${CUTLASS_INCLUDE_DIR}")
  elseif(EXISTS "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include/cutlass/cutlass.h")
    list(APPEND QUTLASS_INCLUDES "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include")
    message(STATUS "[QUTLASS] Using QuTLASS vendored CUTLASS headers (no vLLM CUTLASS detected).")
  else()
    message(FATAL_ERROR "[QUTLASS] CUTLASS headers not found. "
                        "Set -DCUTLASS_INCLUDE_DIR=/path/to/cutlass/include")
  endif()

  set_gencode_flags_for_srcs(
    SRCS "${QUTLASS_SOURCES}"
    CUDA_ARCHS "${QUTLASS_ARCHS}"
  )

  target_sources(_C PRIVATE ${QUTLASS_SOURCES})
  target_include_directories(_C PRIVATE ${QUTLASS_INCLUDES})
  target_compile_definitions(_C PRIVATE
    QUTLASS_DISABLE_PYBIND=1
    TARGET_CUDA_ARCH=${QUTLASS_TARGET_CC}
  )

  set_property(SOURCE ${QUTLASS_SOURCES} APPEND PROPERTY COMPILE_OPTIONS
    $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr --use_fast_math -O3>
  )

else()
  if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8")
    message(STATUS
      "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).")
  else()
    message(STATUS
      "[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in "
      "CUDA_ARCHS='${CUDA_ARCHS}'.")
  endif()
endif()
tests/kernels/quantization/test_mxfp4_qutlass.py (new file, 303 lines)
@@ -0,0 +1,303 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
|
||||
|
||||
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
|
||||
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA required for these tests.", allow_module_level=True)
|
||||
|
||||
if not (
|
||||
current_platform.has_device_capability(100)
|
||||
or current_platform.has_device_capability(120)
|
||||
):
|
||||
pytest.skip(
|
||||
reason="Tests require compute capability 10.0 (100) or 12.0 (120).",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
# ----- Helpers -----
|
||||
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
|
||||
return (
|
||||
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
|
||||
* group_size**-0.5
|
||||
)
|
||||
|
||||
|
||||
def _rtne_fp4(x: torch.Tensor):
|
||||
device = x.device
|
||||
grid = torch.tensor(
|
||||
[
|
||||
-6.0,
|
||||
-4.0,
|
||||
-3.0,
|
||||
-2.0,
|
||||
-1.5,
|
||||
-1.0,
|
||||
-0.5,
|
||||
-0.0,
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
],
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
grid_int = torch.tensor(
|
||||
[-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7],
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
inds = torch.bucketize(x, grid)
|
||||
lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15)
|
||||
g_lo, g_hi = grid[lo], grid[hi]
|
||||
pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0)
|
||||
y = torch.where(pick_hi, g_hi, g_lo)
|
||||
y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo])
|
||||
y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF
|
||||
return y, y_int_packed
|
||||
|
||||
|
||||
def _dq_fp4(x_e2m1: torch.Tensor, x_e8m0: torch.Tensor, alpha: float):
|
||||
device = x_e2m1.device
|
||||
|
||||
x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32)
|
||||
x_e2m1_unpacked = torch.stack(
|
||||
[x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1
|
||||
).flatten(start_dim=-2)
|
||||
|
||||
grid_dq = torch.tensor(
|
||||
[
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
-0.0,
|
||||
-0.5,
|
||||
-1.0,
|
||||
-1.5,
|
||||
-2.0,
|
||||
-3.0,
|
||||
-4.0,
|
||||
-6.0,
|
||||
],
|
||||
dtype=torch.float64,
|
||||
device=device,
|
||||
)
|
||||
x_fp4_dq = grid_dq[x_e2m1_unpacked]
|
||||
scales_dq = x_e8m0.to(torch.float64)
|
||||
|
||||
x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 32)) * scales_dq[..., None]).flatten(
|
||||
start_dim=-2
|
||||
) / alpha
|
||||
return x_dq, x_fp4_dq, scales_dq
|
||||
|
||||
|
||||
def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor:
|
||||
clip_mask_unpacked_dq = torch.zeros(
|
||||
*clip_mask.shape[:-1],
|
||||
clip_mask.size(-1) * 8,
|
||||
dtype=torch.bool,
|
||||
device=clip_mask.device,
|
||||
)
|
||||
for i in range(8):
|
||||
clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1
|
||||
return clip_mask_unpacked_dq
|
||||
|
||||
|
||||
def _forward_quantize_ref(
|
||||
x: torch.Tensor, h: torch.Tensor, rot_size: int, quest: bool = True
|
||||
):
|
||||
device = x.device
|
||||
xh_ref64 = (
|
||||
x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64)
|
||||
@ h.reshape(rot_size, rot_size).to(dtype=torch.float64)
|
||||
).flatten(start_dim=-2)
|
||||
|
||||
if quest:
|
||||
scales_ref64_ = (
|
||||
xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).std(dim=-1, correction=0)
|
||||
* (2.92247856 / 6.0)
|
||||
+ 1e-8
|
||||
)
|
||||
else:
|
||||
abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).abs().amax(dim=-1)
|
||||
scales_ref64_ = abs_max + 1e-8
|
||||
|
||||
xh_e8m0_ref = scales_ref64_.log2().floor().exp2().to(dtype=torch.float8_e8m0fnu)
|
||||
scales_ref64 = xh_e8m0_ref.to(dtype=torch.float64)
|
||||
|
||||
xh_scaled_ref64 = (
|
||||
xh_ref64.unflatten(dim=-1, sizes=(-1, 32)) / scales_ref64[..., None]
|
||||
).flatten(start_dim=-2)
|
||||
if not quest:
|
||||
xh_scaled_ref64 *= 3
|
||||
|
||||
clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0
|
||||
clip_mask_ref = torch.zeros(
|
||||
*x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device
|
||||
)
|
||||
for i in range(8):
|
||||
clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i
|
||||
|
||||
xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64)
|
||||
xh_dq, xh_fp4_dq, scales_dq = _dq_fp4(
|
||||
xh_e2m1_ref, xh_e8m0_ref, alpha=1.0 if quest else 3.0
|
||||
)
|
||||
clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref)
|
||||
|
||||
assert xh_fp4_dq.equal(xh_fp4_ref)
|
||||
assert scales_dq.equal(scales_ref64)
|
||||
assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref)
|
||||
|
||||
return (
|
||||
xh_dq,
|
||||
clip_mask_unpacked_ref,
|
||||
(xh_e2m1_ref, xh_e8m0_ref, clip_mask_ref),
|
||||
)
|
||||
|
||||
|
||||
DTYPE = torch.bfloat16
|
||||
DEVICE = torch.device("cuda:0")
|
||||
|
||||
ROT_SIZES = [32, 64, 128]
|
||||
SEEDS = [0]
|
||||
BATCHES = [1, 16]
|
||||
|
||||
LLAMA_MODELS = {
|
||||
"7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
|
||||
"13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
|
||||
"33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
|
||||
"70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _seed_each_test():
|
||||
current_platform.seed_everything(0)
|
||||
np.random.seed(0)
|
||||
torch.random.manual_seed(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rot_size", ROT_SIZES)
|
||||
@torch.inference_mode()
|
||||
def test_fused_quantization_absmax(rot_size: int):
|
||||
dtype, device = DTYPE, DEVICE
|
||||
h = get_hadamard_matrix(rot_size, dtype, device)
|
||||
x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
|
||||
|
||||
xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=False)
|
||||
xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="abs_max")
|
||||
xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)
|
||||
xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=3.0)
|
||||
|
||||
torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
|
||||
assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
|
||||
|
||||
m, n, k = 1, 504, 4096
|
||||
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
|
||||
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
|
||||
|
||||
a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max")
|
||||
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max")
|
||||
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
|
||||
b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
|
||||
out_ref = a_dq @ b_dq.transpose(-2, -1)
|
||||
|
||||
a_scale_block = to_blocked(a_e8m0, backend="triton")
|
||||
b_scale_block = to_blocked(b_e8m0, backend="triton")
|
||||
alpha = torch.tensor([1.0], device=device)
|
||||
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha)
|
||||
assert out.equal(out_ref.to(dtype=out.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rot_size", ROT_SIZES)
|
||||
@torch.inference_mode()
|
||||
def test_fused_quantization_quest(rot_size: int):
|
||||
dtype, device = DTYPE, DEVICE
|
||||
h = get_hadamard_matrix(rot_size, dtype, device)
|
||||
x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
|
||||
|
||||
xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=True)
|
||||
xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="quest")
|
||||
xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)
|
||||
xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=1.0)
|
||||
|
||||
torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
|
||||
assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
|
||||
|
||||
m, n, k = 504, 504, 2048
|
||||
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
|
||||
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
|
||||
|
||||
a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
|
||||
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
|
||||
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
|
||||
b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
|
||||
out_ref = a_dq @ b_dq.transpose(-2, -1)
|
||||
|
||||
a_scale_block = to_blocked(a_e8m0, backend="triton")
|
||||
b_scale_block = to_blocked(b_e8m0, backend="triton")
|
||||
alpha = torch.tensor([1.0], device=device)
|
||||
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha)
|
||||
assert out.equal(out_ref.to(dtype=out.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys()))
|
||||
@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3])
|
||||
@pytest.mark.parametrize("batch", [1, 16])
|
||||
@pytest.mark.parametrize("had_size", ROT_SIZES)
|
||||
@torch.inference_mode()
|
||||
def test_llama_shapes(model: str, layer_idx: int, batch: int, had_size: int):
|
||||
dtype, device = DTYPE, DEVICE
|
||||
m = batch
|
||||
k, n = LLAMA_MODELS[model][layer_idx]
|
||||
|
||||
h = get_hadamard_matrix(had_size, dtype, device)
|
||||
|
||||
a = torch.rand(m, k, dtype=dtype, device=device) * 25.0
|
||||
b = torch.rand(n, k, dtype=dtype, device=device) * 25.0
|
||||
|
||||
a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
|
||||
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
|
||||
|
||||
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
|
||||
b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
|
||||
out_ref = a_dq @ b_dq.transpose(-2, -1)
|
||||
|
||||
a_scale_block = to_blocked(a_e8m0, backend="triton")
|
||||
b_scale_block = to_blocked(b_e8m0, backend="triton")
|
||||
alpha = torch.tensor([1.0], device=device)
|
||||
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha)
|
||||
assert out.equal(out_ref.to(dtype=out.dtype))
|
tests/kernels/quantization/test_nvfp4_qutlass.py (new file, 268 lines)
@@ -0,0 +1,268 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
|
||||
|
||||
from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm
|
||||
from vllm._custom_ops import fusedQuantizeNv
|
||||
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA required for these tests.", allow_module_level=True)
|
||||
|
||||
if not (
|
||||
current_platform.has_device_capability(100)
|
||||
or current_platform.has_device_capability(120)
|
||||
):
|
||||
pytest.skip(
|
||||
reason="Tests require compute capability 10.0 (100) or 12.0 (120).",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
# ----- Helpers -----
|
||||
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
|
||||
return (
|
||||
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
|
||||
* group_size**-0.5
|
||||
)
|
||||
|
||||
|
||||
def _rtne_fp4(x: torch.Tensor):
|
||||
device = x.device
|
||||
grid = torch.tensor(
|
||||
[
|
||||
-6.0,
|
||||
-4.0,
|
||||
-3.0,
|
||||
-2.0,
|
||||
-1.5,
|
||||
-1.0,
|
||||
-0.5,
|
||||
-0.0,
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
],
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
grid_int = torch.tensor(
|
||||
[-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7],
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
inds = torch.bucketize(x, grid)
|
||||
lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15)
|
||||
g_lo, g_hi = grid[lo], grid[hi]
|
||||
pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0)
|
||||
y = torch.where(pick_hi, g_hi, g_lo)
|
||||
y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo])
|
||||
y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF
|
||||
return y, y_int_packed
|
||||
|
||||
|
||||
def _dq_fp4(x_e2m1: torch.Tensor, x_e4m3: torch.Tensor, alpha: float):
|
||||
device = x_e2m1.device
|
||||
|
||||
x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32)
|
||||
x_e2m1_unpacked = torch.stack(
|
||||
[x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1
|
||||
).flatten(start_dim=-2)
|
||||
|
||||
grid_dq = torch.tensor(
|
||||
[
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
-0.0,
|
||||
-0.5,
|
||||
-1.0,
|
||||
-1.5,
|
||||
-2.0,
|
||||
-3.0,
|
||||
-4.0,
|
||||
-6.0,
|
||||
],
|
||||
dtype=torch.float64,
|
||||
device=device,
|
||||
)
|
||||
x_fp4_dq = grid_dq[x_e2m1_unpacked]
|
||||
|
||||
scales_dq = x_e4m3.to(torch.float64)
|
||||
x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 16)) * scales_dq[..., None]).flatten(
|
||||
start_dim=-2
|
||||
) / alpha # * (4. / 3.)
|
||||
return x_dq, x_fp4_dq, scales_dq
|
||||
|
||||
|
||||
def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor:
|
||||
clip_mask_unpacked_dq = torch.zeros(
|
||||
*clip_mask.shape[:-1],
|
||||
clip_mask.size(-1) * 8,
|
||||
dtype=torch.bool,
|
||||
device=clip_mask.device,
|
||||
)
|
||||
for i in range(8):
|
||||
clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1
|
||||
return clip_mask_unpacked_dq
|
||||
|
||||
|
||||
def _forward_quantize_ref(x: torch.Tensor, h: torch.Tensor, rot_size: int):
|
||||
device = x.device
|
||||
|
||||
xh_ref64 = (
|
||||
x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64)
|
||||
@ h.reshape(rot_size, rot_size).to(dtype=torch.float64)
|
||||
).flatten(start_dim=-2)
|
||||
|
||||
abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 16)).abs().amax(dim=-1)
|
||||
scales_ref64_ = abs_max + 1e-8
|
||||
|
||||
xh_e4m3_ref = scales_ref64_.to(dtype=torch.float8_e4m3fn)
|
||||
scales_ref64 = xh_e4m3_ref.to(dtype=torch.float64)
|
||||
xh_scaled_ref64 = (
|
||||
xh_ref64.unflatten(dim=-1, sizes=(-1, 16)) / scales_ref64[..., None]
|
||||
).flatten(start_dim=-2)
|
||||
|
||||
xh_scaled_ref64 *= 6.0
|
||||
|
||||
clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0
|
||||
clip_mask_ref = torch.zeros(
|
||||
*x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device
|
||||
)
|
||||
for i in range(8):
|
||||
clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i
|
||||
|
||||
xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64)
|
||||
xh_dq, xh_fp4_dq, scales_dq = _dq_fp4(xh_e2m1_ref, xh_e4m3_ref, 6.0)
|
||||
clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref)
|
||||
|
||||
assert xh_fp4_dq.equal(xh_fp4_ref)
|
||||
assert scales_dq.equal(scales_ref64)
|
||||
assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref)
|
||||
|
||||
return (
|
||||
xh_dq,
|
||||
clip_mask_unpacked_ref,
|
||||
(xh_e2m1_ref, xh_e4m3_ref, clip_mask_ref),
|
||||
)
|
||||
|
||||
|
||||
DTYPE = torch.bfloat16
|
||||
DEVICE = torch.device("cuda:0")
|
||||
ROT_SIZES = [16, 32, 64, 128]
|
||||
GLOBAL_SCALES = [6.0]
|
||||
|
||||
LLAMA_MODELS = {
|
||||
"7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
|
||||
"13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
|
||||
"33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
|
||||
"70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _seed_each_test():
|
||||
current_platform.seed_everything(0)
|
||||
np.random.seed(0)
|
||||
torch.random.manual_seed(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rot_size", ROT_SIZES)
|
||||
@pytest.mark.parametrize("global_scale_value", GLOBAL_SCALES)
|
||||
@torch.inference_mode()
|
||||
def test_fused_quantization(rot_size: int, global_scale_value: float):
|
||||
dtype, device = DTYPE, DEVICE
|
||||
h = get_hadamard_matrix(rot_size, dtype, device)
|
||||
x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
|
||||
global_scale = torch.tensor([global_scale_value], device=device)
|
||||
|
||||
xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size)
|
||||
xh_e2m1, xh_e4m3 = fusedQuantizeNv(x, h, global_scale)
|
||||
xh_e4m3 = xh_e4m3.reshape(2, 4096, 4096 // 16)
|
||||
xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e4m3, alpha=global_scale_value)
|
||||
|
||||
torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
|
||||
assert (xh_dq != xh_dq_ref).float().mean() <= 1e-1
|
||||
|
||||
m, n, k = 504, 4096 * 2, 4096
|
||||
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
|
||||
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
|
||||
|
||||
a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale)
|
||||
b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale)
|
||||
|
||||
a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0)
|
||||
b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0)
|
||||
out_ref = a_dq @ b_dq.transpose(-2, -1)
|
||||
|
||||
a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16)
|
||||
b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16)
|
||||
alpha = torch.tensor([1.0], device=device)
|
||||
out = ops.cutlass_scaled_fp4_mm(
|
||||
a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16
|
||||
)
|
||||
assert out.equal(out_ref.to(dtype=out.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys()))
|
||||
@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3])
|
||||
@pytest.mark.parametrize("batch", [1, 16])
|
||||
@pytest.mark.parametrize("rot_size", ROT_SIZES)
|
||||
@torch.inference_mode()
|
||||
def test_llama_shapes(model: str, layer_idx: int, batch: int, rot_size: int):
|
||||
dtype, device = DTYPE, DEVICE
|
||||
m = batch
|
||||
k, n = LLAMA_MODELS[model][layer_idx]
|
||||
|
||||
h = get_hadamard_matrix(rot_size, dtype, device)
|
||||
|
||||
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
|
||||
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
|
||||
|
||||
global_scale = torch.tensor([1.0], device=device)
|
||||
|
||||
a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale)
|
||||
b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale)
|
||||
|
||||
a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0)
|
||||
b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0)
|
||||
out_ref = a_dq @ b_dq.transpose(-2, -1)
|
||||
|
||||
a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16)
|
||||
b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16)
|
||||
alpha = torch.tensor([1.0], device=device)
|
||||
out = ops.cutlass_scaled_fp4_mm(
|
||||
a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16
|
||||
)
|
||||
assert out.equal(out_ref.to(dtype=out.dtype))
|
tests/quantization/fp_quant.py (new file, 32 lines)
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test model set-up and inference for quantized HF models supported
on the GPU backend using FPQuant.

Validating the configuration and printing results for manual checking.

Run `pytest tests/quantization/test_fp_quant.py`.
"""

import pytest

from tests.quantization.utils import is_quant_method_supported

MODELS = [
    "ISTA-DASLab/Qwen3-0.6B-RTN-NVFP4",
    "ISTA-DASLab/Qwen3-0.6B-RTN-MXFP4",
]
DTYPE = ["bfloat16"]
EAGER = [True, False]


@pytest.mark.skipif(
    not is_quant_method_supported("fp_quant"),
    reason="FPQuant is not supported on this GPU type.",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("eager", EAGER)
def test_fpquant(vllm_runner, model, eager):
    with vllm_runner(model, enforce_eager=eager) as llm:
        output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)
        assert output[0][1] == "1 2 3 4 5 6"
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

import torch

@@ -2483,6 +2483,144 @@ def onednn_scaled_mm(
    return output


if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):

    @register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
    def _fake_matmul_mxf4_bf16_tn(
        a: torch.Tensor,
        b: torch.Tensor,
        a_sf: torch.Tensor,
        b_sf: torch.Tensor,
        alpha: torch.Tensor,
    ):
        return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16)


def matmul_mxf4_bf16_tn(
    a: torch.Tensor,
    b: torch.Tensor,
    a_sf: torch.Tensor,
    b_sf: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    return torch.ops._qutlass_C.matmul_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha)


if hasattr(torch.ops._qutlass_C, "matmul_ada_mxf4_bf16_tn"):

    @register_fake("_qutlass_C::matmul_ada_mxf4_bf16_tn")
    def _fake_matmul_ada_mxf4_bf16_tn(
        a: torch.Tensor,
        b: torch.Tensor,
        a_sf: torch.Tensor,
        b_sf: torch.Tensor,
        alpha: torch.Tensor,
    ):
        return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16)


def matmul_ada_mxf4_bf16_tn(
    a: torch.Tensor,
    b: torch.Tensor,
    a_sf: torch.Tensor,
    b_sf: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    return torch.ops._qutlass_C.matmul_ada_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha)


def ceil_div(a, b):
    return (a + b - 1) // b


if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxQuest"):

    @register_fake("_qutlass_C::fusedQuantizeMxQuest")
    def _fake_fused_quantize_mx_quest(
        a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor
    ):
        return xh_e2m1, xh_e8m0


if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxAbsMax"):

    @register_fake("_qutlass_C::fusedQuantizeMxAbsMax")
    def _fake_fused_quantize_mx_absmax(
        a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor
    ):
        return xh_e2m1, xh_e8m0


def fusedQuantizeMx(
    a: torch.Tensor, b: torch.Tensor, *, method: Literal["quest", "abs_max"] = "quest"
) -> tuple[torch.Tensor, torch.Tensor]:
    if a.dim() == 0:
        raise ValueError("`a` must have at least 1 dimension.")
    if a.size(-1) % 32 != 0:
        raise ValueError(f"last dim of `a` must be divisible by 32, got {a.size(-1)}.")
    if b.device != a.device:
        raise ValueError("`a` and `b` must be on the same device.")

    xh_e2m1 = torch.empty(
        *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device
    )

    rows, cols = a.numel() // a.size(-1), a.size(-1) // 32
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    xh_e8m0 = torch.empty(
        padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device
    )

    if not hasattr(torch.ops, "_qutlass_C"):
        raise RuntimeError(
            "The `_qutlass_C` extension is not loaded. "
            "Make sure your custom op library is imported before calling fusedQuantizeMx."
        )

    if method == "quest":
        return torch.ops._qutlass_C.fusedQuantizeMxQuest(a, b, xh_e2m1, xh_e8m0)
    elif method == "abs_max":
        return torch.ops._qutlass_C.fusedQuantizeMxAbsMax(a, b, xh_e2m1, xh_e8m0)
    else:
        raise ValueError(f"invalid method {method!r}, must be 'quest' or 'abs_max'")


if hasattr(torch.ops._qutlass_C, "fusedQuantizeNv"):

    @register_fake("_qutlass_C::fusedQuantizeNv")
    def _fake_fused_quantize_nv(
        a: torch.Tensor,
        b: torch.Tensor,
        xh_e2m1: torch.Tensor,
        xh_e4m3: torch.Tensor,
        global_scale: torch.Tensor,
    ):
        return xh_e2m1, xh_e4m3


def fusedQuantizeNv(
    a: torch.Tensor, b: torch.Tensor, global_scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    xh_e2m1 = torch.empty(
        *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device
    )

    rows, cols = a.numel() // a.size(-1), a.size(-1) // 16
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4
    xh_e4m3 = torch.empty(
        padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=a.device
    )

    return torch.ops._qutlass_C.fusedQuantizeNv(a, b, xh_e2m1, xh_e4m3, global_scale)


def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
    """
    Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832)
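The helpers registered above are composed the same way throughout the new tests and benchmarks: quantize both GEMM operands against a shared Hadamard matrix, convert the per-group scales to the blocked layout, then call the QuTLASS kernel. A condensed sketch of the mxfp4 path follows; it mirrors tests/kernels/quantization/test_mxfp4_qutlass.py rather than adding new API, and the identity matrix is only a stand-in for a real 32x32 Hadamard rotation:

# Sketch only; assumes the _qutlass_C extension is built (CUDA >= 12.8, SM 10.0a/12.0a).
import torch
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked

a = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")   # activations (M, K)
b = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")  # weights (N, K)
h = torch.eye(32, dtype=torch.bfloat16, device="cuda")            # stand-in rotation matrix

a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max")  # packed FP4 data + e8m0 group scales
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max")
out = matmul_mxf4_bf16_tn(
    a_e2m1,
    b_e2m1,
    to_blocked(a_e8m0, backend="triton"),  # scales rearranged into the kernel's blocked layout
    to_blocked(b_e8m0, backend="triton"),
    torch.tensor([1.0], device="cuda"),    # alpha
)  # -> bfloat16 result of shape (M, N)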
@@ -12,6 +12,7 @@ QuantizationMethods = Literal[
    "fp8",
    "ptpc_fp8",
    "fbgemm_fp8",
    "fp_quant",
    "modelopt",
    "modelopt_fp4",
    "bitblas",
@@ -102,6 +103,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
    from .experts_int8 import ExpertsInt8Config
    from .fbgemm_fp8 import FBGEMMFp8Config
    from .fp8 import Fp8Config
    from .fp_quant import FPQuantConfig
    from .gguf import GGUFConfig
    from .gptq import GPTQConfig
    from .gptq_bitblas import GPTQBitBLASConfig
@@ -125,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
    "tpu_int8": Int8TpuConfig,
    "fp8": Fp8Config,
    "fbgemm_fp8": FBGEMMFp8Config,
    "fp_quant": FPQuantConfig,
    "modelopt": ModelOptFp8Config,
    "modelopt_fp4": ModelOptNvFp4Config,
    "bitblas": BitBLASConfig,
vllm/model_executor/layers/quantization/fp_quant.py (new file, 420 lines)
@@ -0,0 +1,420 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm._custom_ops import (
|
||||
cutlass_scaled_fp4_mm,
|
||||
fusedQuantizeMx,
|
||||
fusedQuantizeNv,
|
||||
matmul_mxf4_bf16_tn,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
class FPQuantConfig(QuantizationConfig):
|
||||
"""Config class for FPQuant."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hadamard_group_size: int = 32,
|
||||
forward_dtype: str = "mxfp4",
|
||||
forward_method: str = "abs_max",
|
||||
pseudoquantization: bool = False,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hadamard_group_size = hadamard_group_size
|
||||
self.forward_dtype = forward_dtype
|
||||
self.forward_method = forward_method
|
||||
self.pseudoquantization = pseudoquantization
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
if pseudoquantization:
|
||||
raise ValueError("Pseudoquantization is not supported for vLLM")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"FPQuantConfig(hadamard_group_size={self.hadamard_group_size}, "
|
||||
f"forward_dtype={self.forward_dtype}, "
|
||||
f"forward_method={self.forward_method}, "
|
||||
f"pseudoquantization={self.pseudoquantization}, "
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "fp_quant"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return [] # no extra configs.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "FPQuantConfig":
|
||||
hadamard_group_size = cls.get_from_keys(config, ["hadamard_group_size"])
|
||||
forward_dtype = cls.get_from_keys(config, ["forward_dtype"])
|
||||
forward_method = cls.get_from_keys(config, ["forward_method"])
|
||||
pseudoquantization = cls.get_from_keys(config, ["pseudoquantization"])
|
||||
modules_to_not_convert = cls.get_from_keys(config, ["modules_to_not_convert"])
|
||||
return cls(
|
||||
hadamard_group_size,
|
||||
forward_dtype,
|
||||
forward_method,
|
||||
pseudoquantization,
|
||||
modules_to_not_convert,
|
||||
)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[LinearMethodBase]:
|
||||
if self.modules_to_not_convert is not None and any(
|
||||
prefix.endswith(module) for module in self.modules_to_not_convert
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return FPQuantLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class FPQuantLinearMethod(LinearMethodBase):
|
||||
"""Linear method for FPQuant.
|
||||
|
||||
Args:
|
||||
quant_config: The FPQuant quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: FPQuantConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
del input_size # Unused.
|
||||
|
||||
if params_dtype != torch.bfloat16:
|
||||
raise ValueError("Only bfloat16 is currently supported by FPQuant")
|
||||
if input_size_per_partition % self.quant_config.hadamard_group_size != 0: # noqa: E501
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size. Or other skill issues."
|
||||
)
|
||||
|
||||
assert self.quant_config.forward_dtype in ["mxfp4", "nvfp4"], (
|
||||
"Only mxfp4 and nvfp4 are supported for now"
|
||||
)
|
||||
if self.quant_config.forward_dtype == "mxfp4":
|
||||
group_size = 32
|
||||
elif self.quant_config.forward_dtype == "nvfp4":
|
||||
group_size = 16
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported forward_dtype: {self.quant_config.forward_dtype}"
|
||||
)
|
||||
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight,
|
||||
{
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": 2,
|
||||
}
|
||||
| extra_weight_attrs,
|
||||
)
|
||||
layer.register_parameter("qweight", qweight)
|
||||
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // group_size,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
scales,
|
||||
{
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": group_size,
|
||||
}
|
||||
| extra_weight_attrs,
|
||||
)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
weight_global_scale = Parameter(
|
||||
torch.empty(1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
weight_global_scale, {"ignore_warning": True} | extra_weight_attrs
|
||||
)
|
||||
layer.register_parameter("weight_global_scale", weight_global_scale)
|
||||
|
||||
act_global_scale = Parameter(
|
||||
torch.empty(1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
act_global_scale, {"ignore_warning": True} | extra_weight_attrs
|
||||
)
|
||||
layer.register_parameter("act_global_scale", act_global_scale)
|
||||
|
||||
forward_hadamard_matrix = Parameter(
|
||||
torch.empty(
|
||||
self.quant_config.hadamard_group_size,
|
||||
self.quant_config.hadamard_group_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
forward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs
|
||||
)
|
||||
layer.register_parameter("forward_hadamard_matrix", forward_hadamard_matrix)
|
||||
|
||||
backward_hadamard_matrix = Parameter(
|
||||
torch.empty(
|
||||
self.quant_config.hadamard_group_size,
|
||||
self.quant_config.hadamard_group_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
backward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs
|
||||
)
|
||||
layer.register_parameter("backward_hadamard_matrix", backward_hadamard_matrix)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return quantized_forward(
|
||||
x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.weight_global_scale,
|
||||
layer.act_global_scale,
|
||||
bias,
|
||||
layer.forward_hadamard_matrix,
|
||||
self.quant_config.forward_method,
|
||||
self.quant_config.forward_dtype,
|
||||
)
|
||||
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
def fused_quantize_mx(
|
||||
x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method)
|
||||
|
||||
|
||||
def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method):
|
||||
rows, cols = x_flat.size(0), x_flat.size(1) // 32
|
||||
padded_rows = ((rows + 128 - 1) // 128) * 128
|
||||
padded_cols = ((cols + 4 - 1) // 4) * 4
|
||||
|
||||
xh_e2m1 = torch.empty(
|
||||
x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device
|
||||
)
|
||||
xh_e8m0 = torch.empty(
|
||||
padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=x_flat.device
|
||||
)
|
||||
|
||||
return xh_e2m1, xh_e8m0
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="fused_quantize_mx",
|
||||
op_func=fused_quantize_mx,
|
||||
mutates_args=[],
|
||||
fake_impl=fused_quantize_mx_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
def matmul_mxf4_bf16(
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
xs: torch.Tensor,
|
||||
ws: torch.Tensor,
|
||||
alpha: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return matmul_mxf4_bf16_tn(
|
||||
x,
|
||||
w,
|
||||
to_blocked(xs, backend="triton").view(torch.float8_e8m0fnu),
|
||||
to_blocked(ws, backend="triton").view(torch.float8_e8m0fnu),
|
||||
alpha,
|
||||
)
|
||||
|
||||
|
||||
def matmul_mxf4_bf16_fake(x, w, xs, ws, alpha):
|
||||
return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="matmul_mxf4_bf16",
|
||||
op_func=matmul_mxf4_bf16,
|
||||
mutates_args=[],
|
||||
fake_impl=matmul_mxf4_bf16_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||


def fused_quantize_nv(
    x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, global_scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    return fusedQuantizeNv(x_flat, hadamard_matrix, global_scale)


def fused_quantize_nv_fake(x_flat, hadamard_matrix, global_scale):
    rows, cols = x_flat.size(0), x_flat.size(1) // 16
    padded_rows = ((rows + 128 - 1) // 128) * 128
    padded_cols = ((cols + 4 - 1) // 4) * 4

    xh_e2m1 = torch.empty(
        x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device
    )
    xh_e8m0 = torch.empty(
        padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=x_flat.device
    )

    return xh_e2m1, xh_e8m0


direct_register_custom_op(
    op_name="fused_quantize_nv",
    op_func=fused_quantize_nv,
    mutates_args=[],
    fake_impl=fused_quantize_nv_fake,
    dispatch_key=current_platform.dispatch_key,
)


def matmul_nvf4_bf16(
    x: torch.Tensor,
    w: torch.Tensor,
    xs: torch.Tensor,
    ws: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    return cutlass_scaled_fp4_mm(
        x,
        w,
        to_blocked(xs, backend="triton")
        .view(torch.float8_e4m3fn)
        .view(-1, x.shape[1] // 8),  # x.shape[1] * 2 // 16 scale groups
        to_blocked(ws, backend="triton")
        .view(torch.float8_e4m3fn)
        .view(-1, x.shape[1] // 8),
        alpha,
        torch.bfloat16,
    )


def matmul_nvf4_bf16_fake(x, w, xs, ws, alpha):
    return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device)


direct_register_custom_op(
    op_name="matmul_nvf4_bf16",
    op_func=matmul_nvf4_bf16,
    mutates_args=[],
    fake_impl=matmul_nvf4_bf16_fake,
    dispatch_key=current_platform.dispatch_key,
)
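
The .view(-1, x.shape[1] // 8) reshapes above encode the NVFP4 scale layout: x holds two packed FP4 values per uint8, so x.shape[1] equals K // 2, and with one FP8 scale per 16-element group there are K // 16 = (K // 2) // 8 scale columns. A tiny arithmetic check (illustrative only, arbitrary K):

K = 4096                   # unpacked reduction dimension
packed_cols = K // 2       # x.shape[1]: two FP4 values per uint8
scale_groups = K // 16     # NVFP4 keeps one FP8 scale per 16-element group
assert packed_cols // 8 == scale_groups  # hence .view(-1, x.shape[1] // 8)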


def quantized_forward(
    x: torch.Tensor,
    qweight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_global_scale: torch.Tensor,
    act_global_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    forward_hadamard_matrix: torch.Tensor,
    forward_method: str,
    forward_dtype: str,
) -> torch.Tensor:
    x_flat = x.contiguous().flatten(end_dim=-2)

    if forward_dtype == "mxfp4":
        x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_mx(
            x_flat, forward_hadamard_matrix, forward_method
        )
        y = torch.ops.vllm.matmul_mxf4_bf16(
            x_flat_q,
            qweight,
            x_flat_scales,
            weight_scales,
            1 / (weight_global_scale * act_global_scale),
        )
    elif forward_dtype == "nvfp4":
        x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_nv(
            x_flat, forward_hadamard_matrix, act_global_scale
        )
        y = torch.ops.vllm.matmul_nvf4_bf16(
            x_flat_q,
            qweight,
            x_flat_scales,
            weight_scales,
            1 / (weight_global_scale * act_global_scale),
        )
    else:
        raise ValueError(f"Unsupported forward_dtype: {forward_dtype}")

    y = y.view(*x.shape[:-1], y.shape[-1])
    if bias is not None:
        y += bias

    return y
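
A minimal sketch (illustrative only; arbitrary sizes, no GPU kernels involved) of how the flatten/unflatten around the quantized GEMM in quantized_forward preserves arbitrary leading dimensions:

import torch

x = torch.randn(2, 7, 4096, dtype=torch.bfloat16)  # (batch, seq, hidden)
x_flat = x.contiguous().flatten(end_dim=-2)         # (14, 4096) is what the kernels see
y_flat = torch.empty(x_flat.shape[0], 11008)        # stand-in for the GEMM output
y = y_flat.view(*x.shape[:-1], y_flat.shape[-1])    # restored to (2, 7, 11008)
assert y.shape == (2, 7, 11008)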

185
vllm/model_executor/layers/quantization/qutlass_utils.py
Normal file
@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Modified by Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
#
# Copied from https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Literal

import torch
import triton
import triton.language as tl
from torch.library import wrap_triton


@triton.jit
def triton_scale_swizzle(
    scale_ptr: torch.Tensor,
    scale_rows: int,
    scale_cols: int,
    output_ptr: torch.Tensor,
    input_row_stride: int,
    output_block_stride: int,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_COLS: tl.constexpr,
):
    """
    Rearranges tensor data from row-major to block-scaled swizzle format.

    Args:
        scale_ptr: Pointer to the input scale tensor
        scale_rows: Number of rows in the scale tensor
        scale_cols: Number of columns in the scale tensor
        output_ptr: Pointer to the output tensor
        input_row_stride: Stride between rows in the input tensor
        output_block_stride: Stride between blocks in the output tensor
        BLOCK_ROWS: Number of rows in a tile (compile-time constant)
        BLOCK_COLS: Number of columns in a tile (compile-time constant)
    """
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    rows = tl.arange(0, BLOCK_ROWS)[:, None]
    cols = tl.arange(0, BLOCK_COLS)[None, :]

    # Calculate starting row and column for this tile
    start_row = pid_row * BLOCK_ROWS
    start_col = pid_col * BLOCK_COLS
    global_rows = start_row + rows
    global_cols = start_col + cols

    mask = (global_rows < scale_rows) & (global_cols < scale_cols)

    input_scales = tl.load(
        scale_ptr + global_rows * input_row_stride + global_cols,
        mask=mask,
        other=0.0,
    )

    r_div_32 = rows // 32
    r_mod_32 = rows % 32

    # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates
    dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols

    # Flatten
    dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
    scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))

    # Calculate block offset using provided output block stride
    LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
    block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)

    tl.store(
        output_ptr + block_offset + dest_indices_flat,
        scales_flat,
    )
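
The dest_indices formula above maps each 128x4 tile of scales onto the 32x16 swizzled layout. A plain-PyTorch check (illustrative only) that this mapping is a permutation of the tile:

import torch

BLOCK_ROWS, BLOCK_COLS = 128, 4
rows = torch.arange(BLOCK_ROWS)[:, None]
cols = torch.arange(BLOCK_COLS)[None, :]
dest = (rows % 32) * 16 + (rows // 32) * 4 + cols  # same formula as dest_indices
# every slot of the 32x16 (= 512-element) layout is hit exactly once
assert sorted(dest.flatten().tolist()) == list(range(BLOCK_ROWS * BLOCK_COLS))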


def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
    """
    Rearranges an E8M0 tensor scale from row-major format to
    block-scaled swizzle format.

    This format is suitable for Tmem as described in NVIDIA documentation:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        scale_tensor: Input tensor in row-major format with 8-bit elements

    Returns:
        Rearranged tensor in block-scaled swizzle format
    """
    assert scale_tensor.element_size() == 1, (
        "Expected element size to be 1 byte (8 bits)"
    )
    assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"

    rows, cols = scale_tensor.shape

    # Calculate blocks needed
    n_row_blocks = triton.cdiv(rows, 128)
    n_col_blocks = triton.cdiv(cols, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    out = scale_tensor.new_empty((padded_rows, padded_cols))

    # Input stride (for row-major format)
    input_row_stride = cols

    # We probably want to handle multiple blocks per tile, but
    # for now keep it simple
    BLOCK_ROWS, BLOCK_COLS = 128, 4

    # Output block stride for the rearranged format
    output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)

    grid = lambda META: (
        triton.cdiv(padded_rows, BLOCK_ROWS),
        triton.cdiv(padded_cols, BLOCK_COLS),
    )

    wrap_triton(triton_scale_swizzle)[grid](
        scale_tensor.view(torch.uint8),
        rows,
        cols,
        out.view(torch.uint8),
        input_row_stride,
        output_block_stride,
        BLOCK_ROWS=BLOCK_ROWS,
        BLOCK_COLS=BLOCK_COLS,
    )

    return out


def ceil_div(a, b):
    return (a + b - 1) // b


def to_blocked(
    input_matrix: torch.Tensor, backend: Literal["torch", "triton"] = "triton"
) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying
    the rearrangement pattern.

    See:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)
        backend: "torch" (PyTorch path) or "triton" (Triton kernel)

    Returns:
        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
    """
    if backend == "triton":
        return triton_mx_block_rearrange(input_matrix).flatten()
    elif backend != "torch":
        raise ValueError(f'backend must be "torch" or "triton", got {backend!r}')

    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    assert (rows, cols) == (padded_rows, padded_cols)

    # Rearrange the blocks
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()
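
Usage sketch for the PyTorch fallback path (illustrative only; the "triton" backend is intended to produce the same flattened layout but requires CUDA and runs the kernel above):

import torch

scales = (torch.arange(128 * 4) % 256).to(torch.uint8).reshape(128, 4)  # already padded to 128x4
blocked = to_blocked(scales, backend="torch")
assert blocked.shape == (128 * 4,)           # 32 * ceil_div(128, 128) rows x 16 * ceil_div(4, 4) cols, flattened
assert blocked[:4].tolist() == [0, 1, 2, 3]  # row 0 of the input lands first in the swizzled layout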