[ModelRunner][V1] Optimize V1 attention mask (#442)

### What this PR does / why we need it?
Pre-construct a mask matrix to improve the efficiency of attention mask
construction during inference.

Note that the length of the matrix needs to be carefully balanced: a
matrix that is too large will consume excessive VRAM, while a matrix
that is too small will require dynamic concatenation during inference,
leading to performance degradation.

Therefore, an environment variable (`PAGED_ATTENTION_MASK_LEN`) is added so that
the size of the pre-constructed mask matrix can be tuned to the deployment's
requirements.
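
For reference, the change boils down to the sketch below: build a causal bias matrix once, then gather the rows for the currently scheduled token positions at runtime instead of reconstructing the mask on every step. This is a simplified, standalone illustration; only `PAGED_ATTENTION_MASK_LEN` and the `-10000` bias value come from the change itself, while the helper name and dtype are illustrative.

```python
import os

import torch

MASK_VALUE = -10000  # additive bias for positions a token must not attend to

# Build the shared causal bias matrix once, sized by the environment variable.
mask_len = int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))
causal_bias = torch.full((mask_len, mask_len), MASK_VALUE, dtype=torch.float16)
# tril() keeps the lower triangle at MASK_VALUE, so this zeroes exactly the
# positions each token is allowed to see.
causal_bias.masked_fill_(causal_bias.tril() == MASK_VALUE, 0)


def select_mask(positions: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    """Slice the pre-built matrix for the scheduled tokens (fast path only)."""
    assert max_seq_len <= mask_len, "longer sequences need an on-the-fly mask"
    return causal_bias.index_select(0, positions)[:, :max_seq_len]


# Example: three tokens of one request at positions 0..2, context length 3.
mask = select_mask(torch.tensor([0, 1, 2]), 3)
```

Sizing is the trade-off described above: a small value (e.g. `PAGED_ATTENTION_MASK_LEN=4096`) keeps device memory usage down, while a larger value avoids falling back to per-step mask construction for long sequences.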

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: didongli182 <didongli@huawei.com>
Shanshan Shen committed on 2025-04-02 10:33:53 +08:00 (committed by GitHub)
parent 94bf9c379e
commit 14d9a64047


@@ -18,6 +18,7 @@
#
import gc
import os
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
@@ -53,6 +54,8 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
NPU_PAGED_ATTENTION_MASK_VALUE = -10000
logger = init_logger(__name__)
@@ -210,6 +213,24 @@ class NPUModelRunner:
self.max_num_tokens,
device="cpu")
# NOTE: Pre-construct a mask matrix to improve the efficiency of
# attention mask construction during inference.
# Note that the length of the matrix needs to be carefully balanced: a
# matrix that is too large will consume excessive VRAM, while a matrix
# that is too small will require dynamic concatenation during inference,
# leading to performance degradation.
# Therefore, an environment variable is added here to dynamically set
# the size of the pre-constructed mask matrix based on requirements.
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
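# Cap the pre-built mask at the model's maximum context length; a larger mask could never be used.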
self.attn_mask_len = min(self.max_model_len, int(mask_len))
self.attn_mask_npu = torch.full(
(self.attn_mask_len, self.attn_mask_len),
NPU_PAGED_ATTENTION_MASK_VALUE,
device=self.device,
dtype=self.vllm_config.model_config.dtype)
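# The matrix starts filled with the mask value; zero its lower triangle so
# visible positions add no bias, leaving the upper triangle masked (causal).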
self.attn_mask_npu.masked_fill_(
self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
@@ -365,40 +386,52 @@ class NPUModelRunner:
def get_model(self) -> nn.Module:
return self.model
@staticmethod
def make_attention_mask(kv_dtype, kv_device, max_seq_len, seq_lens,
query_lens):
# for paged attention
atten_mask = np.zeros([0, max_seq_len])
for i, context_length in enumerate(seq_lens):
q_len = query_lens[i]
ones_len = context_length - q_len
ones = np.ones((q_len, ones_len), dtype=np.float16)
bias_cache = np.tril(
np.ones((q_len, max_seq_len - ones_len), dtype=np.float16))
bias_cache = np.concatenate((ones, bias_cache), axis=1)
mask_value = -10000
bias_cache[bias_cache == 0] = mask_value
bias_cache[bias_cache == 1] = 0
def make_attention_mask(self, seq_lens, query_lens,
position) -> torch.Tensor:
max_seq_len = max(seq_lens, default=0)
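# Fast path: when the longest sequence fits into the pre-built mask, gather
# its rows by token position instead of building a new mask.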
if max_seq_len <= self.attn_mask_len:
return torch.index_select(self.attn_mask_npu,
dim=0,
index=position)[:, :max_seq_len]
atten_mask = np.concatenate([atten_mask, bias_cache], axis=0)
atten_mask = torch.from_numpy(atten_mask).to(kv_dtype).to(kv_device)
return atten_mask
total_q_len = sum(query_lens)
attn_mask = torch.zeros((total_q_len, max_seq_len),
dtype=self.vllm_config.model_config.dtype,
device="cpu")
current_row = 0
for i in range(len(query_lens)):
seq_len = seq_lens[i]
q_len = query_lens[i]
context_len = seq_len - q_len
assert context_len >= 0
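# Mask out every position at or beyond this request's context, then unmask
# the causal lower triangle covering its own query tokens.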
attn_mask[current_row:current_row + q_len,
context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE
right_tensor = attn_mask[current_row:current_row + q_len,
context_len:seq_len]
right_tensor.masked_fill_(
right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
current_row += q_len
return attn_mask.to(self.device, non_blocking=True)
def _process_reqs(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
# check input valid
# Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# Copy the blocks from CPU to NPU.
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit(num_reqs)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
@@ -409,7 +442,7 @@ class NPUModelRunner:
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
# prepare positions
# Prepare positions
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
cu_num_tokens = np.cumsum(num_scheduled_tokens)
@@ -444,9 +477,9 @@ class NPUModelRunner:
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
self.device, non_blocking=True)
attn_mask = self.make_attention_mask(
self.vllm_config.model_config.dtype, self.device,
max(seq_lens, default=0), seq_lens, num_scheduled_tokens)
attn_mask = self.make_attention_mask(seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=positions)
attn_metadata = AscendMetadata(
seq_lens=query_lens,
@@ -457,7 +490,7 @@ class NPUModelRunner:
attn_mask=attn_mask,
)
# prepare input_ids
# Prepare input_ids
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
@@ -468,6 +501,7 @@ class NPUModelRunner:
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
input_ids = self.input_ids[:total_num_scheduled_tokens]
# Run forward pass
with set_forward_context(attn_metadata, self.vllm_config):
assert self.model is not None