Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
[ModelRunner][V1] Optimize V1 attention mask (#442)
### What this PR does / why we need it?

Pre-construct a mask matrix to improve the efficiency of attention mask construction during inference. Note that the length of the matrix needs to be carefully balanced: a matrix that is too large will consume excessive VRAM, while a matrix that is too small will require dynamic concatenation during inference, leading to performance degradation. Therefore, an environment variable is added here to dynamically set the size of the pre-constructed mask matrix based on requirements.

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: didongli182 <didongli@huawei.com>
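For readers skimming the diff below, here is a minimal standalone sketch of the idea (illustrative only: `max_model_len`, the dtype, and the example positions are made-up values, not taken from this commit). A square additive mask is built once up front; each batch then reuses it by selecting one row per query-token position and truncating the columns to the longest sequence in the batch.

```python
import os

import torch

MASK_VALUE = -10000  # additive bias for positions a query token must not attend to

# Size the pre-built mask from an environment variable, capped by the model's
# maximum context length (4096 here is a hypothetical example value).
max_model_len = 4096
mask_len = min(max_model_len, int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)))

# Build the mask once: 0 on and below the diagonal, MASK_VALUE above it.
attn_mask = torch.full((mask_len, mask_len), MASK_VALUE, dtype=torch.float16)
attn_mask.masked_fill_(attn_mask.tril() == MASK_VALUE, 0)

# Per batch, reuse the pre-built mask instead of rebuilding it: pick one row per
# query token (indexed by its position) and keep only the columns that cover the
# longest sequence in the batch.
positions = torch.tensor([0, 1, 2, 5, 6])  # example query-token positions
max_seq_len = 7                            # example longest sequence length
batch_mask = torch.index_select(attn_mask, dim=0, index=positions)[:, :max_seq_len]
print(batch_mask.shape)  # torch.Size([5, 7])
```

Lowering `PAGED_ATTENTION_MASK_LEN` shrinks the one-time device memory cost of the pre-built mask, at the price of falling back to the slower per-batch construction path whenever a sequence grows past the pre-built length.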
@@ -18,6 +18,7 @@
 #
 
 import gc
+import os
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import numpy as np
@@ -53,6 +54,8 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
+NPU_PAGED_ATTENTION_MASK_VALUE = -10000
+
 logger = init_logger(__name__)
 
 
@@ -210,6 +213,24 @@ class NPUModelRunner:
             self.max_num_tokens,
             device="cpu")
+
+        # NOTE: Pre-construct a mask matrix to improve the efficiency of
+        # attention mask construction during inference.
+        # Note that the length of the matrix needs to be carefully balanced: a
+        # matrix that is too large will consume excessive VRAM, while a matrix
+        # that is too small will require dynamic concatenation during inference,
+        # leading to performance degradation.
+        # Therefore, an environment variable is added here to dynamically set
+        # the size of the pre-constructed mask matrix based on requirements.
+        mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
+        self.attn_mask_len = min(self.max_model_len, int(mask_len))
+        self.attn_mask_npu = torch.full(
+            (self.attn_mask_len, self.attn_mask_len),
+            NPU_PAGED_ATTENTION_MASK_VALUE,
+            device=self.device,
+            dtype=self.vllm_config.model_config.dtype)
+        self.attn_mask_npu.masked_fill_(
+            self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
@@ -365,40 +386,52 @@ class NPUModelRunner:
     def get_model(self) -> nn.Module:
         return self.model
 
-    @staticmethod
-    def make_attention_mask(kv_dtype, kv_device, max_seq_len, seq_lens,
-                            query_lens):
-        # for paged attention
-        atten_mask = np.zeros([0, max_seq_len])
-        for i, context_length in enumerate(seq_lens):
-            q_len = query_lens[i]
-            ones_len = context_length - q_len
-            ones = np.ones((q_len, ones_len), dtype=np.float16)
-            bias_cache = np.tril(
-                np.ones((q_len, max_seq_len - ones_len), dtype=np.float16))
-            bias_cache = np.concatenate((ones, bias_cache), axis=1)
-            mask_value = -10000
-            bias_cache[bias_cache == 0] = mask_value
-            bias_cache[bias_cache == 1] = 0
+    def make_attention_mask(self, seq_lens, query_lens,
+                            position) -> torch.Tensor:
+        max_seq_len = max(seq_lens, default=0)
+        if max_seq_len <= self.attn_mask_len:
+            return torch.index_select(self.attn_mask_npu,
+                                      dim=0,
+                                      index=position)[:, :max_seq_len]
 
-            atten_mask = np.concatenate([atten_mask, bias_cache], axis=0)
-        atten_mask = torch.from_numpy(atten_mask).to(kv_dtype).to(kv_device)
-        return atten_mask
+        total_q_len = sum(query_lens)
+        attn_mask = torch.zeros((total_q_len, max_seq_len),
+                                dtype=self.vllm_config.model_config.dtype,
+                                device="cpu")
+
+        current_row = 0
+        for i in range(len(query_lens)):
+            seq_len = seq_lens[i]
+            q_len = query_lens[i]
+            context_len = seq_len - q_len
+
+            assert context_len >= 0
+            attn_mask[current_row:current_row + q_len,
+                      context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE
+            right_tensor = attn_mask[current_row:current_row + q_len,
+                                     context_len:seq_len]
+            right_tensor.masked_fill_(
+                right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
+            current_row += q_len
+
+        return attn_mask.to(self.device, non_blocking=True)
 
     def _process_reqs(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
-        # check input valid
+        # Check input valid
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
         # Copy the blocks from CPU to NPU.
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit(num_reqs)
 
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
@@ -409,7 +442,7 @@ class NPUModelRunner:
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
 
-        # prepare positions
+        # Prepare positions
         req_indices = np.repeat(self.arange_np[:num_reqs],
                                 num_scheduled_tokens)
         cu_num_tokens = np.cumsum(num_scheduled_tokens)
@@ -444,9 +477,9 @@ class NPUModelRunner:
         slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
             self.device, non_blocking=True)
 
-        attn_mask = self.make_attention_mask(
-            self.vllm_config.model_config.dtype, self.device,
-            max(seq_lens, default=0), seq_lens, num_scheduled_tokens)
+        attn_mask = self.make_attention_mask(seq_lens=seq_lens,
+                                             query_lens=num_scheduled_tokens,
+                                             position=positions)
 
         attn_metadata = AscendMetadata(
             seq_lens=query_lens,
@@ -457,7 +490,7 @@ class NPUModelRunner:
             attn_mask=attn_mask,
         )
 
-        # prepare input_ids
+        # Prepare input_ids
         token_indices = (positions_np +
                          req_indices * self.input_batch.token_ids_cpu.shape[1])
         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
@@ -468,6 +501,7 @@ class NPUModelRunner:
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
         input_ids = self.input_ids[:total_num_scheduled_tokens]
 
+        # Run forward pass
         with set_forward_context(attn_metadata, self.vllm_config):
             assert self.model is not None