Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 21:53:54 +08:00)
[Core] Append padding logic for Attention (#3256)
### What this PR does / why we need it?

This PR adds padding logic to `seq_lens` and `block_tables` when running in the full-decode scenario. Before this PR, the number of input tokens (with padding) could exceed what the corresponding `seq_lens` covers. For example, in the full-decode scenario:

```
input_ids : [1, 3, 0, 0]
seq_lens: [2, 1]
query_start_loc: [0, 1, 2]
```

Here, `input_ids` is padded by 2 tokens while `seq_lens`/`query_start_loc` are not. The mismatch between `input_ids` and `seq_lens`/`query_start_loc` could cause potential bugs. This PR changes it into:

```
input_ids : [1, 3, 0, 0]
seq_lens: [2, 1, 1, 1]
query_start_loc: [0, 1, 2, 3, 4]
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Angazenn <supperccell@163.com>
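To make the intended transformation concrete, here is a minimal standalone sketch using plain `torch` tensors; `pad_decode_metadata` is a hypothetical helper written for illustration, not code from this repository:

```python
import torch


def pad_decode_metadata(seq_lens: torch.Tensor, query_start_loc: torch.Tensor,
                        num_input_tokens: int, num_actual_tokens: int):
    """Hypothetical helper: extend seq_lens/query_start_loc to cover padding tokens."""
    padded_num_tokens = num_input_tokens - num_actual_tokens
    if padded_num_tokens <= 0:
        return seq_lens, query_start_loc
    # Treat every padding token as a dummy request of length 1.
    seq_lens = torch.cat(
        [seq_lens, torch.ones(padded_num_tokens, dtype=seq_lens.dtype)])
    # Extend the cumulative token offsets by one token per dummy request.
    last = int(query_start_loc[-1])
    tail = torch.arange(last + 1, last + padded_num_tokens + 1,
                        dtype=query_start_loc.dtype)
    return seq_lens, torch.cat([query_start_loc, tail])


# The example from the description: 2 real decode tokens padded to 4.
seq_lens, query_start_loc = pad_decode_metadata(
    torch.tensor([2, 1]), torch.tensor([0, 1, 2]),
    num_input_tokens=4, num_actual_tokens=2)
print(seq_lens.tolist())         # [2, 1, 1, 1]
print(query_start_loc.tolist())  # [0, 1, 2, 3, 4]
```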
```diff
@@ -216,6 +216,29 @@ class AscendAttentionMetadataBuilder:
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
 
+        if attn_state == AscendAttentionState.DecodeOnly and \
+                common_attn_metadata.num_input_tokens > num_actual_tokens:
+            padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
+            seq_lens = torch.cat([
+                seq_lens,
+                torch.ones(padded_num_tokens,
+                           dtype=seq_lens.dtype,
+                           device=seq_lens.device)
+            ])
+            block_table_padding = torch.zeros(
+                (padded_num_tokens, ) + block_table.shape[1:],
+                dtype=block_table.dtype,
+                device=block_table.device)
+            block_table = torch.cat([block_table, block_table_padding], dim=0)
+            query_start_loc_cpu = torch.cat([
+                query_start_loc_cpu,
+                torch.arange(query_start_loc_cpu[-1] + 1,
+                             query_start_loc_cpu[-1] + padded_num_tokens,
+                             dtype=query_start_loc_cpu.dtype,
+                             device=query_start_loc_cpu.device)
+            ])
+
         query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)
```
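For intuition on the block-table padding in the hunk above, here is a small shape check with hypothetical tensors (the sizes and block IDs are illustrative, not taken from the repository or its tests):

```python
import torch

# Hypothetical sizes: 2 real decode requests, batch padded to 4 tokens,
# block table with up to 3 KV-cache blocks per request.
block_table = torch.tensor([[7, 8, 0], [9, 0, 0]], dtype=torch.int32)
padded_num_tokens = 2

# Append all-zero rows so the first dimension matches the padded batch.
block_table_padding = torch.zeros(
    (padded_num_tokens, ) + block_table.shape[1:],
    dtype=block_table.dtype,
    device=block_table.device)
block_table = torch.cat([block_table, block_table_padding], dim=0)

print(block_table.shape)         # torch.Size([4, 3])
print(block_table[2:].tolist())  # [[0, 0, 0], [0, 0, 0]] -- placeholder rows for padding tokens
```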
```diff
@@ -445,6 +445,7 @@ class AscendMLAMetadataBuilder:
                                cos=cos[:num_decode_tokens, ...])
 
         return self.metadata_cls(  # type: ignore
+            num_input_tokens=common_attn_metadata.num_input_tokens,
             num_actual_tokens=num_actual_tokens,
             query_lens=query_lens.tolist(),
             slot_mapping=slot_mapping,
```
```diff
@@ -419,6 +419,7 @@ class AscendSFAMetadataBuilder:
                                cos=cos)
 
         return self.metadata_cls(  # type: ignore
+            num_input_tokens=common_attn_metadata.num_input_tokens,
             num_actual_tokens=num_actual_tokens,
             query_lens=query_lens.tolist(),
             slot_mapping=slot_mapping,
```
```diff
@@ -64,6 +64,10 @@ class AscendCommonAttentionMetadata:
 
     graph_pad_size: int = -1
 
+    # num_input_tokens refers to total number of tokens including
+    # padding tokens. It is used to handle some padding operations.
+    num_input_tokens: int = 0
+
     # NOTE: This is a temporary solution for rotary embedding in MLA
     cos: torch.Tensor = None
     sin: torch.Tensor = None
```
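With `num_input_tokens` carried on the common metadata, any backend builder can derive the padding amount locally instead of having it patched in afterwards. A schematic sketch of that calculation, assuming a stripped-down dataclass (`CommonMetadataSketch` and `padded_token_count` are illustrative names, not the real `AscendCommonAttentionMetadata` API):

```python
from dataclasses import dataclass


@dataclass
class CommonMetadataSketch:
    # Tokens actually scheduled this step (no padding).
    num_actual_tokens: int
    # Total input tokens including padding appended for graph capture/alignment.
    num_input_tokens: int = 0


def padded_token_count(meta: CommonMetadataSketch) -> int:
    """How many trailing padding tokens the attention metadata must cover."""
    return max(meta.num_input_tokens - meta.num_actual_tokens, 0)


print(padded_token_count(
    CommonMetadataSketch(num_actual_tokens=2, num_input_tokens=4)))  # 2
```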
```diff
@@ -1477,6 +1477,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             seq_lens=self.seq_lens_cpu[:num_reqs],
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
+            num_input_tokens=num_input_tokens,
             actual_seq_lengths_q=self.actual_seq_lengths_q,
             # TODO: change this to the right block table for linear attn
             block_table_tensor=blk_table_tensor[:num_reqs],
@@ -1523,8 +1524,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 model=self.get_model(),
                 **extra_attn_metadata_args)
 
-            if self.vllm_config.model_config.use_mla or self.use_sparse:
-                attn_metadata_i.num_input_tokens = num_input_tokens
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
```