[MoE] [Refactor] Remove manual memory cleanup (#3365)

### What this PR does / why we need it?
1. Replace manual memory cleanup with passing parameter.
2. FusedMoEPrepareAndFinalizeWithMC2 inherits All2All avoid duplicated
code.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-10-15 12:36:24 +08:00
committed by GitHub
parent 4e720936d8
commit 4f937f561d
8 changed files with 562 additions and 492 deletions

View File

@ -137,6 +137,7 @@ def test_token_dispatcher_with_all_gather(
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
context_metadata = dispatch_output["context_metadata"]
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
w1=w1_local,
@ -144,7 +145,9 @@ def test_token_dispatcher_with_all_gather(
group_list=group_list,
group_list_type=group_list_type)
combined_output = dispatcher.token_combine(hidden_states=expert_output,
combined_output = dispatcher.token_combine(
hidden_states=expert_output,
context_metadata=context_metadata,
bias=None)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
@ -215,6 +218,7 @@ def test_token_dispatcher_with_all_gather_quant(
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
dynamic_scale = dispatch_output["dynamic_scale"]
context_metadata = dispatch_output["context_metadata"]
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
w1=w1,
@ -225,7 +229,9 @@ def test_token_dispatcher_with_all_gather_quant(
group_list_type=group_list_type,
dynamic_scale=dynamic_scale,
with_quant=True)
combined_output = dispatcher.token_combine(hidden_states=expert_output,
combined_output = dispatcher.token_combine(
hidden_states=expert_output,
context_metadata=context_metadata,
bias=None)
assert combined_output.shape == (m, k)
gc.collect()

View File

@ -44,7 +44,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out, mask = layer.prepare(hidden_states, router_logits)
h_out, r_out, mask, context_metadata = layer.prepare(
hidden_states, router_logits)
# Check padding and split
self.assertEqual(h_out.shape[0], 4)
@ -52,7 +53,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.assertEqual(mask.tolist(), [1, 0, 1])
# Finalize
result = layer.finalize(h_out, reduce_results=False)
result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
self.assertEqual(result.shape[0], 3)
@patch(
@ -77,7 +80,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(4, 8)
router_logits = torch.randn(4, 2)
h_out, r_out, mask = layer.prepare(hidden_states,
h_out, r_out, mask, context_metadata = layer.prepare(
hidden_states,
router_logits,
enable_shared_expert_dp=False,
replace_allreduce=False)
@ -96,7 +100,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
torch.zeros_like(h_out),
torch.zeros_like(h_out)
]
final_result = layer.finalize(h_out, reduce_results=False)
final_result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
# Should concat back to original size
self.assertEqual(final_result.shape[0], 4)
@ -112,12 +118,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
h_out, r_out, _, context_metadata = layer.prepare(
hidden_states, router_logits)
# Pad to tp_size=1, so no change
self.assertEqual(h_out.shape[0], 3)
result = layer.finalize(h_out, reduce_results=False)
result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
self.assertEqual(result.shape[0], 3)
@patch(
@ -133,7 +142,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(2, 8)
router_logits = torch.randn(2, 2)
h_out, r_out, _ = layer.prepare(hidden_states,
h_out, r_out, _, context_metadata = layer.prepare(
hidden_states,
router_logits,
enable_shared_expert_dp=False,
replace_allreduce=False)
@ -152,7 +162,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
torch.zeros_like(h_out),
torch.zeros_like(h_out)
]
final_result = layer.finalize(h_out, reduce_results=False)
final_result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
# Should concat back
self.assertEqual(final_result.shape[0], 2)
@ -195,7 +207,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_gate = MagicMock()
mock_gate.return_value = (router_logits.repeat(2, 1), None)
h_out, r_out, _ = layer.prepare(hidden_states,
h_out, r_out, _, context_metadata = layer.prepare(hidden_states,
router_logits,
gate=mock_gate)
@ -209,7 +221,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
return tensor[:3]
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
result = layer.finalize(h_out, reduce_results=False)
result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
self.assertEqual(result.shape[0], 3)
@ -263,7 +277,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_gate.return_value = (torch.randn(7, 2), None)
# Run prepare
h_out, r_out, _ = layer.prepare(hidden_states,
h_out, r_out, _, _ = layer.prepare(hidden_states,
router_logits,
gate=mock_gate)

View File

@ -45,7 +45,7 @@ class TestMoECommMethod(TestBase):
# Mock prepare finalize
mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
torch.randn(4, 2), None)
torch.randn(4, 2), None, None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance
@ -59,15 +59,18 @@ class TestMoECommMethod(TestBase):
# Test prepare method
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
hidden_states, router_logits)
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, None)
# Test finalize method
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
comm_impl.finalize(h_out,
reduce_results=True,
context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@ -90,7 +93,8 @@ class TestMoECommMethod(TestBase):
mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
torch.randn(4, 2),
torch.tensor([1, 0, 1, 0]))
torch.tensor([1, 0, 1,
0]), None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance
@ -104,15 +108,18 @@ class TestMoECommMethod(TestBase):
# Test prepare method
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
hidden_states, router_logits)
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, None)
# Test finalize method
comm_impl.finalize(h_out, reduce_results=True)
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
comm_impl.finalize(h_out,
reduce_results=True,
context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@ -135,7 +142,7 @@ class TestMoECommMethod(TestBase):
# Mock prepare finalize
mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
torch.randn(4, 2), None)
torch.randn(4, 2), None, None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance
@ -149,7 +156,8 @@ class TestMoECommMethod(TestBase):
# Test prepare method
hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2)
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
hidden_states, router_logits)
# Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with(

View File

@ -77,9 +77,10 @@ class TestTokenDispatcherWithMC2(TestBase):
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
mc2_mask = None
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
hidden_states, topk_weights, topk_ids, expert_map)
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)
self.assertIn("x", kwargs)
self.assertIn("expert_ids", kwargs)
self.assertEqual(kwargs["moe_expert_num"], 8)
@ -123,36 +124,64 @@ class TestTokenDispatcherWithMC2(TestBase):
def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
hidden_states = torch.randn(10, 128)
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1) # 注意:应为 float不是 int
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
mc2_mask = None
assist_info_for_combine = torch.arange(10) # mock 值
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"mc2_mask": mc2_mask,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": None,
}
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.output = torch.randint(0, 8, (10, 1))
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
context_metadata)
self.assertIn("tp_send_counts", kwargs)
def test_token_combine_with_shared_experts(self):
self.dispatcher.shared_experts = MagicMock()
self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
10, 128), torch.tensor(1.0))
self.dispatcher.shared_act = torch.randn(10, 128)
shared_experts = MagicMock()
shared_experts.down_proj.return_value = (torch.randn(10, 128),
torch.tensor(1.0))
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
assist_info_for_combine = torch.arange(10)
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"mc2_mask": None,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": None,
"shared_experts": shared_experts,
"shared_act": torch.randn(10, 128),
"swiglu_out_scale": torch.randn(10, 1),
}
self.dispatcher.with_quant = True
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
self.dispatcher.output = torch.randint(0, 8, (10, 1))
self.hidden_states = torch.randn(10, 128)
hidden_states = torch.randn(10, 128)
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
self.dispatcher.token_combine(self.hidden_states)
result = self.dispatcher.token_combine(hidden_states,
context_metadata)
self.assertIsInstance(result, tuple)
class TestTokenDispatcherWithAllGather(TestBase):
@ -264,35 +293,26 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.assertEqual(results["group_list_type"], 1)
def test_token_combine_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.sorted_weights = torch.tensor(
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
self.dispatcher.original_shape = (3, 128)
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
hidden_states = torch.randn(6, 128)
final_hidden_states = self.dispatcher.token_combine(hidden_states)
context_metadata = {
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
self.assertEqual(final_hidden_states.shape, (6, 128))
def test_token_combine_without_expert_map(self):
self.dispatcher.with_quant = False
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.sorted_weights = torch.tensor(
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
self.dispatcher.original_shape = (3, 128)
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
hidden_states = torch.randn(6, 128)
final_hidden_states = self.dispatcher.token_combine(hidden_states)
# Verify npu_moe_finalize_routing is called
context_metadata = {
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
self.mock_npu_moe_token_unpermute.assert_called_once()
args, kwargs = self.mock_npu_moe_token_unpermute.call_args
self.assertEqual(final_hidden_states.shape, (6, 128))
def test_token_dispatch_with_router_weight(self):
@ -418,25 +438,21 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
self.assertEqual(result["group_list_type"], 1)
def test_token_combine(self):
hidden_states = torch.randn(16, 16)
context_metadata = {
"input_splits": [4, 4],
"output_splits": [4, 4],
"topk_weights": torch.rand(8, 4),
"reversed_local_input_permutation_mapping": torch.arange(8),
"reversed_global_input_permutation_mapping": torch.arange(16),
}
self.dispatcher.hidden_shape = (8, 16)
self.dispatcher.hidden_shape_before_permute = (8, 16)
self.dispatcher.reversed_local_input_permutation_mapping = torch.arange(
8)
self.dispatcher.topk_weights = torch.rand(8, 4)
self.dispatcher.input_splits = [4, 4]
self.dispatcher.output_splits = [4, 4]
self.dispatcher.reversed_global_input_permutation_mapping = torch.arange(
16)
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
self.dispatcher.num_global_tokens_per_local_expert = torch.tensor(
[[2, 2], [2, 2]], dtype=torch.int64)
expert_output = torch.randn(16, 16)
output = self.dispatcher.token_combine(expert_output)
output = self.dispatcher.token_combine(hidden_states, context_metadata)
self.assertIsNotNone(output)
self.assertEqual(output.shape, (8, 16))

View File

@ -283,7 +283,7 @@ class AscendFusedMoE(FusedMoE):
enable_force_load_balance = forward_context.in_profile_run
forward_context = get_forward_context()
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
replace_allreduce=forward_context.sp_enabled,
@ -311,7 +311,8 @@ class AscendFusedMoE(FusedMoE):
shared_experts=None,
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num)
global_redundant_expert_num=self.global_redundant_expert_num,
mc2_mask=mc2_mask)
if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
@ -322,7 +323,8 @@ class AscendFusedMoE(FusedMoE):
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states,
reduce_results=self.reduce_results)
reduce_results=self.reduce_results,
context_metadata=context_metadata)
return final_hidden_states

View File

@ -15,6 +15,7 @@
# This file is a part of the vllm-ascend project.
from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.distributed as dist
@ -49,12 +50,15 @@ class FusedMoEPrepareAndFinalize(ABC):
is_deepseek_v3_r1)
@abstractmethod
def prepare(self,
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries
@ -74,11 +78,14 @@ class FusedMoEPrepareAndFinalize(ABC):
- processed hidden_states (may be padded/sliced/broadcasted)
- processed router_logits (may be recomputed or broadcasted)
- optional communication mask (e.g., mc2_mask for sparse ops)
- optional context metadata (e.g., saved split_hidden_states for finalization)
"""
raise NotImplementedError("Prepare not implemented.")
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
"""
Finalize MoE output. May involve:
- Gathering sliced tensors across TP ranks
@ -96,9 +103,102 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError("Finalize function not implemented.")
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using MC2 (Memory-Centric Communication).
MoE communication strategy using All-to-All style slicing.
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
"""Restore original TP configuration (same as MC2)."""
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
2. If TP > 1, split along token dim and select current TP rank's slice.
3. Save splits for later all-gather in finalize.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {"split_hidden_states": split_hidden_states}
return hidden_states, router_logits, None, context_metadata
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor.
2. Unpad to original token count.
3. Return [original_num_tokens, hidden_size] tensor.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
assert context_metadata is not None
split_hidden_states = context_metadata["split_hidden_states"]
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
dist.all_gather(list(split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(split_hidden_states, dim=0)
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
"""
MoE communication strategy using MC2, which is based on All2All. Hence, it inherits
All2All and share the same finalize method.
Designed for Ascend or environments requiring explicit padding and slicing control.
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
"""
@ -116,12 +216,15 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(self,
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context.
@ -132,10 +235,11 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
if self.tp_size > 1:
@ -165,124 +269,10 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
self.split_hidden_states = split_hidden_states # Save for finalize
return hidden_states, router_logits, mc2_mask
context_metadata = {"split_hidden_states": split_hidden_states}
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor.
2. Unpad to original token count if padding was applied.
3. Return tensor with shape [original_num_tokens, hidden_size].
Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
# All-gather across TP group
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
# If the cache is not cleared after `self.split_hidden_states` is created,
# it can lead to the memory explosion in eager mode.
del self.split_hidden_states
# Unpad if necessary
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using All-to-All style slicing.
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
"""Restore original TP configuration (same as MC2)."""
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
2. If TP > 1, split along token dim and select current TP rank's slice.
3. Save splits for later all-gather in finalize.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, None) — no mask used in All2All.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
self.split_hidden_states = split_hidden_states
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
return hidden_states, router_logits, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor.
2. Unpad to original token count.
3. Return [original_num_tokens, hidden_size] tensor.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
# If the cache is not cleared after `self.split_hidden_states` is created,
# it can lead to the memory explosion in eager mode.
del self.split_hidden_states
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
return hidden_states, router_logits, mc2_mask, context_metadata
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
@ -292,12 +282,15 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
Uses `max_tokens_across_dp` from forward_context for padding alignment.
"""
def prepare(self,
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
1. Fetch max token count across DP group from forward context.
@ -305,7 +298,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
3. All-gather across DP group to form global input tensor.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
Tuple of (global_hidden_states, global_router_logits, None, None)
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
@ -331,10 +324,12 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
return hidden_states, router_logits, None
return hidden_states, router_logits, None, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
"""
Finalization steps:
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
@ -395,19 +390,22 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
get_dp_group().broadcast(buffer[start:end, :], idx)
return buffer
def prepare(self,
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
1. Fetch cumulative token boundaries from forward context.
2. Multicast hidden_states and router_logits to form global tensors.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
Tuple of (global_hidden_states, global_router_logits, None, None)
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
@ -422,10 +420,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
router_logits = self._naive_multicast(
router_logits, self.cu_tokens_across_dp_cpu)
return hidden_states, router_logits, None
return hidden_states, router_logits, None, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
"""
Finalization steps:
1. If DP > 1 and not shared expert:

View File

@ -57,28 +57,31 @@ class MoECommMethod(ABC):
self.model_type = get_current_vllm_config(
).model_config.hf_config.model_type
self.moe_config = moe_config
self.mc2_mask = None
self.token_dispatcher = self._get_token_dispatcher()
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
)
def prepare(self,
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
hidden_states, router_logits, mc2_mask, context_metadata = self.fused_moe_prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp,
replace_allreduce, gate)
self.mc2_mask = mc2_mask
return hidden_states, router_logits
return hidden_states, router_logits, mc2_mask, context_metadata
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
hidden_states = self.fused_moe_prepare_finalize.finalize(
hidden_states, reduce_results)
hidden_states, reduce_results, context_metadata)
return hidden_states
def fused_experts(
@ -108,7 +111,8 @@ class MoECommMethod(ABC):
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
need_trans: bool = False,
dynamic_eplb: bool = False):
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None):
# Check constraints
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
@ -127,12 +131,12 @@ class MoECommMethod(ABC):
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=self.mc2_mask,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8)
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
w1=w1,
@ -152,7 +156,7 @@ class MoECommMethod(ABC):
dynamic_eplb=dynamic_eplb)
final_hidden_states = self.token_dispatcher.token_combine(
hidden_states=mlp_output)
hidden_states=mlp_output, context_metadata=context_metadata)
if dynamic_eplb:
return (final_hidden_states, group_list_type, expert_tokens)

View File

@ -75,6 +75,7 @@ class MoETokenDispatcher(ABC):
@abstractmethod
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None):
raise NotImplementedError("Combine function not implemented.")
@ -102,16 +103,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
# improve communication performance.
self.need_expert_scale = is_hierarchical_communication_enabled()
self.output = None
self.assist_info_for_combine = None
self.ep_recv_counts = None
self.shared_act = None
self.topk_ids = None
self.topk_weights = None
self.shared_experts = None
self.mc2_mask = None
self.with_quant = False
self.expand_scales = None
def get_dispatch_mc2_kwargs(
self,
@ -119,6 +111,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
mc2_mask: torch.Tensor,
global_redundant_expert_num: int = 0,
):
if self.with_quant:
@ -155,7 +148,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": self.mc2_mask,
"x_active_mask": mc2_mask,
})
if self.need_expert_scale:
stage1_kwargs.update({
@ -166,7 +159,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2
def token_dispatch(self,
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
@ -178,87 +172,108 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
with_quant: bool = False,
):
self.with_quant = with_quant
self.expert_map = expert_map
self.topk_ids = topk_ids
self.topk_weights = topk_weights
self.shared_experts = shared_experts
self.mc2_mask = mc2_mask
# Apply log2phy if needed
if log2phy is not None:
topk_ids = log2phy[topk_ids]
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
topk_ids, expert_map,
mc2_mask,
global_redundant_expert_num)
self.output = torch_npu.npu_moe_distribute_dispatch_v2(
output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \
self.ep_recv_counts, _, self.expand_scales = self.output[0:7]
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
ep_recv_counts, _, expand_scales = output[0:7]
if self.with_quant:
# Handle shared experts (store intermediate results in local vars, not self)
shared_act = None
swiglu_out_scale = None
if with_quant:
if shared_experts is not None:
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale))
self.shared_act, self.swiglu_out_scale = \
shared_act_out[0], shared_act_out[1]
shared_act, swiglu_out_scale = shared_act_out[
0], shared_act_out[1]
else:
if shared_experts is not None:
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
self.shared_act = shared_experts.act_fn(shared_gate_up)
group_list_type = 0
shared_act = shared_experts.act_fn(shared_gate_up)
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"mc2_mask": mc2_mask,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"assist_info_for_combine": assist_info_for_combine,
"shared_experts": shared_experts,
"shared_act": shared_act,
"swiglu_out_scale": swiglu_out_scale,
"expand_scales": expand_scales
}
return {
"group_list_type": group_list_type,
"group_list_type": 0,
"hidden_states": expand_x,
"group_list": expert_token_nums,
"dynamic_scale": dynamic_scale,
"context_metadata": context_metadata
}
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
assert self.expert_map is not None
assert self.topk_weights is not None
assert self.topk_ids is not None
assert self.output is not None
moe_expert_num = len(self.expert_map)
# moeCombine
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
context_metadata: dict):
expert_map = context_metadata["expert_map"]
topk_ids = context_metadata["topk_ids"]
topk_weights = context_metadata["topk_weights"]
ep_recv_counts = context_metadata["ep_recv_counts"]
assist_info_for_combine = context_metadata["assist_info_for_combine"]
mc2_mask = context_metadata["mc2_mask"]
expand_scales = context_metadata["expand_scales"]
assert expert_map is not None
moe_expert_num = len(expert_map)
kwargs_mc2 = {
"expand_x": hidden_states,
"expert_ids": self.topk_ids,
"expert_scales": self.topk_weights.to(torch.float32),
"expert_ids": topk_ids,
"expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
}
if self.with_quant:
tp_recv_counts = torch.empty(1,
dtype=torch.int32,
device=hidden_states.device)
else:
tp_recv_counts = self.output[5]
tp_recv_counts = ep_recv_counts
stage3_kwargs = {
"ep_send_counts": self.ep_recv_counts,
"ep_send_counts": ep_recv_counts,
"group_ep": self.moe_all_to_all_group_name,
"ep_world_size": self.ep_world_size,
"ep_rank_id": self.ep_rank_id,
"expand_scales": self.expand_scales,
"expand_scales": expand_scales,
}
if self.enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
self.assist_info_for_combine,
})
stage3_kwargs["assist_info_for_combine"] = assist_info_for_combine
else:
stage3_kwargs.update({
"expand_idx": self.assist_info_for_combine,
})
stage3_kwargs["expand_idx"] = assist_info_for_combine
if self.need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
@ -266,45 +281,40 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": self.mc2_mask,
})
stage3_kwargs["x_active_mask"] = mc2_mask
kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2
def token_combine(self,
def token_combine(
self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
context_metadata: dict,
bias: torch.Tensor = None,
):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
# these values are no longer used, so they need to be set to None for memory release.
self.output = None
self.assist_info_for_combine = None
self.ep_recv_counts = None
self.topk_ids = None
self.topk_weights = None
self.mc2_mask = None
self.expert_map = None
self.expand_scales = None
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
context_metadata)
combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
if self.shared_experts is None:
return hidden_states
else:
# Handle shared experts from metadata
shared_experts = context_metadata["shared_experts"]
if shared_experts is None:
return combined_output
shared_act = context_metadata["shared_act"]
if self.with_quant:
shared_hidden_states, _ = self.shared_experts.down_proj(
(self.shared_act, self.swiglu_out_scale))
swiglu_out_scale = context_metadata["swiglu_out_scale"]
shared_hidden_states, _ = shared_experts.down_proj(
(shared_act, swiglu_out_scale))
else:
shared_hidden_states, _ = self.shared_experts.down_proj(
self.shared_act)
self.shared_act = None
self.shared_experts = None
self.swiglu_out_scale = None
return hidden_states, shared_hidden_states
shared_hidden_states, _ = shared_experts.down_proj(shared_act)
return combined_output, shared_hidden_states
class TokenDispatcherWithAllGather(MoETokenDispatcher):
@ -314,14 +324,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
self.apply_router_weight_on_input = False
self.max_num_tokens = kwargs.get("max_num_tokens")
self.num_experts_local = kwargs.get("num_local_experts", 0)
self.sorted_weights = None
self.expanded_row_idx = None
self.sorted_token_indices = None
self.original_shape = None
self.mask = None
self.expert_map = None
self.topk_weights = None
self.topk_ids = None
self.with_quant = False
def token_dispatch(self,
@ -341,9 +344,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
self.original_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
self.expert_map = expert_map
self.topk_weights = topk_weights
self.topk_ids = topk_ids
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
@ -357,7 +357,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
if expert_map is not None:
global_num_experts = len(expert_map)
mask = (expert_map[topk_ids] != -1)
self.topk_weights = topk_weights * mask
topk_weights = topk_weights * mask
first_expert_idx = get_ep_group(
).rank_in_group * self.num_experts_local
last_expert_idx = first_expert_idx + self.num_experts_local
@ -366,7 +366,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
last_expert_idx = self.num_experts_local
global_num_experts = self.num_experts_local
sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = (
sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = (
torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
@ -379,29 +379,31 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
))
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode
context_metadata = {
"topk_weights": topk_weights,
"expanded_row_idx": expanded_row_idx
}
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": expert_tokens,
"dynamic_scale": pertoken_scale if self.with_quant else None,
"context_metadata": context_metadata
}
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None):
assert self.original_shape is not None
final_hidden_states = torch_npu.npu_moe_token_unpermute(
permuted_tokens=hidden_states,
sorted_indices=torch.abs(self.expanded_row_idx),
probs=self.topk_weights)
sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
probs=context_metadata["topk_weights"])
if len(self.original_shape) == 3:
final_hidden_states = final_hidden_states.view(self.original_shape)
# these values are no longer used, so they need to be set to None for memory release.
self.expert_map = None
self.topk_weights = None
self.topk_ids = None
self.expanded_row_idx = None
return final_hidden_states
@ -450,11 +452,12 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": group_list,
"topk_scales": topk_scales,
"topk_scales": topk_scales
}
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None):
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
torch.int32)
@ -478,19 +481,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
self.num_local_experts = kwargs.get("num_local_experts", 0)
self.hidden_shape = None
self.topk_weights = None
self.input_splits = None
self.output_splits = None
self.hidden_shape_before_permute = None
# [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self.num_global_tokens_per_local_expert = None
# cached intermediate tensors.
self.tokens_per_expert = None
self.global_input_tokens_local_experts_indices = None
assert self.num_local_experts > 0, "Expected at least one expert"
if self.num_local_experts > 1:
self.expert_ids_per_ep_rank = torch.tensor(
@ -512,7 +504,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
self.local_expert_indices[i + 1] -
1), "local_expert_indices must be continuous"
def token_dispatch(self,
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
@ -524,84 +517,103 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
with_quant: bool = False,
):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape
self.topk_weights = topk_weights
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
assert topk_ids.dim() == 2, "Expected 2D tensor for routing map"
if log2phy is not None:
topk_ids = log2phy[topk_ids]
permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess(
hidden_states, topk_ids)
self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
(
permutated_local_input_tokens,
reversed_local_input_permutation_mapping,
tokens_per_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
) = self._dispatch_preprocess(hidden_states, topk_ids)
dynamic_scale_after_all2all = None
if self.with_quant:
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
permutated_local_input_tokens)
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
dynamic_scale,
self.output_splits,
self.input_splits,
self.ep_group,
)
dynamic_scale, output_splits, input_splits, self.ep_group)
permute2_ep_all_to_all_handle.wait()
dynamic_scale.untyped_storage().resize_(0)
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
permutated_local_input_tokens,
self.output_splits,
self.input_splits,
self.ep_group,
)
permutated_local_input_tokens, output_splits, input_splits,
self.ep_group)
permute1_ep_all_to_all_handle.wait()
permutated_local_input_tokens.untyped_storage().resize_(0)
global_input_tokens, dynamic_scale = self._dispatch_postprocess(
global_input_tokens, dynamic_scale_after_all2all)
# Postprocess
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess(
global_input_tokens, dynamic_scale_after_all2all,
global_input_tokens_local_experts_indices)
context_metadata = {
"input_splits":
input_splits,
"output_splits":
output_splits,
"topk_weights":
topk_weights,
"reversed_local_input_permutation_mapping":
reversed_local_input_permutation_mapping,
"reversed_global_input_permutation_mapping":
reversed_global_input_permutation_mapping
}
return {
"hidden_states": global_input_tokens,
"group_list": tokens_per_expert,
"dynamic_scale": dynamic_scale,
"group_list_type": 1
"group_list_type": 1,
"dynamic_scale": dynamic_scale_final,
"context_metadata": context_metadata,
}
def token_combine(self,
def token_combine(
self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
context_metadata: dict,
bias: torch.Tensor = None,
):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
hidden_states = self._combine_preprocess(hidden_states)
# 1. Preprocess using metadata
hidden_states = self._combine_preprocess(hidden_states,
context_metadata)
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
# 2. AllToAll
_, permutated_local_input_tokens, handle = async_all_to_all(
hidden_states, self.input_splits, self.output_splits,
self.ep_group)
hidden_states,
context_metadata["input_splits"],
context_metadata["output_splits"],
self.ep_group,
)
handle.wait()
hidden_states.untyped_storage().resize_(0)
output = self._combine_postprocess(permutated_local_input_tokens)
# these values are no longer used, so they need to be set to None for memory release.
self.input_splits = None
self.output_splits = None
self.num_global_tokens_per_local_expert = None
self.topk_weights = None
self.reversed_local_input_permutation_mapping = None
self.reversed_global_input_permutation_mapping = None
self.global_input_tokens_local_experts_indices = None
# 3. Postprocess using metadata
output = self._combine_postprocess(permutated_local_input_tokens,
context_metadata)
return output
def _dispatch_preprocess(self, hidden_states, topk_ids):
assert self.hidden_shape is not None
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self._preprocess(topk_ids)
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
(
tokens_per_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
) = self._preprocess(topk_ids)
self.hidden_shape_before_permute = hidden_states.shape
@ -610,82 +622,88 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
indices=topk_ids,
num_out_tokens=self.num_out_tokens,
)
return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor:
return (
permutated_local_input_tokens,
reversed_local_input_permutation_mapping,
tokens_per_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
)
def _preprocess(self, topk_ids: torch.Tensor):
num_local_tokens_per_expert = torch.histc(topk_ids,
bins=self.num_experts,
min=0,
max=self.num_experts)
ep_size = self.ep_size
# Dropless
self.num_out_tokens = topk_ids.numel()
# ===================================================
# Calculate input_splits, output_splits for alltoall-v.
# ===================================================
self.input_splits = (num_local_tokens_per_expert.reshape(
input_splits = (num_local_tokens_per_expert.reshape(
ep_size,
self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
non_blocking=True).numpy())
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
num_local_tokens_per_expert,
group=self.ep_group).reshape(ep_size, self.num_experts)
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
0]:self.local_expert_indices[-1] + 1]
if self.num_global_tokens_per_local_expert is None:
if num_global_tokens_per_local_expert is None:
raise ValueError(
"num_global_tokens_per_local_expert must be set before sum.")
self.output_splits = (self.num_global_tokens_per_local_expert.sum(
axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
axis=0)
# ===================================================
# num_global_tokens_per_expert: [ep_size, num_experts]
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
# num_tokens_per_local_expert: [num_local_experts]
# ===================================================
output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to(
torch.device("cpu"), non_blocking=True).numpy())
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(
axis=0)
global_input_tokens_local_experts_indices = None
if self.num_local_experts > 1:
if self.num_global_tokens_per_local_expert is None:
if num_global_tokens_per_local_expert is None:
raise ValueError(
"num_global_tokens_per_local_expert must be set before operations."
)
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
global_input_tokens_local_experts_indices = torch.repeat_interleave(
self.expert_ids_per_ep_rank,
self.num_global_tokens_per_local_expert.ravel())
num_global_tokens_per_local_expert.ravel())
else:
# TODO: This full synchronization can be a performance bottleneck.
# A more granular sync (e.g., blocking D2H copies) should be investigated.
torch.npu.synchronize()
return num_tokens_per_local_expert
return (
num_tokens_per_local_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
)
def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None):
def _dispatch_postprocess(self, global_input_tokens,
dynamic_scale_after_all2all,
global_input_tokens_local_experts_indices):
# Early return if no local experts or no tokens
if self.num_local_experts <= 1:
return global_input_tokens, None
return global_input_tokens, dynamic_scale_after_all2all, None
# Handle quantized case
if self.with_quant:
assert self.global_input_tokens_local_experts_indices is not None, \
"global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess"
expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze(
assert global_input_tokens_local_experts_indices is not None, \
"global_input_tokens_local_experts_indices must be provided"
expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze(
-1)
active_num = self.global_input_tokens_local_experts_indices.numel()
active_num = global_input_tokens_local_experts_indices.numel()
# Handle case with no active tokens
if active_num <= 0:
self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices
return global_input_tokens, dynamic_scale
reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
# Process with active tokens
global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
global_input_tokens,
expert_idx_2d,
scale=dynamic_scale,
scale=dynamic_scale_after_all2all,
active_num=active_num,
expert_capacity=0,
expert_num=self.num_local_experts,
@ -693,32 +711,34 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
expert_tokens_num_flag=True,
active_expert_range=[0, self.num_local_experts],
quant_mode=-1,
row_idx_type=0)
return global_input_tokens, expanded_scale
row_idx_type=0,
)
return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping
# Handle non-quantized case
global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
global_input_tokens,
self.global_input_tokens_local_experts_indices)
return global_input_tokens, None
# Non-quantized case
global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
global_input_tokens, global_input_tokens_local_experts_indices)
return global_input_tokens, None, reversed_global_input_permutation_mapping
def _combine_preprocess(self, hidden_states):
def _combine_preprocess(self, hidden_states: torch.Tensor,
context_metadata: dict) -> torch.Tensor:
# Unpermutation 2: expert output to AlltoAll input
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
rev_global = context_metadata[
"reversed_global_input_permutation_mapping"]
hidden_states = torch_npu.npu_moe_token_unpermute(
hidden_states, self.reversed_global_input_permutation_mapping)
hidden_states, rev_global)
return hidden_states
def _combine_postprocess(self, permutated_local_input_tokens):
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor,
context_metadata: dict) -> torch.Tensor:
# Unpermutation 1: AlltoAll output to output
output = torch_npu.npu_moe_token_unpermute(
permuted_tokens=permutated_local_input_tokens,
sorted_indices=self.reversed_local_input_permutation_mapping.to(
torch.int32),
probs=self.topk_weights,
restore_shape=self.hidden_shape_before_permute)
# Reshape the output tensor
sorted_indices=context_metadata[
"reversed_local_input_permutation_mapping"].to(torch.int32),
probs=context_metadata["topk_weights"],
restore_shape=self.hidden_shape_before_permute,
)
output = output.view(self.hidden_shape)
return output