[Kernel] Added flashinfer fp8 per-tensor gemms (#22895)

Signed-off-by: Julien Lin <jullin@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: nvjullin
Date: 2025-08-26 21:54:04 +08:00
Committed by: GitHub
Parent: b78bed1bc5
Commit: f66673a39d
9 changed files with 198 additions and 36 deletions
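
For orientation, here is a minimal usage sketch of the new helper this commit exposes, condensed from the test it adds (assumes FlashInfer is installed and a GPU with compute capability 10.0 or above; per-tensor quantization via vllm._custom_ops exactly as in the test):

import torch
from vllm import _custom_ops as ops
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm

m, n, k = 128, 256, 64
a = torch.randn((m, k), dtype=torch.bfloat16, device="cuda")
b = torch.randn((n, k), dtype=torch.bfloat16, device="cuda")

# Per-tensor quantization to float8_e4m3fn with scalar float32 scales.
a_fp8, a_scale = ops.scaled_fp8_quant(a)
b_fp8, b_scale = ops.scaled_fp8_quant(b)

# The weight is passed transposed, i.e. as a (k, n) view; the helper checks
# 2-D fp8 CUDA inputs and per-tensor scales, then calls FlashInfer's bmm_fp8
# with a batch dimension of 1.
out = flashinfer_scaled_fp8_mm(a_fp8, b_fp8.t(), a_scale, b_scale,
                               torch.bfloat16)
assert out.shape == (m, n)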

View File

@ -655,6 +655,7 @@ steps:
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_mxfp4_moe.py

View File

@ -15,7 +15,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, QuantKey, ScaleDesc)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
Fp8LinearOp, maybe_create_device_identity)
from vllm.platforms import current_platform
from .backend import TestBackend
@ -26,9 +26,9 @@ FP8_DTYPE = current_platform.fp8_dtype()
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, static: bool,
cutlass_fp8_enabled: bool, *args, **kwargs):
force_fp8_e4m3fnuz: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cutlass_fp8_enabled = cutlass_fp8_enabled
self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
@ -43,7 +43,7 @@ class TestModel(torch.nn.Module):
for _ in range(2)
]
self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_enabled,
force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
act_quant_static=static,
act_quant_group_shape=group_shape,
)
@ -81,12 +81,11 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cutlass_fp8_enabled):
force_fp8_e4m3fnuz):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
@ -103,7 +102,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
fusion_pass = FusionPass.instance(vllm_config)
backend = TestBackend(noop_pass, fusion_pass)
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
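
These test updates swap the old cutlass_fp8_supported toggle for the new force_fp8_e4m3fnuz keyword. A minimal construction sketch, with the signature taken from the Fp8LinearOp changes further down (setting the flag to True appears to steer CUDA runs away from the CUTLASS/FlashInfer paths and onto the torch fallback):

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

fp8_linear = Fp8LinearOp(
    act_quant_static=True,
    act_quant_group_shape=GroupShape.PER_TENSOR,
    force_fp8_e4m3fnuz=False,  # True skips the CUTLASS/FlashInfer CUDA paths
)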

View File

@ -104,8 +104,7 @@ class TestQuantModel(torch.nn.Module):
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
use_per_token_if_dynamic=False)
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,

View File

@ -12,7 +12,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
Fp8LinearOp)
from vllm.platforms import current_platform
from .backend import TestBackend
@ -20,7 +20,7 @@ from .backend import TestBackend
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args,
**kwargs):
super().__init__(*args, **kwargs)
self.silu_and_mul = SiluAndMul()
@ -32,7 +32,7 @@ class TestModel(torch.nn.Module):
hidden_size).to(dtype=current_platform.fp8_dtype()).t())
self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_enabled,
force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
@ -48,12 +48,11 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
cutlass_fp8_enabled):
force_fp8_e4m3fnuz):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
@ -64,7 +63,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
model = TestModel(hidden_size, cutlass_fp8_enabled)
model = TestModel(hidden_size, force_fp8_e4m3fnuz)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size * 2)

View File

@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm
if not current_platform.has_device_capability(100):
pytest.skip(
reason=
"Flashinfer FP8 gemms requires compute capability of 10.0 or above.",
allow_module_level=True,
)
DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)
SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("autotune", [False, True])
@torch.inference_mode()
def test_flashinfer_fp8_gemm(
dtype: torch.dtype,
shape: tuple[int, int, int],
use_bias: bool,
seed: int,
device: str,
autotune: bool,
) -> None:
current_platform.seed_everything(seed)
m, n, k = shape
a = torch.randn((m, k), dtype=dtype, device=device)
b = torch.randn((n, k), dtype=dtype, device=device) / k
a_fp8, a_scale = ops.scaled_fp8_quant(a)
b_fp8, b_scale = ops.scaled_fp8_quant(b)
expected_out = torch.mm(
a_scale * a_fp8.to(dtype=torch.float32),
b_scale * b_fp8.to(dtype=torch.float32).t(),
).to(dtype=dtype)
if use_bias:
bias = torch.randn((n, ), dtype=dtype, device=device)
expected_out = expected_out + bias
else:
bias = None
import flashinfer
with flashinfer.autotune(autotune):
out = flashinfer_scaled_fp8_mm(
a_fp8,
b_fp8.t(),
a_scale,
b_scale,
dtype,
bias=bias,
)
torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2)

View File

@ -223,8 +223,7 @@ class Fp8LinearMethod(LinearMethodBase):
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape,
cutlass_fp8_supported=cutlass_fp8_supported())
act_quant_group_shape=self.act_q_group_shape)
def create_weights(
self,
@ -376,6 +375,8 @@ class Fp8LinearMethod(LinearMethodBase):
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
# layer.input_scale being None indicates dynamic quant; the scale is
# computed from the input.
layer.input_scale = None
# If checkpoint is fp8, handle that there are N scales for N

View File

@ -97,8 +97,8 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
self.quant_config.is_checkpoint_fp8_serialized = False
self.fp8_linear = Fp8LinearOp(
act_quant_static=False,
cutlass_fp8_supported=False,
act_quant_group_shape=GroupShape.PER_TOKEN)
act_quant_group_shape=GroupShape.PER_TOKEN,
force_fp8_e4m3fnuz=True)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,

View File

@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@ -157,6 +158,19 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
return output.view(*output_shape)
def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: list, **kwargs) -> torch.Tensor:
return flashinfer_scaled_fp8_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
def rocm_per_tensor_w8a8_scaled_mm_impl(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
@ -231,8 +245,8 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor:
input_2d: torch.Tensor, output_shape: list,
**kwargs) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
@ -303,16 +317,22 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
def dispatch_w8a8_scaled_mm(
cutlass_fp8_supported: bool, per_tensor_weights: bool,
preferred_backend: str, per_tensor_weights: bool,
per_tensor_activations: bool) -> Callable[..., torch.Tensor]:
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if cutlass_fp8_supported:
return cutlass_w8a8_scaled_mm
if per_tensor_weights and per_tensor_activations:
if current_platform.is_rocm():
if preferred_backend == "rocm":
return rocm_per_tensor_w8a8_scaled_mm
if preferred_backend == "flashinfer":
return flashinfer_w8a8_scaled_mm
if preferred_backend == "cutlass":
return cutlass_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
return cutlass_w8a8_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if not per_tensor_weights and not per_tensor_activations \
and USE_ROWWISE_TORCH_SCALED_MM:
@ -334,10 +354,20 @@ class Fp8LinearOp:
def __init__(self,
act_quant_static: bool,
cutlass_fp8_supported: bool = cutlass_fp8_supported(),
act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
pad_output: Optional[bool] = None):
self.cutlass_fp8_supported = cutlass_fp8_supported
pad_output: Optional[bool] = None,
force_fp8_e4m3fnuz: bool = False):
if current_platform.is_rocm():
self.preferred_backend = "rocm"
elif current_platform.is_cuda(
) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
if has_flashinfer() and current_platform.has_device_capability(
100):
self.preferred_backend = "flashinfer"
else:
self.preferred_backend = "cutlass"
else:
self.preferred_backend = "torch"
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
@ -347,8 +377,7 @@ class Fp8LinearOp:
if pad_output is None:
config = get_current_vllm_config().compilation_config
pad_output = config.level < CompilationLevel.PIECEWISE and \
not cutlass_fp8_supported and \
not current_platform.is_rocm()
self.preferred_backend == "torch"
self.output_padding = 17 if pad_output else None
self.act_quant_static = act_quant_static
@ -393,9 +422,9 @@ class Fp8LinearOp:
per_tensor_activations = (x_scale.numel() == 1)
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.cutlass_fp8_supported, per_tensor_weights,
per_tensor_activations)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(self.preferred_backend,
per_tensor_weights,
per_tensor_activations)
return w8a8_scaled_mm_func(qinput=qinput,
weight=weight,
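
Taken together, these dispatch changes replace the boolean cutlass_fp8_supported with a preferred_backend string chosen once in Fp8LinearOp.__init__. A condensed, hedged restatement of that selection (plain illustrative functions, not the vLLM API surface):

def pick_backend(is_rocm: bool, is_cuda: bool, force_fp8_e4m3fnuz: bool,
                 cutlass_ok: bool, has_flashinfer: bool,
                 is_sm100_or_newer: bool) -> str:
    # Mirrors the new Fp8LinearOp.__init__ logic shown above.
    if is_rocm:
        return "rocm"
    if is_cuda and not force_fp8_e4m3fnuz and cutlass_ok:
        return "flashinfer" if (has_flashinfer and is_sm100_or_newer) else "cutlass"
    return "torch"


def pick_per_tensor_kernel(preferred_backend: str) -> str:
    # FlashInfer is only used on the per-tensor-weight / per-tensor-activation
    # path; all other shapes keep the previous CUTLASS/torch fallbacks.
    return {
        "rocm": "rocm_per_tensor_w8a8_scaled_mm",
        "flashinfer": "flashinfer_w8a8_scaled_mm",
        "cutlass": "cutlass_w8a8_scaled_mm",
    }.get(preferred_backend, "torch_per_tensor_w8a8_scaled_mm")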

View File

@ -265,6 +265,37 @@ if has_flashinfer():
dtype=dtype,
device=A.device)
@torch.library.custom_op(
"vllm::bmm_fp8",
mutates_args=[],
device_types="cuda",
)
def bmm_fp8(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
dtype: torch.dtype,
backend: str,
) -> torch.Tensor:
from flashinfer import bmm_fp8 as bmm_fp8_
return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)
@torch.library.register_fake("vllm::bmm_fp8", )
def bmm_fp8_fake(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
dtype: torch.dtype,
backend: str,
) -> torch.Tensor:
return torch.empty(A.shape[0],
A.shape[1],
B.shape[2],
dtype=dtype,
device=A.device)
def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
block_scale_a: torch.Tensor,
@ -293,6 +324,35 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
)
def flashinfer_scaled_fp8_mm(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2
assert a.shape[1] == b.shape[0]
assert scale_a.numel() == 1 and scale_b.numel() == 1
assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
assert a.device.type == "cuda" and b.device.type == "cuda"
assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"
output = bmm_fp8(
a.unsqueeze(0),
b.unsqueeze(0),
scale_a,
scale_b,
out_dtype,
"auto",
).view(a.shape[0], b.shape[1])
if bias is not None:
output = output + bias
return output
__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
@ -307,4 +367,5 @@ __all__ = [
"supports_trtllm_attention",
"use_trtllm_attention",
"flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm",
]
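
Finally, a small hedged sketch of the custom op the new helper wraps. Since it is registered as vllm::bmm_fp8 (and only when has_flashinfer() is true), it should also be reachable through the torch.ops namespace; flashinfer_scaled_fp8_mm simply adds a batch dimension of 1 around this call:

import torch

# Illustrative shapes: (1, m, k) @ (1, k, n) -> (1, m, n).
a = torch.randn(1, 16, 32, device="cuda").to(torch.float8_e4m3fn)
# FlashInfer's bmm_fp8 expects the second operand column-major in its last two
# dims, so build it as (n, k) and transpose, mirroring flashinfer_scaled_fp8_mm.
b = torch.randn(1, 64, 32, device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
scale = torch.ones(1, dtype=torch.float32, device="cuda")

out = torch.ops.vllm.bmm_fp8(a, b, scale, scale, torch.bfloat16, "auto")
print(out.shape)  # torch.Size([1, 16, 64])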