Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
[BUGFIX] Mtp torchair pd fix (#3506)
### What this PR does / why we need it?
In memory of https://github.com/vllm-project/vllm-ascend/pull/2610 and #3449.

Fix the MTP torchair PD bug. In the PD disaggregation scenario, the first token inferred after the D node receives the KV cache follows eager mode.

Fixes: when running MTP torchair graph mode with Prefill/Decode disaggregation, if all requests processed by the D node are requests just transmitted from the P node, the torchair graph breaks.

Reason: during PD disaggregation, the P node only transmits the KV cache and the prompt to the D node, not the tokens it actually inferred (neither the main-model tokens nor the MTP tokens are transmitted). The D node therefore treats such a request as one without MTP tokens (seq_len=1). The community code has no graph-mode issue because its attention sees seq_len=1 for every batch during the decode phase. We do, because graph mode pads on the assumption of 2 tokens per request: when some requests have seq_len=1 and some have seq_len=2, padding is appended at the end, but if every request received by the D node has seq_len=1, padding cannot be performed within the constraints of the attention FIA operator.

Solution: the KV consumer uses extra torchair graph padding to avoid breaking the FIA graph constraints (the approach implemented in this PR).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
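For illustration only, here is a minimal standalone sketch of the extra KV-consumer padding described above, mirroring `_may_pad_kv_consumer_num_seq` in the diff below. The request counts in the example are assumed values; the 16-token limit is the FIA operator constraint mentioned in the solution.

```python
import math

FIA_SEQ_LEN_LIMIT = 16  # FIA operator constraint: at most 16 tokens of seq_len per batch


def padded_kv_consumer_num_seqs(max_num_reqs: int, decode_token_per_req: int) -> int:
    # Reserve extra request slots on the KV consumer (D node) so that padding can
    # always be appended, even if every incoming request arrives with seq_len=1.
    return (max_num_reqs
            + math.ceil(max_num_reqs / FIA_SEQ_LEN_LIMIT)
            + math.ceil((max_num_reqs * decode_token_per_req) / (FIA_SEQ_LEN_LIMIT**2)))


# Assumed example: 64 requests, 2 decode tokens per request (1 main-model + 1 MTP token)
# -> 64 + ceil(64 / 16) + ceil(128 / 256) = 64 + 4 + 1 = 69 padded request slots.
print(padded_kv_consumer_num_seqs(64, 2))
```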
@@ -82,6 +82,31 @@ class NPUTorchairModelRunner(NPUModelRunner):
         self._check_batch_sizes_consistency()

+    def _may_pad_kv_consumer_num_seq(self):
+        # The PD disaggregation scenario needs redundant_batch_sizes to keep each batch's
+        # seq_len from exceeding 16 tokens.
+        # self.max_num_reqs here is greater than the actual maximum request number.
+        if self.is_kv_consumer:
+            FIA_SEQ_LEN_LIMIT = 16
+            new_max_num_reqs = self.max_num_reqs + math.ceil(
+                self.max_num_reqs / FIA_SEQ_LEN_LIMIT) + math.ceil(
+                    (self.max_num_reqs * self.decode_token_per_req) /
+                    (FIA_SEQ_LEN_LIMIT**2))
+            if self.max_num_reqs < new_max_num_reqs:
+                logger.warning(
+                    f"max_num_reqs is updated to {new_max_num_reqs}")
+                self.max_num_reqs = new_max_num_reqs
+
+    def _init_mc2_tokens_capacity(self):
+        # NOTE: To be clear, we need to make sure that during graph capture, the number of
+        # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
+        # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
+        max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
+        tp_size = self.parallel_config.tensor_parallel_size
+        # Use integer arithmetic for ceiling division.
+        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
+        # Maintain the same calculation logic as _align_graph_size_divisible_by_tp_size().
+        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+
     def _sync_metadata_across_dp(
             self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
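As a quick worked example of the torchair override of `_init_mc2_tokens_capacity` above (all numbers are assumed, continuing the padded request count from the sketch in the commit message):

```python
max_num_reqs = 69               # padded KV-consumer request count (assumed)
uniform_decode_query_len = 2    # 1 main-model token + 1 MTP token per request (assumed)
tp_size = 4                     # tensor parallel size (assumed)

max_num_tokens = max_num_reqs * uniform_decode_query_len             # 138
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size   # ceil(138 / 4) = 35
mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size               # rounded up to 140
```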
@@ -349,6 +349,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             range(self.decode_token_per_req, self.max_num_tokens + 1,
                   self.decode_token_per_req))

+        # kv role
+        self.is_kv_producer = False
+        self.is_kv_consumer = False
+        if vllm_config.kv_transfer_config is not None:
+            self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
+
+        self._may_pad_kv_consumer_num_seq()
+
         # Persistent batch.
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
@@ -459,24 +468,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False

-        # kv role
-        self.is_kv_producer = False
-        self.is_kv_consumer = False
-        if vllm_config.kv_transfer_config is not None:
-            self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
-            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
-
-        # NOTE: To be clear, we need to make sure that during graph capture, the number of
-        # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
-        # the max number of tokens in graph is min(max_num_seqs * 2, 512).
-        if self.compilation_config.cudagraph_capture_sizes:
-            max_num_tokens = self.compilation_config.cudagraph_capture_sizes[0]
-        else:
-            max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
-        tp_size = self.parallel_config.tensor_parallel_size
-        # Use integer arithmetic for ceiling division.
-        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
-        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+        self._init_mc2_tokens_capacity()
         self.reserved_mc2_mask = torch.zeros(
             self.mc2_tokens_capacity,
             dtype=torch.bool,
@@ -534,6 +526,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
                                                   dtype=torch.int32)

+    def _may_pad_kv_consumer_num_seq(self):
+        # For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario,
+        # we may want to pad self.max_num_seqs in kv_consumer nodes to avoid exceeding
+        # the sequence length limit (16 tokens) of the npu_fused_infer_attention_score operator.
+        pass
+
+    def _init_mc2_tokens_capacity(self):
+        # NOTE: To be clear, we need to make sure that during graph capture, the number of
+        # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
+        # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
+        if self.compilation_config.cudagraph_capture_sizes:
+            max_num_tokens = self.compilation_config.cudagraph_capture_sizes[0]
+        else:
+            max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
+        tp_size = self.parallel_config.tensor_parallel_size
+        # Use integer arithmetic for ceiling division.
+        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
+        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+
     def _make_buffer(self,
                      *size: Union[int, torch.SymInt],
                      dtype: torch.dtype,
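Note on the two implementations: this base-class `_may_pad_kv_consumer_num_seq` is a no-op hook, and the base `_init_mc2_tokens_capacity` prefers the first entry of `cudagraph_capture_sizes` when one is configured. The torchair runner overrides both (first hunk), always deriving the token budget from `max_num_reqs * uniform_decode_query_len`, so the KV-consumer padding is also reflected in its `mc2_tokens_capacity`.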
@@ -2656,7 +2667,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):

         # wrap the model with full graph wrapper if needed.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
-            self.update_stream = torch.npu.Stream()
+            self.update_stream: torch.npu.Stream = torch.npu.Stream()
             set_graph_params(self.compilation_config.cudagraph_capture_sizes)
             self.model = ACLGraphWrapper(self.model,
                                          self.vllm_config,