From 4750d45d86632b7085ab9f0b57070cd17aa9d108 Mon Sep 17 00:00:00 2001
From: yechao237 <38792318+yechao237@users.noreply.github.com>
Date: Sat, 18 Oct 2025 00:09:16 +0800
Subject: [PATCH] [BugFix] Support redundant experts in EPLB (#3473)

This PR adds support for redundant experts in the EPLB. Key points:
- Use global_num_experts = num_experts + num_redundant_experts consistently.
- Backward compatible when num_redundant_experts=0.

Tested on a 16-rank setup (W8A8) with static EPLB and expert_map_path,
verifying the router logits shape and successful requests.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: yechao237
---
 tests/ut/eplb/core/test_eplb_utils.py             |  4 ++--
 tests/ut/ops/test_fused_ops.py                    |  2 +-
 tests/ut/torchair/ops/test_torchair_fused_moe.py  |  5 ++++-
 vllm_ascend/eplb/core/eplb_utils.py               | 16 ----------------
 vllm_ascend/ops/moe/experts_selector.py           |  5 ++++-
 vllm_ascend/ops/moe/token_dispatcher.py           |  5 +----
 vllm_ascend/quantization/w4a8_dynamic.py          |  2 +-
 vllm_ascend/quantization/w8a8.py                  |  2 +-
 vllm_ascend/quantization/w8a8_dynamic.py          |  2 +-
 vllm_ascend/torchair/ops/torchair_fused_moe.py    |  5 +++--
 .../quantization/torchair_w4a8_dynamic.py         |  2 +-
 .../quantization/torchair_w8a8_dynamic.py         |  8 ++++----
 12 files changed, 23 insertions(+), 35 deletions(-)

diff --git a/tests/ut/eplb/core/test_eplb_utils.py b/tests/ut/eplb/core/test_eplb_utils.py
index 8a9761f97..624b1d2e6 100644
--- a/tests/ut/eplb/core/test_eplb_utils.py
+++ b/tests/ut/eplb/core/test_eplb_utils.py
@@ -34,8 +34,8 @@ def test_determine_default_expert_map_multiple_worlds_with_redundant():
         rank_id=0,
         global_redundant_expert_num=1)
 
-    assert count == 3
-    assert torch.all(expert_map[0:3] >= 0)
+    assert count == 2
+    assert torch.all(expert_map[0:2] >= 0)
 
 
 def test_generate_log2phy_map_single_rank_holding():
diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py
index 685fbfcda..a6b0ae024 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_ops.py
@@ -257,7 +257,7 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
 
 class TestExpertsSelector:
 
-    @pytest.mark.parametrize("global_num_experts", [[256], [128]])
+    @pytest.mark.parametrize("global_num_experts", [256, 128])
     def test_select_experts(self, mock_dist_env, mock_moe_env,
                             global_num_experts):
 
diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py
index e8945d8af..705c794cf 100644
--- a/tests/ut/torchair/ops/test_torchair_fused_moe.py
+++ b/tests/ut/torchair/ops/test_torchair_fused_moe.py
@@ -22,6 +22,7 @@ import torch_npu
 from pytest_mock import MockerFixture
 from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
 
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.ops.torchair_fused_moe import (
@@ -355,7 +356,9 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
         """
         global_num_experts, ep_size = others_param
         is_prefill = False
-        is_deepseek_v3_r1 = global_num_experts == 256
+        global_redundant_expert_num = get_ascend_config(
+        ).init_redundancy_expert
+        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
         forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
             ep_size, is_prefill, is_deepseek_v3_r1))
         with patch(
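Note: the updated test expectations encode the convention this PR adopts, namely that global_expert_num already includes the redundant experts, so each rank simply takes an even share and the old per-rank "+1" adjustment is removed. A minimal sketch of that convention in Python, illustrative only: the argument names mirror determine_default_expert_map, but the last-rank remainder handling is an assumption, not the actual vllm_ascend code.

import torch

def default_expert_map(global_expert_num: int, world_size: int,
                       rank_id: int):
    """Return (local_count, expert_map); expert_map[g] is the local id of
    global expert g on this rank, or -1 if the expert is not local."""
    expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32)
    local_num_experts = global_expert_num // world_size
    start = rank_id * local_num_experts
    # Assumption: the last rank absorbs the remainder of an uneven split.
    if rank_id == world_size - 1:
        end = global_expert_num
    else:
        end = start + local_num_experts
    local_count = end - start
    expert_map[start:end] = torch.arange(local_count, dtype=torch.int32)
    return local_count, expert_map

# E.g. 4 logical + 1 redundant expert on 2 ranks: rank 0 now holds an even
# share of 2 experts, matching the updated unit-test expectation above.
count, expert_map = default_expert_map(5, world_size=2, rank_id=0)
assert count == 2 and torch.all(expert_map[0:2] >= 0)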
diff --git a/vllm_ascend/eplb/core/eplb_utils.py b/vllm_ascend/eplb/core/eplb_utils.py
index 9a1c3bd4b..af87c9113 100644
--- a/vllm_ascend/eplb/core/eplb_utils.py
+++ b/vllm_ascend/eplb/core/eplb_utils.py
@@ -40,14 +40,6 @@ def determine_default_expert_map(global_expert_num, world_size, rank_id,
             end = global_expert_num
             local_count = global_expert_num - rank_id * local_num_experts
 
-        if isinstance(global_redundant_expert_num,
-                      int) and rank_id < global_redundant_expert_num:
-            local_count += 1
-            if end < global_expert_num:
-                end += 1
-            else:
-                start -= 1
-
         if isinstance(local_count, int):
             local_ids = torch.arange(local_count, dtype=torch.int32)
             expert_map[start:end] = local_ids
@@ -118,14 +110,6 @@ def determine_default_log2phy_map(global_expert_num, world_size, rank_id,
                 end = global_expert_num
                 local_count = global_expert_num - r * local_num_experts
 
-            if isinstance(global_redundant_expert_num,
-                          int) and rank_id < global_redundant_expert_num:
-                local_count += 1
-                if end < global_expert_num:
-                    end += 1
-                else:
-                    start -= 1
-
             if isinstance(local_count, int):
                 local_ids = torch.arange(local_count, dtype=torch.int32)
                 expert_map_all[r, start:end] = local_ids
diff --git a/vllm_ascend/ops/moe/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py
index 7beda17fe..e511d6b55 100644
--- a/vllm_ascend/ops/moe/experts_selector.py
+++ b/vllm_ascend/ops/moe/experts_selector.py
@@ -20,6 +20,8 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
 
+from vllm_ascend.ascend_config import get_ascend_config
+
 
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
@@ -176,7 +178,8 @@ def _select_experts_with_fusion_ops(
 
     topk_weights, topk_ids = None, None
     # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
-    is_deepseek_v3_r1 = global_num_experts == 256
+    global_redundant_expert_num = get_ascend_config().init_redundancy_expert
+    is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
     if is_deepseek_v3_r1:
         topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
             router_logits,
diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py
index 3dd799a42..83da546a1 100644
--- a/vllm_ascend/ops/moe/token_dispatcher.py
+++ b/vllm_ascend/ops/moe/token_dispatcher.py
@@ -123,10 +123,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
     ):
         if self.with_quant:
             quant_mode = 2
-            if (expert_map is not None):
-                moe_expert_num = len(expert_map) + global_redundant_expert_num
-            else:
-                moe_expert_num = global_redundant_expert_num
+            moe_expert_num = len(expert_map)
         else:
             quant_mode = 0
             moe_expert_num = len(expert_map)
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index c8c1eeb69..4b5632fa3 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -263,7 +263,7 @@ class AscendW4A8DynamicFusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         topk_weights, topk_ids = select_experts(
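The assertion rewrite above, and the identical rewrites in the quantized MoE methods that follow, all encode the same invariant: the router emits one logit per logical expert, while global_num_experts now counts physical experts (logical plus redundant replicas). A hedged sketch of that check; the expert counts used here are illustrative and not taken from the PR.

import torch

def check_router_logits(router_logits: torch.Tensor, global_num_experts: int,
                        global_redundant_expert_num: int) -> None:
    # One logit per *logical* expert; global_num_experts counts *physical*
    # experts, so the redundancy must be subtracted before comparing.
    num_logical = global_num_experts - global_redundant_expert_num
    assert router_logits.shape[1] == num_logical, \
        "Number of global experts mismatch (excluding redundancy)"

# E.g. 256 logical experts plus 16 redundant replicas give 272 physical
# experts, but the gating network still emits 256 logits per token.
check_router_logits(torch.randn(8, 256), 272, 16)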
diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py
index fec542cd2..615bece0c 100644
--- a/vllm_ascend/quantization/w8a8.py
+++ b/vllm_ascend/quantization/w8a8.py
@@ -263,7 +263,7 @@ class AscendW8A8FusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
 
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 598e0e50a..8055b5345 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -203,7 +203,7 @@ class AscendW8A8DynamicFusedMoEMethod:
         **kwargs,
    ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
 
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py
index a54700836..1a87f3e89 100644
--- a/vllm_ascend/torchair/ops/torchair_fused_moe.py
+++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py
@@ -856,8 +856,9 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
-
-        is_deepseek_v3_r1 = global_num_experts == 256
+        global_redundant_expert_num = get_ascend_config(
+        ).init_redundancy_expert
+        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if is_deepseek_v3_r1:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
diff --git a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py
index ff7b0eeda..732e94349 100644
--- a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py
+++ b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py
@@ -269,7 +269,7 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
 
         if global_num_experts == 256:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
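For the dispatcher change earlier and the torchair changes below, the reasoning is the same: expert_map is now already sized to the physical expert count, so len(expert_map) is the correct moe_expert_num, and adding global_redundant_expert_num on top would double-count the replicas. A tiny illustrative check; the sizes are hypothetical, not from the PR.

import torch

num_logical, num_redundant = 4, 1

# Under the new convention the expert map already covers every physical
# expert (logical + redundant), so its length is the whole story.
expert_map = torch.full((num_logical + num_redundant, ), -1, dtype=torch.int32)

moe_expert_num = len(expert_map)                      # new code: 5 physical experts
stale_moe_expert_num = len(expert_map) + num_redundant  # old code: 6, double-counted
assert moe_expert_num == num_logical + num_redundant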
diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
index 573933a16..0f4615443 100644
--- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
+++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
@@ -246,7 +246,7 @@ def torchair_fused_experts_with_mc2(
     enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
 
     if (expert_map is not None):
-        moe_expert_num = len(expert_map) + global_redundant_expert_num
+        moe_expert_num = len(expert_map)
     else:
         moe_expert_num = global_redundant_expert_num
     # hidden_states = hidden_states.bfloat16()
@@ -431,7 +431,7 @@ def torchair_fused_experts_with_all2all(
 
     if expert_map is not None:
         assert ep_group is not None, "ep_group must be provided when expert_map is given"
-        global_num_experts = len(expert_map) + global_redundant_expert_num
+        global_num_experts = len(expert_map)
         if hasattr(torch_npu, "npu_moe_init_routing_quant"):
             quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
                 hidden_states,
@@ -929,9 +929,9 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
 
-        is_deepseek_v3_r1 = global_num_experts == 256
+        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if is_deepseek_v3_r1:
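Finally, the DeepSeek V3/R1 detection that gates the fused npu_moe_gating_top_k path now subtracts the redundancy count, since the required group_count=256 pattern refers to logical experts. A standalone sketch of the predicate; the real code reads init_redundancy_expert from get_ascend_config(), as the hunks above show, rather than taking it as a parameter.

def is_deepseek_v3_r1(global_num_experts: int,
                      global_redundant_expert_num: int) -> bool:
    # npu_moe_gating_top_k only supports the `group_count=256` pattern, and
    # the 256 counts *logical* experts, so subtract the redundancy first.
    return global_num_experts - global_redundant_expert_num == 256

# Without the fix, 256 logical + 16 redundant experts (272 physical) would
# have missed the fused gating path.
assert is_deepseek_v3_r1(272, 16)
assert not is_deepseek_v3_r1(272, 0)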