Mirror of https://github.com/vllm-project/vllm.git
[TPU] Fix TPU SMEM OOM by Pallas paged attention kernel (#9350)
@@ -208,35 +208,54 @@ class PallasAttentionBackendImpl(AttentionImpl):
         else:
             # Decoding run.
             assert kv_cache[0].numel() > 0
-
-            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
-            if self.megacore_mode == "batch" and batch_size % 2 != 0:
-                megacore_mode = None
-            else:
-                megacore_mode = self.megacore_mode
-
-            # NOTE(woosuk): A temporary workaround to avoid the error:
-            # "xla::paged_attention() Expected a value of type 'str' for
-            # argument 'megacore_mode' but instead found type 'NoneType'."
-            if megacore_mode is not None:
-                output = torch.ops.xla.paged_attention(
-                    query.squeeze(dim=1),
-                    key_cache,
-                    value_cache,
-                    attn_metadata.context_lens,
-                    attn_metadata.block_tables,
-                    pages_per_compute_block,
-                    megacore_mode=megacore_mode,
-                )
-            else:
-                output = torch.ops.xla.paged_attention(
-                    query.squeeze(dim=1),
-                    key_cache,
-                    value_cache,
-                    attn_metadata.context_lens,
-                    attn_metadata.block_tables,
-                    pages_per_compute_block,
-                )
+            query = query.squeeze(dim=1)
+            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
+
+            assert attn_metadata.block_tables is not None
+            assert attn_metadata.context_lens is not None
+            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
+            # block table in SMEM. Therefore, if the block table is too large,
+            # the kernel compilation will fail. To avoid this, we split the
+            # batch dimension into smaller chunks and run the kernel multiple
+            # times.
+            MAX_SMEM_USAGE = 512 * 1024
+            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
+            max_num_seq = MAX_SMEM_USAGE // size_per_seq
+
+            if batch_size <= max_num_seq:
+                output = paged_attention(
+                    query,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.context_lens,
+                    attn_metadata.block_tables,
+                    pages_per_compute_block,
+                    self.megacore_mode,
+                )
+            else:
+                chunk_size = max_num_seq
+                # Make sure the chunk size is a multiple of 2.
+                chunk_size = chunk_size // 2 * 2
+                num_chunks = (batch_size + chunk_size - 1) // chunk_size
+
+                output = torch.empty_like(query)
+                for chunk_idx in range(num_chunks):
+                    chunk_start = chunk_idx * chunk_size
+                    chunk_end = chunk_start + chunk_size
+                    # NOTE(woosuk): We skip this line because it causes Dynamo
+                    # compilation error. Instead, we rely on the slice operation
+                    # to handle the out-of-bound case.
+                    # chunk_end = min(chunk_end, batch_size)
+                    chunk_output = paged_attention(
+                        query[chunk_start:chunk_end],
+                        key_cache,
+                        value_cache,
+                        attn_metadata.context_lens[chunk_start:chunk_end],
+                        attn_metadata.block_tables[chunk_start:chunk_end],
+                        pages_per_compute_block,
+                        self.megacore_mode,
+                    )
+                    output[chunk_start:chunk_end] = chunk_output

         # Reshape the output tensor.
         return output.reshape(batch_size, seq_len, hidden_size)
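
A minimal, standalone sketch (outside the diff) of the chunking arithmetic above, assuming the same 512 KiB SMEM budget and 4-byte (int32) block-table entries; the helper name plan_chunks and the 16-token block size in the example are illustrative only.

from typing import List, Tuple

MAX_SMEM_USAGE = 512 * 1024  # SMEM budget for the block table, in bytes
BYTES_PER_ENTRY = 4          # int32 block IDs

def plan_chunks(batch_size: int, max_blocks_per_seq: int) -> List[Tuple[int, int]]:
    """Split the batch into (start, end) slices whose block tables fit in SMEM."""
    size_per_seq = BYTES_PER_ENTRY * max_blocks_per_seq
    max_num_seq = MAX_SMEM_USAGE // size_per_seq
    if batch_size <= max_num_seq:
        return [(0, batch_size)]
    # Keep the chunk size even so megacore "batch" mode can still split each chunk.
    chunk_size = max_num_seq // 2 * 2
    num_chunks = (batch_size + chunk_size - 1) // chunk_size
    # Slicing past batch_size is harmless, mirroring the out-of-bound handling above.
    return [(i * chunk_size, (i + 1) * chunk_size) for i in range(num_chunks)]

# Example: a 128k-token max_model_len with 16-token blocks gives 8192 blocks per
# sequence, i.e. 32 KiB of block table each, so at most 16 sequences per call.
print(plan_chunks(batch_size=24, max_blocks_per_seq=8192))
# -> [(0, 16), (16, 32)]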
@@ -258,3 +277,43 @@ def write_to_kv_cache(
     value_cache = value_cache.flatten(0, 2)
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)
+
+
+def paged_attention(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_tables: torch.Tensor,
+    pages_per_compute_block: int,
+    megacore_mode: Optional[str],
+) -> torch.Tensor:
+    batch_size = query.shape[0]
+    if megacore_mode == "batch" and batch_size % 2 != 0:
+        megacore_mode = None
+    else:
+        megacore_mode = megacore_mode
+
+    # NOTE(woosuk): A temporary workaround to avoid the error:
+    # "xla::paged_attention() Expected a value of type 'str' for
+    # argument 'megacore_mode' but instead found type 'NoneType'."
+    if megacore_mode is not None:
+        output = torch.ops.xla.paged_attention(
+            query,
+            key_cache,
+            value_cache,
+            context_lens,
+            block_tables,
+            pages_per_compute_block,
+            megacore_mode=megacore_mode,
+        )
+    else:
+        output = torch.ops.xla.paged_attention(
+            query,
+            key_cache,
+            value_cache,
+            context_lens,
+            block_tables,
+            pages_per_compute_block,
+        )
+    return output
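
The duplicated call above exists because xla::paged_attention rejects megacore_mode=None, so the keyword must be omitted entirely when it is unset. A sketch (outside the diff) of the same pattern expressed with a kwargs dict; fake_kernel is a stand-in for the real op, which needs torch_xla, and its default value is made up.

from typing import Optional

def fake_kernel(query, *args, megacore_mode: str = "placeholder-default") -> str:
    # Stand-in for torch.ops.xla.paged_attention: when megacore_mode is passed,
    # it must be a str, never None.
    assert isinstance(megacore_mode, str)
    return megacore_mode

def call_kernel(query, megacore_mode: Optional[str]) -> str:
    # Forward the kwarg only when it is set, instead of duplicating the call site.
    kwargs = {} if megacore_mode is None else {"megacore_mode": megacore_mode}
    return fake_kernel(query, **kwargs)

print(call_kernel("q", None))     # kwarg omitted entirely -> stand-in default used
print(call_kernel("q", "batch"))  # kwarg forwarded as a str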
@@ -123,6 +123,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         )
         self.cached_step_outputs: List[torch.Tensor] = []

+        smem_size = 512 * 1024
+        block_table_size = 4 * self.block_tables.size
+        if block_table_size >= smem_size:
+            logger.warning(
+                "The max_model_len (%d) is too large. This may degrade the "
+                "performance due to the insufficient smem size. Consider "
+                "setting --max-model-len to a smaller value.",
+                self.model_config.max_model_len)
+
     def load_model(self) -> None:
         self.device = self.device_config.device
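
A rough back-of-the-envelope check (outside the diff) of when this warning fires, assuming block_tables is an int32 array of shape (max_num_seqs, ceil(max_model_len / block_size)); the helper name and the example numbers are illustrative.

import math

SMEM_SIZE = 512 * 1024  # bytes

def block_table_bytes(max_num_seqs: int, max_model_len: int, block_size: int) -> int:
    # 4 bytes per int32 entry, one row per sequence slot.
    max_blocks_per_seq = math.ceil(max_model_len / block_size)
    return 4 * max_num_seqs * max_blocks_per_seq

# Example: 256 sequence slots, 32k-token max_model_len, 16-token blocks
# -> 2 MiB of block table, well past the 512 KiB budget, so the warning is
# logged and decoding relies on the chunked kernel calls from the first hunk.
size = block_table_bytes(max_num_seqs=256, max_model_len=32768, block_size=16)
print(size, size >= SMEM_SIZE)  # 2097152 True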