From 4f937f561d573ae97f953169865cfbf70d0c220b Mon Sep 17 00:00:00 2001 From: weichen <132029610+Pr0Wh1teGivee@users.noreply.github.com> Date: Wed, 15 Oct 2025 12:36:24 +0800 Subject: [PATCH] [MoE] [Refactor] Remove manual memory cleanup (#3365)
### What this PR does / why we need it?
1. Replace manual memory cleanup with explicit parameter passing.
2. Make FusedMoEPrepareAndFinalizeWithMC2 inherit from the All2All implementation to avoid duplicated code.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
Signed-off-by: Pr0Wh1teGivee
--- tests/e2e/singlecard/ops/test_fused_moe.py | 14 +- .../test_fused_moe_prepare_and_finalize.py | 56 ++- tests/ut/ops/test_moe_comm_method.py | 28 +- tests/ut/ops/test_token_dispatcher.py | 126 +++-- vllm_ascend/ops/common_fused_moe.py | 8 +- .../ops/moe/fused_moe_prepare_and_finalize.py | 306 +++++------ vllm_ascend/ops/moe/moe_comm_method.py | 40 +- vllm_ascend/ops/moe/token_dispatcher.py | 476 +++++++++--------- 8 files changed, 562 insertions(+), 492 deletions(-)
diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index 4735a5f15..fae3ecb09 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -137,6 +137,7 @@ def test_token_dispatcher_with_all_gather( sorted_hidden_states = dispatch_output["hidden_states"] group_list = dispatch_output["group_list"] group_list_type = dispatch_output.get("group_list_type", 1) + context_metadata = dispatch_output["context_metadata"] expert_output = apply_mlp(hidden_states=sorted_hidden_states, w1=w1_local, @@ -144,8 +145,10 @@ group_list=group_list, group_list_type=group_list_type) - combined_output = dispatcher.token_combine(hidden_states=expert_output, - bias=None) + combined_output = dispatcher.token_combine( + hidden_states=expert_output, + context_metadata=context_metadata, + bias=None) torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map) @@ -215,6 +218,7 @@ def test_token_dispatcher_with_all_gather_quant( group_list = dispatch_output["group_list"] group_list_type = dispatch_output.get("group_list_type", 1) dynamic_scale = dispatch_output["dynamic_scale"] + context_metadata = dispatch_output["context_metadata"] expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states, w1=w1, @@ -225,8 +229,10 @@ group_list_type=group_list_type, dynamic_scale=dynamic_scale, with_quant=True) - combined_output = dispatcher.token_combine(hidden_states=expert_output, - bias=None) + combined_output = dispatcher.token_combine( + hidden_states=expert_output, + context_metadata=context_metadata, + bias=None) assert combined_output.shape == (m, k) gc.collect() torch.npu.empty_cache()
diff --git a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py index 3a9733b3f..93b73ecfa 100644 --- a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py +++ b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py @@ -44,7 +44,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, mask = layer.prepare(hidden_states, router_logits) + h_out, r_out, mask, context_metadata = layer.prepare( hidden_states, router_logits) # Check padding and split self.assertEqual(h_out.shape[0], 4) @@ -52,7 +53,9 @@
class TestFusedMoEPrepareAndFinalize(unittest.TestCase): self.assertEqual(mask.tolist(), [1, 0, 1]) # Finalize - result = layer.finalize(h_out, reduce_results=False) + result = layer.finalize(h_out, + reduce_results=False, + context_metadata=context_metadata) self.assertEqual(result.shape[0], 3) @patch( @@ -77,10 +80,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(4, 8) router_logits = torch.randn(4, 2) - h_out, r_out, mask = layer.prepare(hidden_states, - router_logits, - enable_shared_expert_dp=False, - replace_allreduce=False) + h_out, r_out, mask, context_metadata = layer.prepare( + hidden_states, + router_logits, + enable_shared_expert_dp=False, + replace_allreduce=False) # With TP=2, should split into 2 parts self.assertEqual(h_out.shape[0], 2) @@ -96,7 +100,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): torch.zeros_like(h_out), torch.zeros_like(h_out) ] - final_result = layer.finalize(h_out, reduce_results=False) + final_result = layer.finalize(h_out, + reduce_results=False, + context_metadata=context_metadata) # Should concat back to original size self.assertEqual(final_result.shape[0], 4) @@ -112,12 +118,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, _ = layer.prepare(hidden_states, router_logits) + h_out, r_out, _, context_metadata = layer.prepare( + hidden_states, router_logits) # Pad to tp_size=1, so no change self.assertEqual(h_out.shape[0], 3) - result = layer.finalize(h_out, reduce_results=False) + result = layer.finalize(h_out, + reduce_results=False, + context_metadata=context_metadata) self.assertEqual(result.shape[0], 3) @patch( @@ -133,10 +142,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(2, 8) router_logits = torch.randn(2, 2) - h_out, r_out, _ = layer.prepare(hidden_states, - router_logits, - enable_shared_expert_dp=False, - replace_allreduce=False) + h_out, r_out, _, context_metadata = layer.prepare( + hidden_states, + router_logits, + enable_shared_expert_dp=False, + replace_allreduce=False) # Split due to TP=2 self.assertEqual(h_out.shape[0], 1) @@ -152,7 +162,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): torch.zeros_like(h_out), torch.zeros_like(h_out) ] - final_result = layer.finalize(h_out, reduce_results=False) + final_result = layer.finalize(h_out, + reduce_results=False, + context_metadata=context_metadata) # Should concat back self.assertEqual(final_result.shape[0], 2) @@ -195,9 +207,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): mock_gate = MagicMock() mock_gate.return_value = (router_logits.repeat(2, 1), None) - h_out, r_out, _ = layer.prepare(hidden_states, - router_logits, - gate=mock_gate) + h_out, r_out, _, context_metadata = layer.prepare(hidden_states, + router_logits, + gate=mock_gate) # After all-gather with DP=2, should double the batch size self.assertEqual(h_out.shape[0], 12) @@ -209,7 +221,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): return tensor[:3] mock_dp_group.reduce_scatter = mock_reduce_scatter_func - result = layer.finalize(h_out, reduce_results=False) + result = layer.finalize(h_out, + reduce_results=False, + context_metadata=context_metadata) self.assertEqual(result.shape[0], 3) @@ -263,9 +277,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): mock_gate.return_value = (torch.randn(7, 2), None) # Run prepare - h_out, r_out, _ = layer.prepare(hidden_states, - router_logits, - 
gate=mock_gate) + h_out, r_out, _, _ = layer.prepare(hidden_states, + router_logits, + gate=mock_gate) # Should be global tensor: [7, 8] and [7, 2] self.assertEqual(h_out.shape, (7, 8)) diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py index 3826a19c1..a3ef44104 100644 --- a/tests/ut/ops/test_moe_comm_method.py +++ b/tests/ut/ops/test_moe_comm_method.py @@ -45,7 +45,7 @@ class TestMoECommMethod(TestBase): # Mock prepare finalize mock_pf_instance = MagicMock() mock_pf_instance.prepare.return_value = (torch.randn(4, 8), - torch.randn(4, 2), None) + torch.randn(4, 2), None, None) mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_prepare_finalize.return_value = mock_pf_instance @@ -59,15 +59,18 @@ class TestMoECommMethod(TestBase): # Test prepare method hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out = comm_impl.prepare(hidden_states, router_logits) + h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( + hidden_states, router_logits) # Verify prepare was called with correct arguments mock_pf_instance.prepare.assert_called_once_with( hidden_states, router_logits, False, False, None) # Test finalize method - comm_impl.finalize(h_out, reduce_results=True) - mock_pf_instance.finalize.assert_called_once_with(h_out, True) + comm_impl.finalize(h_out, + reduce_results=True, + context_metadata=context_metadata) + mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @@ -90,7 +93,8 @@ class TestMoECommMethod(TestBase): mock_pf_instance = MagicMock() mock_pf_instance.prepare.return_value = (torch.randn(4, 8), torch.randn(4, 2), - torch.tensor([1, 0, 1, 0])) + torch.tensor([1, 0, 1, + 0]), None) mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_prepare_finalize.return_value = mock_pf_instance @@ -104,15 +108,18 @@ class TestMoECommMethod(TestBase): # Test prepare method hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out = comm_impl.prepare(hidden_states, router_logits) + h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( + hidden_states, router_logits) # Verify prepare was called with correct arguments mock_pf_instance.prepare.assert_called_once_with( hidden_states, router_logits, False, False, None) # Test finalize method - comm_impl.finalize(h_out, reduce_results=True) - mock_pf_instance.finalize.assert_called_once_with(h_out, True) + comm_impl.finalize(h_out, + reduce_results=True, + context_metadata=context_metadata) + mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @@ -135,7 +142,7 @@ class TestMoECommMethod(TestBase): # Mock prepare finalize mock_pf_instance = MagicMock() mock_pf_instance.prepare.return_value = (torch.randn(4, 8), - torch.randn(4, 2), None) + torch.randn(4, 2), None, None) mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_prepare_finalize.return_value = mock_pf_instance @@ -149,7 +156,8 @@ class TestMoECommMethod(TestBase): # Test prepare method hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out = comm_impl.prepare(hidden_states, router_logits) + h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( + hidden_states, router_logits) # Verify prepare was called with correct 
arguments mock_pf_instance.prepare.assert_called_once_with(
diff --git a/tests/ut/ops/test_token_dispatcher.py index 87f384fad..486696c0d 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -77,9 +77,10 @@ class TestTokenDispatcherWithMC2(TestBase): topk_ids = torch.randint(0, 8, (10, 1)) topk_weights = torch.randn(10, 1) expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + mc2_mask = None kwargs = self.dispatcher.get_dispatch_mc2_kwargs( - hidden_states, topk_weights, topk_ids, expert_map) + hidden_states, topk_weights, topk_ids, expert_map, mc2_mask) self.assertIn("x", kwargs) self.assertIn("expert_ids", kwargs) self.assertEqual(kwargs["moe_expert_num"], 8) @@ -123,36 +124,64 @@ class TestTokenDispatcherWithMC2(TestBase): def test_get_combine_mc_kwargs_with_quant(self): self.dispatcher.with_quant = True hidden_states = torch.randn(10, 128) - self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1)) - self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1)) - self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + topk_ids = torch.randint(0, 8, (10, 1)) + topk_weights = torch.randn(10, 1) # Note: must be float, not int + expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + mc2_mask = None + assist_info_for_combine = torch.arange(10) # mock value + + context_metadata = { + "topk_ids": topk_ids, + "topk_weights": topk_weights, + "expert_map": expert_map, + "ep_recv_counts": ep_recv_counts, + "mc2_mask": mc2_mask, + "assist_info_for_combine": assist_info_for_combine, + "expand_scales": None, + } + self.dispatcher.need_extra_args = True self.dispatcher.enable_dispatch_v2 = True - self.dispatcher.output = torch.randint(0, 8, (10, 1)) - kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states) + kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states, + context_metadata) self.assertIn("tp_send_counts", kwargs)
def test_token_combine_with_shared_experts(self): - self.dispatcher.shared_experts = MagicMock() - self.dispatcher.shared_experts.down_proj.return_value = (torch.randn( 10, 128), torch.tensor(1.0)) - self.dispatcher.shared_act = torch.randn(10, 128) + shared_experts = MagicMock() + shared_experts.down_proj.return_value = (torch.randn(10, 128), + torch.tensor(1.0)) + + topk_ids = torch.randint(0, 8, (10, 1)) + topk_weights = torch.randn(10, 1) + expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + assist_info_for_combine = torch.arange(10) + + context_metadata = { + "topk_ids": topk_ids, + "topk_weights": topk_weights, + "expert_map": expert_map, + "ep_recv_counts": ep_recv_counts, + "mc2_mask": None, + "assist_info_for_combine": assist_info_for_combine, + "expand_scales": None, + "shared_experts": shared_experts, + "shared_act": torch.randn(10, 128), + "swiglu_out_scale": torch.randn(10, 1), + } + self.dispatcher.with_quant = True - self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1)) - self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1)) - self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) self.dispatcher.need_extra_args = True self.dispatcher.enable_dispatch_v2 = True - self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1)) - self.dispatcher.output = torch.randint(0, 8, (10, 1)) -
self.hidden_states = torch.randn(10, 128) + hidden_states = torch.randn(10, 128) with patch("torch_npu.npu_moe_distribute_combine_v2", return_value=torch.randn(10, 128)): - self.dispatcher.token_combine(self.hidden_states) + result = self.dispatcher.token_combine(hidden_states, + context_metadata) + self.assertIsInstance(result, tuple) class TestTokenDispatcherWithAllGather(TestBase): @@ -264,35 +293,26 @@ class TestTokenDispatcherWithAllGather(TestBase): self.assertEqual(results["group_list_type"], 1) def test_token_combine_with_expert_map(self): - self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) - self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) - self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1]) - self.dispatcher.sorted_weights = torch.tensor( - [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) - self.dispatcher.original_shape = (3, 128) - self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) hidden_states = torch.randn(6, 128) - - final_hidden_states = self.dispatcher.token_combine(hidden_states) + context_metadata = { + "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]), + "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), + } + self.dispatcher.original_shape = (6, 128) + final_hidden_states = self.dispatcher.token_combine( + hidden_states, context_metadata) self.assertEqual(final_hidden_states.shape, (6, 128)) def test_token_combine_without_expert_map(self): - self.dispatcher.with_quant = False - self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1]) - self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) - self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) - self.dispatcher.sorted_weights = torch.tensor( - [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) - self.dispatcher.original_shape = (3, 128) - self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) hidden_states = torch.randn(6, 128) - - final_hidden_states = self.dispatcher.token_combine(hidden_states) - - # Verify npu_moe_finalize_routing is called + context_metadata = { + "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]), + "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), + } + self.dispatcher.original_shape = (6, 128) + final_hidden_states = self.dispatcher.token_combine( + hidden_states, context_metadata) self.mock_npu_moe_token_unpermute.assert_called_once() - args, kwargs = self.mock_npu_moe_token_unpermute.call_args - self.assertEqual(final_hidden_states.shape, (6, 128)) def test_token_dispatch_with_router_weight(self): @@ -418,25 +438,21 @@ class TestTokenDispatcherWithAll2AllV(TestBase): self.assertEqual(result["group_list_type"], 1) def test_token_combine(self): + hidden_states = torch.randn(16, 16) + context_metadata = { + "input_splits": [4, 4], + "output_splits": [4, 4], + "topk_weights": torch.rand(8, 4), + "reversed_local_input_permutation_mapping": torch.arange(8), + "reversed_global_input_permutation_mapping": torch.arange(16), + } self.dispatcher.hidden_shape = (8, 16) self.dispatcher.hidden_shape_before_permute = (8, 16) - self.dispatcher.reversed_local_input_permutation_mapping = torch.arange( - 8) - self.dispatcher.topk_weights = torch.rand(8, 4) - self.dispatcher.input_splits = [4, 4] - self.dispatcher.output_splits = [4, 4] - self.dispatcher.reversed_global_input_permutation_mapping = torch.arange( - 16) - self.dispatcher.expert_ids_per_ep_rank = torch.tensor( [0, 1], dtype=torch.int32) self.dispatcher.local_expert_indices = [0, 1] - self.dispatcher.num_global_tokens_per_local_expert = torch.tensor( - [[2, 2], [2, 2]], 
dtype=torch.int64) - - expert_output = torch.randn(16, 16) - output = self.dispatcher.token_combine(expert_output) + output = self.dispatcher.token_combine(hidden_states, context_metadata) self.assertIsNotNone(output) self.assertEqual(output.shape, (8, 16)) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 90131878a..dd2b16616 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -283,7 +283,7 @@ class AscendFusedMoE(FusedMoE): enable_force_load_balance = forward_context.in_profile_run forward_context = get_forward_context() - hidden_states, router_logits = forward_context.moe_comm_method.prepare( + hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare( hidden_states=hidden_states, router_logits=router_logits, replace_allreduce=forward_context.sp_enabled, @@ -311,7 +311,8 @@ class AscendFusedMoE(FusedMoE): shared_experts=None, enable_force_load_balance=enable_force_load_balance, log2phy=self.log2phy, - global_redundant_expert_num=self.global_redundant_expert_num) + global_redundant_expert_num=self.global_redundant_expert_num, + mc2_mask=mc2_mask) if isinstance(final_hidden_states, tuple): final_hidden_states, group_list_type, expert_tokens = final_hidden_states @@ -322,7 +323,8 @@ class AscendFusedMoE(FusedMoE): final_hidden_states = forward_context.moe_comm_method.finalize( hidden_states=final_hidden_states, - reduce_results=self.reduce_results) + reduce_results=self.reduce_results, + context_metadata=context_metadata) return final_hidden_states diff --git a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py index 15595a53a..415c39637 100644 --- a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py +++ b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py @@ -15,6 +15,7 @@ # This file is a part of the vllm-ascend project. from abc import ABC, abstractmethod +from typing import Optional import torch import torch.distributed as dist @@ -49,12 +50,15 @@ class FusedMoEPrepareAndFinalize(ABC): is_deepseek_v3_r1) @abstractmethod - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]: """ Prepare tensors before MoE computation. May involve: - Padding to align communication boundaries @@ -74,11 +78,14 @@ class FusedMoEPrepareAndFinalize(ABC): - processed hidden_states (may be padded/sliced/broadcasted) - processed router_logits (may be recomputed or broadcasted) - optional communication mask (e.g., mc2_mask for sparse ops) + - optional context metadata (e.g., saved split_hidden_states for finalization) """ raise NotImplementedError("Prepare not implemented.") - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: + def finalize(self, + hidden_states: torch.Tensor, + reduce_results: bool, + context_metadata: Optional[dict] = None) -> torch.Tensor: """ Finalize MoE output. 
May involve: - Gathering sliced tensors across TP ranks @@ -96,9 +103,102 @@ class FusedMoEPrepareAndFinalize(ABC): raise NotImplementedError("Finalize function not implemented.") -class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): +class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): """ - MoE communication strategy using MC2 (Memory-Centric Communication). + MoE communication strategy using All-to-All style slicing. + Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing. + Will be used when num_tokens exceed mc2's limitation (512 tokens/rank). + """ + + def __init__(self, moe_config: FusedMoEConfig): + super().__init__(moe_config) + self._restore_tp_across_dp() + + def _restore_tp_across_dp(self): + """Restore original TP configuration (same as MC2).""" + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]: + """ + Preparation steps: + 1. Pad hidden_states and router_logits to next multiple of TP size. + 2. If TP > 1, split along token dim and select current TP rank's slice. + 3. Save splits for later all-gather in finalize. + + Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. + + Returns: + Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All. + """ + self.replace_allreduce = replace_allreduce + self.enable_shared_expert_dp = enable_shared_expert_dp + split_hidden_states = None + + if not (self.replace_allreduce or self.enable_shared_expert_dp): + self.num_tokens, _ = hidden_states.shape + pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic) + + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + if self.tp_size > 1: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + + context_metadata = {"split_hidden_states": split_hidden_states} + + return hidden_states, router_logits, None, context_metadata + + def finalize(self, + hidden_states: torch.Tensor, + reduce_results: bool, + context_metadata: Optional[dict] = None) -> torch.Tensor: + """ + Finalization steps: + 1. If TP > 1, all-gather slices to reconstruct full tensor. + 2. Unpad to original token count. + 3. Return [original_num_tokens, hidden_size] tensor. + + Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. + """ + assert context_metadata is not None + + split_hidden_states = context_metadata["split_hidden_states"] + if not (self.enable_shared_expert_dp or self.replace_allreduce): + if self.tp_size > 1: + dist.all_gather(list(split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(split_hidden_states, dim=0) + + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] + + return hidden_states + + +class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All): + """ + MoE communication strategy using MC2, which is based on All2All. 
Hence, it inherits + All2All and share the same finalize method. Designed for Ascend or environments requiring explicit padding and slicing control. Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment. """ @@ -116,12 +216,15 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]: """ Preparation steps: 1. Fetch `mc2_mask` and target padding length from forward context. @@ -132,10 +235,11 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True. Returns: - Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded. + Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded. """ self.replace_allreduce = replace_allreduce self.enable_shared_expert_dp = enable_shared_expert_dp + split_hidden_states = None forward_context = get_forward_context() mc2_mask = forward_context.mc2_mask if self.tp_size > 1: @@ -165,124 +269,10 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): dim=0) hidden_states = split_hidden_states[self.tp_rank] router_logits = split_router_logits[self.tp_rank] - self.split_hidden_states = split_hidden_states # Save for finalize - return hidden_states, router_logits, mc2_mask + context_metadata = {"split_hidden_states": split_hidden_states} - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """ - Finalization steps: - 1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor. - 2. Unpad to original token count if padding was applied. - 3. Return tensor with shape [original_num_tokens, hidden_size]. - - Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True. - """ - if not (self.enable_shared_expert_dp or self.replace_allreduce): - if self.tp_size > 1: - # All-gather across TP group - dist.all_gather(list(self.split_hidden_states), hidden_states, - self.moe_config.tp_group.device_group) - hidden_states = torch.cat(self.split_hidden_states, dim=0) - - # TODO: It is a quick bugfix for the memory explosion issue in eager mode. - # If the cache is not cleared after `self.split_hidden_states` is created, - # it can lead to the memory explosion in eager mode. - del self.split_hidden_states - - # Unpad if necessary - if self.num_tokens < hidden_states.shape[0]: - hidden_states = hidden_states[:self.num_tokens] - - return hidden_states - - -class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): - """ - MoE communication strategy using All-to-All style slicing. - Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing. - Will be used when num_tokens exceed mc2's limitation (512 tokens/rank). 
- """ - - def __init__(self, moe_config: FusedMoEConfig): - super().__init__(moe_config) - self._restore_tp_across_dp() - - def _restore_tp_across_dp(self): - """Restore original TP configuration (same as MC2).""" - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Preparation steps: - 1. Pad hidden_states and router_logits to next multiple of TP size. - 2. If TP > 1, split along token dim and select current TP rank's slice. - 3. Save splits for later all-gather in finalize. - - Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. - - Returns: - Tuple of (hidden_states, router_logits, None) — no mask used in All2All. - """ - self.replace_allreduce = replace_allreduce - self.enable_shared_expert_dp = enable_shared_expert_dp - - if not (self.replace_allreduce or self.enable_shared_expert_dp): - self.num_tokens, _ = hidden_states.shape - pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic) - - if pad_size > 0: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) - - if self.tp_size > 1: - split_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - split_router_logits = torch.tensor_split(router_logits, - self.tp_size, - dim=0) - self.split_hidden_states = split_hidden_states - - hidden_states = split_hidden_states[self.tp_rank] - router_logits = split_router_logits[self.tp_rank] - - return hidden_states, router_logits, None - - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """ - Finalization steps: - 1. If TP > 1, all-gather slices to reconstruct full tensor. - 2. Unpad to original token count. - 3. Return [original_num_tokens, hidden_size] tensor. - - Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. - """ - if not (self.enable_shared_expert_dp or self.replace_allreduce): - if self.tp_size > 1: - dist.all_gather(list(self.split_hidden_states), hidden_states, - self.moe_config.tp_group.device_group) - hidden_states = torch.cat(self.split_hidden_states, dim=0) - - # TODO: It is a quick bugfix for the memory explosion issue in eager mode. - # If the cache is not cleared after `self.split_hidden_states` is created, - # it can lead to the memory explosion in eager mode. - del self.split_hidden_states - - if self.num_tokens < hidden_states.shape[0]: - hidden_states = hidden_states[:self.num_tokens] - - return hidden_states + return hidden_states, router_logits, mc2_mask, context_metadata class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): @@ -292,12 +282,15 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): Uses `max_tokens_across_dp` from forward_context for padding alignment. 
""" - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]: """ Preparation steps: 1. Fetch max token count across DP group from forward context. @@ -305,7 +298,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): 3. All-gather across DP group to form global input tensor. Returns: - Tuple of (global_hidden_states, global_router_logits, None) + Tuple of (global_hidden_states, global_router_logits, None, None) """ self.enable_shared_expert_dp = enable_shared_expert_dp @@ -331,10 +324,12 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): router_logits = self.moe_config.dp_group.all_gather( router_logits, 0) - return hidden_states, router_logits, None + return hidden_states, router_logits, None, None - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: + def finalize(self, + hidden_states: torch.Tensor, + reduce_results: bool, + context_metadata: Optional[dict] = None) -> torch.Tensor: """ Finalization steps: 1. If DP > 1 and not shared expert, reduce-scatter output across DP group. @@ -395,19 +390,22 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize): get_dp_group().broadcast(buffer[start:end, :], idx) return buffer - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]: """ Preparation steps: 1. Fetch cumulative token boundaries from forward context. 2. Multicast hidden_states and router_logits to form global tensors. Returns: - Tuple of (global_hidden_states, global_router_logits, None) + Tuple of (global_hidden_states, global_router_logits, None, None) """ self.enable_shared_expert_dp = enable_shared_expert_dp @@ -422,10 +420,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize): router_logits = self._naive_multicast( router_logits, self.cu_tokens_across_dp_cpu) - return hidden_states, router_logits, None + return hidden_states, router_logits, None, None - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: + def finalize(self, + hidden_states: torch.Tensor, + reduce_results: bool, + context_metadata: Optional[dict] = None) -> torch.Tensor: """ Finalization steps: 1. 
If DP > 1 and not shared expert: diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/moe/moe_comm_method.py index d1d0c1aa0..a83644338 100644 --- a/vllm_ascend/ops/moe/moe_comm_method.py +++ b/vllm_ascend/ops/moe/moe_comm_method.py @@ -57,28 +57,31 @@ class MoECommMethod(ABC): self.model_type = get_current_vllm_config( ).model_config.hf_config.model_type self.moe_config = moe_config - self.mc2_mask = None self.token_dispatcher = self._get_token_dispatcher() self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize( ) - def prepare(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False, - gate=None) -> tuple[torch.Tensor, torch.Tensor]: - hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare( + def prepare( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + replace_allreduce: bool = False, + gate=None + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor]]: + hidden_states, router_logits, mc2_mask, context_metadata = self.fused_moe_prepare_finalize.prepare( hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, gate) - self.mc2_mask = mc2_mask - return hidden_states, router_logits + return hidden_states, router_logits, mc2_mask, context_metadata - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: + def finalize(self, + hidden_states: torch.Tensor, + reduce_results: bool, + context_metadata: Optional[dict] = None) -> torch.Tensor: hidden_states = self.fused_moe_prepare_finalize.finalize( - hidden_states, reduce_results) + hidden_states, reduce_results, context_metadata) return hidden_states def fused_experts( @@ -108,7 +111,8 @@ class MoECommMethod(ABC): log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, need_trans: bool = False, - dynamic_eplb: bool = False): + dynamic_eplb: bool = False, + mc2_mask: torch.Tensor = None): # Check constraints assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 @@ -127,12 +131,12 @@ class MoECommMethod(ABC): shared_experts=shared_experts, quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, - mc2_mask=self.mc2_mask, + mc2_mask=mc2_mask, apply_router_weight_on_input=apply_router_weight_on_input, with_quant=use_int8_w8a8 or use_int4_w4a8) - permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \ - results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales") + permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \ + results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata") mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, w1=w1, @@ -152,7 +156,7 @@ class MoECommMethod(ABC): dynamic_eplb=dynamic_eplb) final_hidden_states = self.token_dispatcher.token_combine( - hidden_states=mlp_output) + hidden_states=mlp_output, context_metadata=context_metadata) if dynamic_eplb: return (final_hidden_states, group_list_type, expert_tokens) diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index 3dd799a42..9e4f2200b 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ 
b/vllm_ascend/ops/moe/token_dispatcher.py @@ -75,6 +75,7 @@ class MoETokenDispatcher(ABC): @abstractmethod def token_combine(self, hidden_states: torch.Tensor, + context_metadata: dict, bias: torch.Tensor = None): raise NotImplementedError("Combine function not implemented.") @@ -102,16 +103,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # improve communication performance. self.need_expert_scale = is_hierarchical_communication_enabled() - self.output = None - self.assist_info_for_combine = None - self.ep_recv_counts = None - self.shared_act = None - self.topk_ids = None - self.topk_weights = None - self.shared_experts = None - self.mc2_mask = None self.with_quant = False - self.expand_scales = None def get_dispatch_mc2_kwargs( self, @@ -119,6 +111,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): topk_weights: torch.Tensor, topk_ids: torch.Tensor, expert_map: torch.Tensor, + mc2_mask: torch.Tensor, global_redundant_expert_num: int = 0, ): if self.with_quant: @@ -155,7 +148,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): }) if self.a3_need_extra_args and self.enable_dispatch_v2: stage1_kwargs.update({ - "x_active_mask": self.mc2_mask, + "x_active_mask": mc2_mask, }) if self.need_expert_scale: stage1_kwargs.update({ @@ -166,99 +159,121 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): kwargs_mc2.update(stage1_kwargs) return kwargs_mc2 - def token_dispatch(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - mc2_mask: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False): + def token_dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + ): self.with_quant = with_quant - self.expert_map = expert_map - self.topk_ids = topk_ids - self.topk_weights = topk_weights - self.shared_experts = shared_experts - self.mc2_mask = mc2_mask + + # Apply log2phy if needed + if log2phy is not None: + topk_ids = log2phy[topk_ids] kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights, topk_ids, expert_map, + mc2_mask, global_redundant_expert_num) - self.output = torch_npu.npu_moe_distribute_dispatch_v2( + output = torch_npu.npu_moe_distribute_dispatch_v2( **kwargs_mc2 ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( **kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \ - self.ep_recv_counts, _, self.expand_scales = self.output[0:7] + expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ + ep_recv_counts, _, expand_scales = output[0:7] - if self.with_quant: + # Handle shared experts (store intermediate results in local vars, not self) + shared_act = None + swiglu_out_scale = None + 
if with_quant: if shared_experts is not None: share_up_out, _ = shared_experts.gate_up_proj( (quantized_x_for_share, dynamic_scale_for_share)) shared_gate_up, shared_dequant_scale = share_up_out[ 0], share_up_out[1] - shared_act_out = shared_experts.act_fn( (shared_gate_up, shared_dequant_scale)) - self.shared_act, self.swiglu_out_scale = \ - shared_act_out[0], shared_act_out[1] - + shared_act, swiglu_out_scale = shared_act_out[ + 0], shared_act_out[1] else: if shared_experts is not None: shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) - self.shared_act = shared_experts.act_fn(shared_gate_up) - group_list_type = 0 + shared_act = shared_experts.act_fn(shared_gate_up) + + context_metadata = { + "topk_ids": topk_ids, + "topk_weights": topk_weights, + "mc2_mask": mc2_mask, + "expert_map": expert_map, + "ep_recv_counts": ep_recv_counts, + "assist_info_for_combine": assist_info_for_combine, + "shared_experts": shared_experts, + "shared_act": shared_act, + "swiglu_out_scale": swiglu_out_scale, + "expand_scales": expand_scales + } + return { - "group_list_type": group_list_type, + "group_list_type": 0, "hidden_states": expand_x, "group_list": expert_token_nums, "dynamic_scale": dynamic_scale, + "context_metadata": context_metadata } - def get_combine_mc_kwargs(self, hidden_states: torch.Tensor): - assert self.expert_map is not None - assert self.topk_weights is not None - assert self.topk_ids is not None - assert self.output is not None - moe_expert_num = len(self.expert_map) - # moeCombine + def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, + context_metadata: dict): + expert_map = context_metadata["expert_map"] + topk_ids = context_metadata["topk_ids"] + topk_weights = context_metadata["topk_weights"] + ep_recv_counts = context_metadata["ep_recv_counts"] + assist_info_for_combine = context_metadata["assist_info_for_combine"] + mc2_mask = context_metadata["mc2_mask"] + expand_scales = context_metadata["expand_scales"] + + assert expert_map is not None + moe_expert_num = len(expert_map) + kwargs_mc2 = { "expand_x": hidden_states, - "expert_ids": self.topk_ids, - "expert_scales": self.topk_weights.to(torch.float32), + "expert_ids": topk_ids, + "expert_scales": topk_weights.to(torch.float32), "expert_shard_type": 0, "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": 0, } + if self.with_quant: tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device) else: - tp_recv_counts = self.output[5] + tp_recv_counts = ep_recv_counts + stage3_kwargs = { - "ep_send_counts": self.ep_recv_counts, + "ep_send_counts": ep_recv_counts, "group_ep": self.moe_all_to_all_group_name, "ep_world_size": self.ep_world_size, "ep_rank_id": self.ep_rank_id, - "expand_scales": self.expand_scales, + "expand_scales": expand_scales, } + if self.enable_dispatch_v2: - stage3_kwargs.update({ - "assist_info_for_combine": - self.assist_info_for_combine, - }) + stage3_kwargs["assist_info_for_combine"] = assist_info_for_combine else: - stage3_kwargs.update({ - "expand_idx": self.assist_info_for_combine, - }) + stage3_kwargs["expand_idx"] = assist_info_for_combine + if self.need_extra_args: stage3_kwargs.update({ "tp_send_counts": tp_recv_counts, @@ -266,45 +281,40 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "tp_world_size": 1, "tp_rank_id": 0, }) + if self.a3_need_extra_args and self.enable_dispatch_v2: - stage3_kwargs.update({ - "x_active_mask": self.mc2_mask, - }) + stage3_kwargs["x_active_mask"] = mc2_mask + kwargs_mc2.update(stage3_kwargs) return 
kwargs_mc2 - def token_combine(self, - hidden_states: torch.Tensor, - bias: torch.Tensor = None): - kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states) - hidden_states = torch_npu.npu_moe_distribute_combine_v2( - **kwargs_mc2 - ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( - **kwargs_mc2) + def token_combine( + self, + hidden_states: torch.Tensor, + context_metadata: dict, + bias: torch.Tensor = None, + ): + assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." - # these values are no longer used, so they need to be set to None for memory release. - self.output = None - self.assist_info_for_combine = None - self.ep_recv_counts = None - self.topk_ids = None - self.topk_weights = None - self.mc2_mask = None - self.expert_map = None - self.expand_scales = None + kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, + context_metadata) + combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \ + if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2) - if self.shared_experts is None: - return hidden_states + # Handle shared experts from metadata + shared_experts = context_metadata["shared_experts"] + if shared_experts is None: + return combined_output + + shared_act = context_metadata["shared_act"] + if self.with_quant: + swiglu_out_scale = context_metadata["swiglu_out_scale"] + shared_hidden_states, _ = shared_experts.down_proj( + (shared_act, swiglu_out_scale)) else: - if self.with_quant: - shared_hidden_states, _ = self.shared_experts.down_proj( - (self.shared_act, self.swiglu_out_scale)) - else: - shared_hidden_states, _ = self.shared_experts.down_proj( - self.shared_act) - self.shared_act = None - self.shared_experts = None - self.swiglu_out_scale = None - return hidden_states, shared_hidden_states + shared_hidden_states, _ = shared_experts.down_proj(shared_act) + + return combined_output, shared_hidden_states class TokenDispatcherWithAllGather(MoETokenDispatcher): @@ -314,14 +324,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): self.apply_router_weight_on_input = False self.max_num_tokens = kwargs.get("max_num_tokens") self.num_experts_local = kwargs.get("num_local_experts", 0) - self.sorted_weights = None - self.expanded_row_idx = None - self.sorted_token_indices = None self.original_shape = None - self.mask = None - self.expert_map = None - self.topk_weights = None - self.topk_ids = None self.with_quant = False def token_dispatch(self, @@ -341,9 +344,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): self.original_shape = hidden_states.shape num_tokens = hidden_states.shape[:-1].numel() - self.expert_map = expert_map - self.topk_weights = topk_weights - self.topk_ids = topk_ids self.apply_router_weight_on_input = apply_router_weight_on_input if self.apply_router_weight_on_input: assert (topk_weights.dim() == 2 @@ -357,7 +357,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): if expert_map is not None: global_num_experts = len(expert_map) mask = (expert_map[topk_ids] != -1) - self.topk_weights = topk_weights * mask + topk_weights = topk_weights * mask first_expert_idx = get_ep_group( ).rank_in_group * self.num_experts_local last_expert_idx = first_expert_idx + self.num_experts_local @@ -366,7 +366,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): last_expert_idx = self.num_experts_local global_num_experts = self.num_experts_local - sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = ( + sorted_hidden_states, 
expanded_row_idx, expert_tokens, pertoken_scale = ( torch_npu.npu_moe_init_routing_v2( hidden_states, topk_ids, @@ -379,29 +379,31 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): )) expert_tokens = expert_tokens.to(torch.int64) group_list_type = 1 # `count` mode + context_metadata = { + "topk_weights": topk_weights, + "expanded_row_idx": expanded_row_idx + } return { "group_list_type": group_list_type, "hidden_states": sorted_hidden_states, "group_list": expert_tokens, "dynamic_scale": pertoken_scale if self.with_quant else None, + "context_metadata": context_metadata } def token_combine(self, hidden_states: torch.Tensor, + context_metadata: dict, bias: torch.Tensor = None): assert self.original_shape is not None final_hidden_states = torch_npu.npu_moe_token_unpermute( permuted_tokens=hidden_states, - sorted_indices=torch.abs(self.expanded_row_idx), - probs=self.topk_weights) + sorted_indices=torch.abs(context_metadata["expanded_row_idx"]), + probs=context_metadata["topk_weights"]) if len(self.original_shape) == 3: final_hidden_states = final_hidden_states.view(self.original_shape) # these values are no longer used, so they need to be set to None for memory release. - self.expert_map = None - self.topk_weights = None - self.topk_ids = None - self.expanded_row_idx = None return final_hidden_states @@ -450,11 +452,12 @@ class TokenDispatcherWithMoge(MoETokenDispatcher): "group_list_type": group_list_type, "hidden_states": sorted_hidden_states, "group_list": group_list, - "topk_scales": topk_scales, + "topk_scales": topk_scales } def token_combine(self, hidden_states: torch.Tensor, + context_metadata: dict, bias: torch.Tensor = None): unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to( torch.int32) @@ -478,19 +481,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): self.num_local_experts = kwargs.get("num_local_experts", 0) self.hidden_shape = None - self.topk_weights = None - self.input_splits = None - self.output_splits = None self.hidden_shape_before_permute = None - # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent - # to each local expert by all ranks. - self.num_global_tokens_per_local_expert = None - - # cached intermediate tensors. 
- self.tokens_per_expert = None - self.global_input_tokens_local_experts_indices = None - assert self.num_local_experts > 0, "Expected at least one expert" if self.num_local_experts > 1: self.expert_ids_per_ep_rank = torch.tensor( @@ -512,96 +504,116 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): self.local_expert_indices[i + 1] - 1), "local_expert_indices must be continuous" - def token_dispatch(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - mc2_mask: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False): + def token_dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + ): self.with_quant = with_quant self.hidden_shape = hidden_states.shape - self.topk_weights = topk_weights - assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights" - assert topk_ids.dim() == 2, "Expected 2D tensor for routing map" if log2phy is not None: topk_ids = log2phy[topk_ids] - permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess( - hidden_states, topk_ids) - self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping + ( + permutated_local_input_tokens, + reversed_local_input_permutation_mapping, + tokens_per_expert, + input_splits, + output_splits, + num_global_tokens_per_local_expert, + global_input_tokens_local_experts_indices, + ) = self._dispatch_preprocess(hidden_states, topk_ids) dynamic_scale_after_all2all = None if self.with_quant: permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant( permutated_local_input_tokens) - _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all( - dynamic_scale, - self.output_splits, - self.input_splits, - self.ep_group, - ) + dynamic_scale, output_splits, input_splits, self.ep_group) permute2_ep_all_to_all_handle.wait() dynamic_scale.untyped_storage().resize_(0) _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( - permutated_local_input_tokens, - self.output_splits, - self.input_splits, - self.ep_group, - ) + permutated_local_input_tokens, output_splits, input_splits, + self.ep_group) permute1_ep_all_to_all_handle.wait() permutated_local_input_tokens.untyped_storage().resize_(0) - global_input_tokens, dynamic_scale = self._dispatch_postprocess( - global_input_tokens, dynamic_scale_after_all2all) + # Postprocess + global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess( + global_input_tokens, dynamic_scale_after_all2all, + global_input_tokens_local_experts_indices) + + context_metadata = { + "input_splits": + input_splits, + "output_splits": + output_splits, + "topk_weights": + topk_weights, + "reversed_local_input_permutation_mapping": + 
reversed_local_input_permutation_mapping, + "reversed_global_input_permutation_mapping": + reversed_global_input_permutation_mapping + } + return { "hidden_states": global_input_tokens, "group_list": tokens_per_expert, - "dynamic_scale": dynamic_scale, - "group_list_type": 1 + "group_list_type": 1, + "dynamic_scale": dynamic_scale_final, + "context_metadata": context_metadata, } - def token_combine(self, - hidden_states: torch.Tensor, - bias: torch.Tensor = None): + def token_combine( + self, + hidden_states: torch.Tensor, + context_metadata: dict, + bias: torch.Tensor = None, + ): assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." - hidden_states = self._combine_preprocess(hidden_states) + # 1. Preprocess using metadata + hidden_states = self._combine_preprocess(hidden_states, + context_metadata) - # Perform expert parallel AlltoAll communication - # hidden_states: [SEQL, H] -> [SEQL, H/TP] + # 2. AllToAll _, permutated_local_input_tokens, handle = async_all_to_all( - hidden_states, self.input_splits, self.output_splits, - self.ep_group) + hidden_states, + context_metadata["input_splits"], + context_metadata["output_splits"], + self.ep_group, + ) handle.wait() hidden_states.untyped_storage().resize_(0) - output = self._combine_postprocess(permutated_local_input_tokens) - - # these values are no longer used, so they need to be set to None for memory release. - self.input_splits = None - self.output_splits = None - self.num_global_tokens_per_local_expert = None - self.topk_weights = None - self.reversed_local_input_permutation_mapping = None - self.reversed_global_input_permutation_mapping = None - self.global_input_tokens_local_experts_indices = None + # 3. Postprocess using metadata + output = self._combine_postprocess(permutated_local_input_tokens, + context_metadata) return output def _dispatch_preprocess(self, hidden_states, topk_ids): assert self.hidden_shape is not None - hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) - tokens_per_expert = self._preprocess(topk_ids) + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + ( + tokens_per_expert, + input_splits, + output_splits, + num_global_tokens_per_local_expert, + global_input_tokens_local_experts_indices, + ) = self._preprocess(topk_ids) self.hidden_shape_before_permute = hidden_states.shape @@ -610,82 +622,88 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): indices=topk_ids, num_out_tokens=self.num_out_tokens, ) - return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert - def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor: + return ( + permutated_local_input_tokens, + reversed_local_input_permutation_mapping, + tokens_per_expert, + input_splits, + output_splits, + num_global_tokens_per_local_expert, + global_input_tokens_local_experts_indices, + ) + + def _preprocess(self, topk_ids: torch.Tensor): num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts) ep_size = self.ep_size - - # Dropless self.num_out_tokens = topk_ids.numel() - # =================================================== - # Calculate input_splits, output_splits for alltoall-v. 
- # =================================================== - self.input_splits = (num_local_tokens_per_expert.reshape( + input_splits = (num_local_tokens_per_expert.reshape( ep_size, self.num_local_experts).sum(axis=1).to(torch.device("cpu"), non_blocking=True).numpy()) + num_global_tokens_per_expert = gather_from_sequence_parallel_region( num_local_tokens_per_expert, group=self.ep_group).reshape(ep_size, self.num_experts) - self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ + num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ 0]:self.local_expert_indices[-1] + 1] - if self.num_global_tokens_per_local_expert is None: + if num_global_tokens_per_local_expert is None: raise ValueError( "num_global_tokens_per_local_expert must be set before sum.") - self.output_splits = (self.num_global_tokens_per_local_expert.sum( - axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()) - num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum( - axis=0) - # =================================================== - # num_global_tokens_per_expert: [ep_size, num_experts] - # num_global_tokens_per_local_expert: [ep_size, num_local_experts] - # num_tokens_per_local_expert: [num_local_experts] - # =================================================== + output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to( + torch.device("cpu"), non_blocking=True).numpy()) + num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum( + axis=0) + + global_input_tokens_local_experts_indices = None if self.num_local_experts > 1: - if self.num_global_tokens_per_local_expert is None: + if num_global_tokens_per_local_expert is None: raise ValueError( "num_global_tokens_per_local_expert must be set before operations." ) - self.global_input_tokens_local_experts_indices = torch.repeat_interleave( + global_input_tokens_local_experts_indices = torch.repeat_interleave( self.expert_ids_per_ep_rank, - self.num_global_tokens_per_local_expert.ravel()) + num_global_tokens_per_local_expert.ravel()) else: - # TODO: This full synchronization can be a performance bottleneck. - # A more granular sync (e.g., blocking D2H copies) should be investigated. 
torch.npu.synchronize() - return num_tokens_per_local_expert + return ( + num_tokens_per_local_expert, + input_splits, + output_splits, + num_global_tokens_per_local_expert, + global_input_tokens_local_experts_indices, + ) - def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None): + def _dispatch_postprocess(self, global_input_tokens, + dynamic_scale_after_all2all, + global_input_tokens_local_experts_indices): # Early return if no local experts or no tokens if self.num_local_experts <= 1: - return global_input_tokens, None + return global_input_tokens, dynamic_scale_after_all2all, None # Handle quantized case if self.with_quant: - assert self.global_input_tokens_local_experts_indices is not None, \ - "global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess" - expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze( + assert global_input_tokens_local_experts_indices is not None, \ + "global_input_tokens_local_experts_indices must be provided" + expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze( -1) - active_num = self.global_input_tokens_local_experts_indices.numel() + active_num = global_input_tokens_local_experts_indices.numel() - # Handle case with no active tokens if active_num <= 0: - self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices - return global_input_tokens, dynamic_scale + reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices + return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping - # Process with active tokens - global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2( + global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2( global_input_tokens, expert_idx_2d, - scale=dynamic_scale, + scale=dynamic_scale_after_all2all, active_num=active_num, expert_capacity=0, expert_num=self.num_local_experts, @@ -693,32 +711,34 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): expert_tokens_num_flag=True, active_expert_range=[0, self.num_local_experts], quant_mode=-1, - row_idx_type=0) - return global_input_tokens, expanded_scale + row_idx_type=0, + ) + return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping - # Handle non-quantized case - global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( - global_input_tokens, - self.global_input_tokens_local_experts_indices) - return global_input_tokens, None + # Non-quantized case + global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( + global_input_tokens, global_input_tokens_local_experts_indices) + return global_input_tokens, None, reversed_global_input_permutation_mapping - def _combine_preprocess(self, hidden_states): + def _combine_preprocess(self, hidden_states: torch.Tensor, + context_metadata: dict) -> torch.Tensor: # Unpermutation 2: expert output to AlltoAll input if hidden_states.shape[0] > 0 and self.num_local_experts > 1: + rev_global = context_metadata[ + "reversed_global_input_permutation_mapping"] hidden_states = torch_npu.npu_moe_token_unpermute( - hidden_states, self.reversed_global_input_permutation_mapping) - + hidden_states, rev_global) return hidden_states - def _combine_postprocess(self, permutated_local_input_tokens): + def _combine_postprocess(self, 
permutated_local_input_tokens: torch.Tensor, + context_metadata: dict) -> torch.Tensor: # Unpermutation 1: AlltoAll output to output output = torch_npu.npu_moe_token_unpermute( permuted_tokens=permutated_local_input_tokens, - sorted_indices=self.reversed_local_input_permutation_mapping.to( - torch.int32), - probs=self.topk_weights, - restore_shape=self.hidden_shape_before_permute) - - # Reshape the output tensor + sorted_indices=context_metadata[ + "reversed_local_input_permutation_mapping"].to(torch.int32), + probs=context_metadata["topk_weights"], + restore_shape=self.hidden_shape_before_permute, + ) output = output.view(self.hidden_shape) return output
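
### How the refactored interfaces fit together (sketches)

The net effect on callers is easiest to see as a call sequence. The sketch below mirrors the `common_fused_moe.py` hunk above: `prepare()` now returns four values, `mc2_mask` is threaded explicitly into the expert computation, and `context_metadata` is handed back to `finalize()`. `StubCommMethod` and `moe_forward` are hypothetical stand-ins, not the real vllm_ascend classes, and the expert computation is reduced to a no-op.

```python
from typing import Optional

import torch


class StubCommMethod:
    """Stand-in for an Ascend MoE comm method; only the call shape matters."""

    def prepare(self, hidden_states, router_logits):
        # Real strategies may pad/slice here and return a real mc2_mask.
        context_metadata = {"num_tokens": hidden_states.shape[0]}
        return hidden_states, router_logits, None, context_metadata

    def fused_experts(self, hidden_states, router_logits,
                      mc2_mask: Optional[torch.Tensor] = None):
        # Stand-in for dispatch -> grouped expert MLP -> combine.
        return torch.relu(hidden_states)

    def finalize(self, hidden_states, reduce_results: bool,
                 context_metadata: dict):
        return hidden_states[:context_metadata["num_tokens"]]


def moe_forward(comm_method, hidden_states, router_logits):
    h, r, mc2_mask, context_metadata = comm_method.prepare(
        hidden_states, router_logits)
    out = comm_method.fused_experts(h, r, mc2_mask=mc2_mask)
    return comm_method.finalize(out,
                                reduce_results=True,
                                context_metadata=context_metadata)


if __name__ == "__main__":
    y = moe_forward(StubCommMethod(), torch.randn(3, 8), torch.randn(3, 2))
    assert y.shape == (3, 8)
```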
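The second bullet in the description — MC2 inheriting from the All2All strategy — works because both classes can now share a single `finalize()` that reads everything it needs from `context_metadata` instead of from attributes stashed on `self`, which is also why the `del self.split_hidden_states` workaround for the eager-mode memory blow-up disappears. Below is a minimal single-process sketch of that shape with simplified, hypothetical class names; the real classes additionally handle `replace_allreduce`, shared-expert DP, and a distributed `all_gather`.

```python
import torch
import torch.nn.functional as F


class All2AllStylePrepareFinalize:
    """Pads/splits in prepare() and returns the splits via context_metadata,
    so finalize() needs no instance state (and no manual cleanup afterwards)."""

    def __init__(self, tp_size: int, tp_rank: int):
        self.tp_size = tp_size
        self.tp_rank = tp_rank

    def prepare(self, hidden_states, router_logits):
        num_tokens = hidden_states.shape[0]
        pad_size = -num_tokens % self.tp_size  # pad up to a multiple of tp_size
        if pad_size:
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
            router_logits = F.pad(router_logits, (0, 0, 0, pad_size))
        splits = torch.tensor_split(hidden_states, self.tp_size, dim=0)
        r_splits = torch.tensor_split(router_logits, self.tp_size, dim=0)
        context_metadata = {
            "split_hidden_states": splits,
            "num_tokens": num_tokens,
        }
        # Third slot is the comm mask; the plain all-to-all path returns None.
        return splits[self.tp_rank], r_splits[self.tp_rank], None, context_metadata

    def finalize(self, hidden_states, reduce_results, context_metadata):
        # The real code all-gathers across the TP group here; single-process
        # stand-in: drop the local slice back into the saved splits.
        splits = list(context_metadata["split_hidden_states"])
        splits[self.tp_rank] = hidden_states
        full = torch.cat(splits, dim=0)
        return full[:context_metadata["num_tokens"]]


class MC2StylePrepareFinalize(All2AllStylePrepareFinalize):
    """Only the dispatch-side preparation differs (mc2_mask handling);
    finalize() is inherited unchanged, which is what lets the PR drop the
    duplicated finalize body from the MC2 class."""

    def prepare(self, hidden_states, router_logits):
        h, r, _, context_metadata = super().prepare(hidden_states,
                                                    router_logits)
        mc2_mask = torch.ones(h.shape[0], dtype=torch.bool)
        return h, r, mc2_mask, context_metadata


if __name__ == "__main__":
    pf = MC2StylePrepareFinalize(tp_size=2, tp_rank=0)
    h, r = torch.randn(3, 8), torch.randn(3, 2)
    h_local, r_local, mask, meta = pf.prepare(h, r)
    out = pf.finalize(h_local, reduce_results=False, context_metadata=meta)
    assert out.shape == (3, 8) and mask.shape[0] == h_local.shape[0]
```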
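The token dispatchers follow the same pattern: `token_dispatch()` returns a `context_metadata` dict and `token_combine()` consumes it, so nothing is cached on `self` and nothing has to be reset to `None` afterwards. The toy dispatcher below is a hypothetical stand-in that uses plain torch indexing in place of the `npu_moe_init_routing_v2` / `npu_moe_token_unpermute` kernels; only the interface shape is taken from the diff.

```python
import torch


class ToyDispatcher:
    """Stateless dispatcher sketch: all per-batch tensors travel in metadata."""

    def token_dispatch(self, hidden_states, topk_weights, topk_ids):
        num_tokens, topk = topk_ids.shape
        # Expand each token once per selected expert, then group by expert id
        # (a plain-torch stand-in for npu_moe_init_routing_v2).
        flat_expert_ids = topk_ids.reshape(-1)
        order = torch.argsort(flat_expert_ids, stable=True)
        expanded = hidden_states.repeat_interleave(topk, dim=0)
        group_list = torch.bincount(flat_expert_ids,
                                    minlength=int(flat_expert_ids.max()) + 1)
        context_metadata = {
            "expanded_row_idx": order,
            "topk_weights": topk_weights,
        }
        return {
            "hidden_states": expanded[order],
            "group_list": group_list,
            "group_list_type": 1,
            "context_metadata": context_metadata,
        }

    def token_combine(self, hidden_states, context_metadata, bias=None):
        order = context_metadata["expanded_row_idx"]
        topk_weights = context_metadata["topk_weights"]
        topk = topk_weights.shape[1]
        # Undo the permutation, apply routing weights, reduce over top-k
        # (a stand-in for npu_moe_token_unpermute).
        unsorted = torch.empty_like(hidden_states)
        unsorted[order] = hidden_states
        weighted = unsorted * topk_weights.reshape(-1, 1)
        return weighted.reshape(-1, topk, hidden_states.shape[-1]).sum(dim=1)


if __name__ == "__main__":
    dispatcher = ToyDispatcher()
    x = torch.randn(4, 8)
    weights = torch.softmax(torch.randn(4, 2), dim=-1)
    ids = torch.randint(0, 3, (4, 2))
    out = dispatcher.token_dispatch(x, weights, ids)
    # Identity "expert MLP" for the sketch; combine restores [num_tokens, H].
    y = dispatcher.token_combine(out["hidden_states"],
                                 out["context_metadata"])
    assert y.shape == (4, 8)
```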