[TPU] add kv cache update kernel (#19928)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Author: Chengji Yao
Date: 2025-06-26 10:01:37 -07:00
Committed by: GitHub
Parent: b69781f107
Commit: 04e1642e32
6 changed files with 342 additions and 38 deletions

View File

@@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
run_and_track_test 16 "test_kv_cache_update_kernel.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then

View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
import torch
import torch_xla
import vllm.v1.attention.backends.pallas # noqa: F401
from vllm.platforms import current_platform
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This is a test for TPU only")
@pytest.mark.parametrize("page_size", [32, 33])
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
head_dim: int, num_slices_per_block: int):
page_num = 1000
padded_num_tokens = 128
kv_cache_cpu = torch.zeros(
(page_num * page_size, combined_kv_head_num, head_dim),
dtype=torch.bfloat16,
device="cpu")
kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
new_kv_cpu = torch.randn(
(padded_num_tokens, combined_kv_head_num, head_dim),
dtype=torch.bfloat16,
device="cpu")
new_kv_xla = new_kv_cpu.to(torch_xla.device())
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
dtype=np.int32)
kv_cache_start_indices = np.array([
page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
],
dtype=np.int32)
new_kv_cache_indices = np.concatenate(
[np.array([0], dtype=np.int32),
np.cumsum(slice_lens[:-1])])
slot_mapping = np.stack(
[kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
padded_size = (slot_mapping.shape[0] + num_slices_per_block -
1) // num_slices_per_block * num_slices_per_block
slot_mapping = np.pad(slot_mapping,
[[0, padded_size - slot_mapping.shape[0]], [0, 0]],
constant_values=0)
slot_mapping = np.transpose(slot_mapping)
slot_mapping_cpu = torch.tensor(slot_mapping,
device="cpu",
dtype=torch.int32)
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
torch_xla.sync()
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
num_slices_per_block)
kv_cache_xla.copy_(new_kv_cache_xla)
torch_xla.sync()
for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
slice_lens):
kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]
assert torch.allclose(kv_cache_xla.cpu(),
kv_cache_cpu,
atol=1e-4,
rtol=1e-4)

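The slot-mapping layout exercised by this test can be sketched on its own: each slice is a (kv_cache_start, new_kv_start, slice_len) triple, the slice count is padded up to a multiple of num_slices_per_block, and the result is transposed into the [3, num_slices] shape the kernel consumes. A minimal NumPy sketch with made-up values:

import numpy as np

page_size = 32
num_slices_per_block = 8

# Each row is one copy slice: (kv_cache_start, new_kv_start, slice_len).
slices = np.array([
    [page_size * 2 - 7, 0, 7],              # finish a partially filled page
    [page_size * 2, 7, page_size],          # one full page
    [page_size * 4 + 6, 7 + page_size, 1],  # a single token
], dtype=np.int32)

# Pad the slice count to a multiple of num_slices_per_block; padded entries
# are zero-length slices, so the kernel copies nothing for them.
padded = -(-len(slices) // num_slices_per_block) * num_slices_per_block
slices = np.pad(slices, [[0, padded - len(slices)], [0, 0]], constant_values=0)
slot_mapping = slices.T  # shape (3, padded), as passed to the kernel
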
View File

@@ -47,7 +47,7 @@ def test_ragged_paged_attention():
key = torch.zeros(num_tokens, num_kv_heads * head_size)
value = torch.zeros(num_tokens, num_kv_heads * head_size)
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
slot_mapping = torch.zeros(num_tokens, dtype=torch.int64)
slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
max_num_reqs = 8
max_num_blocks_per_req = 8
block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
@@ -65,6 +65,7 @@ def test_ragged_paged_attention():
context_lens=context_lens,
query_start_loc=query_start_loc,
num_seqs=num_seqs,
num_slices_per_kv_cache_update_block=8,
)
with patch("torch.ops.xla.ragged_paged_attention"

View File

@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
def _kv_cache_update_kernel(
# Prefetch
slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start,
# slice_len)
# Input
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
# head_dim]
# Output
_, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
# Scratch
scratch, # [num_slices_per_block, page_size, num_combined_kv_heads,
# head_dim]
sem,
):
async_copies = []
block_idx = pl.program_id(0)
num_slices_per_block = scratch.shape[0]
# Copy from new_kv_hbm_ref to scratch
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
new_kv_start = slices_ref[1, offset_i]
length = slices_ref[2, offset_i]
async_copy = pltpu.make_async_copy(
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
scratch.at[i, pl.ds(0, length), ...],
sem,
)
async_copy.start()
async_copies.append(async_copy)
for async_copy in async_copies:
async_copy.wait()
# Copy from scratch to kv_cache_hbm_ref
async_copies.clear()
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
kv_cache_start = slices_ref[0, offset_i]
length = slices_ref[2, offset_i]
async_copy = pltpu.make_async_copy(
scratch.at[i, pl.ds(0, length), ...],
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
sem,
)
async_copy.start()
async_copies.append(async_copy)
for async_copy in async_copies:
async_copy.wait()
@functools.partial(
jax.jit,
static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
new_kv: jax.Array,  # [total_num_tokens, num_combined_kv_heads, head_dim]
slices: jax.Array,  # [3, num_slices], list of (kv_cache_start, new_kv_start,
                    # slice_len)
kv_cache: jax.Array,  # [total_num_pages * page_size, num_combined_kv_heads,
                      # head_dim]
*,
page_size: int = 32,
num_slices_per_block: int = 8,
):
assert slices.shape[1] % num_slices_per_block == 0
_, num_combined_kv_heads, head_dim = new_kv.shape
assert kv_cache.shape[1] == num_combined_kv_heads
assert kv_cache.shape[2] == head_dim
assert head_dim % 128 == 0
# TODO: Add a dynamic check to make sure that all the slice lengths are
# smaller than or equal to page_size
in_specs = [
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
]
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
scalar_prefetches = [slices]
scratch = pltpu.VMEM(
(num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
new_kv.dtype,
)
scratch_shapes = [
scratch,
pltpu.SemaphoreType.DMA,
]
kernel = pl.pallas_call(
_kv_cache_update_kernel,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=len(scalar_prefetches),
in_specs=in_specs,
out_specs=out_specs,
grid=(slices.shape[1] // num_slices_per_block, ),
scratch_shapes=scratch_shapes,
),
out_shape=out_shape,
input_output_aliases={len(scalar_prefetches) + 1: 0},
)
return kernel(*scalar_prefetches, new_kv, kv_cache)[0]

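For reference, a minimal sketch of calling the JAX entry point directly; shapes and slice values are chosen for illustration only, and this requires a TPU device since the kernel is lowered through Pallas:

import jax.numpy as jnp
import numpy as np

from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update

page_size, num_heads, head_dim = 32, 2, 128
new_kv = jnp.ones((64, num_heads, head_dim), dtype=jnp.bfloat16)
kv_cache = jnp.zeros((100 * page_size, num_heads, head_dim), dtype=jnp.bfloat16)

# [3, num_slices] with rows (kv_cache_start, new_kv_start, slice_len);
# num_slices must be a multiple of num_slices_per_block, and unused columns
# are left as zero-length slices.
slices = np.zeros((3, 8), dtype=np.int32)
slices[:, 0] = [page_size * 5, 0, 7]  # write 7 tokens into page 5

new_cache = kv_cache_update(new_kv, jnp.asarray(slices), kv_cache,
                            page_size=page_size, num_slices_per_block=8)
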
View File

@@ -5,8 +5,12 @@ from dataclasses import dataclass
from typing import Any, Optional
import torch
# Required to register custom ops.
import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
# Required to register custom ops.
from torch.library import impl
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.custom_kernel import XLA_LIB
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
@@ -107,6 +111,7 @@ class PallasMetadata:
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: torch.Tensor
num_slices_per_kv_cache_update_block: int
class PallasAttentionBackendImpl(AttentionImpl):
@@ -212,7 +217,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
# Write input keys and values to the KV cache.
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(key, value, kv_cache, slot_mapping)
write_to_kv_cache(
key, value, kv_cache, slot_mapping,
attn_metadata.num_slices_per_kv_cache_update_block)
output = torch.ops.xla.ragged_paged_attention(
query,
@@ -244,6 +251,7 @@ def write_to_kv_cache(
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
num_slices_per_kv_cache_update_block: int,
) -> None:
""" Write the key and values to the KV cache.
@@ -251,9 +259,9 @@ def write_to_kv_cache(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
num_slices_per_kv_cache_update_block: int
"""
_, _, num_combined_kv_heads, head_size = kv_cache.shape
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
head_size = cdiv(head_size,
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
@@ -262,4 +270,41 @@ def write_to_kv_cache(
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
kv_cache = kv_cache.flatten(0, 1)
kv_cache.index_copy_(0, slot_mapping, kv)
new_kv_cache = torch.ops.xla.kv_cache_update_op(
kv, slot_mapping, kv_cache, page_size,
num_slices_per_kv_cache_update_block)
# NOTE: the in-place copy will be optimized away by the XLA compiler.
kv_cache.copy_(new_kv_cache)
@requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor, page_size: int,
num_slices_per_block: int):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
"page_size": page_size,
"num_slices_per_block": num_slices_per_block
})
return new_kv_cache
XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
"int page_size, int num_slices_per_block) -> Tensor", )
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor, page_size: int,
num_slices_per_block: int) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
page_size, num_slices_per_block)
return new_kv_cache
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor, page_size: int,
num_slices_per_block: int) -> torch.Tensor:
return kv_cache

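On the PyTorch/XLA side the flow mirrors the unit test above: mark the cache as a donated buffer, call the registered op, and copy the result back in place (a copy the XLA compiler is expected to elide). A condensed sketch with illustrative shapes:

import torch
import torch_xla

import vllm.v1.attention.backends.pallas  # noqa: F401  (registers the op)

page_size, num_slices_per_block = 32, 8
device = torch_xla.device()
kv_cache = torch.zeros(1000 * page_size, 2, 128, dtype=torch.bfloat16).to(device)
new_kv = torch.randn(128, 2, 128, dtype=torch.bfloat16).to(device)

# [3, num_slices]: one (kv_cache_start, new_kv_start, slice_len) per column.
slot_mapping = torch.zeros(3, num_slices_per_block, dtype=torch.int32)
slot_mapping[:, 0] = torch.tensor([page_size, 0, 5], dtype=torch.int32)
slot_mapping = slot_mapping.to(device)

torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
new_cache = torch.ops.xla.kv_cache_update_op(new_kv, slot_mapping, kv_cache,
                                             page_size, num_slices_per_block)
kv_cache.copy_(new_cache)
torch_xla.sync()
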
View File

@@ -53,12 +53,11 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
# Block size used for kv cache updating kernel
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
#########################################################
@@ -526,6 +525,69 @@ class TPUModelRunner(LoRAModelRunnerMixin):
return kv_cache_spec
def _get_slot_mapping_metadata(self, num_reqs,
num_scheduled_tokens_per_req):
"""
Computes metadata for mapping slots to blocks in the key-value (KV)
cache for a batch of requests.
This function determines, for each request in the batch, how the
scheduled tokens are distributed across memory blocks, and generates
metadata needed to map slices of tokens to their corresponding positions
in the KV cache.
Args:
num_reqs (int): Number of requests in the current batch.
num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
to be scheduled for each request.
Returns:
np.ndarray: A 2D array of shape (total_block_len, 3), where each row
contains:
- kv_cache_start_index (int): The starting index in the KV cache
for the corresponding slice.
- new_kv_start_index (int): The starting index in the new KV
tensor (new_kv) for the corresponding slice.
- slice_len (int): The length of the slice.
"""
slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \
num_scheduled_tokens_per_req
local_block_start_idx = slices_start // self.block_size
local_block_end_idx = (slices_end - 1) // self.block_size
no_repeat_req_indices = self.arange_np[:num_reqs]
global_block_start_idx = (
no_repeat_req_indices * self.max_num_blocks_per_req +
local_block_start_idx)
block_lens = local_block_end_idx - local_block_start_idx + 1
global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
global_block_indices = global_block_start_idx + slice_arange
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
total_block_len = np.sum(block_lens)
slot_mapping_slices = np.repeat(np.array([[0, self.block_size]],
dtype=np.int32),
total_block_len,
axis=0)
cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
np.cumsum(block_lens, out=cu_block_lens[1:])
for req_idx in range(num_reqs):
slot_mapping_slices[cu_block_lens[req_idx]][
0] = slices_start[req_idx] % self.block_size
slot_mapping_slices[
cu_block_lens[req_idx + 1] -
1][1] = (slices_end[req_idx] - 1) % self.block_size + 1
slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
np.cumsum(slice_lens, out=cu_slices_lens[1:])
kv_cache_start_indices = slot_mapping_slices[:, 0] + \
(block_numbers * self.block_size)
new_kv_start_indices = cu_slices_lens[:-1]
slot_mapping_metadata = np.stack(
[kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1)
return slot_mapping_metadata
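# Illustrative walk-through (not part of this change): with block_size = 16,
# request 0 having 10 computed and 20 newly scheduled tokens (positions
# [10, 30)) and request 1 having 0 computed and 5 scheduled tokens
# (positions [0, 5)), request 0 spans two blocks and request 1 spans one, so
# three (kv_cache_start, new_kv_start, slice_len) rows come out, where
# bn(r, b) denotes request r's b-th block-table entry:
#   [bn(0, 0) * 16 + 10,  0,  6]   # tail of request 0's first block
#   [bn(0, 1) * 16,       6, 14]   # head of request 0's second block
#   [bn(1, 0) * 16,      20,  5]   # request 1's single slice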
def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
start_index: int):
assert scheduler_output.total_num_scheduled_tokens > 0
@@ -603,26 +665,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Calculate the slot mapping.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size` here
# because M (max_model_len) is not necessarily divisible by block_size.
# req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.input_batch.block_table[0].
slot_mapping_np[:total_num_scheduled_tokens])
# Prepare the attention metadata.
self.query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens_per_req,
@@ -645,12 +687,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.position_ids = self.positions_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
self.input_batch.block_table[0].slot_mapping_cpu[
total_num_scheduled_tokens:] = _PAD_SLOT_ID
slot_mapping = (
self.input_batch.block_table[0].
slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
self.device))
if use_max_model_len:
block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, :
self.max_num_blocks_per_req]
@@ -675,6 +711,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.device)
block_tables = block_tables.to(self.device)
slot_mapping_metadata = self._get_slot_mapping_metadata(
num_reqs, num_scheduled_tokens_per_req)
padded_num_slices = _get_padded_num_kv_cache_update_slices(
padded_total_num_scheduled_tokens, self.max_num_reqs,
self.block_size)
slot_mapping_metadata = np.pad(
slot_mapping_metadata,
[[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
constant_values=0)
slot_mapping_metadata = np.transpose(slot_mapping_metadata)
slot_mapping_metadata = torch.tensor(slot_mapping_metadata,
device=self.device)
if self.lora_config is not None:
# We need to respect padding when activating LoRA adapters
padded_num_scheduled_tokens_per_req = np.copy(
@@ -687,13 +736,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
padded_num_scheduled_tokens_per_req)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
slot_mapping=slot_mapping_metadata,
block_tables=block_tables,
context_lens=seq_lens,
query_start_loc=query_start_loc,
num_seqs=torch.tensor([num_reqs],
dtype=torch.int32,
device=self.device),
num_slices_per_kv_cache_update_block=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
@@ -1119,8 +1170,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
actual_num_reqs = min(num_tokens, num_reqs)
position_ids = torch.zeros(num_tokens,
dtype=torch.int32).to(self.device)
slot_mapping = torch.zeros(num_tokens,
dtype=torch.int64).to(self.device)
padded_num_slices = _get_padded_num_kv_cache_update_slices(
num_tokens, self.max_num_reqs, self.block_size)
slot_mapping = torch.zeros((3, padded_num_slices),
dtype=torch.int32).to(self.device)
block_tables = torch.zeros((num_reqs, num_blocks),
dtype=torch.int32).to(self.device)
query_lens = [1] * num_reqs
@@ -1138,6 +1191,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
context_lens=context_lens,
query_start_loc=query_start_loc,
num_seqs=num_seqs,
num_slices_per_kv_cache_update_block=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
)
if self.is_multimodal_model:
@@ -1742,6 +1797,19 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
return paddings[index]
def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
page_size: int) -> int:
"""Calculates the padded number of KV cache update slices to avoid
recompilation."""
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = min(padded_num_slices, num_tokens)
padded_num_slices = (
padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
return padded_num_slices
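# Illustrative arithmetic (not part of this change): with num_tokens = 512,
# max_num_reqs = 32 and page_size = 16, the bound is 2 * 32 + 512 // 16 = 96
# (each request contributes at most two partial-page slices, and the full
# pages across all requests cannot exceed num_tokens // page_size). That is
# already below num_tokens and a multiple of
# NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8, so 96 slices are budgeted
# regardless of how the tokens are actually split across requests.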
def replace_set_lora(model):
def _tpu_set_lora(