[V1][TPU] Remove unnecessary padding for running on TPU. (#14467)

This commit is contained in:
iefgnoix
2025-03-08 18:56:04 -08:00
committed by GitHub
parent b0d541947a
commit 10f7552789
2 changed files with 6 additions and 18 deletions

View File

@ -12,8 +12,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.backends.utils import CommonAttentionState
# These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK = 16
NUM_KV_PAGES_PER_BLOCK = 256
NUM_QUERIES_PER_BLOCK = 32
NUM_KV_PAGES_PER_BLOCK = 128
class PallasAttentionBackend(AttentionBackend):

View File

@ -23,9 +23,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
NUM_QUERIES_PER_BLOCK,
PallasAttentionBackend,
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@ -78,10 +76,8 @@ class TPUModelRunner:
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = _get_padded_number(
scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
NUM_QUERIES_PER_BLOCK)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
@ -142,16 +138,8 @@ class TPUModelRunner:
device="cpu")
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
# self.input_batch.block_table has a shape of [max_num_reqs,
# max_num_blocks_per_req]. To reduce the number of recompilation,
# we want the block_table.shape[0] to be num_tokens.
# To make the block_table to be compatible with the paged attention
# kernel, we want the block_table[1] to be multiple of
# NUM_KV_PAGES_PER_BLOCK.
padded_max_num_blocks_per_req = _get_padded_number(
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
self.block_table_cpu = torch.zeros(
(self.max_num_tokens, padded_max_num_blocks_per_req),
(self.max_num_tokens, self.max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
device="cpu")