diff --git a/vllm/envs.py b/vllm/envs.py
index f8a7197dd1..8a3eb8e509 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -154,6 +154,8 @@ if TYPE_CHECKING:
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
     VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
     VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False
+    VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
+    VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False


 def get_default_cache_root():
@@ -932,6 +934,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_FLASHINFER_MOE_FP4":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))),

+    # If set to 1, use the FlashInfer
+    # MXFP8 (activation) x MXFP4 (weight) MoE backend.
+    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
+
+    # If set to 1, use the FlashInfer
+    # BF16 (activation) x MXFP4 (weight) MoE backend.
+    "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))),
+
     # Control the cache sized used by the xgrammar compiler. The default
     # of 512 MB should be enough for roughly 1000 JSON schemas.
     # It can be changed with this variable if needed for some reason.
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index f155a1b11f..a4a6157fa4 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -33,7 +33,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
+from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
+                        round_up)
 from vllm.utils.flashinfer import has_flashinfer

 if current_platform.is_cuda_alike():
@@ -719,6 +720,12 @@ class FusedMoE(torch.nn.Module):

         self.global_num_experts = num_experts + num_redundant_experts

+        # pad hidden_size globally so EP buffer allocation works
+        if quant_config and quant_config.get_name() == "mxfp4" and (
+                envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+            hidden_size = round_up(hidden_size, 256)
+
         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
@@ -1064,6 +1071,18 @@ class FusedMoE(torch.nn.Module):
                       shard_id: str,
                       expert_id: int,
                       return_success: bool = False) -> Optional[bool]:
+
+        if self.quant_config and self.quant_config.get_name() == "mxfp4":
+            # (FIXME) for gpt-oss all experts are combined
+            if "bias" in weight_name:
+                dim1 = loaded_weight.shape[1]
+                param.data[:, :dim1].copy_(loaded_weight)
+            else:
+                dim1 = loaded_weight.shape[1]
+                dim2 = loaded_weight.shape[2]
+                param.data[:, :dim1, :dim2].copy_(loaded_weight)
+            return True if return_success else None
+
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             # Failed to load this param since it's not local to this rank
@@ -1476,13 +1495,20 @@ class FusedMoE(torch.nn.Module):

     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
+        og_hidden_states = hidden_states.shape[-1]
+        if self.hidden_size != og_hidden_states:
+            hidden_states = F.pad(hidden_states,
+                                  (0, self.hidden_size - og_hidden_states),
+                                  mode='constant',
+                                  value=0.0)
         # TODO: Once the OOM issue for the TPU backend is resolved, we will
         # switch to using the moe_forward custom op.
         if current_platform.is_tpu():
             return self.forward_impl(hidden_states, router_logits)
         else:
-            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                              self.layer_name)
+            return torch.ops.vllm.moe_forward(
+                hidden_states, router_logits,
+                self.layer_name)[..., :og_hidden_states]

     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
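
The padding introduced above is easiest to read as a round trip: when either FlashInfer MXFP4 flag is set, the layer's hidden size is rounded up to a multiple of 256 at construction time, activations are zero-padded on entry to forward(), and the custom-op output is sliced back to the original width. A minimal standalone sketch of that round trip, with made-up sizes and a local re-implementation of round_up:

import torch
import torch.nn.functional as F


def round_up(x: int, multiple: int) -> int:
    # same semantics as vllm.utils.round_up
    return ((x + multiple - 1) // multiple) * multiple


og_hidden = 2880                            # hypothetical unpadded width
padded_hidden = round_up(og_hidden, 256)    # -> 3072

x = torch.randn(4, og_hidden)
x = F.pad(x, (0, padded_hidden - og_hidden), mode='constant', value=0.0)
assert x.shape[-1] == padded_hidden

out = x                                     # stand-in for the fused MoE output
out = out[..., :og_hidden]                  # slice back, as in FusedMoE.forward
assert out.shape[-1] == og_hidden
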
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index 95aea912a1..8d63027e18 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -37,6 +37,7 @@ QuantizationMethods = Literal[
     "auto-round",
     "rtn",
     "inc",
+    "mxfp4",
 ]
 QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

@@ -110,6 +111,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
     from .marlin import MarlinConfig
     from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
     from .moe_wna16 import MoeWNA16Config
+    from .mxfp4 import Mxfp4Config
     from .neuron_quant import NeuronQuantConfig
     from .ptpc_fp8 import PTPCFp8Config
     from .qqq import QQQConfig
@@ -148,6 +150,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
         "auto-round": AutoRoundConfig,
         "rtn": RTNConfig,
         "inc": INCConfig,
+        "mxfp4": Mxfp4Config,
     }
     # Update the `method_to_config` with customized quantization methods.
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
new file mode 100644
index 0000000000..b6d7bc5d5c
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -0,0 +1,387 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Callable, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import envs
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
+                                                  FusedMoEMethodBase)
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
+    _can_support_mxfp4)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped)
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import next_power_of_2, round_up
+
+if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+    # from flashinfer.fused_moe import cutlass_fused_moe
+    from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
+                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
+
+
+class Mxfp4Config(QuantizationConfig):
+
+    def __init__(self, ignored_layers: Optional[list[str]] = None):
+        super().__init__()
+        self.ignored_layers = ignored_layers
+
+    @classmethod
+    def from_config(cls, config):
+        return cls()
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 100
+
+    @classmethod
+    def get_name(cls) -> QuantizationMethods:
+        return "mxfp4"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.bfloat16]
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return []
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            if self.ignored_layers and is_layer_skipped(
+                    prefix=prefix,
+                    ignored_layers=self.ignored_layers,
+                    fused_mapping=self.packed_modules_mapping):
+                return UnquantizedLinearMethod()
+            raise NotImplementedError("Mxfp4 linear layer is not implemented")
+        elif isinstance(layer, FusedMoE):
+            return Mxfp4MoEMethod(layer.moe_config)
+        elif isinstance(layer, Attention):
+            raise NotImplementedError(
+                "Mxfp4 attention layer is not implemented")
+        return None
+
+
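With "mxfp4" added to QuantizationMethods and method_to_config, the config participates in the normal quantization dispatch: FusedMoE layers get Mxfp4MoEMethod, while attention and non-ignored linear layers raise NotImplementedError. A hedged usage sketch — the model name is illustrative, and checkpoints may also select the method through their own config:

import os

# opt into one of the FlashInfer MXFP4 MoE backends before importing vllm
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"

from vllm import LLM

# quantization="mxfp4" routes FusedMoE layers to Mxfp4MoEMethod via
# Mxfp4Config.get_quant_method
llm = LLM(model="openai/gpt-oss-20b", quantization="mxfp4")
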
+class Mxfp4MoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, moe: FusedMoEConfig):
+        super().__init__()
+        self.topk_indices_dtype = None
+        self.moe = moe
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+        self.num_experts = num_experts
+        weight_dtype = torch.uint8
+        scale_dtype = torch.uint8
+
+        # FIXME (zyongye): ship after torch and safetensors support mxfp4
+        # is_torch_mxfp4_available = (
+        #     hasattr(torch, "float4_e2m1fn_x2") and
+        #     hasattr(torch, "float8_e8m0fnu"))
+        # if is_torch_mxfp4_available:
+        #     weight_dtype = torch.float4_e2m1fn_x2
+        #     scale_dtype = torch.float8_e8m0fnu
+
+        mxfp4_block = 32
+
+        intermediate_size_per_partition_after_pad = \
+            intermediate_size_per_partition
+        # pad the intermediate size to be a multiple of 2 * mxfp4_block
+        # so it can hold non-uniformly sharded tensors as well as swizzling
+        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size_per_partition, 256)
+            hidden_size = round_up(hidden_size, 256)
+
+        self.intermediate_size = intermediate_size_per_partition_after_pad
+        self.hidden_size = hidden_size
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            2 * intermediate_size_per_partition_after_pad,
+            hidden_size // 2,
+            dtype=weight_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w13_weight_scale = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            2 * intermediate_size_per_partition_after_pad,
+            hidden_size // mxfp4_block,
+            dtype=scale_dtype),
+                                              requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+        w13_bias = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            2 * intermediate_size_per_partition_after_pad,
+            dtype=torch.bfloat16),
+                                      requires_grad=False)
+        layer.register_parameter("w13_bias", w13_bias)
+        set_weight_attrs(w13_bias, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition_after_pad // 2,
+            dtype=weight_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w2_weight_scale = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition_after_pad // mxfp4_block,
+            dtype=scale_dtype),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
+                                                 hidden_size,
+                                                 dtype=torch.bfloat16),
+                                     requires_grad=False)
+        layer.register_parameter("w2_bias", w2_bias)
+        set_weight_attrs(w2_bias, extra_weight_attrs)
+
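The buffer shapes above follow from two packing rules: two FP4 values share one uint8 (so the contraction dimension is halved) and each block of 32 elements shares one uint8 E8M0 scale. Illustrative shape arithmetic with made-up sizes, mirroring create_weights:

num_experts = 32
intermediate = 2880          # per-partition intermediate size before padding
hidden = 2880                # hidden size before padding
mxfp4_block = 32


def round_up(x: int, m: int) -> int:
    return ((x + m - 1) // m) * m


intermediate_pad = round_up(intermediate, 256)   # 3072
hidden_pad = round_up(hidden, 256)               # 3072

w13_weight_shape = (num_experts, 2 * intermediate_pad, hidden_pad // 2)
w13_scale_shape = (num_experts, 2 * intermediate_pad, hidden_pad // mxfp4_block)
w2_weight_shape = (num_experts, hidden_pad, intermediate_pad // 2)
w2_scale_shape = (num_experts, hidden_pad, intermediate_pad // mxfp4_block)
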
+    def process_weights_after_loading(self, layer):
+        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+            layer.gemm1_alpha = Parameter(torch.tensor(
+                [1.702] * self.num_experts, dtype=torch.float32).cuda(),
+                                          requires_grad=False)
+            layer.gemm1_beta = Parameter(torch.tensor(
+                [1.0] * self.num_experts, dtype=torch.float32).cuda(),
+                                         requires_grad=False)
+            layer.gemm1_clamp_limit = Parameter(torch.tensor(
+                [7.0] * self.num_experts, dtype=torch.float32).cuda(),
+                                                requires_grad=False)
+            sf_block_size = 32  # mxfp4 block size
+
+            assert (layer.w13_weight.dim() == 3
+                    and layer.w13_weight.shape[0] == self.num_experts
+                    and layer.w13_weight.shape[1] == self.intermediate_size * 2
+                    and layer.w13_weight.shape[2] == self.hidden_size // 2)
+            assert (layer.w13_weight_scale.dim() == 3
+                    and layer.w13_weight_scale.shape[0] == self.num_experts
+                    and layer.w13_weight_scale.shape[1]
+                    == self.intermediate_size * 2
+                    and layer.w13_weight_scale.shape[2]
+                    == self.hidden_size // sf_block_size)
+            assert (layer.w2_weight.dim() == 3
+                    and layer.w2_weight.shape[0] == self.num_experts
+                    and layer.w2_weight.shape[1] == self.hidden_size and
+                    layer.w2_weight.shape[2] == self.intermediate_size // 2)
+            assert (layer.w2_weight_scale.dim() == 3
+                    and layer.w2_weight_scale.shape[1] == self.hidden_size
+                    and layer.w2_weight_scale.shape[2]
+                    == self.intermediate_size // sf_block_size)
+            assert (layer.w13_bias.dim() == 2
+                    and layer.w13_bias.shape[0] == self.num_experts
+                    and layer.w13_bias.shape[1] == self.intermediate_size * 2)
+            assert (layer.w2_bias.dim() == 2
+                    and layer.w2_bias.shape[0] == self.num_experts
+                    and layer.w2_bias.shape[1] == self.hidden_size)
+
+            w13_weight_scale = layer.w13_weight_scale.data
+            w2_weight_scale = layer.w2_weight_scale.data
+            w13_weight = layer.w13_weight.data
+            w2_weight = layer.w2_weight.data
+            w13_bias = layer.w13_bias.data.to(torch.float32)
+            w2_bias = layer.w2_bias.data.to(torch.float32)
+
+            # Swap w1 and w3 as the definition of
+            # swiglu is different in trtllm-gen
+            def swap_every_two_rows(x, axis=-1):
+                shape = x.shape
+                if axis < 0:
+                    axis = len(shape) + axis
+
+                # Create a new shape with pairs swapped along specified axis
+                new_shape = list(shape)
+                new_shape[axis] = shape[axis] // 2
+                new_shape.insert(axis + 1, 2)
+
+                # Reshape to expose pairs, swap them, and reshape back
+                x = x.reshape(*new_shape)
+                x = x.flip(axis + 1)
+                new_shape = list(shape)
+                return x.reshape(*new_shape)
+
+            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
+            w13_weight = swap_every_two_rows(w13_weight, -2)
+            w13_bias = swap_every_two_rows(w13_bias, -1)
+
+            # Do not interleave as the checkpoint is already interleaved
+
+            # Shuffle weights and scaling factors for transposed mma output
+            gemm1_weights_mxfp4_shuffled = []
+            gemm1_scales_mxfp4_shuffled = []
+            gemm2_weights_mxfp4_shuffled = []
+            gemm2_scales_mxfp4_shuffled = []
+            gemm1_bias_shuffled = []
+            gemm2_bias_shuffled = []
+            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+            for i in range(self.num_experts):
+                gemm1_weights_mxfp4_shuffled.append(
+                    shuffle_matrix_a(w13_weight[i].view(torch.uint8),
+                                     epilogue_tile_m))
+                gemm1_scales_mxfp4_shuffled.append(
+                    shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
+                                        epilogue_tile_m))
+                gemm1_bias_shuffled.append(
+                    shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
+                                     epilogue_tile_m))
+
+                gemm2_weights_mxfp4_shuffled.append(
+                    shuffle_matrix_a(w2_weight[i].view(torch.uint8),
+                                     epilogue_tile_m))
+                gemm2_scales_mxfp4_shuffled.append(
+                    shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
+                                        epilogue_tile_m))
+                gemm2_bias_shuffled.append(
+                    shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
+                                     epilogue_tile_m))
+
+            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
+            w13_weight_scale = torch.stack(
+                gemm1_scales_mxfp4_shuffled).reshape(
+                    self.num_experts, 2 * self.intermediate_size,
+                    self.hidden_size // sf_block_size).view(
+                        torch.float8_e4m3fn)
+
+            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
+            w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape(
+                self.num_experts, self.hidden_size, self.intermediate_size //
+                sf_block_size).view(torch.float8_e4m3fn)
+
+            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+            layer.w13_weight_scale = Parameter(w13_weight_scale,
+                                               requires_grad=False)
+            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+            layer.w2_weight_scale = Parameter(w2_weight_scale,
+                                              requires_grad=False)
+            layer.w13_bias = Parameter(
+                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
+                requires_grad=False)
+            layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
+                self.num_experts, -1),
+                                      requires_grad=False)
+            return
+
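The only non-obvious transform in the weight post-processing is swap_every_two_rows, which exchanges each adjacent (gate, up) row pair so the fused w13 layout matches the swiglu convention of the trtllm-gen kernels. A toy restatement of the same function shows the effect:

import torch


def swap_every_two_rows(x, axis=-1):
    # identical logic to the helper defined in process_weights_after_loading
    shape = x.shape
    if axis < 0:
        axis = len(shape) + axis
    new_shape = list(shape)
    new_shape[axis] = shape[axis] // 2
    new_shape.insert(axis + 1, 2)
    return x.reshape(*new_shape).flip(axis + 1).reshape(*shape)


rows = torch.arange(8).reshape(4, 2)          # rows 0, 1, 2, 3
print(swap_every_two_rows(rows, axis=-2))     # rows come out as 1, 0, 3, 2
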
+    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
+        # Number of tokens in the input tensor.
+        num_tokens = x.shape[0]
+        # Factor to account for the imbalance of the experts.
+        # The factor equals
+        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
+        # - 1.0 means perfect expert distribution.
+        # - > 1.0 means some experts have more
+        #   tokens than the perfect distribution.
+        # - < 1.0 does not make sense.
+        imbalance_factor = 1.3
+        # Calculate the number of tokens per expert
+        # assuming perfect distribution.
+        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
+        # Apply the imbalance factor.
+        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
+        # And pad the number to the next power of 2.
+        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+        # Cap to 8-64 tokens per CTA tile
+        # as it's the range supported by the kernel.
+        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+
+        return tile_tokens_dim
+
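A worked example of the tile sizing with illustrative numbers (the local next_power_of_2 mirrors vllm.utils.next_power_of_2):

def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


num_tokens, top_k, num_experts = 512, 4, 128
per_expert = (num_tokens * top_k) // num_experts      # 16
per_expert = int(per_expert * 1.3)                    # 20 after the imbalance factor
tile_tokens_dim = min(max(next_power_of_2(per_expert), 8), 64)   # 32
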
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        if enable_eplb:
+            raise NotImplementedError("EPLB is not supported for mxfp4")
+
+        assert _can_support_mxfp4(
+            use_grouped_topk, topk_group, num_expert_group, expert_map,
+            custom_routing_function, e_score_correction_bias,
+            apply_router_weight_on_input, scoring_func, activation,
+            expert_load_view, logical_to_physical_map,
+            logical_replica_count), (
+                "MXFP4 MoE is not supported with this configuration.")
+
+        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+            assert not self.moe.use_ep, (
+                "EP is not supported for the flashinfer mxfp4 moe backend yet.")
+            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
+                assert x.dtype == torch.bfloat16
+                x_quant = x
+                x_scale = None
+            else:
+                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            trtllm_gen_output = trtllm_fp4_block_scale_moe(
+                router_logits.to(torch.bfloat16),
+                None,  # routing_bias
+                x_quant,
+                x_scale,
+                layer.w13_weight,  # uint8 (e2m1 x 2)
+                layer.w13_weight_scale,  # uint8 (e4m3 x 2)
+                layer.w13_bias,  # fp32 per expert per channel
+                layer.gemm1_alpha,  # fp32 per expert
+                layer.gemm1_beta,  # fp32 per expert
+                layer.gemm1_clamp_limit,  # fp32 per expert
+                layer.w2_weight,  # uint8 (e2m1 x 2)
+                layer.w2_weight_scale,  # ue8m0
+                layer.w2_bias,  # fp32 per expert per channel
+                None,  # output1_scale_scalar
+                None,  # output1_scale_gate_scalar
+                None,  # output2_scale_scalar
+                self.num_experts,
+                top_k,
+                None,  # n_group
+                None,  # topk_group
+                self.intermediate_size,  # padded to multiple of 256
+                0,  # local_expert_offset
+                self.num_experts,  # local num experts
+                None,
+                self._get_tile_tokens_dim(x, top_k),
+                1 if renormalize else 0,  # routing_method_type, renormalize
+                True,  # do finalize
+            )[0]
+            return trtllm_gen_output
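
The two environment flags only change how activations reach trtllm_fp4_block_scale_moe: the BF16 path passes x through with no scale, while the MXFP8 path quantizes it first. A condensed restatement of that branch, assuming the flashinfer mxfp8_quantize import used above:

def quantize_activations(x, use_bf16_path: bool):
    # mirrors the activation branch in Mxfp4MoEMethod.apply
    if use_bf16_path:                 # VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
        assert x.dtype == torch.bfloat16
        return x, None
    # VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
    x_quant, x_scale = mxfp8_quantize(x, False)
    return x_quant, x_scale.view(torch.float8_e4m3fn).reshape(-1)
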
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index 1119045db0..4a4e199e13 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Callable, Optional
+
 import torch

 from vllm.utils import direct_register_custom_op
@@ -7,6 +9,26 @@ from vllm.utils import direct_register_custom_op
 OCP_MX_BLOCK_SIZE = 32


+def _can_support_mxfp4(use_grouped_topk: bool = False,
+                       topk_group: Optional[int] = None,
+                       num_expert_group: Optional[int] = None,
+                       expert_map: Optional[torch.Tensor] = None,
+                       custom_routing_function: Optional[Callable] = None,
+                       e_score_correction_bias: Optional[torch.Tensor] = None,
+                       apply_router_weight_on_input: bool = False,
+                       scoring_func: str = "softmax",
+                       activation: str = "silu",
+                       expert_load_view: Optional[torch.Tensor] = None,
+                       logical_to_physical_map: Optional[torch.Tensor] = None,
+                       logical_replica_count: Optional[torch.Tensor] = None):
+    return not (use_grouped_topk or topk_group or num_expert_group
+                or expert_map or custom_routing_function
+                or e_score_correction_bias or apply_router_weight_on_input
+                or scoring_func != "softmax" or activation != "silu"
+                or expert_load_view or logical_to_physical_map
+                or logical_replica_count)
+
+
 def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
                    float_dtype: torch.dtype) -> torch.Tensor:
     try:
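
_can_support_mxfp4 is a pure predicate over the routing configuration, so it is easy to sanity-check in isolation; the default softmax/silu path is accepted and any unsupported knob rejects:

from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    _can_support_mxfp4)

assert _can_support_mxfp4()                              # plain softmax + silu
assert not _can_support_mxfp4(use_grouped_topk=True)     # grouped top-k
assert not _can_support_mxfp4(scoring_func="sigmoid")    # non-softmax scoring
assert not _can_support_mxfp4(activation="gelu")         # non-silu activation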