[Feat] Shared expert dp for deepseek and deepseek_mtp (#3495)

### What this PR does / why we need it?
Add shared expert data parallelism (DP) for DeepSeek and DeepSeek MTP; it can be combined with sequence parallelism (SP) to improve performance.
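For reviewers who want to try the feature, here is a minimal launch sketch. The `enable_shared_expert_dp` flag name is taken from `ascend_config` in this diff; the `additional_config` wiring, model id, and parallel sizes are assumptions for illustration and may differ between vllm-ascend releases.

```python
# Hypothetical usage sketch, not part of this PR.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",          # assumed model id
    tensor_parallel_size=8,                   # assumed parallel size
    additional_config={
        "enable_shared_expert_dp": True,      # flag name from ascend_config in this diff
    },
)
```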

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: zhaozx-cn <zhaozx2116@163.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
Author: zhaozx-cn
Date: 2025-10-17 15:06:37 +08:00
Committed by: GitHub
Parent: d9ee491f70
Commit: bf87606932
9 changed files with 57 additions and 10 deletions

View File

@@ -57,6 +57,8 @@ class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
'eh_proj',
return_value=torch.randn(2, 3, 768))
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
mocker.patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad",
lambda x, label: x)
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
torch.randn(2, 3, 768))
@@ -182,6 +184,8 @@ class TestCustomDeepSeekMTP(PytestBase):
assert isinstance(mtp, CustomDeepSeekMTP)
def test_forward(self, mocker: MockerFixture, setup_mtp):
mocker.patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad",
lambda x, label: x)
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]

View File

@@ -1,4 +1,5 @@
import unittest
from unittest.mock import patch
import pytest
import torch
@@ -53,7 +54,9 @@ class TestAscendRMSNorm(PytestBase):
# Test case for the most common and basic scenario
@pytest.mark.parametrize(
"residual", [None, torch.randn(4, 8, dtype=torch.float16)])
-def test_forward_oot_basic(self, residual):
+@patch("torch.ops.vllm.maybe_chunk_residual")
+def test_forward_oot_basic(self, mock_maybe_chunk_residual, residual):
+    mock_maybe_chunk_residual.side_effect = lambda x, residual: residual
layer = RMSNorm(hidden_size=8, eps=1e-05)
x = torch.randn(4, 8, dtype=torch.float16)
if residual is not None:
@@ -117,6 +120,8 @@ class TestAscendRMSNorm(PytestBase):
mock_forward_context.layer_idx = 0
mock_forward_context.num_hidden_layers = num_hidden_layers
mock_forward_context.fusion_linear = "gate_up_dense"
mocker.patch("torch.ops.vllm.maybe_chunk_residual",
lambda x, residual: residual)
# Ensure fusion and layer_idx increment are handled correctly
x = torch.randn(4, 8, dtype=torch.float16)

View File

@@ -1245,7 +1245,8 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
-o_proj_input[num_decode_tokens:] = output_prefill
+o_proj_input[
+    num_decode_tokens:num_actual_tokens] = output_prefill
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[

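A minimal, self-contained sketch of why the bounded slice matters here, with made-up sizes: under SP the o_proj input buffer can be padded beyond the number of actual tokens, so an open-ended slice no longer matches the prefill output.

```python
import torch

# Toy sizes, assumptions for illustration only.
num_decode_tokens, num_actual_tokens, padded_len, hidden = 2, 5, 8, 16

o_proj_input = torch.zeros(padded_len, hidden)        # buffer padded for SP
output_prefill = torch.randn(num_actual_tokens - num_decode_tokens, hidden)

# o_proj_input[num_decode_tokens:] would select 6 rows and cannot take a
# 3-row tensor; the bounded slice writes exactly the real prefill tokens.
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
```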
View File

@@ -88,6 +88,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
inputs_embeds, True)
# masking inputs at position 0, as not needed by MTP
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
torch.zeros_like(inputs_embeds),
@@ -200,4 +202,6 @@ class CustomDeepSeekMTP(DeepSeekMTP):
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, previous_hidden_states,
inputs_embeds, spec_step_idx)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True)
return hidden_states
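A rough stand-in for what `maybe_all_gather_and_maybe_unpad(x, True)` does on the SP path, using plain tensors instead of the TP communicator (pad size and shapes are assumptions): each rank holds a padded chunk of the token dimension, and the gather restores the full, unpadded sequence around the MTP block.

```python
import torch

def toy_all_gather_and_unpad(rank_chunks, pad_size):
    # Stand-in for the custom op's SP branch: concatenate the per-rank chunks
    # along the token dim (the real op uses a TP all-gather), then drop padding.
    full = torch.cat(rank_chunks, dim=0)
    return full[:full.size(0) - pad_size] if pad_size > 0 else full

tp_size, tokens, hidden, pad_size = 2, 7, 4, 1
x = torch.randn(tokens + pad_size, hidden)        # padded to a multiple of tp_size
chunks = list(torch.chunk(x, tp_size, dim=0))     # what each rank would hold
restored = toy_all_gather_and_unpad(chunks, pad_size)
assert restored.shape == (tokens, hidden)
```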

View File

@@ -122,8 +122,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-need_gather_q_kv = get_forward_context().sp_enabled
-output_shape = hidden_states.shape
+forward_context = get_forward_context()
+sp_enabled = forward_context.sp_enabled
+need_gather_q_kv = False
+if sp_enabled and self.debug_layer_idx < self.layers:
+    need_gather_q_kv = True
+if not sp_enabled or self.debug_layer_idx < self.layers:
+    output_shape = hidden_states.shape
+else:
+    # used in deepseek mtp layer
+    output_shape = torch.chunk(hidden_states, self.tp_size,
+                               dim=0)[0].shape
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,

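A shape-only sketch of the new else branch (the DeepSeek MTP layer case under SP), with assumed sizes: the incoming hidden states are full length, but the attention output buffer is allocated per TP rank.

```python
import torch

tp_size = 4
hidden_states = torch.randn(16, 7168)     # full-length tokens, assumed sizes

# Per-rank output buffer shape, as computed in the else branch above.
output_shape = torch.chunk(hidden_states, tp_size, dim=0)[0].shape
output = torch.empty(output_shape, dtype=hidden_states.dtype)
assert output.shape == (4, 7168)
```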
View File

@@ -99,6 +99,7 @@ class AscendRMSNorm(RMSNorm):
import torch_npu
if residual is not None:
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
x, residual = _addrmsnorm_forward_oot(
self, x, residual, self.next_need_quant_fusion_linear,

View File

@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
import torch_npu
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
@@ -15,6 +16,27 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
def _maybe_chunk_residual_impl(x: torch.Tensor,
residual: torch.Tensor) -> torch.Tensor:
try:
forward_context = get_forward_context()
except AssertionError:
return residual
if x.size(0) != residual.size(0):
sp_enabled = forward_context.sp_enabled
assert sp_enabled is True, ("Currently, this situation only occurs "
"when sp is enabled")
pad_size = forward_context.pad_size
if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size))
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
return residual
def _maybe_all_gather_and_maybe_unpad_impl(
x: torch.Tensor,
label: bool,
@@ -235,6 +257,11 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
return tensor_model_parallel_all_reduce(final_hidden_states)
direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: x,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
op_func=_maybe_all_gather_and_maybe_unpad_impl,
fake_impl=_maybe_all_gather_and_maybe_unpad_fake,

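A standalone sketch of the `_maybe_chunk_residual_impl` branch above, with made-up sizes: when SP has already chunked `x` per TP rank but the residual is still full length, the residual is padded to a multiple of the TP size and the local chunk is taken so the shapes line up again.

```python
import torch
import torch.nn.functional as F

# Assumed sizes for illustration: 14 tokens, tp_size 4, so pad_size is 2.
tp_size, tp_rank, pad_size = 4, 1, 2
residual = torch.randn(14, 8)                        # full-length residual
x = torch.randn((14 + pad_size) // tp_size, 8)       # per-rank chunk of x

if x.size(0) != residual.size(0):
    residual = F.pad(residual, (0, 0, 0, pad_size))  # pad the token dim to 16
    residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]

assert x.size(0) == residual.size(0) == 4
```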
View File

@@ -273,7 +273,7 @@ class NPUPlatform(Platform):
if parallel_config and parallel_config.worker_cls == "auto":
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
-if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp:
+if ascend_config.torchair_graph_config.enabled:
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
else:
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
@@ -320,8 +320,6 @@
ascend_config = get_ascend_config()
if use_mla and ascend_config.enable_shared_expert_dp:
if use_mla and not use_sparse:
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
if use_mla and use_sparse:
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"

View File

@@ -82,9 +82,7 @@ class MtpProposer(Proposer):
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
-if self.torchair_graph_enabled or (
-        self.enable_shared_expert_dp
-        and self.vllm_config.model_config.use_mla):
+if self.torchair_graph_enabled:
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
else: