Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[V1][TPU] Remove unnecessary padding for running on TPU. (#14467)
@@ -12,8 +12,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.backends.utils import CommonAttentionState
 
 # These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 16
-NUM_KV_PAGES_PER_BLOCK = 256
+NUM_QUERIES_PER_BLOCK = 32
+NUM_KV_PAGES_PER_BLOCK = 128
 
 
 class PallasAttentionBackend(AttentionBackend):
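These two constants are the tile sizes the paged attention Pallas kernel is tuned with. As a rough, self-contained sketch (not vLLM code; the cdiv below is a local stand-in for vllm.utils.cdiv, and the workload sizes are made up), the number of kernel tiles per attention call scales with them:

NUM_QUERIES_PER_BLOCK = 32
NUM_KV_PAGES_PER_BLOCK = 128

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the semantics of vllm.utils.cdiv.
    return -(-a // b)

# Hypothetical workload sizes, for illustration only.
num_queries = 1000
num_kv_pages = 4096

query_tiles = cdiv(num_queries, NUM_QUERIES_PER_BLOCK)      # 32
kv_page_tiles = cdiv(num_kv_pages, NUM_KV_PAGES_PER_BLOCK)  # 32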
@@ -23,9 +23,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               NUM_QUERIES_PER_BLOCK,
-                                               PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -78,10 +76,8 @@ class TPUModelRunner:
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = _get_padded_number(
-            scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
-        self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
-                                               NUM_QUERIES_PER_BLOCK)
+        self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = scheduler_config.max_num_seqs
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
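The removed _get_padded_number calls rounded both limits up to a multiple of NUM_QUERIES_PER_BLOCK. A minimal sketch of that round-up behavior, assuming the helper is the usual round-up-to-multiple and using made-up scheduler values:

def _get_padded_number(x: int, multiple: int) -> int:
    # Assumed behavior: round x up to the nearest multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

NUM_QUERIES_PER_BLOCK = 32
max_num_batched_tokens = 8192  # hypothetical scheduler setting
max_num_seqs = 100             # hypothetical scheduler setting

# Before this commit: both limits were padded eagerly at init time.
padded_tokens = _get_padded_number(max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)  # 8192
padded_reqs = _get_padded_number(max_num_seqs, NUM_QUERIES_PER_BLOCK)              # 128

# After this commit: the scheduler limits are used as-is (8192 and 100).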
@@ -142,16 +138,8 @@ class TPUModelRunner:
                                              device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
-        # self.input_batch.block_table has a shape of [max_num_reqs,
-        # max_num_blocks_per_req]. To reduce the number of recompilation,
-        # we want the block_table.shape[0] to be num_tokens.
-        # To make the block_table to be compatible with the paged attention
-        # kernel, we want the block_table[1] to be multiple of
-        # NUM_KV_PAGES_PER_BLOCK.
-        padded_max_num_blocks_per_req = _get_padded_number(
-            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            (self.max_num_tokens, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
 
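With the padding gone, the CPU block table is allocated directly at (max_num_tokens, max_num_blocks_per_req). A worked example with hypothetical configuration values (the dtype here is also an assumption; the real code reads it from the input batch's block table):

import torch

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the semantics of vllm.utils.cdiv.
    return -(-a // b)

# Hypothetical configuration values, for illustration only.
max_model_len = 2048
block_size = 16
max_num_tokens = 8192

max_num_blocks_per_req = cdiv(max_model_len, block_size)  # 128

# The second dimension is no longer rounded up to a multiple of NUM_KV_PAGES_PER_BLOCK.
block_table_cpu = torch.zeros((max_num_tokens, max_num_blocks_per_req),
                              dtype=torch.int32, device="cpu")
print(block_table_cpu.shape)  # torch.Size([8192, 128])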