[Core]Append padding logic for Attention (#3256)

### What this PR does / why we need it?

This PR aims to add padding logic to seq_lens、block_tables when running
in full decode scenario. Before this PR, the number of input tokens with
padding might exceeds corresponding seq_lens. For example, when running
in full decode scenario:

```
input_ids : [1, 3, 0, 0]
seq_lens: [2, 1]
query_start_loc: [0, 1, 2]
```
Here, `input_ids` is padded by 2 tokens while
`seq_lens`/`query_start_loc` are not. The mismatch between `input_ids`
and `seq_lens`/`query_start_loc` might cause some potential bugs. This
PR would change it into :

```
input_ids : [1, 3, 0, 0]
seq_lens: [2, 1, 1, 1]
query_start_loc: [0, 1, 2, 3, 4]
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2025-10-17 21:56:01 +08:00
committed by GitHub
parent b154a8e22c
commit 9547d6f0d9
5 changed files with 30 additions and 2 deletions

View File

@ -216,6 +216,29 @@ class AscendAttentionMetadataBuilder:
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
if attn_state == AscendAttentionState.DecodeOnly and \
common_attn_metadata.num_input_tokens > num_actual_tokens:
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
seq_lens = torch.cat([
seq_lens,
torch.ones(padded_num_tokens,
dtype=seq_lens.dtype,
device=seq_lens.device)
])
block_table_padding = torch.zeros(
(padded_num_tokens, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat([block_table, block_table_padding], dim=0)
query_start_loc_cpu = torch.cat([
query_start_loc_cpu,
torch.arange(query_start_loc_cpu[-1] + 1,
query_start_loc_cpu[-1] + padded_num_tokens,
dtype=query_start_loc_cpu.dtype,
device=query_start_loc_cpu.device)
])
query_start_loc = query_start_loc_cpu.to(self.device,
non_blocking=True)

View File

@ -445,6 +445,7 @@ class AscendMLAMetadataBuilder:
cos=cos[:num_decode_tokens, ...])
return self.metadata_cls( # type: ignore
num_input_tokens=common_attn_metadata.num_input_tokens,
num_actual_tokens=num_actual_tokens,
query_lens=query_lens.tolist(),
slot_mapping=slot_mapping,

View File

@ -419,6 +419,7 @@ class AscendSFAMetadataBuilder:
cos=cos)
return self.metadata_cls( # type: ignore
num_input_tokens=common_attn_metadata.num_input_tokens,
num_actual_tokens=num_actual_tokens,
query_lens=query_lens.tolist(),
slot_mapping=slot_mapping,

View File

@ -64,6 +64,10 @@ class AscendCommonAttentionMetadata:
graph_pad_size: int = -1
# num_input_tokens refers to total number of tokens including
# padding tokens. It is used to handle some padding operations.
num_input_tokens: int = 0
# NOTE: This is a temporary solution for rotary embedding in MLA
cos: torch.Tensor = None
sin: torch.Tensor = None

View File

@ -1477,6 +1477,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
num_input_tokens=num_input_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
# TODO: change this to the right block table for linear attn
block_table_tensor=blk_table_tensor[:num_reqs],
@ -1523,8 +1524,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
model=self.get_model(),
**extra_attn_metadata_args)
if self.vllm_config.model_config.use_mla or self.use_sparse:
attn_metadata_i.num_input_tokens = num_input_tokens
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i