Use Blackwell FlashInfer MXFP4 MoE by default if available (#23008)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-08-18 18:25:49 -04:00
committed by GitHub
parent ac6eb49de3
commit 6d25e3fd6e
2 changed files with 51 additions and 19 deletions

View File

@ -762,10 +762,10 @@ class FusedMoE(CustomOp):
self.global_num_experts = num_experts + num_redundant_experts
# we padding globally so EP buffer allocation works
if (quant_config and quant_config.get_name() == "mxfp4"
and (current_platform.is_rocm()
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)):
if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501
should_use_flashinfer_mxfp4)
if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
hidden_size = round_up(hidden_size, 256)
# For smuggling this layer into the fused moe custom op

View File

@ -6,6 +6,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
@ -26,12 +27,38 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
next_power_of_2, round_up)
from vllm.utils.flashinfer import has_flashinfer
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
# from flashinfer.fused_moe import cutlass_fused_moe
from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
logger = init_logger(__name__)
def _should_use_flashinfer_mxfp4_bf16():
"""Determine if FlashInfer MXFP4 BF16 should be used."""
# If explicitly set, respect the setting
if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
# Enable by default on SM100 if MXFP8 is not explicitly enabled
if (current_platform.is_device_capability(100) and has_flashinfer()
and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
logger.info_once(
"Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
"For faster performance, consider setting "
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
"though this may impact accuracy.")
return True
return False
def _should_use_flashinfer_mxfp4_mxfp8():
"""Determine if FlashInfer MXFP4 MXFP8 should be used."""
return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
def should_use_flashinfer_mxfp4():
return (_should_use_flashinfer_mxfp4_mxfp8()
or _should_use_flashinfer_mxfp4_bf16())
class Mxfp4Config(QuantizationConfig):
@ -87,12 +114,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.moe = moe
self.use_marlin = self._should_use_marlin()
if current_platform.is_device_capability(100) and not has_flashinfer():
logger.warning_once(
"MXFP4 MoE is enabled on Blackwell but FlashInfer "
"is not available. This may result in degraded performance. "
"Please `pip install vllm[flashinfer]` for best results.")
def _should_use_marlin(self):
if envs.VLLM_MXFP4_USE_MARLIN is not None:
return envs.VLLM_MXFP4_USE_MARLIN
if current_platform.is_cuda() and \
not current_platform.has_device_capability(100):
if not current_platform.is_device_capability(90):
not current_platform.is_device_capability(100):
if not current_platform.has_device_capability(90):
# marlin kernel has better performance on ampere
return True
if not has_triton_kernels():
@ -138,8 +171,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.hidden_size = hidden_size
layer.intermediate_size_per_partition = \
intermediate_size_per_partition_after_pad
elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
elif should_use_flashinfer_mxfp4():
# pad the intermediate size to be a multiple of 2 * mxfp4_block
# for to hold non-uniform sharded tensor as well as swizzling
# other padding to increase performance
@ -230,8 +262,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer):
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
elif should_use_flashinfer_mxfp4():
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
layer.gemm1_alpha = Parameter(torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
@ -478,11 +510,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
logical_replica_count), (
"MXFP4 are not supported with this configuration.")
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
if should_use_flashinfer_mxfp4():
from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
assert not self.moe.use_ep, (
"EP is not supported for flashinfer mxfp4 moe backend yet.")
if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
if _should_use_flashinfer_mxfp4_bf16():
assert x.dtype == torch.bfloat16
x_quant = x
x_scale = None