diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index c3c6e47827..4924f1fadb 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -762,11 +762,11 @@ class FusedMoE(CustomOp):
         self.global_num_experts = num_experts + num_redundant_experts
 
         # we padding globally so EP buffer allocation works
-        if (quant_config and quant_config.get_name() == "mxfp4"
-                and (current_platform.is_rocm()
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)):
-            hidden_size = round_up(hidden_size, 256)
+        if quant_config and quant_config.get_name() == "mxfp4":
+            from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
+                should_use_flashinfer_mxfp4)
+            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
+                hidden_size = round_up(hidden_size, 256)
 
         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 3c5d83037c..6a190ebbc0 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -6,6 +6,7 @@ import torch
 from torch.nn.parameter import Parameter
 
 from vllm import envs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
@@ -26,12 +27,38 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
                         next_power_of_2, round_up)
+from vllm.utils.flashinfer import has_flashinfer
 
-if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
-    # from flashinfer.fused_moe import cutlass_fused_moe
-    from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
-                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
+logger = init_logger(__name__)
+
+
+def _should_use_flashinfer_mxfp4_bf16():
+    """Determine if FlashInfer MXFP4 BF16 should be used."""
+    # If explicitly set, respect the setting
+    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
+        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+
+    # Enable by default on SM100 if MXFP8 is not explicitly enabled
+    if (current_platform.is_device_capability(100) and has_flashinfer()
+            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
+        logger.info_once(
+            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
+            "For faster performance, consider setting "
+            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
+            "though this may impact accuracy.")
+        return True
+
+    return False
+
+
+def _should_use_flashinfer_mxfp4_mxfp8():
+    """Determine if FlashInfer MXFP4 MXFP8 should be used."""
+    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+
+
+def should_use_flashinfer_mxfp4():
+    return (_should_use_flashinfer_mxfp4_mxfp8()
+            or _should_use_flashinfer_mxfp4_bf16())
 
 
 class Mxfp4Config(QuantizationConfig):
@@ -87,12 +114,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.moe = moe
         self.use_marlin = self._should_use_marlin()
 
+        if current_platform.is_device_capability(100) and not has_flashinfer():
+            logger.warning_once(
+                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
+                "is not available. This may result in degraded performance. "
+                "Please `pip install vllm[flashinfer]` for best results.")
+
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
             return envs.VLLM_MXFP4_USE_MARLIN
         if current_platform.is_cuda() and \
-                not current_platform.has_device_capability(100):
-            if not current_platform.is_device_capability(90):
+                not current_platform.is_device_capability(100):
+            if not current_platform.has_device_capability(90):
                 # marlin kernel has better performance on ampere
                 return True
             if not has_triton_kernels():
@@ -138,8 +171,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
            # pad the intermediate size to be a multiple of 2 * mxfp4_block
            # for to hold non-uniform sharded tensor as well as swizzling
            # other padding to increase performance
@@ -230,8 +262,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
+            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
             layer.gemm1_alpha = Parameter(torch.tensor(
                 [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                                           requires_grad=False)
@@ -478,11 +510,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 logical_replica_count), (
                     "MXFP4 are not supported with this configuration.")
 
-        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        if should_use_flashinfer_mxfp4():
+            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
             assert not self.moe.use_ep, (
                 "EP is not supported for flashinfer mxfp4 moe backend yet.")
-            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
+            if _should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None