3 Commits

Author SHA1 Message Date
15b2e5c995 Remove unused row_idx in token_dispatcher (#3442)
### What this PR does / why we need it?
The `row_idx` parameter has been unused since
PR [#2689](https://github.com/vllm-project/vllm-ascend/pull/2689), so
this PR removes it across multiple files to eliminate unnecessary
computation and parameter passing.
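
For reference, a minimal before/after sketch of the call sites touched by this PR, pieced together from the diffs below (`x`, `router_logits`, `top_k`, `expert_map`, and `dispatcher` are placeholders; the real signatures live in vllm-ascend's experts selector and token dispatchers):

```python
# Before: select_experts returned a row_idx that had to be threaded through.
# topk_weights, topk_ids, row_idx = select_experts(
#     hidden_states=x, router_logits=router_logits, top_k=top_k)
# dispatcher.token_dispatch(hidden_states=x, topk_weights=topk_weights,
#                           topk_ids=topk_ids, row_idx=row_idx,
#                           expert_map=expert_map)

# After: row_idx is neither returned nor passed.
topk_weights, topk_ids = select_experts(
    hidden_states=x, router_logits=router_logits, top_k=top_k)
dispatcher.token_dispatch(hidden_states=x, topk_weights=topk_weights,
                          topk_ids=topk_ids, expert_map=expert_map)
```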

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Accuracy tests passed for Qwen3 235B and DeepSeek V3 671B after this PR.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: CaranLic <740821011@qq.com>
2025-10-15 09:08:31 +08:00
3642b64afc bugfix for mtp with multistream_moe (#3419)
### What this PR does / why we need it?
When running inference through the DeepSeek MTP layer with multistream_moe, we should pass a
boolean indicating whether the feature is enabled, and fix the bugs that occur inside the MTP layer.
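
For context, a minimal sketch of how the accompanying correctness test enables the feature (keys and values mirror the test diff below, including the string `"True"` as written; treat this as a test configuration sketch rather than a documented API):

```python
# Mirrors the torchair MTP correctness test in this PR: the MTP path is
# exercised with multistream shared-expert overlap enabled.
additional_config = {
    "torchair_graph_config": {
        "enabled": True,
        "use_cached_graph": False,
        "graph_batch_sizes": [1, 2, 4],
    },
    "multistream_overlap_shared_expert": "True",
}
```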

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
2025-10-15 08:59:58 +08:00
c2c1db78a7 [Bugfix] fix ZeroDivisionError when prefill_tp_size > num_kv_head and fix tp_resharding README (#3437)
### What this PR does / why we need it?
Fix a ZeroDivisionError when prefill_tp_size > num_kv_head: in this
situation num_head_replica could be 0 and was later used as a divisor.
This PR restricts its minimum value to 1. It also fixes the
tp_resharding README.
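
A minimal sketch of the guard, mirroring the change to `AscendConfig` in the diff below (`compute_num_head_replica` is a hypothetical helper introduced only for illustration; the actual change is an inline expression):

```python
def compute_num_head_replica(prefill_tp_size: int, num_kv_head: int) -> int:
    # When prefill_tp_size < num_kv_head, integer division would yield 0,
    # which later triggers a ZeroDivisionError; clamp the replica count to 1.
    if prefill_tp_size >= num_kv_head:
        return prefill_tp_size // num_kv_head
    return 1
```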

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
2025-10-15 08:45:44 +08:00
20 changed files with 85 additions and 113 deletions

View File

@ -114,10 +114,10 @@ export VLLM_USE_V1=1
export HCCL_BUFFSIZE=1024
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=10
export ASCEND_AGGREGATE_ENABLE=1
export ASCEND_TRANSPORT_PRINT=0
export ACL_OP_INIT_MODE=1
export ASCEND_A3_ENABLE=1
export ASCEND_AGGREGATE_ENABLE=1 # enable aggregated transmission
export ASCEND_TRANSPORT_PRINT=0 # print ascend transport logs
export ACL_OP_INIT_MODE=1 # acl op initialization mode to prevent device id acquisition failure
export ASCEND_A3_ENABLE=1 # enable hccs transmission for A3; set to 0 for A2
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
vllm serve /model/Qwen3-235B-A22B-W8A8 \
@ -137,7 +137,6 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \
--max-model-len 32768 \
--max-num-batched-tokens 32768 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "MooncakeLayerwiseConnector",
@ -197,7 +196,6 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \
--max-model-len 32768 \
--max-num-batched-tokens 32768 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "MooncakeLayerwiseConnector",
@ -363,6 +361,10 @@ export VLLM_USE_V1=1
export HCCL_BUFFSIZE=1024
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=10
export ASCEND_AGGREGATE_ENABLE=1
export ASCEND_TRANSPORT_PRINT=0
export ACL_OP_INIT_MODE=1
export ASCEND_A3_ENABLE=1
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
vllm serve /model/Qwen3-235B-A22B-W8A8 \
@ -382,7 +384,6 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \
--max-model-len 32768 \
--max-num-batched-tokens 32768 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "MooncakeConnector",
@ -419,6 +420,10 @@ export VLLM_USE_V1=1
export HCCL_BUFFSIZE=1024
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=10
export ASCEND_AGGREGATE_ENABLE=1
export ASCEND_TRANSPORT_PRINT=0
export ACL_OP_INIT_MODE=1
export ASCEND_A3_ENABLE=1
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
vllm serve /model/Qwen3-235B-A22B-W8A8 \
@ -438,7 +443,6 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \
--max-model-len 32768 \
--max-num-batched-tokens 32768 \
--trust-remote-code \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.9 \
--kv-transfer-config \
'{"kv_connector": "MooncakeConnector",
@ -475,6 +479,10 @@ export VLLM_USE_V1=1
export HCCL_BUFFSIZE=2048
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=10
export ASCEND_AGGREGATE_ENABLE=1
export ASCEND_TRANSPORT_PRINT=0
export ACL_OP_INIT_MODE=1
export ASCEND_A3_ENABLE=1
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
vllm serve /model/Qwen3-235B-A22B-W8A8 \
@ -532,6 +540,10 @@ export VLLM_USE_V1=1
export HCCL_BUFFSIZE=2048
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=10
export ASCEND_AGGREGATE_ENABLE=1
export ASCEND_TRANSPORT_PRINT=0
export ACL_OP_INIT_MODE=1
export ASCEND_A3_ENABLE=1
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
vllm serve /model/Qwen3-235B-A22B-W8A8 \

View File

@ -118,12 +118,6 @@ def test_token_dispatcher_with_all_gather(
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())
dispatcher_kwargs = {
"num_experts": e,
@ -137,7 +131,6 @@ def test_token_dispatcher_with_all_gather(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
@ -201,12 +194,6 @@ def test_token_dispatcher_with_all_gather_quant(
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
row_idx = (torch.arange(
0,
m * topk,
device=device,
dtype=torch.int32,
).view(topk, -1).permute(1, 0).contiguous())
dispatcher_kwargs = {
"num_experts": e,
@ -220,7 +207,6 @@ def test_token_dispatcher_with_all_gather_quant(
hidden_states=a,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=True)
@ -297,7 +283,7 @@ def test_select_experts(
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)
topk_weights, topk_ids, row_idx = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
@ -318,7 +304,6 @@ def test_select_experts(
assert topk_weights.shape == (m, topk)
assert topk_ids.shape == (m, topk)
assert topk_ids.dtype == torch.int32
assert row_idx.shape == (m, topk)
gc.collect()
torch.npu.empty_cache()

View File

@ -41,6 +41,7 @@ def test_mtp_torchair_correctness(
"use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4],
},
"multistream_overlap_shared_expert": "True"
}) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
with VllmRunner(model_name,
@ -60,7 +61,8 @@ def test_mtp_torchair_correctness(
"enabled": True,
"use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4],
}
},
"multistream_overlap_shared_expert": "True"
}) as spec_llm:
spec_outputs = spec_llm.generate(example_prompts, sampling_config)

View File

@ -79,7 +79,7 @@ class TestKVCacheSendingLayerThreadBasic(unittest.TestCase):
self.p1 = patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config',
new=MagicMock(return_value=SimpleNamespace(
pd_tp_ratio=1, num_head_replica=0, pd_head_ratio=1)))
pd_tp_ratio=1, num_head_replica=1, pd_head_ratio=1)))
self.p2 = patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.get_current_vllm_config',
new=MagicMock(return_value=SimpleNamespace(
@ -244,7 +244,7 @@ class TestSendingLayerThread(unittest.TestCase):
self.p1 = patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config',
new=MagicMock(return_value=SimpleNamespace(
pd_tp_ratio=1, num_head_replica=0, pd_head_ratio=1)))
pd_tp_ratio=1, num_head_replica=1, pd_head_ratio=1)))
self.p2 = patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.get_current_vllm_config',
new=MagicMock(return_value=SimpleNamespace(
@ -903,7 +903,7 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config',
return_value=SimpleNamespace(pd_tp_ratio=1,
num_head_replica=0,
num_head_replica=1,
pd_head_ratio=1),
),
patch(

View File

@ -263,7 +263,7 @@ class TestExpertsSelector:
x = torch.randn(8, 2)
router_logits = torch.randn(8, 2)
topk_weights, topk_ids, _ = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=2,

View File

@ -204,7 +204,6 @@ class TestMoECommMethod(TestBase):
topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
[0.6, 0.4]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
row_idx = torch.arange(4)
# Make sure tensors are contiguous and have correct strides
hidden_states = hidden_states.contiguous()
@ -216,7 +215,6 @@ class TestMoECommMethod(TestBase):
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
activation="silu")
# Verify result shape

View File

@ -58,7 +58,6 @@ class TestTokenDispatcherWithMC2(TestBase):
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.mc2_group_patch.stop()
@ -96,7 +95,7 @@ class TestTokenDispatcherWithMC2(TestBase):
(None, None)) as mock_dispatch:
output = self.dispatcher.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, expert_map)
expert_map)
mock_dispatch.assert_called_once()
self.assertEqual(output["group_list_type"],
0) # group_list_type == 0
@ -117,7 +116,6 @@ class TestTokenDispatcherWithMC2(TestBase):
self.dispatcher.token_dispatch(self.hidden_states,
self.topk_weights,
torch.randint(0, 8, (10, 1)),
self.row_idx,
torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7]),
shared_experts=self.shared_experts)
@ -181,7 +179,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
torch.tensor([0, 1, 0, 1, 0, 1]))
self.row_idx = torch.arange(10, dtype=torch.int32)
self.patcher_npu_moe_token_unpermute = patch(
'torch_npu.npu_moe_token_unpermute')
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
@ -198,7 +195,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
topk_ids, None)
# Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_v2.assert_called_once()
@ -213,7 +210,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
topk_ids, None)
# Verify npu_moe_init_routing is called
self.mock_npu_moe_init_routing_v2.assert_called_once()
@ -237,7 +234,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, None)
None)
self.assertEqual(results["group_list_type"], 1)
@ -258,7 +255,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights,
topk_ids,
self.row_idx,
None,
with_quant=True)
@ -401,7 +397,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
num_experts=4,
num_local_experts=2,
with_quant=False)
self.row_idx = torch.arange(10, dtype=torch.int32)
def test_token_dispatch(self):
hidden_states = torch.randn(8, 16)
@ -416,7 +411,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
@ -463,7 +457,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
with_quant=True)
@ -492,7 +485,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
with_quant=True)
@ -515,7 +507,6 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
log2phy=log2phy)

View File

@ -777,12 +777,12 @@ class TestSelectExperts(TestBase):
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="softmax")
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="softmax")
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@ -790,12 +790,12 @@ class TestSelectExperts(TestBase):
def test_sigmoid_scoring(self):
"""Test sigmoid scoring function"""
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@ -818,13 +818,13 @@ class TestSelectExperts(TestBase):
self.top_k,
dtype=torch.long))
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)
mock_topk.assert_called()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@ -838,7 +838,7 @@ class TestSelectExperts(TestBase):
self.num_experts)
e_score_correction_bias = torch.randn(self.num_experts)
weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@ -861,7 +861,7 @@ class TestSelectExperts(TestBase):
self.top_k,
dtype=torch.int32))
weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@ -888,7 +888,7 @@ class TestSelectExperts(TestBase):
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
@ -914,7 +914,7 @@ class TestSelectExperts(TestBase):
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,

View File

@ -17,6 +17,9 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
config = PretrainedConfig(vocab_size=1000,
hidden_size=768,
rms_norm_eps=1e-5)
mocker.patch(
'vllm_ascend.torchair.models.torchair_deepseek_mtp.get_tensor_model_parallel_world_size',
return_value=1)
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
@ -56,6 +59,8 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
torch.randn(2, 3, 768))
mtp_layer.enorm.return_value = torch.randn(2, 3, 768)
mtp_layer.hnorm.return_value = torch.randn(2, 3, 768)
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
@ -65,7 +70,7 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
output = mtp_layer(input_ids, positions, kv_cache, None,
previous_hidden_states, inputs_embeds, 0)
assert output.shape == (2, 3, 768)
assert output.shape == (3, 768)
class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):

View File

@ -102,7 +102,7 @@ class AscendConfig:
)
self.pd_tp_ratio = 1
self.pd_head_ratio = 1
self.num_head_replica = 0
self.num_head_replica = 1
if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
"prefill", {"tp_size": 1})["tp_size"]
@ -115,7 +115,7 @@ class AscendConfig:
# only support Qwen model now
# TODO: use a more robust method to get kv_head_num
num_kv_head = vllm_config.model_config.hf_config.num_key_value_heads
self.num_head_replica = prefill_tp_size // num_kv_head
self.num_head_replica = prefill_tp_size // num_kv_head if prefill_tp_size >= num_kv_head else 1
prefill_tp_size = min(prefill_tp_size, num_kv_head)
decode_tp_size = min(decode_tp_size, num_kv_head)
self.pd_head_ratio = prefill_tp_size // decode_tp_size

View File

@ -103,8 +103,6 @@ def split_decodes_and_prefills(
return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[first_prefill:] > decode_threshold)
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes
num_decode_tokens = query_start_loc[first_prefill].item()

View File

@ -360,7 +360,7 @@ class SendingLayerThread(threading.Thread):
remote_kv_base_addrs = req_meta.kv_caches_base_addr
remote_block_ids = req_meta.block_ids
if self.num_head_replica >= 1 and self.tp_rank % self.num_head_replica != 0:
if self.tp_rank % self.num_head_replica != 0:
pass
elif self.pd_head_ratio == 1:
layer_local_kv_base_addr = [

View File

@ -110,7 +110,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
shared_experts: Optional[Any] = None,
**kwargs) -> torch.Tensor:
topk_weights, topk_ids, row_idx = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@ -138,7 +138,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,

View File

@ -21,17 +21,6 @@ import torch_npu
from vllm.forward_context import get_forward_context
def return_row_idx(hidden_states, top_k):
num_tokens = hidden_states.shape[0]
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=hidden_states.device).view(
top_k, -1).permute(1, 0).contiguous())
return row_idx
def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
@ -71,7 +60,7 @@ def select_experts(hidden_states: torch.Tensor,
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
hidden_states, "gate_up")
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
topk_weights, topk_ids = _select_experts_with_fusion_ops(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
@ -99,9 +88,7 @@ def select_experts(hidden_states: torch.Tensor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
if row_idx is None:
row_idx = return_row_idx(hidden_states, top_k)
return topk_weights, topk_ids, row_idx
return topk_weights, topk_ids
def _native_grouped_topk(
@ -187,7 +174,7 @@ def _select_experts_with_fusion_ops(
routed_scaling_factor=1.0,
global_num_experts: int = -1):
topk_weights, topk_ids, row_idx = None, None, None
topk_weights, topk_ids = None, None
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
is_deepseek_v3_r1 = global_num_experts == 256
if is_deepseek_v3_r1:
@ -205,14 +192,13 @@ def _select_experts_with_fusion_ops(
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
row_idx = return_row_idx(hidden_states, top_k)
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids, row_idx
return topk_weights, topk_ids
def _native_select_experts(

View File

@ -88,7 +88,6 @@ class MoECommMethod(ABC):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
@ -122,7 +121,6 @@ class MoECommMethod(ABC):
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,

View File

@ -61,7 +61,6 @@ class MoETokenDispatcher(ABC):
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
@ -171,7 +170,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
@ -330,7 +328,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
@ -422,7 +419,6 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
@ -520,7 +516,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,

View File

@ -265,7 +265,7 @@ class AscendW4A8DynamicFusedMoEMethod:
1] == global_num_experts, "Number of global experts mismatch"
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
topk_weights, topk_ids, row_idx = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@ -297,7 +297,6 @@ class AscendW4A8DynamicFusedMoEMethod:
w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
use_int4_w4a8=True,
expert_map=expert_map,
log2phy=log2phy,

View File

@ -205,7 +205,7 @@ class AscendW8A8DynamicFusedMoEMethod:
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
topk_weights, topk_ids, row_idx = select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@ -232,7 +232,6 @@ class AscendW8A8DynamicFusedMoEMethod:
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
use_int8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
@ -252,7 +251,6 @@ class AscendW8A8DynamicFusedMoEMethod:
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
use_int8_w8a8=True,
expert_map=expert_map,
log2phy=log2phy,

View File

@ -24,6 +24,7 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -66,6 +67,7 @@ class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
) -> None:
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
@ -100,11 +102,15 @@ class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=None)
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
hidden_states, residual = self.mtp_block(
positions=positions,
hidden_states=hidden_states,
residual=None,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
replace_allreduce=replace_allreduce)
hidden_states = residual + hidden_states
return hidden_states

View File

@ -975,7 +975,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual)
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace and self.layer_idx < self.layers:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
@ -1034,7 +1034,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
if mla_moe_communication and self.layer_idx == self.layers - 1:
if mla_moe_communication and self.layer_idx >= self.layers - 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0)