[MoE] [Refactor] Remove manual memory cleanup (#3365)

### What this PR does / why we need it?
1. Replace manual memory cleanup (deleting cached attributes such as `split_hidden_states` after use) with explicit parameter passing: `prepare()`/`token_dispatch()` now return a `context_metadata` dict that the caller hands back to `finalize()`/`token_combine()`. A minimal sketch of the two patterns follows below.
2. `FusedMoEPrepareAndFinalizeWithMC2` now inherits from `FusedMoEPrepareAndFinalizeWithAll2All` to avoid duplicated code.
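
For context, a minimal sketch of the two patterns this refactor trades between: caching intermediate splits on `self` and deleting them by hand, versus returning them in a `context_metadata` dict that the caller passes back into `finalize()`. The class names below are invented for illustration only and are not the actual vllm-ascend classes; only the pattern mirrors the diff.

```python
import torch


class StatefulPrepareFinalize:
    """Old pattern: intermediate splits are cached on self and must be
    deleted manually after finalize to release memory in eager mode."""

    def prepare(self, hidden_states: torch.Tensor) -> torch.Tensor:
        self.split_hidden_states = list(
            torch.tensor_split(hidden_states, 2, dim=0))
        return self.split_hidden_states[0]  # this rank's slice

    def finalize(self, hidden_states: torch.Tensor) -> torch.Tensor:
        splits = self.split_hidden_states
        splits[0] = hidden_states  # stand-in for the all-gather result
        out = torch.cat(splits, dim=0)
        del self.split_hidden_states  # manual memory cleanup
        return out


class StatelessPrepareFinalize:
    """New pattern: intermediate splits travel in a context_metadata dict,
    so they are freed naturally once the caller drops the dict."""

    def prepare(self, hidden_states: torch.Tensor):
        splits = list(torch.tensor_split(hidden_states, 2, dim=0))
        return splits[0], {"split_hidden_states": splits}

    def finalize(self, hidden_states: torch.Tensor,
                 context_metadata: dict) -> torch.Tensor:
        splits = context_metadata["split_hidden_states"]
        splits[0] = hidden_states  # stand-in for the all-gather result
        return torch.cat(splits, dim=0)


x = torch.randn(4, 8)
layer = StatelessPrepareFinalize()
local, metadata = layer.prepare(x)
assert layer.finalize(local, context_metadata=metadata).shape == x.shape
```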

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
E2E and unit tests.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Author: weichen
Date: 2025-10-15 12:36:24 +08:00 (committed by GitHub)
Parent: 4e720936d8
Commit: 4f937f561d
8 changed files with 562 additions and 492 deletions

View File

@ -137,6 +137,7 @@ def test_token_dispatcher_with_all_gather(
sorted_hidden_states = dispatch_output["hidden_states"]
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
+context_metadata = dispatch_output["context_metadata"]
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
w1=w1_local,
@ -144,8 +145,10 @@ def test_token_dispatcher_with_all_gather(
group_list=group_list,
group_list_type=group_list_type)
-combined_output = dispatcher.token_combine(hidden_states=expert_output,
-bias=None)
+combined_output = dispatcher.token_combine(
+hidden_states=expert_output,
+context_metadata=context_metadata,
+bias=None)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
expert_map)
@ -215,6 +218,7 @@ def test_token_dispatcher_with_all_gather_quant(
group_list = dispatch_output["group_list"]
group_list_type = dispatch_output.get("group_list_type", 1)
dynamic_scale = dispatch_output["dynamic_scale"]
+context_metadata = dispatch_output["context_metadata"]
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
w1=w1,
@ -225,8 +229,10 @@ def test_token_dispatcher_with_all_gather_quant(
group_list_type=group_list_type,
dynamic_scale=dynamic_scale,
with_quant=True)
-combined_output = dispatcher.token_combine(hidden_states=expert_output,
-bias=None)
+combined_output = dispatcher.token_combine(
+hidden_states=expert_output,
+context_metadata=context_metadata,
+bias=None)
assert combined_output.shape == (m, k)
gc.collect()
torch.npu.empty_cache()

View File

@ -44,7 +44,8 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out, mask = layer.prepare(hidden_states, router_logits) h_out, r_out, mask, context_metadata = layer.prepare(
hidden_states, router_logits)
# Check padding and split # Check padding and split
self.assertEqual(h_out.shape[0], 4) self.assertEqual(h_out.shape[0], 4)
@ -52,7 +53,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
self.assertEqual(mask.tolist(), [1, 0, 1]) self.assertEqual(mask.tolist(), [1, 0, 1])
# Finalize # Finalize
result = layer.finalize(h_out, reduce_results=False) result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
self.assertEqual(result.shape[0], 3) self.assertEqual(result.shape[0], 3)
@patch( @patch(
@ -77,10 +80,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(4, 8) hidden_states = torch.randn(4, 8)
router_logits = torch.randn(4, 2) router_logits = torch.randn(4, 2)
h_out, r_out, mask = layer.prepare(hidden_states, h_out, r_out, mask, context_metadata = layer.prepare(
router_logits, hidden_states,
enable_shared_expert_dp=False, router_logits,
replace_allreduce=False) enable_shared_expert_dp=False,
replace_allreduce=False)
# With TP=2, should split into 2 parts # With TP=2, should split into 2 parts
self.assertEqual(h_out.shape[0], 2) self.assertEqual(h_out.shape[0], 2)
@ -96,7 +100,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
torch.zeros_like(h_out), torch.zeros_like(h_out),
torch.zeros_like(h_out) torch.zeros_like(h_out)
] ]
final_result = layer.finalize(h_out, reduce_results=False) final_result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
# Should concat back to original size # Should concat back to original size
self.assertEqual(final_result.shape[0], 4) self.assertEqual(final_result.shape[0], 4)
@ -112,12 +118,15 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out, _ = layer.prepare(hidden_states, router_logits) h_out, r_out, _, context_metadata = layer.prepare(
hidden_states, router_logits)
# Pad to tp_size=1, so no change # Pad to tp_size=1, so no change
self.assertEqual(h_out.shape[0], 3) self.assertEqual(h_out.shape[0], 3)
result = layer.finalize(h_out, reduce_results=False) result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
self.assertEqual(result.shape[0], 3) self.assertEqual(result.shape[0], 3)
@patch( @patch(
@ -133,10 +142,11 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
hidden_states = torch.randn(2, 8) hidden_states = torch.randn(2, 8)
router_logits = torch.randn(2, 2) router_logits = torch.randn(2, 2)
h_out, r_out, _ = layer.prepare(hidden_states, h_out, r_out, _, context_metadata = layer.prepare(
router_logits, hidden_states,
enable_shared_expert_dp=False, router_logits,
replace_allreduce=False) enable_shared_expert_dp=False,
replace_allreduce=False)
# Split due to TP=2 # Split due to TP=2
self.assertEqual(h_out.shape[0], 1) self.assertEqual(h_out.shape[0], 1)
@ -152,7 +162,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
torch.zeros_like(h_out), torch.zeros_like(h_out),
torch.zeros_like(h_out) torch.zeros_like(h_out)
] ]
final_result = layer.finalize(h_out, reduce_results=False) final_result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
# Should concat back # Should concat back
self.assertEqual(final_result.shape[0], 2) self.assertEqual(final_result.shape[0], 2)
@ -195,9 +207,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_gate = MagicMock() mock_gate = MagicMock()
mock_gate.return_value = (router_logits.repeat(2, 1), None) mock_gate.return_value = (router_logits.repeat(2, 1), None)
h_out, r_out, _ = layer.prepare(hidden_states, h_out, r_out, _, context_metadata = layer.prepare(hidden_states,
router_logits, router_logits,
gate=mock_gate) gate=mock_gate)
# After all-gather with DP=2, should double the batch size # After all-gather with DP=2, should double the batch size
self.assertEqual(h_out.shape[0], 12) self.assertEqual(h_out.shape[0], 12)
@ -209,7 +221,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
return tensor[:3] return tensor[:3]
mock_dp_group.reduce_scatter = mock_reduce_scatter_func mock_dp_group.reduce_scatter = mock_reduce_scatter_func
result = layer.finalize(h_out, reduce_results=False) result = layer.finalize(h_out,
reduce_results=False,
context_metadata=context_metadata)
self.assertEqual(result.shape[0], 3) self.assertEqual(result.shape[0], 3)
@ -263,9 +277,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
mock_gate.return_value = (torch.randn(7, 2), None) mock_gate.return_value = (torch.randn(7, 2), None)
# Run prepare # Run prepare
h_out, r_out, _ = layer.prepare(hidden_states, h_out, r_out, _, _ = layer.prepare(hidden_states,
router_logits, router_logits,
gate=mock_gate) gate=mock_gate)
# Should be global tensor: [7, 8] and [7, 2] # Should be global tensor: [7, 8] and [7, 2]
self.assertEqual(h_out.shape, (7, 8)) self.assertEqual(h_out.shape, (7, 8))

View File

@ -45,7 +45,7 @@ class TestMoECommMethod(TestBase):
# Mock prepare finalize # Mock prepare finalize
mock_pf_instance = MagicMock() mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8), mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
torch.randn(4, 2), None) torch.randn(4, 2), None, None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance mock_prepare_finalize.return_value = mock_pf_instance
@ -59,15 +59,18 @@ class TestMoECommMethod(TestBase):
# Test prepare method # Test prepare method
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out = comm_impl.prepare(hidden_states, router_logits) h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
hidden_states, router_logits)
# Verify prepare was called with correct arguments # Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with( mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, None) hidden_states, router_logits, False, False, None)
# Test finalize method # Test finalize method
comm_impl.finalize(h_out, reduce_results=True) comm_impl.finalize(h_out,
mock_pf_instance.finalize.assert_called_once_with(h_out, True) reduce_results=True,
context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@ -90,7 +93,8 @@ class TestMoECommMethod(TestBase):
mock_pf_instance = MagicMock() mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8), mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
torch.randn(4, 2), torch.randn(4, 2),
torch.tensor([1, 0, 1, 0])) torch.tensor([1, 0, 1,
0]), None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance mock_prepare_finalize.return_value = mock_pf_instance
@ -104,15 +108,18 @@ class TestMoECommMethod(TestBase):
# Test prepare method # Test prepare method
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out = comm_impl.prepare(hidden_states, router_logits) h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
hidden_states, router_logits)
# Verify prepare was called with correct arguments # Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with( mock_pf_instance.prepare.assert_called_once_with(
hidden_states, router_logits, False, False, None) hidden_states, router_logits, False, False, None)
# Test finalize method # Test finalize method
comm_impl.finalize(h_out, reduce_results=True) comm_impl.finalize(h_out,
mock_pf_instance.finalize.assert_called_once_with(h_out, True) reduce_results=True,
context_metadata=context_metadata)
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@ -135,7 +142,7 @@ class TestMoECommMethod(TestBase):
# Mock prepare finalize # Mock prepare finalize
mock_pf_instance = MagicMock() mock_pf_instance = MagicMock()
mock_pf_instance.prepare.return_value = (torch.randn(4, 8), mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
torch.randn(4, 2), None) torch.randn(4, 2), None, None)
mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_pf_instance.finalize.return_value = torch.randn(4, 8)
mock_prepare_finalize.return_value = mock_pf_instance mock_prepare_finalize.return_value = mock_pf_instance
@ -149,7 +156,8 @@ class TestMoECommMethod(TestBase):
# Test prepare method # Test prepare method
hidden_states = torch.randn(3, 8) hidden_states = torch.randn(3, 8)
router_logits = torch.randn(3, 2) router_logits = torch.randn(3, 2)
h_out, r_out = comm_impl.prepare(hidden_states, router_logits) h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
hidden_states, router_logits)
# Verify prepare was called with correct arguments # Verify prepare was called with correct arguments
mock_pf_instance.prepare.assert_called_once_with( mock_pf_instance.prepare.assert_called_once_with(

View File

@ -77,9 +77,10 @@ class TestTokenDispatcherWithMC2(TestBase):
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+mc2_mask = None
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
-hidden_states, topk_weights, topk_ids, expert_map)
+hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)
self.assertIn("x", kwargs)
self.assertIn("expert_ids", kwargs)
self.assertEqual(kwargs["moe_expert_num"], 8)
@ -123,36 +124,64 @@ class TestTokenDispatcherWithMC2(TestBase):
def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
hidden_states = torch.randn(10, 128)
-self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
-self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
-self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+topk_ids = torch.randint(0, 8, (10, 1))
+topk_weights = torch.randn(10, 1)  # Note: must be float, not int
+expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+mc2_mask = None
+assist_info_for_combine = torch.arange(10)  # mock value
+context_metadata = {
+"topk_ids": topk_ids,
+"topk_weights": topk_weights,
+"expert_map": expert_map,
+"ep_recv_counts": ep_recv_counts,
+"mc2_mask": mc2_mask,
+"assist_info_for_combine": assist_info_for_combine,
+"expand_scales": None,
+}
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
-self.dispatcher.output = torch.randint(0, 8, (10, 1))
-kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
+kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
+context_metadata)
self.assertIn("tp_send_counts", kwargs)

def test_token_combine_with_shared_experts(self):
-self.dispatcher.shared_experts = MagicMock()
-self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
-10, 128), torch.tensor(1.0))
-self.dispatcher.shared_act = torch.randn(10, 128)
+shared_experts = MagicMock()
+shared_experts.down_proj.return_value = (torch.randn(10, 128),
+torch.tensor(1.0))
+topk_ids = torch.randint(0, 8, (10, 1))
+topk_weights = torch.randn(10, 1)
+expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
+assist_info_for_combine = torch.arange(10)
+context_metadata = {
+"topk_ids": topk_ids,
+"topk_weights": topk_weights,
+"expert_map": expert_map,
+"ep_recv_counts": ep_recv_counts,
+"mc2_mask": None,
+"assist_info_for_combine": assist_info_for_combine,
+"expand_scales": None,
+"shared_experts": shared_experts,
+"shared_act": torch.randn(10, 128),
+"swiglu_out_scale": torch.randn(10, 1),
+}
self.dispatcher.with_quant = True
-self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
-self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
-self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
-self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
-self.dispatcher.output = torch.randint(0, 8, (10, 1))
-self.hidden_states = torch.randn(10, 128)
+hidden_states = torch.randn(10, 128)
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
-self.dispatcher.token_combine(self.hidden_states)
+result = self.dispatcher.token_combine(hidden_states,
+context_metadata)
+self.assertIsInstance(result, tuple)

class TestTokenDispatcherWithAllGather(TestBase):
@ -264,35 +293,26 @@ class TestTokenDispatcherWithAllGather(TestBase):
self.assertEqual(results["group_list_type"], 1) self.assertEqual(results["group_list_type"], 1)
def test_token_combine_with_expert_map(self): def test_token_combine_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.sorted_weights = torch.tensor(
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
self.dispatcher.original_shape = (3, 128)
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
hidden_states = torch.randn(6, 128) hidden_states = torch.randn(6, 128)
context_metadata = {
final_hidden_states = self.dispatcher.token_combine(hidden_states) "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
}
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
self.assertEqual(final_hidden_states.shape, (6, 128)) self.assertEqual(final_hidden_states.shape, (6, 128))
def test_token_combine_without_expert_map(self): def test_token_combine_without_expert_map(self):
self.dispatcher.with_quant = False
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.sorted_weights = torch.tensor(
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
self.dispatcher.original_shape = (3, 128)
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
hidden_states = torch.randn(6, 128) hidden_states = torch.randn(6, 128)
context_metadata = {
final_hidden_states = self.dispatcher.token_combine(hidden_states) "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
# Verify npu_moe_finalize_routing is called }
self.dispatcher.original_shape = (6, 128)
final_hidden_states = self.dispatcher.token_combine(
hidden_states, context_metadata)
self.mock_npu_moe_token_unpermute.assert_called_once() self.mock_npu_moe_token_unpermute.assert_called_once()
args, kwargs = self.mock_npu_moe_token_unpermute.call_args
self.assertEqual(final_hidden_states.shape, (6, 128)) self.assertEqual(final_hidden_states.shape, (6, 128))
def test_token_dispatch_with_router_weight(self): def test_token_dispatch_with_router_weight(self):
@ -418,25 +438,21 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
self.assertEqual(result["group_list_type"], 1) self.assertEqual(result["group_list_type"], 1)
def test_token_combine(self): def test_token_combine(self):
hidden_states = torch.randn(16, 16)
context_metadata = {
"input_splits": [4, 4],
"output_splits": [4, 4],
"topk_weights": torch.rand(8, 4),
"reversed_local_input_permutation_mapping": torch.arange(8),
"reversed_global_input_permutation_mapping": torch.arange(16),
}
self.dispatcher.hidden_shape = (8, 16) self.dispatcher.hidden_shape = (8, 16)
self.dispatcher.hidden_shape_before_permute = (8, 16) self.dispatcher.hidden_shape_before_permute = (8, 16)
self.dispatcher.reversed_local_input_permutation_mapping = torch.arange(
8)
self.dispatcher.topk_weights = torch.rand(8, 4)
self.dispatcher.input_splits = [4, 4]
self.dispatcher.output_splits = [4, 4]
self.dispatcher.reversed_global_input_permutation_mapping = torch.arange(
16)
self.dispatcher.expert_ids_per_ep_rank = torch.tensor( self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32) [0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1] self.dispatcher.local_expert_indices = [0, 1]
self.dispatcher.num_global_tokens_per_local_expert = torch.tensor(
[[2, 2], [2, 2]], dtype=torch.int64)
expert_output = torch.randn(16, 16)
output = self.dispatcher.token_combine(expert_output)
output = self.dispatcher.token_combine(hidden_states, context_metadata)
self.assertIsNotNone(output) self.assertIsNotNone(output)
self.assertEqual(output.shape, (8, 16)) self.assertEqual(output.shape, (8, 16))
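
The tests above exercise the new dispatcher contract: `token_dispatch` returns a `context_metadata` dict alongside the permuted tokens, and `token_combine` consumes that dict instead of reading attributes cached on `self`. Below is a toy, self-contained sketch of that contract; the class and its permutation logic are invented for illustration and stand in for the real vllm-ascend dispatchers and `torch_npu` kernels.

```python
import torch


class ToyTokenDispatcher:
    """Toy stand-in (not a vllm-ascend class) for the dispatch/combine
    contract: metadata is returned by dispatch and passed back to combine."""

    def token_dispatch(self, hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor) -> dict:
        # Sort tokens by expert id so each expert sees a contiguous block.
        sort_idx = torch.argsort(topk_ids.flatten())
        context_metadata = {"sort_idx": sort_idx, "topk_weights": topk_weights}
        return {
            "hidden_states": hidden_states[sort_idx],
            "group_list": torch.bincount(topk_ids.flatten()).to(torch.int64),
            "group_list_type": 1,
            "context_metadata": context_metadata,
        }

    def token_combine(self, hidden_states: torch.Tensor,
                      context_metadata: dict,
                      bias: torch.Tensor = None) -> torch.Tensor:
        # Undo the permutation and apply routing weights from the metadata.
        unpermuted = torch.empty_like(hidden_states)
        unpermuted[context_metadata["sort_idx"]] = hidden_states
        return unpermuted * context_metadata["topk_weights"]


# Round trip: the metadata produced by dispatch is handed back into combine.
dispatcher = ToyTokenDispatcher()
x = torch.randn(6, 16)
out = dispatcher.token_dispatch(x, torch.ones(6, 1), torch.randint(0, 4, (6,)))
combined = dispatcher.token_combine(out["hidden_states"],
                                    context_metadata=out["context_metadata"])
assert combined.shape == x.shape
```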

View File

@ -283,7 +283,7 @@ class AscendFusedMoE(FusedMoE):
enable_force_load_balance = forward_context.in_profile_run
forward_context = get_forward_context()
-hidden_states, router_logits = forward_context.moe_comm_method.prepare(
+hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
replace_allreduce=forward_context.sp_enabled,
@ -311,7 +311,8 @@ class AscendFusedMoE(FusedMoE):
shared_experts=None,
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
-global_redundant_expert_num=self.global_redundant_expert_num)
+global_redundant_expert_num=self.global_redundant_expert_num,
+mc2_mask=mc2_mask)
if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
@ -322,7 +323,8 @@ class AscendFusedMoE(FusedMoE):
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states,
-reduce_results=self.reduce_results)
+reduce_results=self.reduce_results,
+context_metadata=context_metadata)
return final_hidden_states

View File

@ -15,6 +15,7 @@
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -49,12 +50,15 @@ class FusedMoEPrepareAndFinalize(ABC):
is_deepseek_v3_r1) is_deepseek_v3_r1)
@abstractmethod @abstractmethod
def prepare(self, def prepare(
hidden_states: torch.Tensor, self,
router_logits: torch.Tensor, hidden_states: torch.Tensor,
enable_shared_expert_dp: bool = False, router_logits: torch.Tensor,
replace_allreduce: bool = False, enable_shared_expert_dp: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
""" """
Prepare tensors before MoE computation. May involve: Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries - Padding to align communication boundaries
@ -74,11 +78,14 @@ class FusedMoEPrepareAndFinalize(ABC):
- processed hidden_states (may be padded/sliced/broadcasted) - processed hidden_states (may be padded/sliced/broadcasted)
- processed router_logits (may be recomputed or broadcasted) - processed router_logits (may be recomputed or broadcasted)
- optional communication mask (e.g., mc2_mask for sparse ops) - optional communication mask (e.g., mc2_mask for sparse ops)
- optional context metadata (e.g., saved split_hidden_states for finalization)
""" """
raise NotImplementedError("Prepare not implemented.") raise NotImplementedError("Prepare not implemented.")
def finalize(self, hidden_states: torch.Tensor, def finalize(self,
reduce_results: bool) -> torch.Tensor: hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
""" """
Finalize MoE output. May involve: Finalize MoE output. May involve:
- Gathering sliced tensors across TP ranks - Gathering sliced tensors across TP ranks
@ -96,9 +103,102 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError("Finalize function not implemented.") raise NotImplementedError("Finalize function not implemented.")
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
""" """
MoE communication strategy using MC2 (Memory-Centric Communication). MoE communication strategy using All-to-All style slicing.
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
Used when num_tokens exceeds MC2's limit (512 tokens per rank).
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
"""Restore original TP configuration (same as MC2)."""
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
2. If TP > 1, split along token dim and select current TP rank's slice.
3. Save splits for later all-gather in finalize.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {"split_hidden_states": split_hidden_states}
return hidden_states, router_logits, None, context_metadata
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor.
2. Unpad to original token count.
3. Return [original_num_tokens, hidden_size] tensor.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
assert context_metadata is not None
split_hidden_states = context_metadata["split_hidden_states"]
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
dist.all_gather(list(split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(split_hidden_states, dim=0)
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
"""
MoE communication strategy using MC2, which is based on All2All. Hence, it inherits from
All2All and shares the same finalize method.
Designed for Ascend or environments requiring explicit padding and slicing control. Designed for Ascend or environments requiring explicit padding and slicing control.
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment. Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
""" """
@ -116,12 +216,15 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
def prepare(self, def prepare(
hidden_states: torch.Tensor, self,
router_logits: torch.Tensor, hidden_states: torch.Tensor,
enable_shared_expert_dp: bool = False, router_logits: torch.Tensor,
replace_allreduce: bool = False, enable_shared_expert_dp: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
""" """
Preparation steps: Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context. 1. Fetch `mc2_mask` and target padding length from forward context.
@ -132,10 +235,11 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True. Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns: Returns:
Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded. Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded.
""" """
self.replace_allreduce = replace_allreduce self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
split_hidden_states = None
forward_context = get_forward_context() forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask mc2_mask = forward_context.mc2_mask
if self.tp_size > 1: if self.tp_size > 1:
@ -165,124 +269,10 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
dim=0) dim=0)
hidden_states = split_hidden_states[self.tp_rank] hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank] router_logits = split_router_logits[self.tp_rank]
self.split_hidden_states = split_hidden_states # Save for finalize
return hidden_states, router_logits, mc2_mask context_metadata = {"split_hidden_states": split_hidden_states}
def finalize(self, hidden_states: torch.Tensor, return hidden_states, router_logits, mc2_mask, context_metadata
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor.
2. Unpad to original token count if padding was applied.
3. Return tensor with shape [original_num_tokens, hidden_size].
Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
# All-gather across TP group
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
# If the cache is not cleared after `self.split_hidden_states` is created,
# it can lead to the memory explosion in eager mode.
del self.split_hidden_states
# Unpad if necessary
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using All-to-All style slicing.
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
"""Restore original TP configuration (same as MC2)."""
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
def prepare(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
2. If TP > 1, split along token dim and select current TP rank's slice.
3. Save splits for later all-gather in finalize.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, None) — no mask used in All2All.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
if not (self.replace_allreduce or self.enable_shared_expert_dp):
self.num_tokens, _ = hidden_states.shape
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
self.split_hidden_states = split_hidden_states
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
return hidden_states, router_logits, None
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor.
2. Unpad to original token count.
3. Return [original_num_tokens, hidden_size] tensor.
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
"""
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
# If the cache is not cleared after `self.split_hidden_states` is created,
# it can lead to the memory explosion in eager mode.
del self.split_hidden_states
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
return hidden_states
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
@ -292,12 +282,15 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
Uses `max_tokens_across_dp` from forward_context for padding alignment. Uses `max_tokens_across_dp` from forward_context for padding alignment.
""" """
def prepare(self, def prepare(
hidden_states: torch.Tensor, self,
router_logits: torch.Tensor, hidden_states: torch.Tensor,
enable_shared_expert_dp: bool = False, router_logits: torch.Tensor,
replace_allreduce: bool = False, enable_shared_expert_dp: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
""" """
Preparation steps: Preparation steps:
1. Fetch max token count across DP group from forward context. 1. Fetch max token count across DP group from forward context.
@ -305,7 +298,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
3. All-gather across DP group to form global input tensor. 3. All-gather across DP group to form global input tensor.
Returns: Returns:
Tuple of (global_hidden_states, global_router_logits, None) Tuple of (global_hidden_states, global_router_logits, None, None)
""" """
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
@ -331,10 +324,12 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
router_logits = self.moe_config.dp_group.all_gather( router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0) router_logits, 0)
return hidden_states, router_logits, None return hidden_states, router_logits, None, None
def finalize(self, hidden_states: torch.Tensor, def finalize(self,
reduce_results: bool) -> torch.Tensor: hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
""" """
Finalization steps: Finalization steps:
1. If DP > 1 and not shared expert, reduce-scatter output across DP group. 1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
@ -395,19 +390,22 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
get_dp_group().broadcast(buffer[start:end, :], idx) get_dp_group().broadcast(buffer[start:end, :], idx)
return buffer return buffer
def prepare(self, def prepare(
hidden_states: torch.Tensor, self,
router_logits: torch.Tensor, hidden_states: torch.Tensor,
enable_shared_expert_dp: bool = False, router_logits: torch.Tensor,
replace_allreduce: bool = False, enable_shared_expert_dp: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
""" """
Preparation steps: Preparation steps:
1. Fetch cumulative token boundaries from forward context. 1. Fetch cumulative token boundaries from forward context.
2. Multicast hidden_states and router_logits to form global tensors. 2. Multicast hidden_states and router_logits to form global tensors.
Returns: Returns:
Tuple of (global_hidden_states, global_router_logits, None) Tuple of (global_hidden_states, global_router_logits, None, None)
""" """
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
@ -422,10 +420,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
router_logits = self._naive_multicast( router_logits = self._naive_multicast(
router_logits, self.cu_tokens_across_dp_cpu) router_logits, self.cu_tokens_across_dp_cpu)
return hidden_states, router_logits, None return hidden_states, router_logits, None, None
def finalize(self, hidden_states: torch.Tensor, def finalize(self,
reduce_results: bool) -> torch.Tensor: hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
""" """
Finalization steps: Finalization steps:
1. If DP > 1 and not shared expert: 1. If DP > 1 and not shared expert:
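
The prepare/finalize docstrings in this file describe a pad, split, all-gather, unpad round trip, with the split views carried in `context_metadata` instead of on `self`. A single-process sketch of that data flow in plain PyTorch follows; the `tp_size`/`tp_rank` values are assumed for illustration, and the `dist.all_gather` over the TP group is replaced by a local concatenation since no process group is initialized here.

```python
import torch
import torch.nn.functional as F

tp_size, tp_rank = 4, 1            # assumed TP configuration (illustrative)
hidden_states = torch.randn(6, 8)  # 6 tokens, hidden size 8
num_tokens = hidden_states.shape[0]

# prepare(): pad the token dim so it splits evenly, keep this rank's slice,
# and stash the split views in context_metadata rather than on self.
pad_size = (-num_tokens) % tp_size
padded = F.pad(hidden_states, (0, 0, 0, pad_size))
split_hidden_states = list(torch.tensor_split(padded, tp_size, dim=0))
local_slice = split_hidden_states[tp_rank]
context_metadata = {"split_hidden_states": split_hidden_states}

# ... expert computation would run on local_slice here ...

# finalize(): the real code calls dist.all_gather into the saved splits; in
# this single-process sketch every slice is already present, so we simply
# concatenate and then unpad back to the original token count.
splits = context_metadata["split_hidden_states"]
splits[tp_rank] = local_slice
gathered = torch.cat(splits, dim=0)
output = gathered[:num_tokens]
assert output.shape == hidden_states.shape
```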

View File

@ -57,28 +57,31 @@ class MoECommMethod(ABC):
self.model_type = get_current_vllm_config( self.model_type = get_current_vllm_config(
).model_config.hf_config.model_type ).model_config.hf_config.model_type
self.moe_config = moe_config self.moe_config = moe_config
self.mc2_mask = None
self.token_dispatcher = self._get_token_dispatcher() self.token_dispatcher = self._get_token_dispatcher()
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize( self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
) )
def prepare(self, def prepare(
hidden_states: torch.Tensor, self,
router_logits: torch.Tensor, hidden_states: torch.Tensor,
enable_shared_expert_dp: bool = False, router_logits: torch.Tensor,
replace_allreduce: bool = False, enable_shared_expert_dp: bool = False,
gate=None) -> tuple[torch.Tensor, torch.Tensor]: replace_allreduce: bool = False,
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare( gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
hidden_states, router_logits, mc2_mask, context_metadata = self.fused_moe_prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp, hidden_states, router_logits, enable_shared_expert_dp,
replace_allreduce, gate) replace_allreduce, gate)
self.mc2_mask = mc2_mask return hidden_states, router_logits, mc2_mask, context_metadata
return hidden_states, router_logits
def finalize(self, hidden_states: torch.Tensor, def finalize(self,
reduce_results: bool) -> torch.Tensor: hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
hidden_states = self.fused_moe_prepare_finalize.finalize( hidden_states = self.fused_moe_prepare_finalize.finalize(
hidden_states, reduce_results) hidden_states, reduce_results, context_metadata)
return hidden_states return hidden_states
def fused_experts( def fused_experts(
@ -108,7 +111,8 @@ class MoECommMethod(ABC):
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
need_trans: bool = False, need_trans: bool = False,
dynamic_eplb: bool = False): dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None):
# Check constraints # Check constraints
assert hidden_states.dtype in [ assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16 torch.float32, torch.float16, torch.bfloat16
@ -127,12 +131,12 @@ class MoECommMethod(ABC):
shared_experts=shared_experts, shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share, quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share, dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=self.mc2_mask, mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8) with_quant=use_int8_w8a8 or use_int4_w4a8)
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \ permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales") results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
w1=w1, w1=w1,
@ -152,7 +156,7 @@ class MoECommMethod(ABC):
dynamic_eplb=dynamic_eplb) dynamic_eplb=dynamic_eplb)
final_hidden_states = self.token_dispatcher.token_combine( final_hidden_states = self.token_dispatcher.token_combine(
hidden_states=mlp_output) hidden_states=mlp_output, context_metadata=context_metadata)
if dynamic_eplb: if dynamic_eplb:
return (final_hidden_states, group_list_type, expert_tokens) return (final_hidden_states, group_list_type, expert_tokens)

View File

@ -75,6 +75,7 @@ class MoETokenDispatcher(ABC):
@abstractmethod @abstractmethod
def token_combine(self, def token_combine(self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor = None): bias: torch.Tensor = None):
raise NotImplementedError("Combine function not implemented.") raise NotImplementedError("Combine function not implemented.")
@ -102,16 +103,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
# improve communication performance. # improve communication performance.
self.need_expert_scale = is_hierarchical_communication_enabled() self.need_expert_scale = is_hierarchical_communication_enabled()
self.output = None
self.assist_info_for_combine = None
self.ep_recv_counts = None
self.shared_act = None
self.topk_ids = None
self.topk_weights = None
self.shared_experts = None
self.mc2_mask = None
self.with_quant = False self.with_quant = False
self.expand_scales = None
def get_dispatch_mc2_kwargs( def get_dispatch_mc2_kwargs(
self, self,
@ -119,6 +111,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
expert_map: torch.Tensor, expert_map: torch.Tensor,
mc2_mask: torch.Tensor,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
): ):
if self.with_quant: if self.with_quant:
@ -155,7 +148,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
}) })
if self.a3_need_extra_args and self.enable_dispatch_v2: if self.a3_need_extra_args and self.enable_dispatch_v2:
stage1_kwargs.update({ stage1_kwargs.update({
"x_active_mask": self.mc2_mask, "x_active_mask": mc2_mask,
}) })
if self.need_expert_scale: if self.need_expert_scale:
stage1_kwargs.update({ stage1_kwargs.update({
@ -166,99 +159,121 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
kwargs_mc2.update(stage1_kwargs) kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2 return kwargs_mc2
def token_dispatch(self, def token_dispatch(
hidden_states: torch.Tensor, self,
topk_weights: torch.Tensor, hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_weights: torch.Tensor,
expert_map: Optional[torch.Tensor] = None, topk_ids: torch.Tensor,
log2phy: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0, log2phy: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None, global_redundant_expert_num: int = 0,
quantized_x_for_share: Optional[Any] = None, shared_experts: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None, dynamic_scale_for_share: Optional[Any] = None,
apply_router_weight_on_input: bool = False, mc2_mask: Optional[torch.Tensor] = None,
with_quant: bool = False): apply_router_weight_on_input: bool = False,
with_quant: bool = False,
):
self.with_quant = with_quant self.with_quant = with_quant
self.expert_map = expert_map
self.topk_ids = topk_ids # Apply log2phy if needed
self.topk_weights = topk_weights if log2phy is not None:
self.shared_experts = shared_experts topk_ids = log2phy[topk_ids]
self.mc2_mask = mc2_mask
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights, kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
topk_ids, expert_map, topk_ids, expert_map,
mc2_mask,
global_redundant_expert_num) global_redundant_expert_num)
self.output = torch_npu.npu_moe_distribute_dispatch_v2( output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2 **kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2) **kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream()) # comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \ expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
self.ep_recv_counts, _, self.expand_scales = self.output[0:7] ep_recv_counts, _, expand_scales = output[0:7]
if self.with_quant: # Handle shared experts (store intermediate results in local vars, not self)
shared_act = None
swiglu_out_scale = None
if with_quant:
if shared_experts is not None: if shared_experts is not None:
share_up_out, _ = shared_experts.gate_up_proj( share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share)) (quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[ shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1] 0], share_up_out[1]
shared_act_out = shared_experts.act_fn( shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale)) (shared_gate_up, shared_dequant_scale))
self.shared_act, self.swiglu_out_scale = \ shared_act, swiglu_out_scale = shared_act_out[
shared_act_out[0], shared_act_out[1] 0], shared_act_out[1]
else: else:
if shared_experts is not None: if shared_experts is not None:
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
self.shared_act = shared_experts.act_fn(shared_gate_up) shared_act = shared_experts.act_fn(shared_gate_up)
group_list_type = 0
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"mc2_mask": mc2_mask,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"assist_info_for_combine": assist_info_for_combine,
"shared_experts": shared_experts,
"shared_act": shared_act,
"swiglu_out_scale": swiglu_out_scale,
"expand_scales": expand_scales
}
return { return {
"group_list_type": group_list_type, "group_list_type": 0,
"hidden_states": expand_x, "hidden_states": expand_x,
"group_list": expert_token_nums, "group_list": expert_token_nums,
"dynamic_scale": dynamic_scale, "dynamic_scale": dynamic_scale,
"context_metadata": context_metadata
} }
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor): def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
assert self.expert_map is not None context_metadata: dict):
assert self.topk_weights is not None expert_map = context_metadata["expert_map"]
assert self.topk_ids is not None topk_ids = context_metadata["topk_ids"]
assert self.output is not None topk_weights = context_metadata["topk_weights"]
moe_expert_num = len(self.expert_map) ep_recv_counts = context_metadata["ep_recv_counts"]
# moeCombine assist_info_for_combine = context_metadata["assist_info_for_combine"]
mc2_mask = context_metadata["mc2_mask"]
expand_scales = context_metadata["expand_scales"]
assert expert_map is not None
moe_expert_num = len(expert_map)
kwargs_mc2 = { kwargs_mc2 = {
"expand_x": hidden_states, "expand_x": hidden_states,
"expert_ids": self.topk_ids, "expert_ids": topk_ids,
"expert_scales": self.topk_weights.to(torch.float32), "expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0, "expert_shard_type": 0,
"shared_expert_rank_num": 0, "shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num, "moe_expert_num": moe_expert_num,
"global_bs": 0, "global_bs": 0,
} }
if self.with_quant: if self.with_quant:
tp_recv_counts = torch.empty(1, tp_recv_counts = torch.empty(1,
dtype=torch.int32, dtype=torch.int32,
device=hidden_states.device) device=hidden_states.device)
else: else:
tp_recv_counts = self.output[5] tp_recv_counts = ep_recv_counts
stage3_kwargs = { stage3_kwargs = {
"ep_send_counts": self.ep_recv_counts, "ep_send_counts": ep_recv_counts,
"group_ep": self.moe_all_to_all_group_name, "group_ep": self.moe_all_to_all_group_name,
"ep_world_size": self.ep_world_size, "ep_world_size": self.ep_world_size,
"ep_rank_id": self.ep_rank_id, "ep_rank_id": self.ep_rank_id,
"expand_scales": self.expand_scales, "expand_scales": expand_scales,
} }
if self.enable_dispatch_v2: if self.enable_dispatch_v2:
stage3_kwargs.update({ stage3_kwargs["assist_info_for_combine"] = assist_info_for_combine
"assist_info_for_combine":
self.assist_info_for_combine,
})
else: else:
stage3_kwargs.update({ stage3_kwargs["expand_idx"] = assist_info_for_combine
"expand_idx": self.assist_info_for_combine,
})
if self.need_extra_args: if self.need_extra_args:
stage3_kwargs.update({ stage3_kwargs.update({
"tp_send_counts": tp_recv_counts, "tp_send_counts": tp_recv_counts,
@ -266,45 +281,40 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"tp_world_size": 1, "tp_world_size": 1,
"tp_rank_id": 0, "tp_rank_id": 0,
}) })
if self.a3_need_extra_args and self.enable_dispatch_v2: if self.a3_need_extra_args and self.enable_dispatch_v2:
stage3_kwargs.update({ stage3_kwargs["x_active_mask"] = mc2_mask
"x_active_mask": self.mc2_mask,
})
kwargs_mc2.update(stage3_kwargs) kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2 return kwargs_mc2
def token_combine(self, def token_combine(
hidden_states: torch.Tensor, self,
bias: torch.Tensor = None): hidden_states: torch.Tensor,
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states) context_metadata: dict,
hidden_states = torch_npu.npu_moe_distribute_combine_v2( bias: torch.Tensor = None,
**kwargs_mc2 ):
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
**kwargs_mc2)
# these values are no longer used, so they need to be set to None for memory release. kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
self.output = None context_metadata)
self.assist_info_for_combine = None combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
self.ep_recv_counts = None if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
self.topk_ids = None
self.topk_weights = None
self.mc2_mask = None
self.expert_map = None
self.expand_scales = None
if self.shared_experts is None: # Handle shared experts from metadata
return hidden_states shared_experts = context_metadata["shared_experts"]
if shared_experts is None:
return combined_output
shared_act = context_metadata["shared_act"]
if self.with_quant:
swiglu_out_scale = context_metadata["swiglu_out_scale"]
shared_hidden_states, _ = shared_experts.down_proj(
(shared_act, swiglu_out_scale))
else: else:
if self.with_quant: shared_hidden_states, _ = shared_experts.down_proj(shared_act)
shared_hidden_states, _ = self.shared_experts.down_proj(
(self.shared_act, self.swiglu_out_scale)) return combined_output, shared_hidden_states
else:
shared_hidden_states, _ = self.shared_experts.down_proj(
self.shared_act)
self.shared_act = None
self.shared_experts = None
self.swiglu_out_scale = None
return hidden_states, shared_hidden_states
 class TokenDispatcherWithAllGather(MoETokenDispatcher):

@@ -314,14 +324,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         self.apply_router_weight_on_input = False
         self.max_num_tokens = kwargs.get("max_num_tokens")
         self.num_experts_local = kwargs.get("num_local_experts", 0)
-        self.sorted_weights = None
-        self.expanded_row_idx = None
-        self.sorted_token_indices = None
         self.original_shape = None
-        self.mask = None
-        self.expert_map = None
-        self.topk_weights = None
-        self.topk_ids = None
         self.with_quant = False

     def token_dispatch(self,

@@ -341,9 +344,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         self.original_shape = hidden_states.shape
         num_tokens = hidden_states.shape[:-1].numel()
-        self.expert_map = expert_map
-        self.topk_weights = topk_weights
-        self.topk_ids = topk_ids
         self.apply_router_weight_on_input = apply_router_weight_on_input
         if self.apply_router_weight_on_input:
             assert (topk_weights.dim() == 2

@@ -357,7 +357,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         if expert_map is not None:
             global_num_experts = len(expert_map)
             mask = (expert_map[topk_ids] != -1)
-            self.topk_weights = topk_weights * mask
+            topk_weights = topk_weights * mask
             first_expert_idx = get_ep_group(
             ).rank_in_group * self.num_experts_local
             last_expert_idx = first_expert_idx + self.num_experts_local

@@ -366,7 +366,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
             last_expert_idx = self.num_experts_local
             global_num_experts = self.num_experts_local

-        sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = (
+        sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = (
             torch_npu.npu_moe_init_routing_v2(
                 hidden_states,
                 topk_ids,
@@ -379,29 +379,31 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
             ))
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 1  # `count` mode
+        context_metadata = {
+            "topk_weights": topk_weights,
+            "expanded_row_idx": expanded_row_idx
+        }
         return {
             "group_list_type": group_list_type,
             "hidden_states": sorted_hidden_states,
             "group_list": expert_tokens,
             "dynamic_scale": pertoken_scale if self.with_quant else None,
+            "context_metadata": context_metadata
         }

     def token_combine(self,
                       hidden_states: torch.Tensor,
+                      context_metadata: dict,
                       bias: torch.Tensor = None):
         assert self.original_shape is not None
         final_hidden_states = torch_npu.npu_moe_token_unpermute(
             permuted_tokens=hidden_states,
-            sorted_indices=torch.abs(self.expanded_row_idx),
-            probs=self.topk_weights)
+            sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
+            probs=context_metadata["topk_weights"])
         if len(self.original_shape) == 3:
             final_hidden_states = final_hidden_states.view(self.original_shape)
         # these values are no longer used, so they need to be set to None for memory release.
-        self.expert_map = None
-        self.topk_weights = None
-        self.topk_ids = None
-        self.expanded_row_idx = None
         return final_hidden_states
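The all-gather dispatcher now carries only two things across the dispatch/combine boundary: the permutation it applied and the routing weights. A toy CPU sketch (plain-PyTorch stand-ins for npu_moe_init_routing_v2 / npu_moe_token_unpermute, assumed simplification) of why that is sufficient:

    import torch

    num_tokens, hidden, topk = 4, 8, 2
    x = torch.randn(num_tokens, hidden)
    topk_ids = torch.randint(0, 4, (num_tokens, topk))
    topk_weights = torch.rand(num_tokens, topk)

    # Dispatch: replicate each token once per selected expert and sort rows by expert id.
    replicated = x.repeat_interleave(topk, dim=0)            # row i*topk+j <-> token i, choice j
    expanded_row_idx = torch.argsort(topk_ids.reshape(-1))   # permutation applied to rows
    sorted_hidden = replicated[expanded_row_idx]

    # Combine: invert the permutation, weight each copy, then sum the copies per token.
    unpermuted = torch.empty_like(sorted_hidden)
    unpermuted[expanded_row_idx] = sorted_hidden             # inverse permutation
    weighted = unpermuted * topk_weights.reshape(-1, 1)
    combined = weighted.reshape(num_tokens, topk, hidden).sum(dim=1)
    assert combined.shape == (num_tokens, hidden)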
@@ -450,11 +452,12 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
             "group_list_type": group_list_type,
             "hidden_states": sorted_hidden_states,
             "group_list": group_list,
-            "topk_scales": topk_scales,
+            "topk_scales": topk_scales
         }

     def token_combine(self,
                       hidden_states: torch.Tensor,
+                      context_metadata: dict,
                       bias: torch.Tensor = None):
         unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
             torch.int32)
@@ -478,19 +481,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
         self.num_local_experts = kwargs.get("num_local_experts", 0)
         self.hidden_shape = None
-        self.topk_weights = None
-        self.input_splits = None
-        self.output_splits = None
         self.hidden_shape_before_permute = None
-        # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
-        # to each local expert by all ranks.
-        self.num_global_tokens_per_local_expert = None
-        # cached intermediate tensors.
-        self.tokens_per_expert = None
-        self.global_input_tokens_local_experts_indices = None

         assert self.num_local_experts > 0, "Expected at least one expert"
         if self.num_local_experts > 1:
             self.expert_ids_per_ep_rank = torch.tensor(
@@ -512,96 +504,116 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                 self.local_expert_indices[i + 1] -
                 1), "local_expert_indices must be continuous"

-    def token_dispatch(self,
-                       hidden_states: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
-                       global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
-                       mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False):
+    def token_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        expert_map: Optional[torch.Tensor] = None,
+        log2phy: Optional[torch.Tensor] = None,
+        global_redundant_expert_num: int = 0,
+        shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
+        mc2_mask: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        with_quant: bool = False,
+    ):
         self.with_quant = with_quant
         self.hidden_shape = hidden_states.shape
-        self.topk_weights = topk_weights
-        assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
-        assert topk_ids.dim() == 2, "Expected 2D tensor for routing map"

         if log2phy is not None:
             topk_ids = log2phy[topk_ids]
-        permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess(
-            hidden_states, topk_ids)
-        self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
+        (
+            permutated_local_input_tokens,
+            reversed_local_input_permutation_mapping,
+            tokens_per_expert,
+            input_splits,
+            output_splits,
+            num_global_tokens_per_local_expert,
+            global_input_tokens_local_experts_indices,
+        ) = self._dispatch_preprocess(hidden_states, topk_ids)
         dynamic_scale_after_all2all = None
         if self.with_quant:
             permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
                 permutated_local_input_tokens)
             _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
-                dynamic_scale,
-                self.output_splits,
-                self.input_splits,
-                self.ep_group,
-            )
+                dynamic_scale, output_splits, input_splits, self.ep_group)
             permute2_ep_all_to_all_handle.wait()
             dynamic_scale.untyped_storage().resize_(0)

         _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
-            permutated_local_input_tokens,
-            self.output_splits,
-            self.input_splits,
-            self.ep_group,
-        )
+            permutated_local_input_tokens, output_splits, input_splits,
+            self.ep_group)
         permute1_ep_all_to_all_handle.wait()
         permutated_local_input_tokens.untyped_storage().resize_(0)

-        global_input_tokens, dynamic_scale = self._dispatch_postprocess(
-            global_input_tokens, dynamic_scale_after_all2all)
+        # Postprocess
+        global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess(
+            global_input_tokens, dynamic_scale_after_all2all,
+            global_input_tokens_local_experts_indices)
+
+        context_metadata = {
+            "input_splits":
+            input_splits,
+            "output_splits":
+            output_splits,
+            "topk_weights":
+            topk_weights,
+            "reversed_local_input_permutation_mapping":
+            reversed_local_input_permutation_mapping,
+            "reversed_global_input_permutation_mapping":
+            reversed_global_input_permutation_mapping
+        }
+
         return {
             "hidden_states": global_input_tokens,
             "group_list": tokens_per_expert,
-            "dynamic_scale": dynamic_scale,
-            "group_list_type": 1
+            "group_list_type": 1,
+            "dynamic_scale": dynamic_scale_final,
+            "context_metadata": context_metadata,
         }
-    def token_combine(self,
-                      hidden_states: torch.Tensor,
-                      bias: torch.Tensor = None):
+    def token_combine(
+        self,
+        hidden_states: torch.Tensor,
+        context_metadata: dict,
+        bias: torch.Tensor = None,
+    ):
         assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."

-        hidden_states = self._combine_preprocess(hidden_states)
+        # 1. Preprocess using metadata
+        hidden_states = self._combine_preprocess(hidden_states,
+                                                 context_metadata)

-        # Perform expert parallel AlltoAll communication
-        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
+        # 2. AllToAll
         _, permutated_local_input_tokens, handle = async_all_to_all(
-            hidden_states, self.input_splits, self.output_splits,
-            self.ep_group)
+            hidden_states,
+            context_metadata["input_splits"],
+            context_metadata["output_splits"],
+            self.ep_group,
+        )
         handle.wait()
         hidden_states.untyped_storage().resize_(0)

-        output = self._combine_postprocess(permutated_local_input_tokens)
-
-        # these values are no longer used, so they need to be set to None for memory release.
-        self.input_splits = None
-        self.output_splits = None
-        self.num_global_tokens_per_local_expert = None
-        self.topk_weights = None
-        self.reversed_local_input_permutation_mapping = None
-        self.reversed_global_input_permutation_mapping = None
-        self.global_input_tokens_local_experts_indices = None
+        # 3. Postprocess using metadata
+        output = self._combine_postprocess(permutated_local_input_tokens,
+                                           context_metadata)

         return output
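To keep the keys in one place while reading, the metadata contract the all-to-all-v dispatch/combine pair above agrees on can be written as a documentation-only TypedDict (an assumed summary, not something added by this PR):

    from typing import TypedDict

    import numpy as np
    import torch

    class All2AllVContextMetadata(TypedDict):
        # Produced by token_dispatch, consumed unchanged by token_combine.
        input_splits: np.ndarray       # tokens this rank sends to each EP rank
        output_splits: np.ndarray      # tokens this rank receives from each EP rank
        topk_weights: torch.Tensor
        reversed_local_input_permutation_mapping: torch.Tensor
        reversed_global_input_permutation_mapping: torch.Tensor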
     def _dispatch_preprocess(self, hidden_states, topk_ids):
         assert self.hidden_shape is not None
-        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
-        tokens_per_expert = self._preprocess(topk_ids)
+        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+        (
+            tokens_per_expert,
+            input_splits,
+            output_splits,
+            num_global_tokens_per_local_expert,
+            global_input_tokens_local_experts_indices,
+        ) = self._preprocess(topk_ids)
         self.hidden_shape_before_permute = hidden_states.shape

@@ -610,82 +622,88 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
             indices=topk_ids,
             num_out_tokens=self.num_out_tokens,
         )
-        return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
+        return (
+            permutated_local_input_tokens,
+            reversed_local_input_permutation_mapping,
+            tokens_per_expert,
+            input_splits,
+            output_splits,
+            num_global_tokens_per_local_expert,
+            global_input_tokens_local_experts_indices,
+        )

-    def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor:
+    def _preprocess(self, topk_ids: torch.Tensor):
         num_local_tokens_per_expert = torch.histc(topk_ids,
                                                   bins=self.num_experts,
                                                   min=0,
                                                   max=self.num_experts)

         ep_size = self.ep_size
-        # Dropless
         self.num_out_tokens = topk_ids.numel()

-        # ===================================================
-        # Calculate input_splits, output_splits for alltoall-v.
-        # ===================================================
-        self.input_splits = (num_local_tokens_per_expert.reshape(
+        input_splits = (num_local_tokens_per_expert.reshape(
             ep_size,
             self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
                                                    non_blocking=True).numpy())
         num_global_tokens_per_expert = gather_from_sequence_parallel_region(
             num_local_tokens_per_expert,
             group=self.ep_group).reshape(ep_size, self.num_experts)
-        self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
+        num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
             0]:self.local_expert_indices[-1] + 1]
-        if self.num_global_tokens_per_local_expert is None:
+        if num_global_tokens_per_local_expert is None:
             raise ValueError(
                 "num_global_tokens_per_local_expert must be set before sum.")
-        self.output_splits = (self.num_global_tokens_per_local_expert.sum(
-            axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
-        num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
-            axis=0)
-        # ===================================================
-        # num_global_tokens_per_expert: [ep_size, num_experts]
-        # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
-        # num_tokens_per_local_expert: [num_local_experts]
-        # ===================================================
+        output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to(
+            torch.device("cpu"), non_blocking=True).numpy())
+        num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(
+            axis=0)

+        global_input_tokens_local_experts_indices = None
         if self.num_local_experts > 1:
-            if self.num_global_tokens_per_local_expert is None:
+            if num_global_tokens_per_local_expert is None:
                 raise ValueError(
                     "num_global_tokens_per_local_expert must be set before operations."
                 )
-            self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
+            global_input_tokens_local_experts_indices = torch.repeat_interleave(
                 self.expert_ids_per_ep_rank,
-                self.num_global_tokens_per_local_expert.ravel())
+                num_global_tokens_per_local_expert.ravel())
         else:
-            # TODO: This full synchronization can be a performance bottleneck.
-            # A more granular sync (e.g., blocking D2H copies) should be investigated.
             torch.npu.synchronize()

-        return num_tokens_per_local_expert
+        return (
+            num_tokens_per_local_expert,
+            input_splits,
+            output_splits,
+            num_global_tokens_per_local_expert,
+            global_input_tokens_local_experts_indices,
+        )
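For intuition, a small CPU-only sketch (assumed toy numbers, plain PyTorch in place of the NPU and distributed calls) of how the split sizes above are derived from the routing ids:

    import torch

    # 2 EP ranks with 2 local experts each -> 4 global experts; 5 routed token copies.
    ep_size, num_local_experts = 2, 2
    num_experts = ep_size * num_local_experts
    topk_ids = torch.tensor([0, 2, 2, 3, 1]).float()  # histc needs floating point

    # How many of this rank's tokens go to every global expert.
    num_local_tokens_per_expert = torch.histc(
        topk_ids, bins=num_experts, min=0, max=num_experts)
    # -> tensor([1., 1., 2., 1.])

    # input_splits: tokens sent to each EP rank (sum over that rank's local experts).
    input_splits = num_local_tokens_per_expert.reshape(
        ep_size, num_local_experts).sum(dim=1)
    # -> tensor([2., 3.])  (2 tokens to rank 0, 3 tokens to rank 1)

    # output_splits comes from the gathered [ep_size, num_experts] matrix: slice the
    # columns of our local experts and sum per sending rank. With a single rank the
    # gather is just a broadcast of the local counts:
    num_global_tokens_per_expert = num_local_tokens_per_expert.unsqueeze(0).repeat(ep_size, 1)
    local_cols = slice(0, num_local_experts)  # pretend we are EP rank 0
    output_splits = num_global_tokens_per_expert[:, local_cols].sum(dim=-1)
    # -> tensor([2., 2.])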
-    def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None):
+    def _dispatch_postprocess(self, global_input_tokens,
+                              dynamic_scale_after_all2all,
+                              global_input_tokens_local_experts_indices):
         # Early return if no local experts or no tokens
         if self.num_local_experts <= 1:
-            return global_input_tokens, None
+            return global_input_tokens, dynamic_scale_after_all2all, None

         # Handle quantized case
         if self.with_quant:
-            assert self.global_input_tokens_local_experts_indices is not None, \
-                "global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess"
-            expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze(
+            assert global_input_tokens_local_experts_indices is not None, \
+                "global_input_tokens_local_experts_indices must be provided"
+            expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze(
                 -1)
-            active_num = self.global_input_tokens_local_experts_indices.numel()
+            active_num = global_input_tokens_local_experts_indices.numel()

-            # Handle case with no active tokens
             if active_num <= 0:
-                self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices
-                return global_input_tokens, dynamic_scale
+                reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices
+                return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping

-            # Process with active tokens
-            global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
+            global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
                 global_input_tokens,
                 expert_idx_2d,
-                scale=dynamic_scale,
+                scale=dynamic_scale_after_all2all,
                 active_num=active_num,
                 expert_capacity=0,
                 expert_num=self.num_local_experts,

@@ -693,32 +711,34 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                 expert_tokens_num_flag=True,
                 active_expert_range=[0, self.num_local_experts],
                 quant_mode=-1,
-                row_idx_type=0)
-            return global_input_tokens, expanded_scale
+                row_idx_type=0,
+            )
+            return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping

-        # Handle non-quantized case
-        global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
-            global_input_tokens,
-            self.global_input_tokens_local_experts_indices)
-        return global_input_tokens, None
+        # Non-quantized case
+        global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
+            global_input_tokens, global_input_tokens_local_experts_indices)
+        return global_input_tokens, None, reversed_global_input_permutation_mapping
-    def _combine_preprocess(self, hidden_states):
+    def _combine_preprocess(self, hidden_states: torch.Tensor,
+                            context_metadata: dict) -> torch.Tensor:
         # Unpermutation 2: expert output to AlltoAll input
         if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
+            rev_global = context_metadata[
+                "reversed_global_input_permutation_mapping"]
             hidden_states = torch_npu.npu_moe_token_unpermute(
-                hidden_states, self.reversed_global_input_permutation_mapping)
+                hidden_states, rev_global)
         return hidden_states

-    def _combine_postprocess(self, permutated_local_input_tokens):
+    def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor,
+                             context_metadata: dict) -> torch.Tensor:
         # Unpermutation 1: AlltoAll output to output
         output = torch_npu.npu_moe_token_unpermute(
             permuted_tokens=permutated_local_input_tokens,
-            sorted_indices=self.reversed_local_input_permutation_mapping.to(
-                torch.int32),
-            probs=self.topk_weights,
-            restore_shape=self.hidden_shape_before_permute)
+            sorted_indices=context_metadata[
+                "reversed_local_input_permutation_mapping"].to(torch.int32),
+            probs=context_metadata["topk_weights"],
+            restore_shape=self.hidden_shape_before_permute,
+        )

-        # Reshape the output tensor
         output = output.view(self.hidden_shape)
         return output