mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
[Feat] Shared expert dp for deepseek and deepseek_mtp (#3495)
### What this PR does / why we need it? shared expert dp for deepseek and deepseek_mtp, could be combined with sp to improve performance. ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: zhaozx-cn <zhaozx2116@163.com> Co-authored-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@ -57,6 +57,8 @@ class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
|
||||
'eh_proj',
|
||||
return_value=torch.randn(2, 3, 768))
|
||||
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
|
||||
mocker.patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad",
|
||||
lambda x, label: x)
|
||||
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
|
||||
torch.randn(2, 3, 768))
|
||||
|
||||
@ -182,6 +184,8 @@ class TestCustomDeepSeekMTP(PytestBase):
|
||||
assert isinstance(mtp, CustomDeepSeekMTP)
|
||||
|
||||
def test_forward(self, mocker: MockerFixture, setup_mtp):
|
||||
mocker.patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad",
|
||||
lambda x, label: x)
|
||||
input_ids = torch.tensor([[1, 2, 3]])
|
||||
positions = torch.tensor([[0, 1, 2]])
|
||||
kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
|
||||
|
@ -1,4 +1,5 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -53,7 +54,9 @@ class TestAscendRMSNorm(PytestBase):
|
||||
# Test case for the most common and basic scenario
|
||||
@pytest.mark.parametrize(
|
||||
"residual", [None, torch.randn(4, 8, dtype=torch.float16)])
|
||||
def test_forward_oot_basic(self, residual):
|
||||
@patch("torch.ops.vllm.maybe_chunk_residual")
|
||||
def test_forward_oot_basic(self, mock_maybe_chunk_residual, residual):
|
||||
mock_maybe_chunk_residual.side_effect = lambda x, residual: residual
|
||||
layer = RMSNorm(hidden_size=8, eps=1e-05)
|
||||
x = torch.randn(4, 8, dtype=torch.float16)
|
||||
if residual is not None:
|
||||
@ -117,6 +120,8 @@ class TestAscendRMSNorm(PytestBase):
|
||||
mock_forward_context.layer_idx = 0
|
||||
mock_forward_context.num_hidden_layers = num_hidden_layers
|
||||
mock_forward_context.fusion_linear = "gate_up_dense"
|
||||
mocker.patch("torch.ops.vllm.maybe_chunk_residual",
|
||||
lambda x, residual: residual)
|
||||
|
||||
# Ensure fusion and layer_idx increment are handled correctly
|
||||
x = torch.randn(4, 8, dtype=torch.float16)
|
||||
|
@ -1245,7 +1245,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is not None:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
o_proj_input[num_decode_tokens:] = output_prefill
|
||||
o_proj_input[
|
||||
num_decode_tokens:num_actual_tokens] = output_prefill
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
else:
|
||||
o_proj_input[
|
||||
|
@ -88,6 +88,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
inputs_embeds, True)
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
|
||||
torch.zeros_like(inputs_embeds),
|
||||
@ -200,4 +202,6 @@ class CustomDeepSeekMTP(DeepSeekMTP):
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, previous_hidden_states,
|
||||
inputs_embeds, spec_step_idx)
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True)
|
||||
return hidden_states
|
||||
|
@ -122,8 +122,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
need_gather_q_kv = get_forward_context().sp_enabled
|
||||
output_shape = hidden_states.shape
|
||||
forward_context = get_forward_context()
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
need_gather_q_kv = False
|
||||
if sp_enabled and self.debug_layer_idx < self.layers:
|
||||
need_gather_q_kv = True
|
||||
if not sp_enabled or self.debug_layer_idx < self.layers:
|
||||
output_shape = hidden_states.shape
|
||||
else:
|
||||
# used in deepseek mtp layer
|
||||
output_shape = torch.chunk(hidden_states, self.tp_size,
|
||||
dim=0)[0].shape
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
|
@ -99,6 +99,7 @@ class AscendRMSNorm(RMSNorm):
|
||||
import torch_npu
|
||||
|
||||
if residual is not None:
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
assert x.size(0) == residual.size(0)
|
||||
x, residual = _addrmsnorm_forward_oot(
|
||||
self, x, residual, self.next_need_quant_fusion_linear,
|
||||
|
@ -2,6 +2,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.distributed import (get_dp_group, get_ep_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
@ -15,6 +16,27 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
|
||||
|
||||
|
||||
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
||||
residual: torch.Tensor) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
return residual
|
||||
|
||||
if x.size(0) != residual.size(0):
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
assert sp_enabled is True, ("Currently, this situation only occurs "
|
||||
"when sp is enabled")
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
|
||||
|
||||
return residual
|
||||
|
||||
|
||||
def _maybe_all_gather_and_maybe_unpad_impl(
|
||||
x: torch.Tensor,
|
||||
label: bool,
|
||||
@ -235,6 +257,11 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
|
||||
direct_register_custom_op(op_name="maybe_chunk_residual",
|
||||
op_func=_maybe_chunk_residual_impl,
|
||||
fake_impl=lambda x, residual: x,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
|
||||
op_func=_maybe_all_gather_and_maybe_unpad_impl,
|
||||
fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
|
||||
|
@ -273,7 +273,7 @@ class NPUPlatform(Platform):
|
||||
if parallel_config and parallel_config.worker_cls == "auto":
|
||||
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
|
||||
os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
|
||||
if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp:
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
|
||||
@ -320,8 +320,6 @@ class NPUPlatform(Platform):
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
if use_mla and ascend_config.enable_shared_expert_dp:
|
||||
if use_mla and not use_sparse:
|
||||
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
|
||||
if use_mla and use_sparse:
|
||||
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
|
||||
|
||||
|
@ -82,9 +82,7 @@ class MtpProposer(Proposer):
|
||||
with set_default_torch_dtype(
|
||||
draft_model_config.dtype), set_current_vllm_config(
|
||||
self.vllm_config):
|
||||
if self.torchair_graph_enabled or (
|
||||
self.enable_shared_expert_dp
|
||||
and self.vllm_config.model_config.use_mla):
|
||||
if self.torchair_graph_enabled:
|
||||
self.model = TorchairDeepSeekMTP(
|
||||
vllm_config=self.vllm_config).to(target_device)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user