[Kernel][B200] mxfp4 fused cutlass moe (#23696)

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Authored by Duncan Moss on 2025-09-11 14:04:56 -07:00, committed by GitHub
commit 074854b24f (parent 79ac59f32e)
5 changed files with 622 additions and 60 deletions

View File

@@ -11,6 +11,7 @@ import torch
from packaging import version
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse(
@@ -19,6 +20,10 @@ QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
) and current_platform.is_device_capability(100)
HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
and current_platform.is_device_capability(90)
and has_flashinfer())
if TRTLLM_GEN_MXFP4_AVAILABLE:
from flashinfer import (fp4_quantize, mxfp8_quantize,
next_positive_power_of_2,
@@ -542,3 +547,317 @@ def test_trtllm_gen_mxfp4_fused_moe(
transpose_optimized=transpose_optimized)
# relatively loose check since the mxfp4 quantization is less accurate
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
"""Interleave scales on the last dimension by groups of 4, matching
the transformation in mxfp4.py's BF16 (Hopper) path."""
s = scales.to(torch.uint8)
s_shape = s.shape
assert s_shape[-1] % 4 == 0
s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
# Move the 4-group dimension before the row dimension
permuted = s.permute(0, 2, 1, 3)
# Merge the row dim with the 4-group dim
return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)
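Editor's note, a minimal shape check for the helper above (illustrative only, not part of this commit; the dummy tensor and sizes are assumptions):
# Illustrative sketch, not part of the diff: the helper maps a [E, R, S] uint8
# scale tensor to [E, S // 4, R * 4].
E, R, S = 2, 8, 16  # experts, rows, scale columns (S % 4 == 0)
dummy_scales = torch.randint(0, 256, (E, R, S), dtype=torch.uint8)
out = _interleave_scales_lastdim_by4(dummy_scales)
assert out.shape == (E, S // 4, R * 4)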
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.skipif(
not HOPPER_MXFP4_BF16_AVAILABLE,
reason="nvidia gpu sm90 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_fused_moe(
topk: int,
num_experts: int,
num_tokens: int,
intermediate_size: int,
hidden_size: int,
alpha: float,
beta: float,
limit: Optional[float],
):
torch.manual_seed(42)
device = "cuda:0"
# Inputs
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=torch.bfloat16)
# Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
w13_q = torch.randint(
0,
256, (num_experts, 2 * intermediate_size, hidden_size // 2),
device=device,
dtype=torch.uint8)
w13_scale = torch.randint(
118,
123, (num_experts, 2 * intermediate_size, hidden_size // 32),
device=device,
dtype=torch.uint8)
w2_q = torch.randint(0,
256,
(num_experts, hidden_size, intermediate_size // 2),
device=device,
dtype=torch.uint8)
w2_scale = torch.randint(
118,
123, (num_experts, hidden_size, intermediate_size // 32),
device=device,
dtype=torch.uint8)
# Bias contiguous [b1; b3]
bias13 = (torch.randn(num_experts,
2 * intermediate_size,
device=device,
dtype=torch.bfloat16) * 10)
bias2 = (torch.randn(
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
router_logits = torch.rand(num_tokens,
num_experts,
dtype=torch.float32,
device=device)
w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
num_experts, 2 * intermediate_size, hidden_size)
w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
num_experts, hidden_size, intermediate_size)
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
hidden_states.to(torch.float32), w13_ref,
bias13.to(torch.float32), w2_ref,
bias2.to(torch.float32), alpha, beta, limit, 'bf16')
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
w13_s = torch.cat([w3_s, w1_s], dim=1)
w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)
routing_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
token_final_scales, token_selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
token_final_scales = (token_final_scales /
token_final_scales.sum(dim=-1, keepdim=True))
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
if alpha is not None:
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
if beta is not None:
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
if limit is not None:
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
_ = flashinfer_cutlass_fused_moe(
input=hidden_states,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
fc1_expert_weights=w13_q_swapped,
fc2_expert_weights=w2_q,
output_dtype=torch.bfloat16,
output=out,
quant_scales=[w13_s_inter.to(torch.uint8),
w2_s_inter.to(torch.uint8)],
fc1_expert_biases=w13_b,
fc2_expert_biases=bias2.to(torch.bfloat16),
swiglu_alpha=alpha,
swiglu_beta=beta,
swiglu_limit=limit,
tp_size=1,
tp_rank=0,
ep_size=1,
ep_rank=0,
use_w4_group_scaling=True,
)
# Allow some mismatch due to MXFP4 quantization
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.skipif(
not (current_platform.is_cuda()
and current_platform.is_device_capability(100) and has_flashinfer()),
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
topk: int,
num_experts: int,
num_tokens: int,
intermediate_size: int,
hidden_size: int,
alpha: Optional[float],
beta: Optional[float],
limit: Optional[float],
):
torch.manual_seed(42)
device = "cuda:0"
# Inputs
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=torch.bfloat16)
# Float weights in w13 format [w1; w3]
w13 = (torch.randn(num_experts,
2 * intermediate_size,
hidden_size,
device=device,
dtype=torch.bfloat16) / 10)
w2 = (torch.randn(num_experts,
hidden_size,
intermediate_size,
device=device,
dtype=torch.bfloat16) / 10)
# Bias contiguous [b1; b3]
bias13 = (torch.randn(num_experts,
2 * intermediate_size,
device=device,
dtype=torch.bfloat16) * 10)
bias2 = (torch.randn(
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
router_logits = torch.rand(num_tokens,
num_experts,
dtype=torch.float32,
device=device)
# Quantize weights to MXFP4 per expert (SM100 path)
from flashinfer import mxfp4_quantize
def quant_mxfp4_batches(a: torch.Tensor, e: int):
qs, sfs = [], []
for i in range(e):
q, sf = mxfp4_quantize(a[i].cuda())
qs.append(q)
sfs.append(sf)
return torch.stack(qs), torch.stack(sfs)
def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
scale_tensor: torch.Tensor):
num_batches = mat_fp4.size(0)
scale_tensor = scale_tensor.view(num_batches, -1)
from flashinfer import mxfp4_dequantize
return torch.stack([
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
for b in range(num_batches)
])
w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)
# Reference result using dequantized tensors and reference_moe
w13_ref = dequant_mxfp4_batches(
w13_q.view(torch.uint8),
w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
num_experts, 2 * intermediate_size, hidden_size)
w2_ref = dequant_mxfp4_batches(
w2_q.view(torch.uint8),
w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
num_experts, hidden_size, intermediate_size)
# Quantize activations to MXFP8 for the SM100 kernel path
hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
# Reference uses BF16 input but quantizes intermediate activation to MXFP8
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
hidden_states.to(torch.float32), w13_ref,
bias13.to(torch.float32), w2_ref,
bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')
# Prepare inputs for FlashInfer CUTLASS fused MoE
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
# Swap scales halves to match swapped weights
s1, s3 = torch.chunk(w13_scale, 2, dim=1)
w13_scale_swapped = torch.cat([s3, s1], dim=1)
b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
# Build routing for kernel
routing_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
token_final_scales, token_selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
token_final_scales = (token_final_scales /
token_final_scales.sum(dim=-1, keepdim=True))
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
if alpha is not None:
alpha_t = torch.full((num_experts, ),
alpha,
device=hidden_states.device)
else:
alpha_t = None
if beta is not None:
beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
else:
beta_t = None
if limit is not None:
limit_t = torch.full((num_experts, ),
limit,
device=hidden_states.device)
else:
limit_t = None
# Quant scales for SM100 MXFP8+MXFP4 path
fake_input_scale = torch.ones(num_experts, device=device)
quant_scales = [
w13_scale_swapped.view(torch.int32),
fake_input_scale,
w2_scale.view(torch.int32),
fake_input_scale,
]
_ = flashinfer_cutlass_fused_moe(
input=hidden_states_q,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
fc2_expert_weights=w2_q.contiguous().view(torch.long),
output_dtype=torch.bfloat16,
output=out,
quant_scales=quant_scales,
fc1_expert_biases=w13_b,
fc2_expert_biases=bias2.to(torch.bfloat16),
swiglu_alpha=alpha_t,
swiglu_beta=beta_t,
swiglu_limit=limit_t,
tp_size=1,
tp_rank=0,
ep_size=1,
ep_rank=0,
use_mxfp8_act_scaling=True,
input_sf=hidden_states_sf,
)
# Allow some mismatch due to MXFP4 quantization
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)

View File

@@ -166,7 +166,8 @@ if TYPE_CHECKING:
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
@@ -1004,6 +1005,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
# If set to 1, use the FlashInfer CUTLASS backend for
# MXFP8 (activation) x MXFP4 (weight) MoE.
# This is separate from the TRTLLMGEN path controlled by
# VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8.
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS":
lambda: bool(int(
os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0")
)),
# If set to 1, use the FlashInfer
# BF16 (activation) x MXFP4 (weight) MoE backend.
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
@@ -1296,6 +1306,7 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_MOE_FP8",
"VLLM_USE_FLASHINFER_MOE_FP4",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
"VLLM_USE_CUDNN_PREFILL",
"VLLM_USE_TRTLLM_ATTENTION",

View File

@@ -813,9 +813,16 @@ class FusedMoE(CustomOp):
# we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501
should_use_flashinfer_mxfp4)
if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend, get_mxfp4_backend)
current_mxfp4_backend = get_mxfp4_backend()
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
hidden_size = round_up(hidden_size, 128)
elif (current_platform.is_rocm() or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)
# For smuggling this layer into the fused moe custom op

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Callable, Optional, Union
import torch
@@ -33,33 +34,72 @@ from vllm.utils.flashinfer import has_flashinfer
logger = init_logger(__name__)
def _should_use_flashinfer_mxfp4_bf16():
"""Determine if FlashInfer MXFP4 BF16 should be used."""
# If explicitly set, respect the setting
if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
# enum for mxfp4 backend
class Mxfp4Backend(Enum):
NONE = 0
# Enable by default on SM100 if MXFP8 is not explicitly enabled
if (current_platform.is_device_capability(100) and has_flashinfer()
and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
logger.info_once(
"Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
"For faster performance, consider setting "
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
"though this may impact accuracy.")
return True
# FlashInfer Backend
SM100_FI_MXFP4_MXFP8_TRTLLM = 1
SM100_FI_MXFP4_MXFP8_CUTLASS = 2
SM100_FI_MXFP4_BF16 = 3
SM90_FI_MXFP4_BF16 = 4
return False
# Marlin Backend
MARLIN = 5
# Triton Backend
TRITON = 6
def _should_use_flashinfer_mxfp4_mxfp8():
"""Determine if FlashInfer MXFP4 MXFP8 should be used."""
return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
def get_mxfp4_backend():
# Backend Selection
if current_platform.is_cuda():
if (current_platform.is_device_capability(90) and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
return Mxfp4Backend.SM90_FI_MXFP4_BF16
elif (current_platform.is_device_capability(100) and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS):
logger.info_once(
"Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
elif (current_platform.is_device_capability(100) and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
logger.info_once(
"Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, "
"for high concurrency throughput workloads consider setting "
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
"performance")
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
elif current_platform.is_device_capability(100) and has_flashinfer():
logger.info_once(
"Using FlashInfer MXFP4 BF16 backend for SM100, "
"For faster performance on SM100, consider setting "
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
"accuracy.")
return Mxfp4Backend.SM100_FI_MXFP4_BF16
elif ((current_platform.is_device_capability(100)
or current_platform.is_device_capability(90))
and not has_flashinfer()):
logger.warning_once(
"MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
"is not available. This may result in degraded performance. "
"Please `pip install vllm[flashinfer]` for best results.")
# If FlashInfer is not available, try either Marlin or Triton
if current_platform.get_device_capability(
)[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
"2.8.0"):
logger.info_once("Using Marlin backend")
return Mxfp4Backend.MARLIN
else:
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
elif current_platform.is_rocm() and has_triton_kernels():
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
def should_use_flashinfer_mxfp4():
return (_should_use_flashinfer_mxfp4_mxfp8()
or _should_use_flashinfer_mxfp4_bf16())
return Mxfp4Backend.NONE
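Editor's note, a short sketch of how a call site consumes the new enum in place of the removed boolean helpers (illustrative only, not part of this commit):
# Illustrative sketch, not part of the diff: resolve the backend once and
# branch on the enum instead of should_use_flashinfer_mxfp4().
backend = get_mxfp4_backend()
uses_flashinfer = backend in (
    Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
    Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
    Mxfp4Backend.SM100_FI_MXFP4_BF16,
    Mxfp4Backend.SM90_FI_MXFP4_BF16,
)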
class Mxfp4Config(QuantizationConfig):
@@ -113,31 +153,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
super().__init__(moe)
self.topk_indices_dtype = None
self.moe = moe
self.use_marlin = self._should_use_marlin()
self.mxfp4_backend = get_mxfp4_backend()
self.max_capture_size = get_current_vllm_config(
).compilation_config.max_capture_size
if current_platform.is_device_capability(100) and not has_flashinfer():
logger.warning_once(
"MXFP4 MoE is enabled on Blackwell but FlashInfer "
"is not available. This may result in degraded performance. "
"Please `pip install vllm[flashinfer]` for best results.")
assert self.mxfp4_backend != Mxfp4Backend.NONE, (
"No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available."
"Please check your environment and try again.")
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
def _should_use_marlin(self):
if envs.VLLM_MXFP4_USE_MARLIN is not None:
return envs.VLLM_MXFP4_USE_MARLIN
if current_platform.is_cuda() and \
not current_platform.is_device_capability(100):
if not current_platform.has_device_capability(90):
# marlin kernel has better performance on ampere
return True
if not has_triton_kernels():
return True
if not is_torch_equal_or_newer("2.8.0"):
return True
return False
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@@ -157,7 +181,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition_after_pad = \
intermediate_size_per_partition
if self.use_marlin:
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
# The moe marlin kernel requires that for each linear
# n % 256 == 0 and k % 128 == 0.
# In gate_up_proj:
@@ -175,16 +199,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.hidden_size = hidden_size
layer.intermediate_size_per_partition = \
intermediate_size_per_partition_after_pad
elif should_use_flashinfer_mxfp4():
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
# pad the intermediate size to a multiple of 2 * mxfp4_block so it
# can hold non-uniformly sharded tensors and support swizzling;
# the extra padding also improves performance
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
hidden_size = round_up(hidden_size, 256)
elif current_platform.is_rocm():
elif current_platform.is_rocm() or (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128)
hidden_size = round_up(hidden_size, 128)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64)
@@ -264,9 +292,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_bias, extra_weight_attrs)
def process_weights_after_loading(self, layer):
if self.use_marlin:
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(layer)
elif should_use_flashinfer_mxfp4():
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
from flashinfer.fp4_quantization import (
nvfp4_block_scale_interleave)
from flashinfer.fused_moe.core import (
@@ -429,7 +458,116 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
self.num_experts, -1),
requires_grad=False)
else:
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
layer.gemm1_alpha = Parameter(torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
layer.gemm1_beta = Parameter(torch.tensor(
[1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
layer.gemm1_clamp_limit = Parameter(torch.tensor(
[7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
sf_block_size = 32 # mxfp4 block size
# Common shape assertions
assert (layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2)
assert (layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1]
== self.intermediate_size * 2
and layer.w13_weight_scale.shape[2]
== self.hidden_size // sf_block_size)
assert (layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size and
layer.w2_weight.shape[2] == self.intermediate_size // 2)
assert (layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size)
assert (layer.w13_bias.dim() == 2
and layer.w13_bias.shape[0] == self.num_experts
and layer.w13_bias.shape[1] == self.intermediate_size * 2)
assert (layer.w2_bias.dim() == 2
and layer.w2_bias.shape[0] == self.num_experts
and layer.w2_bias.shape[1] == self.hidden_size)
# De-interleave and swap for w13 weight, bias, and scales
w13_w = layer.w13_weight.data
gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
w13_b = layer.w13_bias.data.to(torch.float32)
gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
w13_s = layer.w13_weight_scale.data
gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
w13_scale_swapped = torch.cat([s3, s1], dim=1)
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import block_scale_interleave
orig_shape = w13_scale_swapped.shape
w13_scale_interleaved = block_scale_interleave(
w13_scale_swapped.view(torch.uint8)).reshape(orig_shape)
w2_s = layer.w2_weight_scale.data
orig_shape = w2_s.shape
w2_scale_interleaved = block_scale_interleave(
w2_s.view(torch.uint8)).reshape(orig_shape)
layer.w13_weight = Parameter(w13_weight_swapped,
requires_grad=False)
layer.w13_weight_scale = Parameter(w13_scale_interleaved,
requires_grad=False)
layer.w13_bias = Parameter(w13_bias_swapped,
requires_grad=False)
layer.w2_weight_scale = Parameter(w2_scale_interleaved,
requires_grad=False)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
def _interleave_mxfp4_cutlass_sm90(w):
w_shape = w.shape
w_interleaved = w.reshape(w_shape[0], w_shape[1],
(w_shape[2] // 4), 4)
w_interleaved = w_interleaved.permute(0, 2, 1, 3)
w_interleaved = w_interleaved.reshape(
w_shape[0], w_shape[2] // 4, w_shape[1] * 4)
return w_interleaved
w31_scales = w13_scale_swapped.to(torch.uint8).view(
torch.uint8)
w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
w31_scales)
w2_weight_scale = layer.w2_weight_scale.data
w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(
w2_scales)
layer.w13_weight = torch.nn.Parameter(torch.cat([w3_w, w1_w],
dim=1),
requires_grad=False)
layer.w13_bias = torch.nn.Parameter(w13_bias_swapped,
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w31_scales_interleaved, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales_interleaved, requires_grad=False)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_bias = layer.w13_bias.to(torch.float32)
@@ -464,6 +602,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
# Number of tokens in the input tensor.
@@ -500,7 +640,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError(
"Mxfp4 does not support batched experts format for EP")
else:
if should_use_flashinfer_mxfp4():
if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
# B200 code-path
kwargs = {
"gemm1_alpha": layer.gemm1_alpha,
@@ -601,7 +742,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.use_marlin:
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -665,16 +806,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
logical_replica_count), (
"MXFP4 are not supported with this configuration.")
if should_use_flashinfer_mxfp4():
from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
if _should_use_flashinfer_mxfp4_bf16():
if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
from flashinfer import trtllm_fp4_block_scale_moe
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
x_quant = x
x_scale = None
else:
elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
*x.shape[:-1], -1)
trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16),
None, # routing_bias
@@ -706,7 +850,86 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tune_max_num_tokens=self.max_capture_size,
)[0]
return trtllm_gen_output
else:
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
# Backend-specific preparation
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, True, 32)
fake_input_scale = torch.ones(self.num_experts,
device=x.device)
quant_scales = [
layer.w13_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
layer.w2_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
]
fi_input = x_quant
extra_kwargs = dict(
use_mxfp8_act_scaling=True,
input_sf=x_scale,
fc1_expert_weights=layer.w13_weight.contiguous().view(
torch.long),
fc2_expert_weights=layer.w2_weight.contiguous().view(
torch.long),
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
quant_scales = [
layer.w13_weight_scale,
layer.w2_weight_scale,
]
fi_input = x
extra_kwargs = dict(
use_w4_group_scaling=True,
fc1_expert_weights=layer.w13_weight,
fc2_expert_weights=layer.w2_weight,
)
output = torch.empty_like(x, dtype=torch.bfloat16)
_ = flashinfer_cutlass_fused_moe(
input=fi_input,
token_selected_experts=topk_ids.to(torch.int).contiguous(),
token_final_scales=topk_weights,
output_dtype=torch.bfloat16,
output=output,
quant_scales=quant_scales,
fc1_expert_biases=layer.w13_bias,
fc2_expert_biases=layer.w2_bias,
swiglu_alpha=layer.gemm1_alpha,
swiglu_beta=layer.gemm1_beta,
swiglu_limit=layer.gemm1_clamp_limit,
tp_size=self.moe.tp_size,
tp_rank=self.moe.tp_rank,
ep_size=self.moe.ep_size,
ep_rank=self.moe.ep_rank,
tune_max_num_tokens=self.max_capture_size,
**extra_kwargs,
)
return output
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward)
return triton_kernel_moe_forward(
@@ -724,3 +947,5 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w2_precision=self.w2_precision_config,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")

View File

@ -33,8 +33,8 @@ def kernel_warmup(worker: "Worker"):
max_tokens = worker.scheduler_config.max_num_batched_tokens
deep_gemm_warmup(model, max_tokens)
# FlashInfer kernel autotune for Blackwell (SM 10.0) GPUs
if has_flashinfer() and current_platform.is_device_capability(100):
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
if has_flashinfer() and current_platform.has_device_capability(90):
flashinfer_autotune(worker.model_runner)
# FlashInfer attention warmup
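Editor's note, the switch from is_device_capability(100) (exact SM 10.0 match) to has_device_capability(90) (SM 9.0 or newer) is what widens autotune from Blackwell-only to Hopper and newer; a minimal sketch of that predicate, assuming vLLM's capability-encoding convention (illustrative only, not part of this commit):
# Illustrative sketch, not part of the diff: has_device_capability(90) is a
# "greater or equal" check, while is_device_capability(100) is an exact match.
def _autotune_enabled(capability: int, flashinfer_available: bool) -> bool:
    # capability is major * 10 + minor, e.g. 90 for SM 9.0, 100 for SM 10.0
    return flashinfer_available and capability >= 90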