From 0925b28a8e92855131a1c1308a16ec14a9c94ceb Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Date: Fri, 17 Oct 2025 11:06:33 -0700
Subject: [PATCH] [ROCM] MoE fp4 CK kernel (#26545)

Signed-off-by: Aleksandr Malyshev
Co-authored-by: Aleksandr Malyshev
---
 .../layers/fused_moe/rocm_aiter_fused_moe.py  |  5 +
 .../layers/quantization/quark/quark_moe.py    | 92 ++++++++++++++-----
 2 files changed, 73 insertions(+), 24 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index b572baecd7..820c0af71c 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -46,6 +46,11 @@ def is_rocm_aiter_moe_enabled() -> bool:
     )
 
 
+@cache
+def use_mxfp4_aiter_moe() -> bool:
+    return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
+
+
 @cache
 def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
     return (
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index c13cf7007e..5cab6e205c 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled,
+    use_mxfp4_aiter_moe,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin,
@@ -472,22 +473,22 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
                 "not implemented. Please open an issue."
             )
 
-        if not current_platform.supports_mx():
-            self.emulate = True
+        self.emulate = not current_platform.supports_mx() or not (
+            use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
+        )
+        if self.emulate:
             logger.warning_once(
-                "The current platform does not support native MXFP4/MXFP6 "
+                f"The current mode (supports_mx={current_platform.supports_mx()}, "
+                f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
+                f"ocp_mx_scheme={self.ocp_mx_scheme}) "
+                "does not support native MXFP4/MXFP6 "
                 "computation. Simulated weight dequantization and activation "
                 "QDQ (quantize and dequantize) will be used, with the linear "
                 "layers computed in high precision."
             )
         else:
-            self.emulate = True
             logger.warning_once(
-                "The current platform supports native MXFP4/MXFP6 "
-                "computation, but kernels are not yet integrated in vLLM. "
-                "Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision."
+                "The current mode supports native MoE MXFP4 computation"
             )
 
     def get_packed_dim(self, dim: int, quant_dtype: str):
@@ -568,6 +569,24 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
+    def process_weights_after_loading(self, layer):
+        if self.emulate:
+            return
+
+        from aiter.utility.fp4_utils import e8m0_shuffle
+
+        # Pre-shuffle weight scales
+        s0, s1, _ = layer.w13_weight_scale.shape
+        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
+        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
+        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
+
+        s0, s1, _ = layer.w2_weight_scale.shape
+        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
+        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
+        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+        torch.cuda.empty_cache()
+
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
@@ -611,8 +630,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
                 "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
             )
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -628,17 +645,44 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
             indices_type=self.topk_indices_dtype,
         )
 
-        out = fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            quant_config=self.moe_quant_config,
-        )
+        if not self.emulate:
+            from aiter import ActivationType, QuantType
+            from aiter.fused_moe import fused_moe
+
+            aiter_acts = {
+                ActivationType.No.name.lower(): ActivationType.No,
+                ActivationType.Silu.name.lower(): ActivationType.Silu,
+                ActivationType.Gelu.name.lower(): ActivationType.Gelu,
+            }
+            assert activation in aiter_acts, (
+                f"Aiter CK fp4 MoE doesn't support activation {activation}"
+            )
+            out = fused_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                quant_type=QuantType.per_1x32,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                activation=aiter_acts[activation],
+                doweight_stage1=False,
+            )
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+
+            out = fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                quant_config=self.moe_quant_config,
+            )
         return out