mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 21:53:54 +08:00
This reverts commit 4f937f561d573ae97f953169865cfbf70d0c220b. ### What this PR does / why we need it? This reverts commit 4f937f561d573ae97f953169865cfbf70d0c220b. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? e2e & ut - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@ -137,7 +137,6 @@ def test_token_dispatcher_with_all_gather(
|
||||
sorted_hidden_states = dispatch_output["hidden_states"]
|
||||
group_list = dispatch_output["group_list"]
|
||||
group_list_type = dispatch_output.get("group_list_type", 1)
|
||||
context_metadata = dispatch_output["context_metadata"]
|
||||
|
||||
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
|
||||
w1=w1_local,
|
||||
@ -145,10 +144,8 @@ def test_token_dispatcher_with_all_gather(
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type)
|
||||
|
||||
combined_output = dispatcher.token_combine(
|
||||
hidden_states=expert_output,
|
||||
context_metadata=context_metadata,
|
||||
bias=None)
|
||||
combined_output = dispatcher.token_combine(hidden_states=expert_output,
|
||||
bias=None)
|
||||
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
|
||||
expert_map)
|
||||
@ -218,7 +215,6 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
group_list = dispatch_output["group_list"]
|
||||
group_list_type = dispatch_output.get("group_list_type", 1)
|
||||
dynamic_scale = dispatch_output["dynamic_scale"]
|
||||
context_metadata = dispatch_output["context_metadata"]
|
||||
|
||||
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
|
||||
w1=w1,
|
||||
@ -229,10 +225,8 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
group_list_type=group_list_type,
|
||||
dynamic_scale=dynamic_scale,
|
||||
with_quant=True)
|
||||
combined_output = dispatcher.token_combine(
|
||||
hidden_states=expert_output,
|
||||
context_metadata=context_metadata,
|
||||
bias=None)
|
||||
combined_output = dispatcher.token_combine(hidden_states=expert_output,
|
||||
bias=None)
|
||||
assert combined_output.shape == (m, k)
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
|
@ -44,8 +44,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
h_out, r_out, mask, context_metadata = layer.prepare(
|
||||
hidden_states, router_logits)
|
||||
h_out, r_out, mask = layer.prepare(hidden_states, router_logits)
|
||||
|
||||
# Check padding and split
|
||||
self.assertEqual(h_out.shape[0], 4)
|
||||
@ -53,9 +52,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
self.assertEqual(mask.tolist(), [1, 0, 1])
|
||||
|
||||
# Finalize
|
||||
result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
@ -80,11 +77,10 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(4, 8)
|
||||
router_logits = torch.randn(4, 2)
|
||||
|
||||
h_out, r_out, mask, context_metadata = layer.prepare(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
h_out, r_out, mask = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
|
||||
# With TP=2, should split into 2 parts
|
||||
self.assertEqual(h_out.shape[0], 2)
|
||||
@ -100,9 +96,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
torch.zeros_like(h_out),
|
||||
torch.zeros_like(h_out)
|
||||
]
|
||||
final_result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
final_result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should concat back to original size
|
||||
self.assertEqual(final_result.shape[0], 4)
|
||||
@ -118,15 +112,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
h_out, r_out, _, context_metadata = layer.prepare(
|
||||
hidden_states, router_logits)
|
||||
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
|
||||
|
||||
# Pad to tp_size=1, so no change
|
||||
self.assertEqual(h_out.shape[0], 3)
|
||||
|
||||
result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
@ -142,11 +133,10 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
hidden_states = torch.randn(2, 8)
|
||||
router_logits = torch.randn(2, 2)
|
||||
|
||||
h_out, r_out, _, context_metadata = layer.prepare(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
|
||||
# Split due to TP=2
|
||||
self.assertEqual(h_out.shape[0], 1)
|
||||
@ -162,9 +152,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
torch.zeros_like(h_out),
|
||||
torch.zeros_like(h_out)
|
||||
]
|
||||
final_result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
final_result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should concat back
|
||||
self.assertEqual(final_result.shape[0], 2)
|
||||
@ -207,9 +195,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_gate = MagicMock()
|
||||
mock_gate.return_value = (router_logits.repeat(2, 1), None)
|
||||
|
||||
h_out, r_out, _, context_metadata = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
|
||||
# After all-gather with DP=2, should double the batch size
|
||||
self.assertEqual(h_out.shape[0], 12)
|
||||
@ -221,9 +209,7 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
return tensor[:3]
|
||||
|
||||
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
|
||||
result = layer.finalize(h_out,
|
||||
reduce_results=False,
|
||||
context_metadata=context_metadata)
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@ -277,9 +263,9 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
mock_gate.return_value = (torch.randn(7, 2), None)
|
||||
|
||||
# Run prepare
|
||||
h_out, r_out, _, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
gate=mock_gate)
|
||||
|
||||
# Should be global tensor: [7, 8] and [7, 2]
|
||||
self.assertEqual(h_out.shape, (7, 8))
|
||||
|
@ -45,7 +45,7 @@ class TestMoECommMethod(TestBase):
|
||||
# Mock prepare finalize
|
||||
mock_pf_instance = MagicMock()
|
||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
||||
torch.randn(4, 2), None, None)
|
||||
torch.randn(4, 2), None)
|
||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||
mock_prepare_finalize.return_value = mock_pf_instance
|
||||
|
||||
@ -59,18 +59,15 @@ class TestMoECommMethod(TestBase):
|
||||
# Test prepare method
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
|
||||
hidden_states, router_logits)
|
||||
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, None)
|
||||
|
||||
# Test finalize method
|
||||
comm_impl.finalize(h_out,
|
||||
reduce_results=True,
|
||||
context_metadata=context_metadata)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@ -93,8 +90,7 @@ class TestMoECommMethod(TestBase):
|
||||
mock_pf_instance = MagicMock()
|
||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
||||
torch.randn(4, 2),
|
||||
torch.tensor([1, 0, 1,
|
||||
0]), None)
|
||||
torch.tensor([1, 0, 1, 0]))
|
||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||
mock_prepare_finalize.return_value = mock_pf_instance
|
||||
|
||||
@ -108,18 +104,15 @@ class TestMoECommMethod(TestBase):
|
||||
# Test prepare method
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
|
||||
hidden_states, router_logits)
|
||||
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, None)
|
||||
|
||||
# Test finalize method
|
||||
comm_impl.finalize(h_out,
|
||||
reduce_results=True,
|
||||
context_metadata=context_metadata)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@ -142,7 +135,7 @@ class TestMoECommMethod(TestBase):
|
||||
# Mock prepare finalize
|
||||
mock_pf_instance = MagicMock()
|
||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
||||
torch.randn(4, 2), None, None)
|
||||
torch.randn(4, 2), None)
|
||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||
mock_prepare_finalize.return_value = mock_pf_instance
|
||||
|
||||
@ -156,8 +149,7 @@ class TestMoECommMethod(TestBase):
|
||||
# Test prepare method
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
|
||||
hidden_states, router_logits)
|
||||
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
|
@ -77,10 +77,9 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
topk_weights = torch.randn(10, 1)
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
mc2_mask = None
|
||||
|
||||
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
|
||||
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)
|
||||
hidden_states, topk_weights, topk_ids, expert_map)
|
||||
self.assertIn("x", kwargs)
|
||||
self.assertIn("expert_ids", kwargs)
|
||||
self.assertEqual(kwargs["moe_expert_num"], 8)
|
||||
@ -124,64 +123,36 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
def test_get_combine_mc_kwargs_with_quant(self):
|
||||
self.dispatcher.with_quant = True
|
||||
hidden_states = torch.randn(10, 128)
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
topk_weights = torch.randn(10, 1) # 注意:应为 float,不是 int
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
mc2_mask = None
|
||||
assist_info_for_combine = torch.arange(10) # mock 值
|
||||
|
||||
context_metadata = {
|
||||
"topk_ids": topk_ids,
|
||||
"topk_weights": topk_weights,
|
||||
"expert_map": expert_map,
|
||||
"ep_recv_counts": ep_recv_counts,
|
||||
"mc2_mask": mc2_mask,
|
||||
"assist_info_for_combine": assist_info_for_combine,
|
||||
"expand_scales": None,
|
||||
}
|
||||
|
||||
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.need_extra_args = True
|
||||
self.dispatcher.enable_dispatch_v2 = True
|
||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||
|
||||
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
|
||||
context_metadata)
|
||||
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
|
||||
self.assertIn("tp_send_counts", kwargs)
|
||||
|
||||
def test_token_combine_with_shared_experts(self):
|
||||
shared_experts = MagicMock()
|
||||
shared_experts.down_proj.return_value = (torch.randn(10, 128),
|
||||
torch.tensor(1.0))
|
||||
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
topk_weights = torch.randn(10, 1)
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
assist_info_for_combine = torch.arange(10)
|
||||
|
||||
context_metadata = {
|
||||
"topk_ids": topk_ids,
|
||||
"topk_weights": topk_weights,
|
||||
"expert_map": expert_map,
|
||||
"ep_recv_counts": ep_recv_counts,
|
||||
"mc2_mask": None,
|
||||
"assist_info_for_combine": assist_info_for_combine,
|
||||
"expand_scales": None,
|
||||
"shared_experts": shared_experts,
|
||||
"shared_act": torch.randn(10, 128),
|
||||
"swiglu_out_scale": torch.randn(10, 1),
|
||||
}
|
||||
|
||||
self.dispatcher.shared_experts = MagicMock()
|
||||
self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
|
||||
10, 128), torch.tensor(1.0))
|
||||
self.dispatcher.shared_act = torch.randn(10, 128)
|
||||
self.dispatcher.with_quant = True
|
||||
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.need_extra_args = True
|
||||
self.dispatcher.enable_dispatch_v2 = True
|
||||
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||
self.hidden_states = torch.randn(10, 128)
|
||||
|
||||
hidden_states = torch.randn(10, 128)
|
||||
with patch("torch_npu.npu_moe_distribute_combine_v2",
|
||||
return_value=torch.randn(10, 128)):
|
||||
result = self.dispatcher.token_combine(hidden_states,
|
||||
context_metadata)
|
||||
self.assertIsInstance(result, tuple)
|
||||
self.dispatcher.token_combine(self.hidden_states)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAllGather(TestBase):
|
||||
@ -293,26 +264,35 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_combine_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.sorted_weights = torch.tensor(
|
||||
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
|
||||
self.dispatcher.original_shape = (3, 128)
|
||||
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
|
||||
hidden_states = torch.randn(6, 128)
|
||||
context_metadata = {
|
||||
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
|
||||
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
|
||||
}
|
||||
self.dispatcher.original_shape = (6, 128)
|
||||
final_hidden_states = self.dispatcher.token_combine(
|
||||
hidden_states, context_metadata)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_combine_without_expert_map(self):
|
||||
self.dispatcher.with_quant = False
|
||||
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
|
||||
self.dispatcher.sorted_weights = torch.tensor(
|
||||
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
|
||||
self.dispatcher.original_shape = (3, 128)
|
||||
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
|
||||
hidden_states = torch.randn(6, 128)
|
||||
context_metadata = {
|
||||
"expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
|
||||
"topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
|
||||
}
|
||||
self.dispatcher.original_shape = (6, 128)
|
||||
final_hidden_states = self.dispatcher.token_combine(
|
||||
hidden_states, context_metadata)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
|
||||
# Verify npu_moe_finalize_routing is called
|
||||
self.mock_npu_moe_token_unpermute.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_token_unpermute.call_args
|
||||
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_dispatch_with_router_weight(self):
|
||||
@ -438,21 +418,25 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
def test_token_combine(self):
|
||||
hidden_states = torch.randn(16, 16)
|
||||
context_metadata = {
|
||||
"input_splits": [4, 4],
|
||||
"output_splits": [4, 4],
|
||||
"topk_weights": torch.rand(8, 4),
|
||||
"reversed_local_input_permutation_mapping": torch.arange(8),
|
||||
"reversed_global_input_permutation_mapping": torch.arange(16),
|
||||
}
|
||||
self.dispatcher.hidden_shape = (8, 16)
|
||||
self.dispatcher.hidden_shape_before_permute = (8, 16)
|
||||
self.dispatcher.reversed_local_input_permutation_mapping = torch.arange(
|
||||
8)
|
||||
self.dispatcher.topk_weights = torch.rand(8, 4)
|
||||
self.dispatcher.input_splits = [4, 4]
|
||||
self.dispatcher.output_splits = [4, 4]
|
||||
self.dispatcher.reversed_global_input_permutation_mapping = torch.arange(
|
||||
16)
|
||||
|
||||
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
|
||||
[0, 1], dtype=torch.int32)
|
||||
self.dispatcher.local_expert_indices = [0, 1]
|
||||
self.dispatcher.num_global_tokens_per_local_expert = torch.tensor(
|
||||
[[2, 2], [2, 2]], dtype=torch.int64)
|
||||
|
||||
expert_output = torch.randn(16, 16)
|
||||
output = self.dispatcher.token_combine(expert_output)
|
||||
|
||||
output = self.dispatcher.token_combine(hidden_states, context_metadata)
|
||||
self.assertIsNotNone(output)
|
||||
self.assertEqual(output.shape, (8, 16))
|
||||
|
||||
|
@ -301,7 +301,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
enable_force_load_balance = forward_context.in_profile_run
|
||||
|
||||
forward_context = get_forward_context()
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
|
||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
replace_allreduce=forward_context.sp_enabled,
|
||||
@ -329,8 +329,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
shared_experts=None,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
log2phy=self.log2phy,
|
||||
global_redundant_expert_num=self.global_redundant_expert_num,
|
||||
mc2_mask=mc2_mask)
|
||||
global_redundant_expert_num=self.global_redundant_expert_num)
|
||||
|
||||
if isinstance(final_hidden_states, tuple):
|
||||
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
|
||||
@ -341,8 +340,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
|
||||
final_hidden_states = forward_context.moe_comm_method.finalize(
|
||||
hidden_states=final_hidden_states,
|
||||
reduce_results=self.reduce_results,
|
||||
context_metadata=context_metadata)
|
||||
reduce_results=self.reduce_results)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
@ -15,7 +15,6 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -50,15 +49,12 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
is_deepseek_v3_r1)
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Prepare tensors before MoE computation. May involve:
|
||||
- Padding to align communication boundaries
|
||||
@ -78,14 +74,11 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- processed hidden_states (may be padded/sliced/broadcasted)
|
||||
- processed router_logits (may be recomputed or broadcasted)
|
||||
- optional communication mask (e.g., mc2_mask for sparse ops)
|
||||
- optional context metadata (e.g., saved split_hidden_states for finalization)
|
||||
"""
|
||||
raise NotImplementedError("Prepare not implemented.")
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalize MoE output. May involve:
|
||||
- Gathering sliced tensors across TP ranks
|
||||
@ -103,102 +96,9 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
raise NotImplementedError("Finalize function not implemented.")
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-to-All style slicing.
|
||||
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
|
||||
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
"""Restore original TP configuration (same as MC2)."""
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Pad hidden_states and router_logits to next multiple of TP size.
|
||||
2. If TP > 1, split along token dim and select current TP rank's slice.
|
||||
3. Save splits for later all-gather in finalize.
|
||||
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
|
||||
Returns:
|
||||
Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All.
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
split_hidden_states = None
|
||||
|
||||
if not (self.replace_allreduce or self.enable_shared_expert_dp):
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
context_metadata = {"split_hidden_states": split_hidden_states}
|
||||
|
||||
return hidden_states, router_logits, None, context_metadata
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices to reconstruct full tensor.
|
||||
2. Unpad to original token count.
|
||||
3. Return [original_num_tokens, hidden_size] tensor.
|
||||
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
"""
|
||||
assert context_metadata is not None
|
||||
|
||||
split_hidden_states = context_metadata["split_hidden_states"]
|
||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||
if self.tp_size > 1:
|
||||
dist.all_gather(list(split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(split_hidden_states, dim=0)
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
|
||||
"""
|
||||
MoE communication strategy using MC2, which is based on All2All. Hence, it inherits
|
||||
All2All and share the same finalize method.
|
||||
MoE communication strategy using MC2 (Memory-Centric Communication).
|
||||
Designed for Ascend or environments requiring explicit padding and slicing control.
|
||||
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
|
||||
"""
|
||||
@ -216,15 +116,12 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch `mc2_mask` and target padding length from forward context.
|
||||
@ -235,11 +132,10 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
|
||||
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
|
||||
Returns:
|
||||
Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded.
|
||||
Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
split_hidden_states = None
|
||||
forward_context = get_forward_context()
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
if self.tp_size > 1:
|
||||
@ -269,10 +165,124 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
|
||||
dim=0)
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
self.split_hidden_states = split_hidden_states # Save for finalize
|
||||
|
||||
context_metadata = {"split_hidden_states": split_hidden_states}
|
||||
return hidden_states, router_logits, mc2_mask
|
||||
|
||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices from all TP ranks to reconstruct full tensor.
|
||||
2. Unpad to original token count if padding was applied.
|
||||
3. Return tensor with shape [original_num_tokens, hidden_size].
|
||||
|
||||
Skips communication and unpadding if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
"""
|
||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||
if self.tp_size > 1:
|
||||
# All-gather across TP group
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
|
||||
# If the cache is not cleared after `self.split_hidden_states` is created,
|
||||
# it can lead to the memory explosion in eager mode.
|
||||
del self.split_hidden_states
|
||||
|
||||
# Unpad if necessary
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
MoE communication strategy using All-to-All style slicing.
|
||||
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
|
||||
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
"""Restore original TP configuration (same as MC2)."""
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Pad hidden_states and router_logits to next multiple of TP size.
|
||||
2. If TP > 1, split along token dim and select current TP rank's slice.
|
||||
3. Save splits for later all-gather in finalize.
|
||||
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
|
||||
Returns:
|
||||
Tuple of (hidden_states, router_logits, None) — no mask used in All2All.
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
if not (self.replace_allreduce or self.enable_shared_expert_dp):
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
self.split_hidden_states = split_hidden_states
|
||||
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices to reconstruct full tensor.
|
||||
2. Unpad to original token count.
|
||||
3. Return [original_num_tokens, hidden_size] tensor.
|
||||
|
||||
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
|
||||
"""
|
||||
if not (self.enable_shared_expert_dp or self.replace_allreduce):
|
||||
if self.tp_size > 1:
|
||||
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||
|
||||
# TODO: It is a quick bugfix for the memory explosion issue in eager mode.
|
||||
# If the cache is not cleared after `self.split_hidden_states` is created,
|
||||
# it can lead to the memory explosion in eager mode.
|
||||
del self.split_hidden_states
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
@ -297,15 +307,12 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
TP AG → Attn → TP RS → EP AG → MoE → EP RS
|
||||
"""
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
AllGather hidden_states and router_logits to form global tensors.
|
||||
@ -324,24 +331,21 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True, True)
|
||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
router_logits, True, True)
|
||||
|
||||
return hidden_states, router_logits, None, None
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def _prepare_with_dp_group(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch max token count across DP group from forward context.
|
||||
@ -349,7 +353,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
3. All-gather across DP group to form global input tensor.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None, None)
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
if self.moe_config.dp_size > 1:
|
||||
@ -373,12 +377,11 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
else:
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
return hidden_states, router_logits, None, None
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
Reduce Scatter hidden states.
|
||||
@ -469,22 +472,19 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
||||
return buffer
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch cumulative token boundaries from forward context.
|
||||
2. Multicast hidden_states and router_logits to form global tensors.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None, None)
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
|
||||
@ -499,12 +499,10 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
||||
router_logits = self._naive_multicast(
|
||||
router_logits, self.cu_tokens_across_dp_cpu)
|
||||
|
||||
return hidden_states, router_logits, None, None
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If DP > 1 and not shared expert:
|
||||
|
@ -57,31 +57,28 @@ class MoECommMethod(ABC):
|
||||
self.model_type = get_current_vllm_config(
|
||||
).model_config.hf_config.model_type
|
||||
self.moe_config = moe_config
|
||||
self.mc2_mask = None
|
||||
|
||||
self.token_dispatcher = self._get_token_dispatcher()
|
||||
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
|
||||
)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = self.fused_moe_prepare_finalize.prepare(
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
|
||||
hidden_states, router_logits, enable_shared_expert_dp,
|
||||
replace_allreduce, gate)
|
||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
||||
self.mc2_mask = mc2_mask
|
||||
return hidden_states, router_logits
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
def finalize(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
hidden_states = self.fused_moe_prepare_finalize.finalize(
|
||||
hidden_states, reduce_results, context_metadata)
|
||||
hidden_states, reduce_results)
|
||||
return hidden_states
|
||||
|
||||
def fused_experts(
|
||||
@ -111,8 +108,7 @@ class MoECommMethod(ABC):
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None):
|
||||
dynamic_eplb: bool = False):
|
||||
# Check constraints
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
@ -131,12 +127,12 @@ class MoECommMethod(ABC):
|
||||
shared_experts=shared_experts,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
mc2_mask=mc2_mask,
|
||||
mc2_mask=self.mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8)
|
||||
|
||||
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
|
||||
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
|
||||
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
|
||||
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")
|
||||
|
||||
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
|
||||
w1=w1,
|
||||
@ -156,7 +152,7 @@ class MoECommMethod(ABC):
|
||||
dynamic_eplb=dynamic_eplb)
|
||||
|
||||
final_hidden_states = self.token_dispatcher.token_combine(
|
||||
hidden_states=mlp_output, context_metadata=context_metadata)
|
||||
hidden_states=mlp_output)
|
||||
|
||||
if dynamic_eplb:
|
||||
return (final_hidden_states, group_list_type, expert_tokens)
|
||||
|
@ -75,7 +75,6 @@ class MoETokenDispatcher(ABC):
|
||||
@abstractmethod
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
context_metadata: dict,
|
||||
bias: torch.Tensor = None):
|
||||
raise NotImplementedError("Combine function not implemented.")
|
||||
|
||||
@ -103,7 +102,16 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
self.need_expert_scale = is_hierarchical_communication_enabled()
|
||||
self.output = None
|
||||
self.assist_info_for_combine = None
|
||||
self.ep_recv_counts = None
|
||||
self.shared_act = None
|
||||
self.topk_ids = None
|
||||
self.topk_weights = None
|
||||
self.shared_experts = None
|
||||
self.mc2_mask = None
|
||||
self.with_quant = False
|
||||
self.expand_scales = None
|
||||
|
||||
def get_dispatch_mc2_kwargs(
|
||||
self,
|
||||
@ -111,7 +119,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor,
|
||||
mc2_mask: torch.Tensor,
|
||||
global_redundant_expert_num: int = 0,
|
||||
):
|
||||
if self.with_quant:
|
||||
@ -148,7 +155,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
})
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
if self.need_expert_scale:
|
||||
stage1_kwargs.update({
|
||||
@ -159,121 +166,99 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
):
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.with_quant = with_quant
|
||||
|
||||
# Apply log2phy if needed
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
self.expert_map = expert_map
|
||||
self.topk_ids = topk_ids
|
||||
self.topk_weights = topk_weights
|
||||
self.shared_experts = shared_experts
|
||||
self.mc2_mask = mc2_mask
|
||||
|
||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
|
||||
topk_ids, expert_map,
|
||||
mc2_mask,
|
||||
global_redundant_expert_num)
|
||||
output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||
self.output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||
**kwargs_mc2
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
|
||||
ep_recv_counts, _, expand_scales = output[0:7]
|
||||
expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \
|
||||
self.ep_recv_counts, _, self.expand_scales = self.output[0:7]
|
||||
|
||||
# Handle shared experts (store intermediate results in local vars, not self)
|
||||
shared_act = None
|
||||
swiglu_out_scale = None
|
||||
if with_quant:
|
||||
if self.with_quant:
|
||||
if shared_experts is not None:
|
||||
share_up_out, _ = shared_experts.gate_up_proj(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||
0], share_up_out[1]
|
||||
|
||||
shared_act_out = shared_experts.act_fn(
|
||||
(shared_gate_up, shared_dequant_scale))
|
||||
shared_act, swiglu_out_scale = shared_act_out[
|
||||
0], shared_act_out[1]
|
||||
self.shared_act, self.swiglu_out_scale = \
|
||||
shared_act_out[0], shared_act_out[1]
|
||||
|
||||
else:
|
||||
if shared_experts is not None:
|
||||
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
||||
shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
|
||||
context_metadata = {
|
||||
"topk_ids": topk_ids,
|
||||
"topk_weights": topk_weights,
|
||||
"mc2_mask": mc2_mask,
|
||||
"expert_map": expert_map,
|
||||
"ep_recv_counts": ep_recv_counts,
|
||||
"assist_info_for_combine": assist_info_for_combine,
|
||||
"shared_experts": shared_experts,
|
||||
"shared_act": shared_act,
|
||||
"swiglu_out_scale": swiglu_out_scale,
|
||||
"expand_scales": expand_scales
|
||||
}
|
||||
|
||||
self.shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
group_list_type = 0
|
||||
return {
|
||||
"group_list_type": 0,
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": expand_x,
|
||||
"group_list": expert_token_nums,
|
||||
"dynamic_scale": dynamic_scale,
|
||||
"context_metadata": context_metadata
|
||||
}
|
||||
|
||||
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
|
||||
context_metadata: dict):
|
||||
expert_map = context_metadata["expert_map"]
|
||||
topk_ids = context_metadata["topk_ids"]
|
||||
topk_weights = context_metadata["topk_weights"]
|
||||
ep_recv_counts = context_metadata["ep_recv_counts"]
|
||||
assist_info_for_combine = context_metadata["assist_info_for_combine"]
|
||||
mc2_mask = context_metadata["mc2_mask"]
|
||||
expand_scales = context_metadata["expand_scales"]
|
||||
|
||||
assert expert_map is not None
|
||||
moe_expert_num = len(expert_map)
|
||||
|
||||
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
|
||||
assert self.expert_map is not None
|
||||
assert self.topk_weights is not None
|
||||
assert self.topk_ids is not None
|
||||
assert self.output is not None
|
||||
moe_expert_num = len(self.expert_map)
|
||||
# moeCombine
|
||||
kwargs_mc2 = {
|
||||
"expand_x": hidden_states,
|
||||
"expert_ids": topk_ids,
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
"expert_ids": self.topk_ids,
|
||||
"expert_scales": self.topk_weights.to(torch.float32),
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
}
|
||||
|
||||
if self.with_quant:
|
||||
tp_recv_counts = torch.empty(1,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
else:
|
||||
tp_recv_counts = ep_recv_counts
|
||||
|
||||
tp_recv_counts = self.output[5]
|
||||
stage3_kwargs = {
|
||||
"ep_send_counts": ep_recv_counts,
|
||||
"ep_send_counts": self.ep_recv_counts,
|
||||
"group_ep": self.moe_all_to_all_group_name,
|
||||
"ep_world_size": self.ep_world_size,
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
"expand_scales": expand_scales,
|
||||
"expand_scales": self.expand_scales,
|
||||
}
|
||||
|
||||
if self.enable_dispatch_v2:
|
||||
stage3_kwargs["assist_info_for_combine"] = assist_info_for_combine
|
||||
stage3_kwargs.update({
|
||||
"assist_info_for_combine":
|
||||
self.assist_info_for_combine,
|
||||
})
|
||||
else:
|
||||
stage3_kwargs["expand_idx"] = assist_info_for_combine
|
||||
|
||||
stage3_kwargs.update({
|
||||
"expand_idx": self.assist_info_for_combine,
|
||||
})
|
||||
if self.need_extra_args:
|
||||
stage3_kwargs.update({
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
@ -281,40 +266,45 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
|
||||
if self.a3_need_extra_args and self.enable_dispatch_v2:
|
||||
stage3_kwargs["x_active_mask"] = mc2_mask
|
||||
|
||||
stage3_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_combine(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
context_metadata: dict,
|
||||
bias: torch.Tensor = None,
|
||||
):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
|
||||
**kwargs_mc2
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||
**kwargs_mc2)
|
||||
|
||||
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
|
||||
context_metadata)
|
||||
combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
|
||||
if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||
# these values are no longer used, so they need to be set to None for memory release.
|
||||
self.output = None
|
||||
self.assist_info_for_combine = None
|
||||
self.ep_recv_counts = None
|
||||
self.topk_ids = None
|
||||
self.topk_weights = None
|
||||
self.mc2_mask = None
|
||||
self.expert_map = None
|
||||
self.expand_scales = None
|
||||
|
||||
# Handle shared experts from metadata
|
||||
shared_experts = context_metadata["shared_experts"]
|
||||
if shared_experts is None:
|
||||
return combined_output
|
||||
|
||||
shared_act = context_metadata["shared_act"]
|
||||
if self.with_quant:
|
||||
swiglu_out_scale = context_metadata["swiglu_out_scale"]
|
||||
shared_hidden_states, _ = shared_experts.down_proj(
|
||||
(shared_act, swiglu_out_scale))
|
||||
if self.shared_experts is None:
|
||||
return hidden_states
|
||||
else:
|
||||
shared_hidden_states, _ = shared_experts.down_proj(shared_act)
|
||||
|
||||
return combined_output, shared_hidden_states
|
||||
if self.with_quant:
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
(self.shared_act, self.swiglu_out_scale))
|
||||
else:
|
||||
shared_hidden_states, _ = self.shared_experts.down_proj(
|
||||
self.shared_act)
|
||||
self.shared_act = None
|
||||
self.shared_experts = None
|
||||
self.swiglu_out_scale = None
|
||||
return hidden_states, shared_hidden_states
|
||||
|
||||
|
||||
class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
@ -324,7 +314,14 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
self.apply_router_weight_on_input = False
|
||||
self.max_num_tokens = kwargs.get("max_num_tokens")
|
||||
self.num_experts_local = kwargs.get("num_local_experts", 0)
|
||||
self.sorted_weights = None
|
||||
self.expanded_row_idx = None
|
||||
self.sorted_token_indices = None
|
||||
self.original_shape = None
|
||||
self.mask = None
|
||||
self.expert_map = None
|
||||
self.topk_weights = None
|
||||
self.topk_ids = None
|
||||
self.with_quant = False
|
||||
|
||||
def token_dispatch(self,
|
||||
@ -344,6 +341,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
self.original_shape = hidden_states.shape
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
self.expert_map = expert_map
|
||||
self.topk_weights = topk_weights
|
||||
self.topk_ids = topk_ids
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
if self.apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
@ -357,7 +357,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
if expert_map is not None:
|
||||
global_num_experts = len(expert_map)
|
||||
mask = (expert_map[topk_ids] != -1)
|
||||
topk_weights = topk_weights * mask
|
||||
self.topk_weights = topk_weights * mask
|
||||
first_expert_idx = get_ep_group(
|
||||
).rank_in_group * self.num_experts_local
|
||||
last_expert_idx = first_expert_idx + self.num_experts_local
|
||||
@ -366,7 +366,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
last_expert_idx = self.num_experts_local
|
||||
global_num_experts = self.num_experts_local
|
||||
|
||||
sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = (
|
||||
sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = (
|
||||
torch_npu.npu_moe_init_routing_v2(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
@ -379,31 +379,29 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
))
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 1 # `count` mode
|
||||
context_metadata = {
|
||||
"topk_weights": topk_weights,
|
||||
"expanded_row_idx": expanded_row_idx
|
||||
}
|
||||
return {
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": sorted_hidden_states,
|
||||
"group_list": expert_tokens,
|
||||
"dynamic_scale": pertoken_scale if self.with_quant else None,
|
||||
"context_metadata": context_metadata
|
||||
}
|
||||
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
context_metadata: dict,
|
||||
bias: torch.Tensor = None):
|
||||
assert self.original_shape is not None
|
||||
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=hidden_states,
|
||||
sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
|
||||
probs=context_metadata["topk_weights"])
|
||||
sorted_indices=torch.abs(self.expanded_row_idx),
|
||||
probs=self.topk_weights)
|
||||
if len(self.original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(self.original_shape)
|
||||
|
||||
# these values are no longer used, so they need to be set to None for memory release.
|
||||
self.expert_map = None
|
||||
self.topk_weights = None
|
||||
self.topk_ids = None
|
||||
self.expanded_row_idx = None
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
@ -452,12 +450,11 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": sorted_hidden_states,
|
||||
"group_list": group_list,
|
||||
"topk_scales": topk_scales
|
||||
"topk_scales": topk_scales,
|
||||
}
|
||||
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
context_metadata: dict,
|
||||
bias: torch.Tensor = None):
|
||||
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
|
||||
torch.int32)
|
||||
@ -481,8 +478,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
||||
|
||||
self.hidden_shape = None
|
||||
self.topk_weights = None
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.hidden_shape_before_permute = None
|
||||
|
||||
# [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
|
||||
# to each local expert by all ranks.
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
|
||||
# cached intermediate tensors.
|
||||
self.tokens_per_expert = None
|
||||
self.global_input_tokens_local_experts_indices = None
|
||||
|
||||
assert self.num_local_experts > 0, "Expected at least one expert"
|
||||
if self.num_local_experts > 1:
|
||||
self.expert_ids_per_ep_rank = torch.tensor(
|
||||
@ -504,116 +512,96 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
self.local_expert_indices[i + 1] -
|
||||
1), "local_expert_indices must be continuous"
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
):
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.with_quant = with_quant
|
||||
self.hidden_shape = hidden_states.shape
|
||||
self.topk_weights = topk_weights
|
||||
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
|
||||
assert topk_ids.dim() == 2, "Expected 2D tensor for routing map"
|
||||
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
|
||||
(
|
||||
permutated_local_input_tokens,
|
||||
reversed_local_input_permutation_mapping,
|
||||
tokens_per_expert,
|
||||
input_splits,
|
||||
output_splits,
|
||||
num_global_tokens_per_local_expert,
|
||||
global_input_tokens_local_experts_indices,
|
||||
) = self._dispatch_preprocess(hidden_states, topk_ids)
|
||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess(
|
||||
hidden_states, topk_ids)
|
||||
self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
|
||||
|
||||
dynamic_scale_after_all2all = None
|
||||
if self.with_quant:
|
||||
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
|
||||
permutated_local_input_tokens)
|
||||
|
||||
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
|
||||
dynamic_scale, output_splits, input_splits, self.ep_group)
|
||||
dynamic_scale,
|
||||
self.output_splits,
|
||||
self.input_splits,
|
||||
self.ep_group,
|
||||
)
|
||||
permute2_ep_all_to_all_handle.wait()
|
||||
dynamic_scale.untyped_storage().resize_(0)
|
||||
|
||||
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
|
||||
permutated_local_input_tokens, output_splits, input_splits,
|
||||
self.ep_group)
|
||||
permutated_local_input_tokens,
|
||||
self.output_splits,
|
||||
self.input_splits,
|
||||
self.ep_group,
|
||||
)
|
||||
permute1_ep_all_to_all_handle.wait()
|
||||
permutated_local_input_tokens.untyped_storage().resize_(0)
|
||||
|
||||
# Postprocess
|
||||
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess(
|
||||
global_input_tokens, dynamic_scale_after_all2all,
|
||||
global_input_tokens_local_experts_indices)
|
||||
|
||||
context_metadata = {
|
||||
"input_splits":
|
||||
input_splits,
|
||||
"output_splits":
|
||||
output_splits,
|
||||
"topk_weights":
|
||||
topk_weights,
|
||||
"reversed_local_input_permutation_mapping":
|
||||
reversed_local_input_permutation_mapping,
|
||||
"reversed_global_input_permutation_mapping":
|
||||
reversed_global_input_permutation_mapping
|
||||
}
|
||||
|
||||
global_input_tokens, dynamic_scale = self._dispatch_postprocess(
|
||||
global_input_tokens, dynamic_scale_after_all2all)
|
||||
return {
|
||||
"hidden_states": global_input_tokens,
|
||||
"group_list": tokens_per_expert,
|
||||
"group_list_type": 1,
|
||||
"dynamic_scale": dynamic_scale_final,
|
||||
"context_metadata": context_metadata,
|
||||
"dynamic_scale": dynamic_scale,
|
||||
"group_list_type": 1
|
||||
}
|
||||
|
||||
def token_combine(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
context_metadata: dict,
|
||||
bias: torch.Tensor = None,
|
||||
):
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||
|
||||
# 1. Preprocess using metadata
|
||||
hidden_states = self._combine_preprocess(hidden_states,
|
||||
context_metadata)
|
||||
hidden_states = self._combine_preprocess(hidden_states)
|
||||
|
||||
# 2. AllToAll
|
||||
# Perform expert parallel AlltoAll communication
|
||||
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
|
||||
_, permutated_local_input_tokens, handle = async_all_to_all(
|
||||
hidden_states,
|
||||
context_metadata["input_splits"],
|
||||
context_metadata["output_splits"],
|
||||
self.ep_group,
|
||||
)
|
||||
hidden_states, self.input_splits, self.output_splits,
|
||||
self.ep_group)
|
||||
handle.wait()
|
||||
hidden_states.untyped_storage().resize_(0)
|
||||
|
||||
# 3. Postprocess using metadata
|
||||
output = self._combine_postprocess(permutated_local_input_tokens,
|
||||
context_metadata)
|
||||
output = self._combine_postprocess(permutated_local_input_tokens)
|
||||
|
||||
# these values are no longer used, so they need to be set to None for memory release.
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
self.topk_weights = None
|
||||
self.reversed_local_input_permutation_mapping = None
|
||||
self.reversed_global_input_permutation_mapping = None
|
||||
self.global_input_tokens_local_experts_indices = None
|
||||
|
||||
return output
|
||||
|
||||
def _dispatch_preprocess(self, hidden_states, topk_ids):
|
||||
assert self.hidden_shape is not None
|
||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||
(
|
||||
tokens_per_expert,
|
||||
input_splits,
|
||||
output_splits,
|
||||
num_global_tokens_per_local_expert,
|
||||
global_input_tokens_local_experts_indices,
|
||||
) = self._preprocess(topk_ids)
|
||||
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
|
||||
tokens_per_expert = self._preprocess(topk_ids)
|
||||
|
||||
self.hidden_shape_before_permute = hidden_states.shape
|
||||
|
||||
@ -622,88 +610,82 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
indices=topk_ids,
|
||||
num_out_tokens=self.num_out_tokens,
|
||||
)
|
||||
return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
|
||||
|
||||
return (
|
||||
permutated_local_input_tokens,
|
||||
reversed_local_input_permutation_mapping,
|
||||
tokens_per_expert,
|
||||
input_splits,
|
||||
output_splits,
|
||||
num_global_tokens_per_local_expert,
|
||||
global_input_tokens_local_experts_indices,
|
||||
)
|
||||
|
||||
def _preprocess(self, topk_ids: torch.Tensor):
|
||||
def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor:
|
||||
num_local_tokens_per_expert = torch.histc(topk_ids,
|
||||
bins=self.num_experts,
|
||||
min=0,
|
||||
max=self.num_experts)
|
||||
|
||||
ep_size = self.ep_size
|
||||
|
||||
# Dropless
|
||||
self.num_out_tokens = topk_ids.numel()
|
||||
|
||||
input_splits = (num_local_tokens_per_expert.reshape(
|
||||
# ===================================================
|
||||
# Calculate input_splits, output_splits for alltoall-v.
|
||||
# ===================================================
|
||||
self.input_splits = (num_local_tokens_per_expert.reshape(
|
||||
ep_size,
|
||||
self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
|
||||
non_blocking=True).numpy())
|
||||
|
||||
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
|
||||
num_local_tokens_per_expert,
|
||||
group=self.ep_group).reshape(ep_size, self.num_experts)
|
||||
num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
|
||||
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
|
||||
0]:self.local_expert_indices[-1] + 1]
|
||||
if num_global_tokens_per_local_expert is None:
|
||||
if self.num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before sum.")
|
||||
|
||||
output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to(
|
||||
torch.device("cpu"), non_blocking=True).numpy())
|
||||
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(
|
||||
self.output_splits = (self.num_global_tokens_per_local_expert.sum(
|
||||
axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
|
||||
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
|
||||
axis=0)
|
||||
# ===================================================
|
||||
# num_global_tokens_per_expert: [ep_size, num_experts]
|
||||
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
|
||||
# num_tokens_per_local_expert: [num_local_experts]
|
||||
# ===================================================
|
||||
|
||||
global_input_tokens_local_experts_indices = None
|
||||
if self.num_local_experts > 1:
|
||||
if num_global_tokens_per_local_expert is None:
|
||||
if self.num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before operations."
|
||||
)
|
||||
global_input_tokens_local_experts_indices = torch.repeat_interleave(
|
||||
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
|
||||
self.expert_ids_per_ep_rank,
|
||||
num_global_tokens_per_local_expert.ravel())
|
||||
self.num_global_tokens_per_local_expert.ravel())
|
||||
else:
|
||||
# TODO: This full synchronization can be a performance bottleneck.
|
||||
# A more granular sync (e.g., blocking D2H copies) should be investigated.
|
||||
torch.npu.synchronize()
|
||||
|
||||
return (
|
||||
num_tokens_per_local_expert,
|
||||
input_splits,
|
||||
output_splits,
|
||||
num_global_tokens_per_local_expert,
|
||||
global_input_tokens_local_experts_indices,
|
||||
)
|
||||
return num_tokens_per_local_expert
|
||||
|
||||
def _dispatch_postprocess(self, global_input_tokens,
|
||||
dynamic_scale_after_all2all,
|
||||
global_input_tokens_local_experts_indices):
|
||||
def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None):
|
||||
# Early return if no local experts or no tokens
|
||||
if self.num_local_experts <= 1:
|
||||
return global_input_tokens, dynamic_scale_after_all2all, None
|
||||
return global_input_tokens, None
|
||||
|
||||
# Handle quantized case
|
||||
if self.with_quant:
|
||||
assert global_input_tokens_local_experts_indices is not None, \
|
||||
"global_input_tokens_local_experts_indices must be provided"
|
||||
expert_idx_2d = global_input_tokens_local_experts_indices.unsqueeze(
|
||||
assert self.global_input_tokens_local_experts_indices is not None, \
|
||||
"global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess"
|
||||
expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze(
|
||||
-1)
|
||||
active_num = global_input_tokens_local_experts_indices.numel()
|
||||
active_num = self.global_input_tokens_local_experts_indices.numel()
|
||||
|
||||
# Handle case with no active tokens
|
||||
if active_num <= 0:
|
||||
reversed_global_input_permutation_mapping = global_input_tokens_local_experts_indices
|
||||
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
|
||||
self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices
|
||||
return global_input_tokens, dynamic_scale
|
||||
|
||||
global_input_tokens, reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
|
||||
# Process with active tokens
|
||||
global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
|
||||
global_input_tokens,
|
||||
expert_idx_2d,
|
||||
scale=dynamic_scale_after_all2all,
|
||||
scale=dynamic_scale,
|
||||
active_num=active_num,
|
||||
expert_capacity=0,
|
||||
expert_num=self.num_local_experts,
|
||||
@ -711,34 +693,32 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
expert_tokens_num_flag=True,
|
||||
active_expert_range=[0, self.num_local_experts],
|
||||
quant_mode=-1,
|
||||
row_idx_type=0,
|
||||
)
|
||||
return global_input_tokens, expanded_scale, reversed_global_input_permutation_mapping
|
||||
row_idx_type=0)
|
||||
return global_input_tokens, expanded_scale
|
||||
|
||||
# Non-quantized case
|
||||
global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
global_input_tokens, global_input_tokens_local_experts_indices)
|
||||
return global_input_tokens, None, reversed_global_input_permutation_mapping
|
||||
# Handle non-quantized case
|
||||
global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
global_input_tokens,
|
||||
self.global_input_tokens_local_experts_indices)
|
||||
return global_input_tokens, None
|
||||
|
||||
def _combine_preprocess(self, hidden_states: torch.Tensor,
|
||||
context_metadata: dict) -> torch.Tensor:
|
||||
def _combine_preprocess(self, hidden_states):
|
||||
# Unpermutation 2: expert output to AlltoAll input
|
||||
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
|
||||
rev_global = context_metadata[
|
||||
"reversed_global_input_permutation_mapping"]
|
||||
hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
hidden_states, rev_global)
|
||||
hidden_states, self.reversed_global_input_permutation_mapping)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor,
|
||||
context_metadata: dict) -> torch.Tensor:
|
||||
def _combine_postprocess(self, permutated_local_input_tokens):
|
||||
# Unpermutation 1: AlltoAll output to output
|
||||
output = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=permutated_local_input_tokens,
|
||||
sorted_indices=context_metadata[
|
||||
"reversed_local_input_permutation_mapping"].to(torch.int32),
|
||||
probs=context_metadata["topk_weights"],
|
||||
restore_shape=self.hidden_shape_before_permute,
|
||||
)
|
||||
sorted_indices=self.reversed_local_input_permutation_mapping.to(
|
||||
torch.int32),
|
||||
probs=self.topk_weights,
|
||||
restore_shape=self.hidden_shape_before_permute)
|
||||
|
||||
# Reshape the output tensor
|
||||
output = output.view(self.hidden_shape)
|
||||
return output
|
||||
|
Reference in New Issue
Block a user