[TPU] kv cache update kernel doesn't need slices to be padded to a multiple of num_slices_per_block (#22394)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
This commit is contained in:
Chengji Yao
2025-08-09 20:49:04 -07:00
committed by GitHub
parent 534c45b962
commit 2a84fb422f
3 changed files with 19 additions and 21 deletions

View File

@ -43,11 +43,6 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
np.cumsum(slice_lens[:-1])])
slot_mapping = np.stack(
[kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
padded_size = (slot_mapping.shape[0] + num_slices_per_block -
1) // num_slices_per_block * num_slices_per_block
slot_mapping = np.pad(slot_mapping,
[[0, padded_size - slot_mapping.shape[0]], [0, 0]],
constant_values=0)
slot_mapping = np.transpose(slot_mapping)
slot_mapping_cpu = torch.tensor(slot_mapping,
device="cpu",

View File

@ -14,6 +14,7 @@ def _kv_cache_update_kernel(
# Prefetch
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
# new_kv_start, slice_len)
num_slices_ref, # [1]
# Input
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
@ -32,8 +33,10 @@ def _kv_cache_update_kernel(
# Copy from new_kv_hbm_ref to scratch
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
new_kv_start = slices_ref[1, offset_i]
length = slices_ref[2, offset_i]
new_kv_start = jax.lax.select(offset_i < num_slices_ref[0],
slices_ref[1, offset_i], 0)
length = jax.lax.select(offset_i < num_slices_ref[0],
slices_ref[2, offset_i], 0)
async_copy = pltpu.make_async_copy(
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
scratch.at[i, pl.ds(0, length), ...],
@ -49,8 +52,10 @@ def _kv_cache_update_kernel(
async_copies.clear()
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
kv_cache_start = slices_ref[0, offset_i]
length = slices_ref[2, offset_i]
kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0],
slices_ref[0, offset_i], 0)
length = jax.lax.select(offset_i < num_slices_ref[0],
slices_ref[2, offset_i], 0)
async_copy = pltpu.make_async_copy(
scratch.at[i, pl.ds(0, length), ...],
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
@ -77,7 +82,6 @@ def kv_cache_update(
page_size: int = 32,
num_slices_per_block: int = 8,
):
assert slices.shape[1] % num_slices_per_block == 0
_, num_combined_kv_heads, head_dim = new_kv.shape
assert kv_cache.shape[1] == num_combined_kv_heads
assert kv_cache.shape[2] == head_dim
@ -93,7 +97,7 @@ def kv_cache_update(
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
scalar_prefetches = [slices]
scalar_prefetches = [slices, num_kv_update_slices]
scratch = pltpu.VMEM(
(num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
new_kv.dtype,

View File

@ -745,7 +745,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_kv_update_slices = slot_mapping_metadata.shape[0]
padded_num_slices = _get_padded_num_kv_cache_update_slices(
padded_total_num_scheduled_tokens, self.max_num_reqs,
self.block_size, self._num_slices_per_kv_cache_update_block)
self.block_size)
slot_mapping_metadata = np.pad(
slot_mapping_metadata,
[[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@ -1244,8 +1244,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
position_ids = torch.zeros(num_tokens,
dtype=torch.int32).to(self.device)
padded_num_slices = _get_padded_num_kv_cache_update_slices(
num_tokens, self.max_num_reqs, self.block_size,
self._num_slices_per_kv_cache_update_block)
num_tokens, self.max_num_reqs, self.block_size)
num_kv_update_slices = torch.tensor([padded_num_slices],
dtype=torch.int32).to(self.device)
slot_mapping = torch.zeros((3, padded_num_slices),
@ -1963,17 +1962,17 @@ def copy_kv_blocks(
_copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
def _get_padded_num_kv_cache_update_slices(
num_tokens: int, max_num_reqs: int, page_size: int,
num_slices_per_kv_cache_update_block: int) -> int:
def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
page_size: int) -> int:
"""Calculates the padded number of KV cache update slices to avoid
recompilation."""
# NOTE(chengjiyao): let's say R_i is the token num for i-th request,
# so it occupies at most 2 + R_i // page_size pages. The total maximum
# possible number of pages needed is sum(2 + R_i // page_size), which
# is <= 2 * max_num_reqs + sum(R_i) // page_size
# = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = min(padded_num_slices, num_tokens)
padded_num_slices = (
padded_num_slices + num_slices_per_kv_cache_update_block - 1
) // num_slices_per_kv_cache_update_block * \
num_slices_per_kv_cache_update_block
return padded_num_slices