[BUGFIX] Mtp torchair pd fix (#3506)

### What this PR does / why we need it?

In memory of https://github.com/vllm-project/vllm-ascend/pull/2610 and
#3449. This PR fixes the MTP torchair PD bug.

In the PD disaggregation scenario, the first token inferred after the D node
receives the KV cache is produced in eager mode.

Fixes:
When running MTP torchair graph mode with Prefill/Decode disaggregation, if
all requests processed by the D node are requests just transmitted from the
P node, the torchair graph breaks.

Reason: During PD disaggregation, the P node only transmits the KV cache and
the prompt to the D node, not the tokens it actually inferred (neither the
main-model tokens nor the MTP tokens are transmitted). The D node therefore
treats such a request as one without MTP tokens (seq_len=1).
The community does not hit this graph-mode issue because upstream attention
uses seq_len=1 for every batch entry during the decode phase. We hit it
because our graph mode pads assuming 2 tokens per request: when some requests
have seq_len=1 and some have seq_len=2, the padding is appended at the end.
If every request received by the D node has seq_len=1, this padding cannot be
laid out within the constraints of the attention's FIA (fused infer attention)
operator.
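
As a rough, illustrative sketch (the request count below is an assumption, not taken from the PR): with MTP the graph is captured for 2 tokens per request, so a D node whose requests all arrive with seq_len=1 must absorb the missing tokens as tail padding, which collides with the FIA per-sequence limit of 16 tokens.

```python
# Illustrative only: hypothetical sizes for a D node whose requests all
# arrived from the P node with seq_len=1.
num_reqs = 32
decode_token_per_req = 2             # one main-model token + one MTP token
FIA_SEQ_LEN_LIMIT = 16               # per-sequence token limit of the FIA operator

graph_num_tokens = num_reqs * decode_token_per_req    # graph captured for 64 tokens
actual_num_tokens = num_reqs * 1                      # only 32 real tokens present
tail_padding = graph_num_tokens - actual_num_tokens   # 32 padding tokens at the end

# The padding needs at least ceil(32 / 16) = 2 extra request slots to stay
# within the per-sequence limit; without them the captured graph breaks.
min_extra_slots = -(-tail_padding // FIA_SEQ_LEN_LIMIT)
print(tail_padding, min_extra_slots)                  # 32 2
```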

Solution:

The KV consumer applies extra torchair graph padding so that the FIA graph
constraints are not violated (this is the approach implemented in this PR).
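
For a sense of the numbers, here is a minimal sketch of the padding math added in `_may_pad_kv_consumer_num_seq` (see the diff below); the values of `max_num_reqs` and `decode_token_per_req` are assumptions chosen for illustration.

```python
import math

FIA_SEQ_LEN_LIMIT = 16
max_num_reqs = 32            # assumed engine setting
decode_token_per_req = 2     # MTP decode: main token + draft token

# Extra request slots so tail padding can be split into <=16-token pieces,
# plus a small second-order term because the extra slots themselves add tokens.
new_max_num_reqs = max_num_reqs + math.ceil(
    max_num_reqs / FIA_SEQ_LEN_LIMIT) + math.ceil(
        (max_num_reqs * decode_token_per_req) / (FIA_SEQ_LEN_LIMIT ** 2))
print(new_max_num_reqs)      # 35: three redundant request slots reserved for padding
```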

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Author: xuyexiong
Date: 2025-10-17 21:57:05 +08:00 (committed by GitHub)
Parent: 9547d6f0d9
Commit: 21769e8f44
2 changed files with 55 additions and 19 deletions


@@ -82,6 +82,31 @@ class NPUTorchairModelRunner(NPUModelRunner):
        self._check_batch_sizes_consistency()

    def _may_pad_kv_consumer_num_seq(self):
        # pd disaggregation scenario need redundant_batch_sizes to avoid each batch's seq_len exceed 16 tokens
        # self.max_num_reqs here is greater than the actual maximum request number
        if self.is_kv_consumer:
            FIA_SEQ_LEN_LIMIT = 16
            new_max_num_reqs = self.max_num_reqs + math.ceil(
                self.max_num_reqs / FIA_SEQ_LEN_LIMIT) + math.ceil(
                    (self.max_num_reqs * self.decode_token_per_req) /
                    (FIA_SEQ_LEN_LIMIT**2))
            if self.max_num_reqs < new_max_num_reqs:
                logger.warning(
                    f"max_num_reqs is updated to {new_max_num_reqs}")
                self.max_num_reqs = new_max_num_reqs

    def _init_mc2_tokens_capacity(self):
        # NOTE: To be clear, we need to make sure that during graph capture, the number of
        # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
        # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
        max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
        tp_size = self.parallel_config.tensor_parallel_size
        # Use integer arithmetic for ceiling division.
        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
        # maintain the same calculation logic as the function _align_graph_size_divisible_by_tp_size()
        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size

    def _sync_metadata_across_dp(
            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
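
As a quick sanity check of the `_init_mc2_tokens_capacity` arithmetic above, a small sketch with assumed values for `max_num_reqs`, `uniform_decode_query_len`, and the tensor-parallel size: the capacity is simply the decode token budget rounded up to a multiple of the TP size.

```python
# Assumed values for illustration; the real ones come from the vLLM config.
max_num_reqs = 35               # e.g. after the KV-consumer padding above
uniform_decode_query_len = 2    # MTP decode: 2 tokens per request
tp_size = 4                     # assumed tensor-parallel size

max_num_tokens = max_num_reqs * uniform_decode_query_len             # 70
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size   # ceil(70 / 4) = 18
mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size               # 72, a multiple of tp_size
print(mc2_tokens_capacity)
```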


@@ -349,6 +349,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            range(self.decode_token_per_req, self.max_num_tokens + 1,
                  self.decode_token_per_req))

        # kv role
        self.is_kv_producer = False
        self.is_kv_consumer = False
        if vllm_config.kv_transfer_config is not None:
            self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer

        self._may_pad_kv_consumer_num_seq()

        # Persistent batch.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
@@ -459,24 +468,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
        self.in_profile_run = False

        # kv role
        self.is_kv_producer = False
        self.is_kv_consumer = False
        if vllm_config.kv_transfer_config is not None:
            self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer

        # NOTE: To be clear, we need to make sure that during graph capture, the number of
        # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
        # the max number of tokens in graph is min(max_num_seqs * 2, 512).
        if self.compilation_config.cudagraph_capture_sizes:
            max_num_tokens = self.compilation_config.cudagraph_capture_sizes[0]
        else:
            max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
        tp_size = self.parallel_config.tensor_parallel_size
        # Use integer arithmetic for ceiling division.
        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
        self._init_mc2_tokens_capacity()

        self.reserved_mc2_mask = torch.zeros(
            self.mc2_tokens_capacity,
            dtype=torch.bool,
@@ -534,6 +526,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
                                                  dtype=torch.int32)

    def _may_pad_kv_consumer_num_seq(self):
        # For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario,
        # we may want to pad self.max_num_seqs in kv_consumer nodes to avoid
        # exceeding a sequence length limit (16 tokens) in npu_fused_infer_attention_score operation
        pass

    def _init_mc2_tokens_capacity(self):
        # NOTE: To be clear, we need to make sure that during graph capture, the number of
        # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
        # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
        if self.compilation_config.cudagraph_capture_sizes:
            max_num_tokens = self.compilation_config.cudagraph_capture_sizes[0]
        else:
            max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
        tp_size = self.parallel_config.tensor_parallel_size
        # Use integer arithmetic for ceiling division.
        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size

    def _make_buffer(self,
                     *size: Union[int, torch.SymInt],
                     dtype: torch.dtype,
@@ -2656,7 +2667,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        # wrap the model with full graph wrapper if needed.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.update_stream = torch.npu.Stream()
            self.update_stream: torch.npu.Stream = torch.npu.Stream()
            set_graph_params(self.compilation_config.cudagraph_capture_sizes)
            self.model = ACLGraphWrapper(self.model,
                                         self.vllm_config,