Use Blackwell FlashInfer MXFP4 MoE by default if available (#23008)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2025-08-18 18:25:49 -04:00
Committed by: GitHub
Parent: ac6eb49de3
Commit: 6d25e3fd6e

2 changed files with 51 additions and 19 deletions
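
In short: on Blackwell (SM100) GPUs with FlashInfer installed, the FlashInfer MXFP4 BF16 MoE backend is now selected by default unless VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 or VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 is set explicitly. A minimal sketch of querying the new helper (assuming a vLLM build that contains this commit; the import path and function name come from the diff below):

    from vllm.model_executor.layers.quantization.mxfp4 import (
        should_use_flashinfer_mxfp4)

    # On an SM100 GPU with FlashInfer available and neither
    # VLLM_USE_FLASHINFER_MOE_MXFP4_* variable set, this returns True and
    # logs a one-time hint about the faster (but possibly less accurate)
    # MXFP8 activation path.
    if should_use_flashinfer_mxfp4():
        print("FlashInfer MXFP4 MoE backend will be used")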

View File

@@ -762,11 +762,11 @@ class FusedMoE(CustomOp):
         self.global_num_experts = num_experts + num_redundant_experts

         # we padding globally so EP buffer allocation works
-        if (quant_config and quant_config.get_name() == "mxfp4"
-                and (current_platform.is_rocm()
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)):
-            hidden_size = round_up(hidden_size, 256)
+        if quant_config and quant_config.get_name() == "mxfp4":
+            from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
+                should_use_flashinfer_mxfp4)
+            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
+                hidden_size = round_up(hidden_size, 256)

         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config

View File

@@ -6,6 +6,7 @@ import torch
 from torch.nn.parameter import Parameter

 from vllm import envs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
@@ -26,12 +27,38 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
                         next_power_of_2, round_up)
+from vllm.utils.flashinfer import has_flashinfer

-if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
-    # from flashinfer.fused_moe import cutlass_fused_moe
-    from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
-                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
+logger = init_logger(__name__)
+
+
+def _should_use_flashinfer_mxfp4_bf16():
+    """Determine if FlashInfer MXFP4 BF16 should be used."""
+    # If explicitly set, respect the setting
+    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
+        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+
+    # Enable by default on SM100 if MXFP8 is not explicitly enabled
+    if (current_platform.is_device_capability(100) and has_flashinfer()
+            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
+        logger.info_once(
+            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
+            "For faster performance, consider setting "
+            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
+            "though this may impact accuracy.")
+        return True
+
+    return False
+
+
+def _should_use_flashinfer_mxfp4_mxfp8():
+    """Determine if FlashInfer MXFP4 MXFP8 should be used."""
+    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+
+
+def should_use_flashinfer_mxfp4():
+    return (_should_use_flashinfer_mxfp4_mxfp8()
+            or _should_use_flashinfer_mxfp4_bf16())


 class Mxfp4Config(QuantizationConfig):
@@ -87,12 +114,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.moe = moe
         self.use_marlin = self._should_use_marlin()

+        if current_platform.is_device_capability(100) and not has_flashinfer():
+            logger.warning_once(
+                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
+                "is not available. This may result in degraded performance. "
+                "Please `pip install vllm[flashinfer]` for best results.")
+
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
             return envs.VLLM_MXFP4_USE_MARLIN
         if current_platform.is_cuda() and \
-                not current_platform.has_device_capability(100):
-            if not current_platform.is_device_capability(90):
+                not current_platform.is_device_capability(100):
+            if not current_platform.has_device_capability(90):
                 # marlin kernel has better performance on ampere
                 return True
             if not has_triton_kernels():
@@ -138,8 +171,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
             # pad the intermediate size to be a multiple of 2 * mxfp4_block
             # for to hold non-uniform sharded tensor as well as swizzling
             # other padding to increase performance
@ -230,8 +262,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
if self.use_marlin: if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer) prepare_moe_fp4_layer_for_marlin(layer)
elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 elif should_use_flashinfer_mxfp4():
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
layer.gemm1_alpha = Parameter(torch.tensor( layer.gemm1_alpha = Parameter(torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32).cuda(), [1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False) requires_grad=False)
@@ -478,11 +510,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             logical_replica_count), (
                 "MXFP4 are not supported with this configuration.")

-        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        if should_use_flashinfer_mxfp4():
+            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
             assert not self.moe.use_ep, (
                 "EP is not supported for flashinfer mxfp4 moe backend yet.")
-            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
+            if _should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None
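
To opt out of the new default, or to opt into the MXFP8 activation path, the existing environment variables still take precedence over the Blackwell auto-enable. A hedged sketch (the variable names are the ones touched by this diff; setting them via os.environ before vLLM initializes is only one illustrative way to do it):

    import os

    # Explicitly disable the BF16 FlashInfer path that is now enabled by
    # default on SM100; an explicit setting is respected before the default.
    os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"] = "0"

    # Or opt into the MXFP8 activation path instead (faster, but may impact
    # accuracy, per the log message added in this commit).
    # os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"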