[Kernel] Added flashinfer fp8 per-tensor gemms (#22895)
Signed-off-by: Julien Lin <jullin@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -655,6 +655,7 @@ steps:
     - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
     - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
     - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
+    - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
     - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
     - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
     - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, QuantKey, ScaleDesc)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
+    Fp8LinearOp, maybe_create_device_identity)
 from vllm.platforms import current_platform
 
 from .backend import TestBackend
@@ -26,9 +26,9 @@ FP8_DTYPE = current_platform.fp8_dtype()
 class TestModel(torch.nn.Module):
 
     def __init__(self, hidden_size: int, eps: float, static: bool,
-                 cutlass_fp8_enabled: bool, *args, **kwargs):
+                 force_fp8_e4m3fnuz: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.cutlass_fp8_enabled = cutlass_fp8_enabled
+        self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
@@ -43,7 +43,7 @@ class TestModel(torch.nn.Module):
             for _ in range(2)
         ]
         self.fp8_linear = Fp8LinearOp(
-            cutlass_fp8_supported=cutlass_fp8_enabled,
+            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
             act_quant_static=static,
             act_quant_group_shape=group_shape,
         )
@@ -81,12 +81,11 @@ class TestModel(torch.nn.Module):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
-@pytest.mark.parametrize("cutlass_fp8_enabled",
-                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
+@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
 def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
-                              cutlass_fp8_enabled):
+                              force_fp8_e4m3fnuz):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
@@ -103,7 +102,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     fusion_pass = FusionPass.instance(vllm_config)
 
     backend = TestBackend(noop_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
+    model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz)
 
     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size)
@@ -104,8 +104,7 @@ class TestQuantModel(torch.nn.Module):
         # Initialize weights
         torch.nn.init.normal_(self.gate_proj, std=0.02)
 
-        self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
-                                      use_per_token_if_dynamic=False)
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
 
         self.scale = torch.rand(1, dtype=torch.float32)
         # Create a weight that is compatible with torch._scaled_mm,
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
+    Fp8LinearOp)
 from vllm.platforms import current_platform
 
 from .backend import TestBackend
@@ -20,7 +20,7 @@ from .backend import TestBackend
 
 class TestModel(torch.nn.Module):
 
-    def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
+    def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args,
                  **kwargs):
         super().__init__(*args, **kwargs)
         self.silu_and_mul = SiluAndMul()
@@ -32,7 +32,7 @@ class TestModel(torch.nn.Module):
                        hidden_size).to(dtype=current_platform.fp8_dtype()).t())
 
         self.fp8_linear = Fp8LinearOp(
-            cutlass_fp8_supported=cutlass_fp8_enabled,
+            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
             act_quant_static=True,
             act_quant_group_shape=GroupShape.PER_TENSOR,
         )
@@ -48,12 +48,11 @@ class TestModel(torch.nn.Module):
 
 @pytest.mark.parametrize("num_tokens", [256])
 @pytest.mark.parametrize("hidden_size", [64])
-@pytest.mark.parametrize("cutlass_fp8_enabled",
-                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
+@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
 def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
-                                   cutlass_fp8_enabled):
+                                   force_fp8_e4m3fnuz):
     torch.set_default_device("cuda")
     torch.set_default_dtype(torch.float16)
 
@@ -64,7 +63,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
     fusion_pass = ActivationQuantFusionPass(config)
 
     backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
-    model = TestModel(hidden_size, cutlass_fp8_enabled)
+    model = TestModel(hidden_size, force_fp8_e4m3fnuz)
 
     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size * 2)
tests/kernels/quantization/test_flashinfer_scaled_mm.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
+from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm
+
+if not current_platform.has_device_capability(100):
+    pytest.skip(
+        reason=
+        "Flashinfer FP8 gemms requires compute capability of 10.0 or above.",
+        allow_module_level=True,
+    )
+
+DTYPES = [torch.float16, torch.bfloat16]
+# m, n, k
+SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
+PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
+SHAPES.extend(PAD_SHAPES)
+
+SEEDS = [42]
+CUDA_DEVICES = ["cuda:0"]
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("autotune", [False, True])
+@torch.inference_mode()
+def test_flashinfer_fp8_gemm(
+    dtype: torch.dtype,
+    shape: tuple[int, int, int],
+    use_bias: bool,
+    seed: int,
+    device: str,
+    autotune: bool,
+) -> None:
+    current_platform.seed_everything(seed)
+    m, n, k = shape
+    a = torch.randn((m, k), dtype=dtype, device=device)
+    b = torch.randn((n, k), dtype=dtype, device=device) / k
+
+    a_fp8, a_scale = ops.scaled_fp8_quant(a)
+    b_fp8, b_scale = ops.scaled_fp8_quant(b)
+
+    expected_out = torch.mm(
+        a_scale * a_fp8.to(dtype=torch.float32),
+        b_scale * b_fp8.to(dtype=torch.float32).t(),
+    ).to(dtype=dtype)
+
+    if use_bias:
+        bias = torch.randn((n, ), dtype=dtype, device=device)
+        expected_out = expected_out + bias
+    else:
+        bias = None
+
+    import flashinfer
+
+    with flashinfer.autotune(autotune):
+        out = flashinfer_scaled_fp8_mm(
+            a_fp8,
+            b_fp8.t(),
+            a_scale,
+            b_scale,
+            dtype,
+            bias=bias,
+        )
+
+    torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2)
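The reference check in this new test relies on per-tensor FP8 semantics: ops.scaled_fp8_quant returns the quantized tensor together with a single float32 scale, and dequantization is just an elementwise multiply by that scale. A minimal sketch of that round trip (illustration only, not part of the commit; it assumes a CUDA build of vLLM):

import torch
from vllm import _custom_ops as ops

x = torch.randn(128, 64, dtype=torch.float16, device="cuda")
x_fp8, x_scale = ops.scaled_fp8_quant(x)   # per-tensor: x_scale.numel() == 1
x_dq = x_scale * x_fp8.to(torch.float32)   # dequantize back to higher precision
print(x_scale.shape, (x.float() - x_dq).abs().max())  # small quantization error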
@@ -223,8 +223,7 @@ class Fp8LinearMethod(LinearMethodBase):
 
         self.fp8_linear = Fp8LinearOp(
             act_quant_static=self.act_q_static,
-            act_quant_group_shape=self.act_q_group_shape,
-            cutlass_fp8_supported=cutlass_fp8_supported())
+            act_quant_group_shape=self.act_q_group_shape)
 
     def create_weights(
         self,
@@ -376,6 +375,8 @@ class Fp8LinearMethod(LinearMethodBase):
             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            # layer.input_scale is None indicates dynamic quant and scale is
+            # computed from input.
             layer.input_scale = None
 
         # If checkpoint is fp8, handle that there are N scales for N
@@ -97,8 +97,8 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
         self.quant_config.is_checkpoint_fp8_serialized = False
         self.fp8_linear = Fp8LinearOp(
             act_quant_static=False,
-            cutlass_fp8_supported=False,
-            act_quant_group_shape=GroupShape.PER_TOKEN)
+            act_quant_group_shape=GroupShape.PER_TOKEN,
+            force_fp8_e4m3fnuz=True)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
+from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -157,6 +158,19 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
     return output.view(*output_shape)
 
 
+def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
+                              out_dtype: torch.dtype, scale_a: torch.Tensor,
+                              scale_b: torch.Tensor, bias: torch.Tensor,
+                              output_shape: list, **kwargs) -> torch.Tensor:
+
+    return flashinfer_scaled_fp8_mm(qinput,
+                                    weight,
+                                    out_dtype=out_dtype,
+                                    scale_a=scale_a,
+                                    scale_b=scale_b,
+                                    bias=bias)
+
+
 def rocm_per_tensor_w8a8_scaled_mm_impl(
         qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
         scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
@@ -231,8 +245,8 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    out_dtype: torch.dtype,
                                    scale_a: torch.Tensor,
                                    scale_b: torch.Tensor, bias: torch.Tensor,
-                                   input_2d: torch.Tensor,
-                                   output_shape: list) -> torch.Tensor:
+                                   input_2d: torch.Tensor, output_shape: list,
+                                   **kwargs) -> torch.Tensor:
     # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
     # when using it.
     # For now it has only been validated on ROCm platform.
@@ -303,16 +317,22 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
 
 
 def dispatch_w8a8_scaled_mm(
-        cutlass_fp8_supported: bool, per_tensor_weights: bool,
+        preferred_backend: str, per_tensor_weights: bool,
         per_tensor_activations: bool) -> Callable[..., torch.Tensor]:
 
-    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-    if cutlass_fp8_supported:
-        return cutlass_w8a8_scaled_mm
     if per_tensor_weights and per_tensor_activations:
-        if current_platform.is_rocm():
+        if preferred_backend == "rocm":
             return rocm_per_tensor_w8a8_scaled_mm
+        if preferred_backend == "flashinfer":
+            return flashinfer_w8a8_scaled_mm
+        if preferred_backend == "cutlass":
+            return cutlass_w8a8_scaled_mm
         return torch_per_tensor_w8a8_scaled_mm
 
+    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
+    if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
+        return cutlass_w8a8_scaled_mm
+
     # If torch.scaled_mm supports per-channel (weights) per-token (inputs)
     if not per_tensor_weights and not per_tensor_activations \
             and USE_ROWWISE_TORCH_SCALED_MM:
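To make the new routing concrete, here is a small sketch (not part of the commit) of what dispatch_w8a8_scaled_mm returns for a "flashinfer" preferred backend, assuming this commit's version of w8a8_utils is importable: per-tensor weights and activations select the new FlashInfer kernel, while any other granularity falls back to the CUTLASS path.

from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    dispatch_w8a8_scaled_mm)

# Per-tensor W and A: the new FlashInfer kernel is chosen.
fn = dispatch_w8a8_scaled_mm("flashinfer", per_tensor_weights=True,
                             per_tensor_activations=True)
print(fn.__name__)  # flashinfer_w8a8_scaled_mm

# Non-per-tensor quantization (e.g. per-channel W / per-token A):
# falls back to the CUTLASS path.
fn = dispatch_w8a8_scaled_mm("flashinfer", per_tensor_weights=False,
                             per_tensor_activations=False)
print(fn.__name__)  # cutlass_w8a8_scaled_mm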
@@ -334,10 +354,20 @@ class Fp8LinearOp:
 
     def __init__(self,
                  act_quant_static: bool,
-                 cutlass_fp8_supported: bool = cutlass_fp8_supported(),
                  act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
-                 pad_output: Optional[bool] = None):
-        self.cutlass_fp8_supported = cutlass_fp8_supported
+                 pad_output: Optional[bool] = None,
+                 force_fp8_e4m3fnuz: bool = False):
+        if current_platform.is_rocm():
+            self.preferred_backend = "rocm"
+        elif current_platform.is_cuda(
+        ) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
+            if has_flashinfer() and current_platform.has_device_capability(
+                    100):
+                self.preferred_backend = "flashinfer"
+            else:
+                self.preferred_backend = "cutlass"
+        else:
+            self.preferred_backend = "torch"
 
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
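The constructor above now resolves a single preferred backend up front instead of carrying a cutlass_fp8_supported flag. For illustration only, a standalone restatement of that precedence (the helper below is hypothetical, not part of vLLM):

def resolve_preferred_backend(is_rocm: bool, is_cuda: bool,
                              force_fp8_e4m3fnuz: bool,
                              cutlass_supported: bool,
                              flashinfer_available: bool,
                              is_sm100_or_newer: bool) -> str:
    # Mirrors Fp8LinearOp.__init__: ROCm first, then CUDA with CUTLASS support
    # (FlashInfer only on compute capability >= 10.0), otherwise plain torch.
    if is_rocm:
        return "rocm"
    if is_cuda and not force_fp8_e4m3fnuz and cutlass_supported:
        if flashinfer_available and is_sm100_or_newer:
            return "flashinfer"
        return "cutlass"
    return "torch"


# Example: a Blackwell (SM100) GPU with FlashInfer installed prefers FlashInfer.
assert resolve_preferred_backend(False, True, False, True, True,
                                 True) == "flashinfer"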
@@ -347,8 +377,7 @@ class Fp8LinearOp:
         if pad_output is None:
             config = get_current_vllm_config().compilation_config
             pad_output = config.level < CompilationLevel.PIECEWISE and \
-                not cutlass_fp8_supported and \
-                not current_platform.is_rocm()
+                self.preferred_backend == "torch"
 
         self.output_padding = 17 if pad_output else None
         self.act_quant_static = act_quant_static
@@ -393,9 +422,9 @@ class Fp8LinearOp:
         per_tensor_activations = (x_scale.numel() == 1)
 
         # TODO(luka) do this dispatch during init (after ScaledMM refactor)
-        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
-            self.cutlass_fp8_supported, per_tensor_weights,
-            per_tensor_activations)
+        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(self.preferred_backend,
+                                                      per_tensor_weights,
+                                                      per_tensor_activations)
 
         return w8a8_scaled_mm_func(qinput=qinput,
                                    weight=weight,
@@ -265,6 +265,37 @@ if has_flashinfer():
                            dtype=dtype,
                            device=A.device)
 
+    @torch.library.custom_op(
+        "vllm::bmm_fp8",
+        mutates_args=[],
+        device_types="cuda",
+    )
+    def bmm_fp8(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        A_scale: torch.Tensor,
+        B_scale: torch.Tensor,
+        dtype: torch.dtype,
+        backend: str,
+    ) -> torch.Tensor:
+        from flashinfer import bmm_fp8 as bmm_fp8_
+        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)
+
+    @torch.library.register_fake("vllm::bmm_fp8", )
+    def bmm_fp8_fake(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        A_scale: torch.Tensor,
+        B_scale: torch.Tensor,
+        dtype: torch.dtype,
+        backend: str,
+    ) -> torch.Tensor:
+        return torch.empty(A.shape[0],
+                           A.shape[1],
+                           B.shape[2],
+                           dtype=dtype,
+                           device=A.device)
+
 
 def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
                              block_scale_a: torch.Tensor,
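The pair of registrations above follows the standard torch.library pattern: the real op calls into FlashInfer, and the registered fake returns an empty tensor of the right shape and dtype so torch.compile can trace through the call without launching the kernel. A generic sketch of that pattern with a hypothetical op of my own (not the vLLM registration):

import torch


@torch.library.custom_op("demo::scale_copy", mutates_args=[],
                         device_types="cuda")
def scale_copy(x: torch.Tensor, s: float) -> torch.Tensor:
    # Real implementation: runs on the device.
    return x * s


@torch.library.register_fake("demo::scale_copy")
def scale_copy_fake(x: torch.Tensor, s: float) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing and compilation.
    return torch.empty_like(x)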
@@ -293,6 +324,35 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
     )
 
 
+def flashinfer_scaled_fp8_mm(
+        a: torch.Tensor,
+        b: torch.Tensor,
+        scale_a: torch.Tensor,
+        scale_b: torch.Tensor,
+        out_dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    assert a.ndim == 2 and b.ndim == 2
+    assert a.shape[1] == b.shape[0]
+    assert scale_a.numel() == 1 and scale_b.numel() == 1
+    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
+    assert a.device.type == "cuda" and b.device.type == "cuda"
+    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
+    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"
+
+    output = bmm_fp8(
+        a.unsqueeze(0),
+        b.unsqueeze(0),
+        scale_a,
+        scale_b,
+        out_dtype,
+        "auto",
+    ).view(a.shape[0], b.shape[1])
+
+    if bias is not None:
+        output = output + bias
+    return output
+
+
 __all__ = [
     "has_flashinfer",
     "flashinfer_trtllm_fp8_block_scale_moe",
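A hedged end-to-end usage sketch of the wrapper above (assumes a compute-capability 10.0+ GPU with vLLM and FlashInfer installed); as in the new test, the weight is quantized in row-major (n, k) form and passed transposed:

import torch
from vllm import _custom_ops as ops
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm

a = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
w = torch.randn(512, 128, dtype=torch.bfloat16, device="cuda") / 128

a_fp8, a_scale = ops.scaled_fp8_quant(a)   # per-tensor activation quantization
w_fp8, w_scale = ops.scaled_fp8_quant(w)   # per-tensor weight quantization

out = flashinfer_scaled_fp8_mm(a_fp8, w_fp8.t(), a_scale, w_scale,
                               torch.bfloat16)
print(out.shape)  # torch.Size([256, 512])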
@@ -307,4 +367,5 @@ __all__ = [
     "supports_trtllm_attention",
     "use_trtllm_attention",
     "flashinfer_scaled_fp4_mm",
+    "flashinfer_scaled_fp8_mm",
 ]