diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py
index 5e9d9c373..4735a5f15 100644
--- a/tests/e2e/singlecard/ops/test_fused_moe.py
+++ b/tests/e2e/singlecard/ops/test_fused_moe.py
@@ -118,12 +118,6 @@ def test_token_dispatcher_with_all_gather(
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
-    row_idx = (torch.arange(
-        0,
-        m * topk,
-        device=device,
-        dtype=torch.int32,
-    ).view(topk, -1).permute(1, 0).contiguous())
 
     dispatcher_kwargs = {
         "num_experts": e,
@@ -137,7 +131,6 @@ def test_token_dispatcher_with_all_gather(
         hidden_states=a,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        row_idx=row_idx,
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input)
 
@@ -201,12 +194,6 @@ def test_token_dispatcher_with_all_gather_quant(
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
-    row_idx = (torch.arange(
-        0,
-        m * topk,
-        device=device,
-        dtype=torch.int32,
-    ).view(topk, -1).permute(1, 0).contiguous())
 
     dispatcher_kwargs = {
         "num_experts": e,
@@ -220,7 +207,6 @@ def test_token_dispatcher_with_all_gather_quant(
         hidden_states=a,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        row_idx=row_idx,
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input,
         with_quant=True)
@@ -297,7 +283,7 @@ def test_select_experts(
     mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
         x)
 
-    topk_weights, topk_ids, row_idx = select_experts(
+    topk_weights, topk_ids = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=topk,
@@ -318,7 +304,6 @@ def test_select_experts(
     assert topk_weights.shape == (m, topk)
     assert topk_ids.shape == (m, topk)
     assert topk_ids.dtype == torch.int32
-    assert row_idx.shape == (m, topk)
 
     gc.collect()
     torch.npu.empty_cache()
diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py
index eba948f17..685fbfcda 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_ops.py
@@ -263,7 +263,7 @@ class TestExpertsSelector:
         x = torch.randn(8, 2)
         router_logits = torch.randn(8, 2)
 
-        topk_weights, topk_ids, _ = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=2,
diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index 65085cee1..3826a19c1 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -204,7 +204,6 @@ class TestMoECommMethod(TestBase):
         topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
                                      [0.6, 0.4]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
-        row_idx = torch.arange(4)
 
         # Make sure tensors are contiguous and have correct strides
         hidden_states = hidden_states.contiguous()
@@ -216,7 +215,6 @@ class TestMoECommMethod(TestBase):
             w2=w2,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             activation="silu")
 
         # Verify result shape
diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py
index aed2b7dfc..87f384fad 100644
--- a/tests/ut/ops/test_token_dispatcher.py
+++ b/tests/ut/ops/test_token_dispatcher.py
@@ -58,7 +58,6 @@ class TestTokenDispatcherWithMC2(TestBase):
 
         kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
         self.dispatcher = TokenDispatcherWithMC2(**kwargs)
-        self.row_idx = torch.arange(10, dtype=torch.int32)
 
     def tearDown(self):
         self.mc2_group_patch.stop()
@@ -96,7 +95,7 @@ class TestTokenDispatcherWithMC2(TestBase):
                            (None, None)) as mock_dispatch:
             output = self.dispatcher.token_dispatch(hidden_states,
                                                     topk_weights, topk_ids,
-                                                    self.row_idx, expert_map)
+                                                    expert_map)
             mock_dispatch.assert_called_once()
             self.assertEqual(output["group_list_type"],
                              0)  # group_list_type == 0
@@ -117,7 +116,6 @@ class TestTokenDispatcherWithMC2(TestBase):
         self.dispatcher.token_dispatch(self.hidden_states,
                                        self.topk_weights,
                                        torch.randint(0, 8, (10, 1)),
-                                       self.row_idx,
                                        torch.tensor(
                                            [0, 1, 2, 3, 4, 5, 6, 7]),
                                        shared_experts=self.shared_experts)
@@ -181,7 +179,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
             torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
             torch.tensor([0, 1, 0, 1, 0, 1]),  # expanded_expert_idx
             torch.tensor([0, 1, 0, 1, 0, 1]))
-        self.row_idx = torch.arange(10, dtype=torch.int32)
         self.patcher_npu_moe_token_unpermute = patch(
             'torch_npu.npu_moe_token_unpermute')
         self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
         )
@@ -198,7 +195,7 @@
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
         results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, self.row_idx, None)
+                                                 topk_ids, None)
 
         # Verify npu_moe_init_routing is called
         self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -213,7 +210,7 @@
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
         results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, self.row_idx, None)
+                                                 topk_ids, None)
 
         # Verify npu_moe_init_routing is called
         self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -237,7 +234,7 @@
         results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                        topk_weights,
                                                        topk_ids,
-                                                       self.row_idx, None)
+                                                       None)
 
         self.assertEqual(results["group_list_type"], 1)
 
@@ -258,7 +255,6 @@
         results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                        topk_weights,
                                                        topk_ids,
-                                                       self.row_idx,
                                                        None,
                                                        with_quant=True)
 
@@ -401,7 +397,6 @@
             num_experts=4,
             num_local_experts=2,
             with_quant=False)
-        self.row_idx = torch.arange(10, dtype=torch.int32)
 
     def test_token_dispatch(self):
         hidden_states = torch.randn(8, 16)
@@ -416,7 +411,6 @@
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map)
 
         self.assertIsNotNone(result["hidden_states"])
@@ -463,7 +457,6 @@
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map,
                                                 with_quant=True)
 
@@ -492,7 +485,6 @@
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map,
                                                 with_quant=True)
 
@@ -515,7 +507,6 @@
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map,
                                                 log2phy=log2phy)
diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index 3bccb1ede..b88e78f7d 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -777,12 +777,12 @@ class TestSelectExperts(TestBase):
                                              -1).permute(1, 0).contiguous())
 
-        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
-                                         router_logits=self.router_logits,
-                                         top_k=self.top_k,
-                                         use_grouped_topk=False,
-                                         renormalize=False,
-                                         scoring_func="softmax")
+        weights, ids = select_experts(hidden_states=self.hidden_states,
+                                      router_logits=self.router_logits,
+                                      top_k=self.top_k,
+                                      use_grouped_topk=False,
+                                      renormalize=False,
+                                      scoring_func="softmax")
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
 
@@ -790,12 +790,12 @@
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
 
-        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
-                                         router_logits=self.router_logits,
-                                         top_k=self.top_k,
-                                         use_grouped_topk=False,
-                                         renormalize=False,
-                                         scoring_func="sigmoid")
+        weights, ids = select_experts(hidden_states=self.hidden_states,
+                                      router_logits=self.router_logits,
+                                      top_k=self.top_k,
+                                      use_grouped_topk=False,
+                                      renormalize=False,
+                                      scoring_func="sigmoid")
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -818,13 +818,13 @@
                                              self.top_k,
                                              dtype=torch.long))
 
-        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
-                                         router_logits=self.router_logits,
-                                         top_k=self.top_k,
-                                         use_grouped_topk=True,
-                                         renormalize=False,
-                                         topk_group=4,
-                                         num_expert_group=2)
+        weights, ids = select_experts(hidden_states=self.hidden_states,
+                                      router_logits=self.router_logits,
+                                      top_k=self.top_k,
+                                      use_grouped_topk=True,
+                                      renormalize=False,
+                                      topk_group=4,
+                                      num_expert_group=2)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -838,7 +838,7 @@
                                      self.num_experts)
         e_score_correction_bias = torch.randn(self.num_experts)
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -861,7 +861,7 @@
                                  self.top_k,
                                  dtype=torch.int32))
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -888,7 +888,7 @@
                                              -1).permute(1, 0).contiguous())
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -914,7 +914,7 @@
                                              -1).permute(1, 0).contiguous())
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index b82c9314e..90131878a 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -110,7 +110,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                       shared_experts: Optional[Any] = None,
                       **kwargs) -> torch.Tensor:
 
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
@@ -138,7 +138,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             shared_experts=shared_experts,
diff --git a/vllm_ascend/ops/moe/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py
index 623611328..7beda17fe 100644
--- a/vllm_ascend/ops/moe/experts_selector.py
+++ b/vllm_ascend/ops/moe/experts_selector.py
@@ -21,17 +21,6 @@ import torch_npu
 from vllm.forward_context import get_forward_context
 
 
-def return_row_idx(hidden_states, top_k):
-    num_tokens = hidden_states.shape[0]
-    row_idx_len = num_tokens * top_k
-    row_idx = (torch.arange(0,
-                            row_idx_len,
-                            dtype=torch.int32,
-                            device=hidden_states.device).view(
-                                top_k, -1).permute(1, 0).contiguous())
-    return row_idx
-
-
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
                    top_k: int,
@@ -71,7 +60,7 @@ def select_experts(hidden_states: torch.Tensor,
         if weight_prefetch_method:
             weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
                 hidden_states, "gate_up")
-        topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
+        topk_weights, topk_ids = _select_experts_with_fusion_ops(
             hidden_states=hidden_states,
             router_logits=router_logits,
             top_k=top_k,
@@ -99,9 +88,7 @@ def select_experts(hidden_states: torch.Tensor,
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts,
         )
-    if row_idx is None:
-        row_idx = return_row_idx(hidden_states, top_k)
-    return topk_weights, topk_ids, row_idx
+    return topk_weights, topk_ids
 
 
 def _native_grouped_topk(
@@ -187,7 +174,7 @@
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    topk_weights, topk_ids, row_idx = None, None, None
+    topk_weights, topk_ids = None, None
     # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
     is_deepseek_v3_r1 = global_num_experts == 256
     if is_deepseek_v3_r1:
@@ -205,14 +192,13 @@
             # y2_flag=False, # old api; should the third output be output
             routed_scaling_factor=1,
             eps=float(1e-20))
-        row_idx = return_row_idx(hidden_states, top_k)
     if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
-        topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
+        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
             x=router_logits, finished=None, k=top_k)
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
-    return topk_weights, topk_ids, row_idx
+    return topk_weights, topk_ids
 
 
 def _native_select_experts(
diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/moe/moe_comm_method.py
index 6c76aa124..d1d0c1aa0 100644
--- a/vllm_ascend/ops/moe/moe_comm_method.py
+++ b/vllm_ascend/ops/moe/moe_comm_method.py
@@ -88,7 +88,6 @@ class MoECommMethod(ABC):
                     w2: torch.Tensor,
                     topk_weights: torch.Tensor,
                     topk_ids: torch.Tensor,
-                    row_idx: torch.Tensor,
                     activation: str = "silu",
                     apply_router_weight_on_input: bool = False,
                     use_int8_w8a8: bool = False,
@@ -122,7 +121,6 @@ class MoECommMethod(ABC):
             hidden_states=hidden_states,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,
diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py
index 6cb97c33e..3dd799a42 100644
--- a/vllm_ascend/ops/moe/token_dispatcher.py
+++ b/vllm_ascend/ops/moe/token_dispatcher.py
@@ -61,7 +61,6 @@
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -171,7 +170,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -330,7 +328,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -422,7 +419,6 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -520,7 +516,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index 30f081141..4f4dbb048 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -265,7 +265,7 @@
             1] == global_num_experts, "Number of global experts mismatch"
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -297,7 +296,6 @@
             w2_scale_bias=layer.w2_scale_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index df9c3b272..978826e5c 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -205,7 +205,7 @@
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -232,7 +231,6 @@
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             use_int8_w8a8=True,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
@@ -252,7 +251,6 @@
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             use_int8_w8a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
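
Note for downstream callers: `select_experts` now returns a `(topk_weights, topk_ids)` pair, and `fused_experts`/`token_dispatch` no longer accept `row_idx`. Out-of-tree code that still needs the old index layout can rebuild it locally. A minimal sketch, mirroring the removed `return_row_idx` helper from `experts_selector.py` (the `build_row_idx` name is hypothetical, not part of this patch):

```python
import torch


def build_row_idx(hidden_states: torch.Tensor, top_k: int) -> torch.Tensor:
    """Rebuild the (num_tokens, top_k) int32 index map that the removed
    return_row_idx() produced: entry [t, k] == k * num_tokens + t."""
    num_tokens = hidden_states.shape[0]
    row_idx_len = num_tokens * top_k
    return (torch.arange(0,
                         row_idx_len,
                         dtype=torch.int32,
                         device=hidden_states.device).view(
                             top_k, -1).permute(1, 0).contiguous())


# Callers now unpack two values instead of three:
#   topk_weights, topk_ids = select_experts(...)
# and, only if some op still expects the old layout:
#   row_idx = build_row_idx(hidden_states, top_k)
```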