Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
[BugFix] Support redundant experts in EPLB (#3473)
This PR adds support for redundant experts in the EPLB. Key points:

- Use global_num_experts = num_experts + num_redundant_experts consistently.
- Backward compatible when num_redundant_experts = 0.

Tested on a 16-rank setup (W8A8) with static EPLB and expert_map_path, verifying the router-logits shape and successful requests.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: yechao237 <yechao20180411@gmail.com>
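To make the sizing convention in the commit message concrete, here is a minimal illustrative sketch (an editor's addition, not code from this commit); num_logical_experts and num_redundant_experts are hypothetical names standing in for the model's expert count and the EPLB redundancy setting.

# Illustrative sketch only: the sizing convention described in the commit
# message. Names are hypothetical, not identifiers from vllm-ascend.
num_logical_experts = 256        # experts the router actually scores
num_redundant_experts = 16       # extra physical replicas managed by EPLB

# The global count used for expert maps and dispatch covers every physical slot.
global_num_experts = num_logical_experts + num_redundant_experts

# Router logits stay per logical expert, so shape checks subtract the redundancy.
assert global_num_experts - num_redundant_experts == num_logical_experts

# Backward compatible: with num_redundant_experts = 0 the two counts coincide.
assert num_logical_experts + 0 == num_logical_experts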
@@ -34,8 +34,8 @@ def test_determine_default_expert_map_multiple_worlds_with_redundant():
         rank_id=0,
         global_redundant_expert_num=1)

-    assert count == 3
-    assert torch.all(expert_map[0:3] >= 0)
+    assert count == 2
+    assert torch.all(expert_map[0:2] >= 0)


 def test_generate_log2phy_map_single_rank_holding():
@@ -257,7 +257,7 @@ class MockFusedMoEMethod(FusedMoEMethodBase):

 class TestExpertsSelector:

-    @pytest.mark.parametrize("global_num_experts", [[256], [128]])
+    @pytest.mark.parametrize("global_num_experts", [256, 128])
     def test_select_experts(self, mock_dist_env, mock_moe_env,
                             global_num_experts):

@@ -22,6 +22,7 @@ import torch_npu
 from pytest_mock import MockerFixture
 from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.ops.torchair_fused_moe import (
@@ -355,7 +356,9 @@ class TestTorchairAscendUnquantizedFusedMoEMethod:
         """
         global_num_experts, ep_size = others_param
         is_prefill = False
-        is_deepseek_v3_r1 = global_num_experts == 256
+        global_redundant_expert_num = get_ascend_config(
+        ).init_redundancy_expert
+        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
         forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
             ep_size, is_prefill, is_deepseek_v3_r1))
         with patch(
@@ -40,14 +40,6 @@ def determine_default_expert_map(global_expert_num, world_size, rank_id,
         end = global_expert_num
         local_count = global_expert_num - rank_id * local_num_experts

-    if isinstance(global_redundant_expert_num,
-                  int) and rank_id < global_redundant_expert_num:
-        local_count += 1
-        if end < global_expert_num:
-            end += 1
-        else:
-            start -= 1
-
     if isinstance(local_count, int):
         local_ids = torch.arange(local_count, dtype=torch.int32)
         expert_map[start:end] = local_ids
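For readers unfamiliar with the surrounding function, the following is a hedged sketch of the even-split logic that remains after this hunk. It is an editor's illustration with an assumed, simplified signature (the real determine_default_expert_map in vllm-ascend takes more parameters); the point is that, with global_expert_num already counting redundant experts, no per-rank "+1" adjustment is needed.

import torch

def default_expert_map_sketch(global_expert_num: int, world_size: int,
                              rank_id: int):
    """Hedged sketch: evenly split a global expert space that already
    includes redundant experts; the last rank absorbs the remainder."""
    local_num_experts = global_expert_num // world_size
    start = rank_id * local_num_experts
    if rank_id == world_size - 1:
        end = global_expert_num
    else:
        end = start + local_num_experts
    local_count = end - start
    expert_map = torch.full((global_expert_num,), -1, dtype=torch.int32)
    expert_map[start:end] = torch.arange(local_count, dtype=torch.int32)
    return local_count, expert_map

# Example: 17 global slots (16 logical + 1 redundant) over 8 ranks.
count, expert_map = default_expert_map_sketch(17, 8, rank_id=7)
assert count == 3 and torch.all(expert_map[14:17] >= 0)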
@@ -118,14 +110,6 @@ def determine_default_log2phy_map(global_expert_num, world_size, rank_id,
             end = global_expert_num
             local_count = global_expert_num - r * local_num_experts

-        if isinstance(global_redundant_expert_num,
-                      int) and rank_id < global_redundant_expert_num:
-            local_count += 1
-            if end < global_expert_num:
-                end += 1
-            else:
-                start -= 1
-
         if isinstance(local_count, int):
             local_ids = torch.arange(local_count, dtype=torch.int32)
             expert_map_all[r, start:end] = local_ids
@@ -20,6 +20,8 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context

+from vllm_ascend.ascend_config import get_ascend_config
+

 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
@@ -176,7 +178,8 @@ def _select_experts_with_fusion_ops(

     topk_weights, topk_ids = None, None
     # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
-    is_deepseek_v3_r1 = global_num_experts == 256
+    global_redundant_expert_num = get_ascend_config().init_redundancy_expert
+    is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
     if is_deepseek_v3_r1:
         topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
             router_logits,
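A small standalone illustration of the detection above (an editor's sketch; the helper name and the literal values are assumptions, while get_ascend_config().init_redundancy_expert is the config field actually read in the hunk): the npu_moe_gating_top_k fast path is keyed to DeepSeek-V3/R1's 256 logical experts, so redundant replicas must be subtracted before the comparison.

def looks_like_deepseek_v3_r1(global_num_experts: int,
                              init_redundancy_expert: int) -> bool:
    # Sketch of the check above: compare the *logical* expert count, i.e.
    # the global count minus the configured redundancy, against 256.
    return global_num_experts - init_redundancy_expert == 256

assert looks_like_deepseek_v3_r1(256, 0)        # no redundancy: unchanged behavior
assert looks_like_deepseek_v3_r1(256 + 16, 16)  # redundancy-aware detection
assert not looks_like_deepseek_v3_r1(128, 0)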
@@ -123,10 +123,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
     ):
         if self.with_quant:
             quant_mode = 2
-            if (expert_map is not None):
-                moe_expert_num = len(expert_map) + global_redundant_expert_num
-            else:
-                moe_expert_num = global_redundant_expert_num
+            moe_expert_num = len(expert_map)
         else:
             quant_mode = 0
             moe_expert_num = len(expert_map)
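The dispatcher change above follows from the same convention: once expert_map is sized to the full global expert count (logical plus redundant), its length alone determines moe_expert_num in both branches. A minimal editor's sketch with illustrative numbers:

import torch

# Illustrative values only, not taken from the PR.
num_logical_experts, num_redundant_experts = 256, 16
global_num_experts = num_logical_experts + num_redundant_experts

# expert_map covers every physical slot, so its length is the MoE expert count
# handed to the MC2 dispatch, with or without quantization.
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
moe_expert_num = len(expert_map)
assert moe_expert_num == global_num_experts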
@@ -263,7 +263,7 @@ class AscendW4A8DynamicFusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         topk_weights, topk_ids = select_experts(
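Why the assertions in this and the following hunks subtract the redundancy count: the router produces one logit per logical expert, and redundant physical replicas add no columns. A hedged, self-contained illustration (the tensor shapes and counts are assumed example values):

import torch

global_redundant_expert_num = 16
global_num_experts = 256 + global_redundant_expert_num   # logical + redundant
router_logits = torch.randn(8, 256)                      # (num_tokens, num_logical_experts)

# Mirrors the updated assert: compare against the logical expert count.
assert router_logits.shape[
    1] == global_num_experts - global_redundant_expert_num, \
    "Number of global experts mismatch (excluding redundancy)"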
@@ -263,7 +263,7 @@ class AscendW8A8FusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -203,7 +203,7 @@ class AscendW8A8DynamicFusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -856,8 +856,9 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
-        is_deepseek_v3_r1 = global_num_experts == 256
+        global_redundant_expert_num = get_ascend_config(
+        ).init_redundancy_expert
+        is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if is_deepseek_v3_r1:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
@@ -269,7 +269,7 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
-            1] == global_num_experts, "Number of global experts mismatch"
+            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

         if global_num_experts == 256:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
@@ -246,7 +246,7 @@ def torchair_fused_experts_with_mc2(
     enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")

     if (expert_map is not None):
-        moe_expert_num = len(expert_map) + global_redundant_expert_num
+        moe_expert_num = len(expert_map)
     else:
         moe_expert_num = global_redundant_expert_num
     # hidden_states = hidden_states.bfloat16()
|
|||||||
|
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
assert ep_group is not None, "ep_group must be provided when expert_map is given"
|
assert ep_group is not None, "ep_group must be provided when expert_map is given"
|
||||||
global_num_experts = len(expert_map) + global_redundant_expert_num
|
global_num_experts = len(expert_map)
|
||||||
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
|
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
|
||||||
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
|
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@ -929,9 +929,9 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert router_logits.shape[
|
assert router_logits.shape[
|
||||||
1] == global_num_experts, "Number of global experts mismatch"
|
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||||
|
|
||||||
is_deepseek_v3_r1 = global_num_experts == 256
|
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||||
|
|
||||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||||
if is_deepseek_v3_r1:
|
if is_deepseek_v3_r1:
|
||||||