[Misc] Add numpy implementation of compute_slot_mapping (#7377)
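Both the existing Python loop and the new numpy path compute the same mapping: token i of a sequence lands in slot block_table[i // block_size] * block_size + i % block_size. A worked example of that arithmetic, with an illustrative block table (values below are not taken from the diff):

    # Token 5 sits in the sequence's second logical block (5 // 4 == 1),
    # which maps to physical block 3; its offset within that block is 1.
    block_table = [7, 3]   # illustrative physical block numbers
    block_size = 4
    i = 5
    slot = block_table[i // block_size] * block_size + i % block_size
    assert slot == 3 * 4 + 1 == 13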
vllm/attention/backends/utils.py
@@ -1,6 +1,7 @@
 """Attention backend utils"""
 from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
 
+import numpy as np
 import torch
 
 from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
@@ -13,6 +14,10 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
 
 PAD_SLOT_ID = -1
 
+# Switch to numpy implementation of compute_slot_mapping
+# if we have at least this many elements. Could be tuned further.
+_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
+
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder
 
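The 256-element crossover is an admitted heuristic ("Could be tuned further"). A rough way to probe the crossover on a particular machine, assuming the two helpers are importable from vllm.attention.backends.utils as added by this diff (block size, element counts, and iteration count below are illustrative):

    import timeit

    from vllm.attention.backends.utils import (
        _compute_slot_mapping_numpy, _compute_slot_mapping_python)

    block_size = 16
    for numel in (32, 256, 2048):
        # One physical block per block_size tokens; block ids are arbitrary.
        block_table = list(range(numel // block_size + 1))
        t_py = timeit.timeit(
            lambda: _compute_slot_mapping_python([], block_table, 0, numel,
                                                 block_size), number=1000)
        t_np = timeit.timeit(
            lambda: _compute_slot_mapping_numpy([], block_table, 0, numel,
                                                block_size), number=1000)
        print(f"numel={numel}: python {t_py:.4f}s, numpy {t_np:.4f}s")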
@@ -46,6 +51,29 @@ def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
     return start_idx
 
 
+def _compute_slot_mapping_python(slot_mapping: List[int],
+                                 block_table: List[int], range_start: int,
+                                 range_end: int, block_size: int):
+    for i in range(range_start, range_end):
+        block_number = block_table[i // block_size]
+        block_offset = i % block_size
+        slot = block_number * block_size + block_offset
+        slot_mapping.append(slot)
+
+
+def _compute_slot_mapping_numpy(slot_mapping: List[int],
+                                block_table: List[int], range_start: int,
+                                range_end: int, block_size: int):
+    block_table_array = np.array(block_table)
+    idx = np.arange(range_start, range_end)
+    block_offset = idx % block_size
+    idx //= block_size
+    seq_slot_mapping_array = block_table_array[idx]
+    seq_slot_mapping_array *= block_size
+    seq_slot_mapping_array += block_offset
+    slot_mapping.extend(seq_slot_mapping_array)
+
+
 def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                          seq_id: int, seq_len: int, context_len: int,
                          start_idx: int, block_size: int,
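A quick self-contained check that the two implementations agree, restating the helpers above so the snippet runs without vllm installed:

    import numpy as np

    def slots_python(block_table, range_start, range_end, block_size):
        # Mirrors _compute_slot_mapping_python above.
        out = []
        for i in range(range_start, range_end):
            out.append(block_table[i // block_size] * block_size
                       + i % block_size)
        return out

    def slots_numpy(block_table, range_start, range_end, block_size):
        # Mirrors _compute_slot_mapping_numpy above.
        idx = np.arange(range_start, range_end)
        block_offset = idx % block_size
        idx //= block_size
        return list(np.array(block_table)[idx] * block_size + block_offset)

    block_table = [9, 4, 7, 2]   # illustrative physical block ids
    assert (slots_python(block_table, 3, 15, 4)
            == slots_numpy(block_table, 3, 15, 4))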
@@ -67,21 +95,22 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
     # sliding window is 8, and block size is 4, the first two
     # tokens are masked and the slot mapping will be
     # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
+    padding_mask_len = max(0, start_idx - context_len)
+    slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)
+
+    range_start = max(start_idx, context_len)
+    range_end = seq_len
+    numel = range_end - range_start
     block_table = block_tables[seq_id]
 
-    def add_slot(i):
-        block_number = block_table[i // block_size]
-        block_offset = i % block_size
-        slot = block_number * block_size + block_offset
-        slot_mapping.append(slot)
-
-    if start_idx == 0 and (seq_len - context_len) == 1:
-        # Optimization for common-case of decoding next token
-        add_slot(seq_len - 1)
+    # numpy implementation will be faster than python if we have
+    # many elements, otherwise it will be slower.
+    if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
+        _compute_slot_mapping_python(slot_mapping, block_table, range_start,
+                                     range_end, block_size)
     else:
-        slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
-        for i in range(max(start_idx, context_len), seq_len):
-            add_slot(i)
+        _compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
+                                    range_end, block_size)
 
 
 TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
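The sliding-window example in the comment above can be reproduced end to end with the updated compute_slot_mapping. A sketch; the block table {0: [0, 1, 0]} is an assumption chosen so the wrapped-around tokens reuse physical block 0, as the expected output requires:

    from vllm.attention.backends.utils import compute_slot_mapping

    # Prompt length 10, sliding window 8 => start_idx == 2, block size 4.
    # numel == 8 here, so the Python path is taken.
    slot_mapping = []
    compute_slot_mapping(is_profile_run=False, slot_mapping=slot_mapping,
                         seq_id=0, seq_len=10, context_len=0, start_idx=2,
                         block_size=4, block_tables={0: [0, 1, 0]})
    print(slot_mapping)  # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]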