Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
Remove unused row_idx in token_dispatcher (#3442)
### What this PR does / why we need it?
The `row_idx` parameter has been unused since PR [#2689](https://github.com/vllm-project/vllm-ascend/pull/2689), so this PR removes it across multiple files to eliminate unnecessary computation and parameter passing.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Accuracy tests passed for Qwen3 235B and DeepSeek V3 671B after this PR.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: CaranLic <740821011@qq.com>
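The only signature change visible to callers is the return arity of `select_experts`: it now returns two tensors instead of three. As a minimal sketch (not the repo's code), the snippet below reproduces the routing math the updated tests use with plain torch ops and made-up sizes; the real `select_experts` lives in vllm_ascend and takes many more arguments.

```python
import torch

# Hedged sketch: native softmax + topk routing, as in the updated tests.
# After this PR, select_experts-style helpers return only these two tensors;
# the third (num_tokens, top_k) row_idx tensor is gone.
m, num_experts, top_k = 4, 8, 2  # made-up sizes for illustration
router_logits = torch.randn(m, num_experts)

score = torch.softmax(router_logits, dim=-1)
topk_weights, topk_ids = torch.topk(score, top_k)
topk_ids = topk_ids.to(torch.int32)

assert topk_weights.shape == (m, top_k)
assert topk_ids.shape == (m, top_k)
```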
@@ -118,12 +118,6 @@ def test_token_dispatcher_with_all_gather(
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
-    row_idx = (torch.arange(
-        0,
-        m * topk,
-        device=device,
-        dtype=torch.int32,
-    ).view(topk, -1).permute(1, 0).contiguous())
 
     dispatcher_kwargs = {
         "num_experts": e,
@@ -137,7 +131,6 @@ def test_token_dispatcher_with_all_gather(
         hidden_states=a,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        row_idx=row_idx,
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input)
 
@@ -201,12 +194,6 @@ def test_token_dispatcher_with_all_gather_quant(
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
-    row_idx = (torch.arange(
-        0,
-        m * topk,
-        device=device,
-        dtype=torch.int32,
-    ).view(topk, -1).permute(1, 0).contiguous())
 
     dispatcher_kwargs = {
         "num_experts": e,
@@ -220,7 +207,6 @@ def test_token_dispatcher_with_all_gather_quant(
         hidden_states=a,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        row_idx=row_idx,
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input,
         with_quant=True)
@@ -297,7 +283,7 @@ def test_select_experts(
         mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
             x)
 
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=topk,
@@ -318,7 +304,6 @@ def test_select_experts(
     assert topk_weights.shape == (m, topk)
     assert topk_ids.shape == (m, topk)
     assert topk_ids.dtype == torch.int32
-    assert row_idx.shape == (m, topk)
 
     gc.collect()
     torch.npu.empty_cache()

@@ -263,7 +263,7 @@ class TestExpertsSelector:
 
         x = torch.randn(8, 2)
         router_logits = torch.randn(8, 2)
-        topk_weights, topk_ids, _ = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=2,

@@ -204,7 +204,6 @@ class TestMoECommMethod(TestBase):
         topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
                                      [0.6, 0.4]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
-        row_idx = torch.arange(4)
 
         # Make sure tensors are contiguous and have correct strides
         hidden_states = hidden_states.contiguous()
@@ -216,7 +215,6 @@ class TestMoECommMethod(TestBase):
             w2=w2,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             activation="silu")
 
         # Verify result shape

@@ -58,7 +58,6 @@ class TestTokenDispatcherWithMC2(TestBase):
 
         kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
         self.dispatcher = TokenDispatcherWithMC2(**kwargs)
-        self.row_idx = torch.arange(10, dtype=torch.int32)
 
     def tearDown(self):
         self.mc2_group_patch.stop()
@@ -96,7 +95,7 @@ class TestTokenDispatcherWithMC2(TestBase):
                    (None, None)) as mock_dispatch:
             output = self.dispatcher.token_dispatch(hidden_states,
                                                     topk_weights, topk_ids,
-                                                    self.row_idx, expert_map)
+                                                    expert_map)
             mock_dispatch.assert_called_once()
             self.assertEqual(output["group_list_type"],
                              0)  # group_list_type == 0
@@ -117,7 +116,6 @@ class TestTokenDispatcherWithMC2(TestBase):
         self.dispatcher.token_dispatch(self.hidden_states,
                                        self.topk_weights,
                                        torch.randint(0, 8, (10, 1)),
-                                       self.row_idx,
                                        torch.tensor(
                                            [0, 1, 2, 3, 4, 5, 6, 7]),
                                        shared_experts=self.shared_experts)
@@ -181,7 +179,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
             torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
             torch.tensor([0, 1, 0, 1, 0, 1]),  # expanded_expert_idx
             torch.tensor([0, 1, 0, 1, 0, 1]))
-        self.row_idx = torch.arange(10, dtype=torch.int32)
         self.patcher_npu_moe_token_unpermute = patch(
             'torch_npu.npu_moe_token_unpermute')
         self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
@@ -198,7 +195,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
         results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, self.row_idx, None)
+                                                 topk_ids, None)
 
         # Verify npu_moe_init_routing is called
         self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -213,7 +210,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
         results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, self.row_idx, None)
+                                                 topk_ids, None)
 
         # Verify npu_moe_init_routing is called
         self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -237,7 +234,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
 
         results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                        topk_weights, topk_ids,
-                                                       self.row_idx, None)
+                                                       None)
 
         self.assertEqual(results["group_list_type"], 1)
 
@@ -258,7 +255,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
         results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                        topk_weights,
                                                        topk_ids,
-                                                       self.row_idx,
                                                        None,
                                                        with_quant=True)
 
@@ -401,7 +397,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
             num_experts=4,
             num_local_experts=2,
             with_quant=False)
-        self.row_idx = torch.arange(10, dtype=torch.int32)
 
     def test_token_dispatch(self):
         hidden_states = torch.randn(8, 16)
@@ -416,7 +411,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map)
 
         self.assertIsNotNone(result["hidden_states"])
@@ -463,7 +457,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map,
                                                 with_quant=True)
 
@@ -492,7 +485,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map,
                                                 with_quant=True)
 
@@ -515,7 +507,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map,
                                                 log2phy=log2phy)

@@ -777,7 +777,7 @@ class TestSelectExperts(TestBase):
                                        -1).permute(1,
                                                    0).contiguous())
 
-        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+        weights, ids = select_experts(hidden_states=self.hidden_states,
                                          router_logits=self.router_logits,
                                          top_k=self.top_k,
                                          use_grouped_topk=False,
@@ -790,7 +790,7 @@ class TestSelectExperts(TestBase):
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
 
-        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+        weights, ids = select_experts(hidden_states=self.hidden_states,
                                          router_logits=self.router_logits,
                                          top_k=self.top_k,
                                          use_grouped_topk=False,
@@ -818,7 +818,7 @@ class TestSelectExperts(TestBase):
                                              self.top_k,
                                              dtype=torch.long))
 
-        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+        weights, ids = select_experts(hidden_states=self.hidden_states,
                                          router_logits=self.router_logits,
                                          top_k=self.top_k,
                                          use_grouped_topk=True,
@@ -838,7 +838,7 @@ class TestSelectExperts(TestBase):
                                     self.num_experts)
 
         e_score_correction_bias = torch.randn(self.num_experts)
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -861,7 +861,7 @@ class TestSelectExperts(TestBase):
                                              self.top_k,
                                              dtype=torch.int32))
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -888,7 +888,7 @@ class TestSelectExperts(TestBase):
                                        -1).permute(1,
                                                    0).contiguous())
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -914,7 +914,7 @@ class TestSelectExperts(TestBase):
                                        -1).permute(1,
                                                    0).contiguous())
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,

@@ -110,7 +110,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
               shared_experts: Optional[Any] = None,
               **kwargs) -> torch.Tensor:
 
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -138,7 +138,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             shared_experts=shared_experts,

@@ -21,17 +21,6 @@ import torch_npu
 from vllm.forward_context import get_forward_context
 
 
-def return_row_idx(hidden_states, top_k):
-    num_tokens = hidden_states.shape[0]
-    row_idx_len = num_tokens * top_k
-    row_idx = (torch.arange(0,
-                            row_idx_len,
-                            dtype=torch.int32,
-                            device=hidden_states.device).view(
-                                top_k, -1).permute(1, 0).contiguous())
-    return row_idx
-
-
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
                    top_k: int,
@@ -71,7 +60,7 @@ def select_experts(hidden_states: torch.Tensor,
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")
-    topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
+    topk_weights, topk_ids = _select_experts_with_fusion_ops(
         hidden_states=hidden_states,
         router_logits=router_logits,
        top_k=top_k,
@@ -99,9 +88,7 @@ def select_experts(hidden_states: torch.Tensor,
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts,
         )
-    if row_idx is None:
-        row_idx = return_row_idx(hidden_states, top_k)
-    return topk_weights, topk_ids, row_idx
+    return topk_weights, topk_ids
 
 
 def _native_grouped_topk(
@@ -187,7 +174,7 @@ def _select_experts_with_fusion_ops(
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    topk_weights, topk_ids, row_idx = None, None, None
+    topk_weights, topk_ids = None, None
     # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
     is_deepseek_v3_r1 = global_num_experts == 256
     if is_deepseek_v3_r1:
@@ -205,14 +192,13 @@ def _select_experts_with_fusion_ops(
             # y2_flag=False, # old api; should the third output be output
             routed_scaling_factor=1,
             eps=float(1e-20))
-        row_idx = return_row_idx(hidden_states, top_k)
     if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
-        topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
+        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
             x=router_logits, finished=None, k=top_k)
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
-    return topk_weights, topk_ids, row_idx
+    return topk_weights, topk_ids
 
 
 def _native_select_experts(

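For reference, the deleted `return_row_idx` helper (reproduced verbatim in the hunk above) only enumerated token slots in a fixed pattern derived from the input shape, which is why dropping it is behavior-neutral. A standalone sketch with made-up inputs:

```python
import torch

def return_row_idx(hidden_states: torch.Tensor, top_k: int) -> torch.Tensor:
    # Copy of the helper deleted above, kept here only to show its output.
    num_tokens = hidden_states.shape[0]
    row_idx_len = num_tokens * top_k
    row_idx = (torch.arange(0,
                            row_idx_len,
                            dtype=torch.int32,
                            device=hidden_states.device).view(
                                top_k, -1).permute(1, 0).contiguous())
    return row_idx

# For 3 tokens and top_k=2 the result is a fixed (num_tokens, top_k) grid:
print(return_row_idx(torch.randn(3, 16), top_k=2))
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]], dtype=torch.int32)
```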
@@ -88,7 +88,6 @@ class MoECommMethod(ABC):
                 w2: torch.Tensor,
                 topk_weights: torch.Tensor,
                 topk_ids: torch.Tensor,
-                row_idx: torch.Tensor,
                 activation: str = "silu",
                 apply_router_weight_on_input: bool = False,
                 use_int8_w8a8: bool = False,
@@ -122,7 +121,6 @@ class MoECommMethod(ABC):
             hidden_states=hidden_states,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,

@@ -61,7 +61,6 @@ class MoETokenDispatcher(ABC):
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -171,7 +170,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -330,7 +328,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -422,7 +419,6 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -520,7 +516,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,

@@ -265,7 +265,7 @@ class AscendW4A8DynamicFusedMoEMethod:
             1] == global_num_experts, "Number of global experts mismatch"
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -297,7 +297,6 @@ class AscendW4A8DynamicFusedMoEMethod:
             w2_scale_bias=layer.w2_scale_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,

@@ -205,7 +205,7 @@ class AscendW8A8DynamicFusedMoEMethod:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -232,7 +232,6 @@ class AscendW8A8DynamicFusedMoEMethod:
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             use_int8_w8a8=True,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
@@ -252,7 +251,6 @@ class AscendW8A8DynamicFusedMoEMethod:
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             use_int8_w8a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
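If out-of-tree code calls these dispatchers positionally, the removed `row_idx` slot between `topk_ids` and `expert_map` must simply be dropped. A hedged sketch with a stand-in class (the real dispatchers are TokenDispatcherWithMC2 and friends above, which actually permute tokens per expert):

```python
import torch

class DemoDispatcher:
    # Stand-in with the post-PR calling convention; this stub only echoes
    # its input instead of dispatching tokens to experts.
    def token_dispatch(self, hidden_states, topk_weights, topk_ids,
                       expert_map=None, **kwargs):
        return {"hidden_states": hidden_states, "group_list_type": 0}

dispatcher = DemoDispatcher()
hidden_states = torch.randn(3, 16)
topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0]], dtype=torch.int32)
expert_map = torch.tensor([0, 1, 2, 3])

# Old: token_dispatch(hidden_states, topk_weights, topk_ids, row_idx, expert_map)
# New: the row_idx argument is gone.
output = dispatcher.token_dispatch(hidden_states, topk_weights, topk_ids,
                                   expert_map)
assert output["hidden_states"].shape == (3, 16)
```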