mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
[BugFix]Support redundant experts in EPLB (#3473)
This PR adds support for redundant experts in the EPLB. Key points: - Use global_num_experts = num_experts + num_redundant_experts consistently. - Backward compatible when num_redundant_experts=0. Tested On a 16-rank setup (W8A8) with static EPLB and expert_map_path, verifying router logits shape and successful requests. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: yechao237 <yechao20180411@gmail.com>
This commit is contained in:
@ -34,8 +34,8 @@ def test_determine_default_expert_map_multiple_worlds_with_redundant():
|
||||
rank_id=0,
|
||||
global_redundant_expert_num=1)
|
||||
|
||||
assert count == 3
|
||||
assert torch.all(expert_map[0:3] >= 0)
|
||||
assert count == 2
|
||||
assert torch.all(expert_map[0:2] >= 0)
|
||||
|
||||
|
||||
def test_generate_log2phy_map_single_rank_holding():
|
||||
|
@ -257,7 +257,7 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
class TestExpertsSelector:
|
||||
|
||||
@pytest.mark.parametrize("global_num_experts", [[256], [128]])
|
||||
@pytest.mark.parametrize("global_num_experts", [256, 128])
|
||||
def test_select_experts(self, mock_dist_env, mock_moe_env,
|
||||
global_num_experts):
|
||||
|
||||
|
@ -22,6 +22,7 @@ import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
|
||||
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import (
|
||||
@ -355,7 +356,9 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
|
||||
"""
|
||||
global_num_experts, ep_size = others_param
|
||||
is_prefill = False
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
global_redundant_expert_num = get_ascend_config(
|
||||
).init_redundancy_expert
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||
ep_size, is_prefill, is_deepseek_v3_r1))
|
||||
with patch(
|
||||
|
@ -40,14 +40,6 @@ def determine_default_expert_map(global_expert_num, world_size, rank_id,
|
||||
end = global_expert_num
|
||||
local_count = global_expert_num - rank_id * local_num_experts
|
||||
|
||||
if isinstance(global_redundant_expert_num,
|
||||
int) and rank_id < global_redundant_expert_num:
|
||||
local_count += 1
|
||||
if end < global_expert_num:
|
||||
end += 1
|
||||
else:
|
||||
start -= 1
|
||||
|
||||
if isinstance(local_count, int):
|
||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
||||
expert_map[start:end] = local_ids
|
||||
@ -118,14 +110,6 @@ def determine_default_log2phy_map(global_expert_num, world_size, rank_id,
|
||||
end = global_expert_num
|
||||
local_count = global_expert_num - r * local_num_experts
|
||||
|
||||
if isinstance(global_redundant_expert_num,
|
||||
int) and rank_id < global_redundant_expert_num:
|
||||
local_count += 1
|
||||
if end < global_expert_num:
|
||||
end += 1
|
||||
else:
|
||||
start -= 1
|
||||
|
||||
if isinstance(local_count, int):
|
||||
local_ids = torch.arange(local_count, dtype=torch.int32)
|
||||
expert_map_all[r, start:end] = local_ids
|
||||
|
@ -20,6 +20,8 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@ -176,7 +178,8 @@ def _select_experts_with_fusion_ops(
|
||||
|
||||
topk_weights, topk_ids = None, None
|
||||
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
global_redundant_expert_num = get_ascend_config().init_redundancy_expert
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
|
@ -123,10 +123,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
):
|
||||
if self.with_quant:
|
||||
quant_mode = 2
|
||||
if (expert_map is not None):
|
||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
else:
|
||||
moe_expert_num = global_redundant_expert_num
|
||||
moe_expert_num = len(expert_map)
|
||||
else:
|
||||
quant_mode = 0
|
||||
moe_expert_num = len(expert_map)
|
||||
|
@ -263,7 +263,7 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
topk_weights, topk_ids = select_experts(
|
||||
|
@ -263,7 +263,7 @@ class AscendW8A8FusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
|
@ -203,7 +203,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
|
@ -856,8 +856,9 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
shared_experts: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
global_redundant_expert_num = get_ascend_config(
|
||||
).init_redundancy_expert
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
|
@ -269,7 +269,7 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
|
@ -246,7 +246,7 @@ def torchair_fused_experts_with_mc2(
|
||||
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||
|
||||
if (expert_map is not None):
|
||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
moe_expert_num = len(expert_map)
|
||||
else:
|
||||
moe_expert_num = global_redundant_expert_num
|
||||
# hidden_states = hidden_states.bfloat16()
|
||||
@ -431,7 +431,7 @@ def torchair_fused_experts_with_all2all(
|
||||
|
||||
if expert_map is not None:
|
||||
assert ep_group is not None, "ep_group must be provided when expert_map is given"
|
||||
global_num_experts = len(expert_map) + global_redundant_expert_num
|
||||
global_num_experts = len(expert_map)
|
||||
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
|
||||
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
|
||||
hidden_states,
|
||||
@ -929,9 +929,9 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if is_deepseek_v3_r1:
|
||||
|
Reference in New Issue
Block a user