[BugFix] Support redundant experts in EPLB (#3473)

This PR adds support for redundant experts in EPLB.

Key points:
- Use global_num_experts = num_experts + num_redundant_experts consistently across the MoE path (see the sketch below).
- Because the expert map already spans all physical experts, call sites derive sizes from len(expert_map) instead of adding global_redundant_expert_num again.
- The DeepSeek-V3/R1 fused-gating check (group_count=256) and the router-logits shape assertions subtract the redundancy count (init_redundancy_expert) from global_num_experts.
- Backward compatible when num_redundant_experts=0.
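A minimal sketch of this accounting, just to make the convention concrete; `moe_expert_counts` is a hypothetical helper, not part of the codebase:

```python
# Illustrative only: mirrors the counting convention described above.
def moe_expert_counts(num_experts: int, num_redundant_experts: int):
    # All physical expert slots, including redundant replicas.
    global_num_experts = num_experts + num_redundant_experts
    # Router logits are produced only for the logical (non-redundant) experts.
    num_logical_experts = global_num_experts - num_redundant_experts
    # The DeepSeek-V3/R1 fused gating path keys off the logical expert count.
    is_deepseek_v3_r1 = num_logical_experts == 256
    return global_num_experts, num_logical_experts, is_deepseek_v3_r1

# num_redundant_experts=0 collapses to the old behavior (backward compatible).
assert moe_expert_counts(256, 0) == (256, 256, True)
# With 16 redundant experts the totals grow, but the 256-expert check still holds.
assert moe_expert_counts(256, 16) == (272, 256, True)
```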

Tested
On a 16-rank W8A8 deployment with static EPLB and an expert_map_path, verified the router-logits shape and that requests complete successfully.
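The shape relation verified above follows the assertions touched in this diff; a small sketch with made-up sizes:

```python
import torch

num_experts, num_redundant_experts = 256, 16  # hypothetical sizes
global_num_experts = num_experts + num_redundant_experts

# The router emits one logit per logical expert, not per physical replica.
router_logits = torch.randn(8, num_experts)  # (num_tokens, num_experts)
assert router_logits.shape[1] == global_num_experts - num_redundant_experts
```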

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: yechao237 <yechao20180411@gmail.com>
Author: yechao237 (2025-10-18 00:09:16 +08:00), committed by GitHub
Parent: 07ca1b9b78
Commit: 4750d45d86
12 changed files with 23 additions and 35 deletions

@@ -34,8 +34,8 @@ def test_determine_default_expert_map_multiple_worlds_with_redundant():
         rank_id=0,
         global_redundant_expert_num=1)
-    assert count == 3
-    assert torch.all(expert_map[0:3] >= 0)
+    assert count == 2
+    assert torch.all(expert_map[0:2] >= 0)

 def test_generate_log2phy_map_single_rank_holding():

@@ -257,7 +257,7 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
 class TestExpertsSelector:

-    @pytest.mark.parametrize("global_num_experts", [[256], [128]])
+    @pytest.mark.parametrize("global_num_experts", [256, 128])
     def test_select_experts(self, mock_dist_env, mock_moe_env,
                             global_num_experts):

@@ -22,6 +22,7 @@ import torch_npu
 from pytest_mock import MockerFixture
 from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.ops.torchair_fused_moe import (
@@ -355,7 +356,9 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
         """
         global_num_experts, ep_size = others_param
         is_prefill = False
-        is_deepseek_v3_r1 = global_num_experts == 256
+        global_redundant_expert_num = get_ascend_config(
+        ).init_redundancy_expert
+        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
         forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
             ep_size, is_prefill, is_deepseek_v3_r1))
         with patch(

@@ -40,14 +40,6 @@ def determine_default_expert_map(global_expert_num, world_size, rank_id,
         end = global_expert_num
         local_count = global_expert_num - rank_id * local_num_experts

-    if isinstance(global_redundant_expert_num,
-                  int) and rank_id < global_redundant_expert_num:
-        local_count += 1
-        if end < global_expert_num:
-            end += 1
-        else:
-            start -= 1
-
     if isinstance(local_count, int):
         local_ids = torch.arange(local_count, dtype=torch.int32)
         expert_map[start:end] = local_ids
@@ -118,14 +110,6 @@ def determine_default_log2phy_map(global_expert_num, world_size, rank_id,
             end = global_expert_num
             local_count = global_expert_num - r * local_num_experts

-        if isinstance(global_redundant_expert_num,
-                      int) and rank_id < global_redundant_expert_num:
-            local_count += 1
-            if end < global_expert_num:
-                end += 1
-            else:
-                start -= 1
-
         if isinstance(local_count, int):
             local_ids = torch.arange(local_count, dtype=torch.int32)
             expert_map_all[r, start:end] = local_ids
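For context, a minimal sketch of the simplified layout after this change, assuming an even split of global_expert_num (which already includes redundant experts) across ranks with the last rank absorbing the remainder; the real determine_default_expert_map signature and details differ:

```python
import torch

def default_expert_map_sketch(global_expert_num: int, world_size: int,
                              rank_id: int) -> torch.Tensor:
    # No rank gets an extra "+1" for redundancy anymore, because
    # global_expert_num already counts the redundant experts.
    expert_map = torch.full((global_expert_num,), -1, dtype=torch.int32)
    local_num_experts = global_expert_num // world_size
    start = rank_id * local_num_experts
    end = global_expert_num if rank_id == world_size - 1 else start + local_num_experts
    expert_map[start:end] = torch.arange(end - start, dtype=torch.int32)
    return expert_map
```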

@@ -20,6 +20,8 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
+from vllm_ascend.ascend_config import get_ascend_config

 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
@@ -176,7 +178,8 @@ def _select_experts_with_fusion_ops(
     topk_weights, topk_ids = None, None
     # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
-    is_deepseek_v3_r1 = global_num_experts == 256
+    global_redundant_expert_num = get_ascend_config().init_redundancy_expert
+    is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
     if is_deepseek_v3_r1:
         topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
             router_logits,

@@ -123,10 +123,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
     ):
         if self.with_quant:
             quant_mode = 2
-            if (expert_map is not None):
-                moe_expert_num = len(expert_map) + global_redundant_expert_num
-            else:
-                moe_expert_num = global_redundant_expert_num
+            moe_expert_num = len(expert_map)
         else:
             quant_mode = 0
             moe_expert_num = len(expert_map)
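Why the `+ global_redundant_expert_num` term goes away: the expert map is now built over global_num_experts, so len(expert_map) already includes the redundant replicas and adding the count again would double-count them. A tiny illustration with hypothetical sizes:

```python
num_experts = 256                # hypothetical
num_redundant_experts = 16       # hypothetical
global_num_experts = num_experts + num_redundant_experts  # 272

expert_map_len = global_num_experts  # the map spans every physical expert slot
moe_expert_num = expert_map_len      # no extra "+ num_redundant_experts" needed
assert moe_expert_num == 272
```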

@@ -263,7 +263,7 @@ class AscendW4A8DynamicFusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         topk_weights, topk_ids = select_experts(

@@ -263,7 +263,7 @@ class AscendW8A8FusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
         topk_weights, topk_ids = select_experts(
             hidden_states=x,

@@ -203,7 +203,7 @@ class AscendW8A8DynamicFusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
         topk_weights, topk_ids = select_experts(
             hidden_states=x,

@@ -856,8 +856,9 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
-        is_deepseek_v3_r1 = global_num_experts == 256
+        global_redundant_expert_num = get_ascend_config(
+        ).init_redundancy_expert
+        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if is_deepseek_v3_r1:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(

@@ -269,7 +269,7 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
         if global_num_experts == 256:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(

@@ -246,7 +246,7 @@ def torchair_fused_experts_with_mc2(
     enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
     if (expert_map is not None):
-        moe_expert_num = len(expert_map) + global_redundant_expert_num
+        moe_expert_num = len(expert_map)
     else:
         moe_expert_num = global_redundant_expert_num
     # hidden_states = hidden_states.bfloat16()
@@ -431,7 +431,7 @@ def torchair_fused_experts_with_all2all(
     if expert_map is not None:
         assert ep_group is not None, "ep_group must be provided when expert_map is given"
-        global_num_experts = len(expert_map) + global_redundant_expert_num
+        global_num_experts = len(expert_map)
         if hasattr(torch_npu, "npu_moe_init_routing_quant"):
             quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
                 hidden_states,
@@ -929,9 +929,9 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
-        is_deepseek_v3_r1 = global_num_experts == 256
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
+        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if is_deepseek_v3_r1: