[TPU] kv cache update kernel doesn't need slices padded to a multiple of num_slices_per_block (#22394)
Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
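In short: the Pallas KV-cache update kernel now receives the true number of update slices as an extra scalar-prefetch operand (num_kv_update_slices) and masks out-of-range slice slots with jax.lax.select, so callers no longer have to pad the slices array to a multiple of num_slices_per_block, and the corresponding host-side rounding and assert are removed.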
@@ -43,11 +43,6 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
                                             np.cumsum(slice_lens[:-1])])
     slot_mapping = np.stack(
         [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
-    padded_size = (slot_mapping.shape[0] + num_slices_per_block -
-                   1) // num_slices_per_block * num_slices_per_block
-    slot_mapping = np.pad(slot_mapping,
-                          [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
-                          constant_values=0)
     slot_mapping = np.transpose(slot_mapping)
     slot_mapping_cpu = torch.tensor(slot_mapping,
                                     device="cpu",
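The removed test code rounded the slice count up to a multiple of num_slices_per_block and zero-padded. A minimal NumPy sketch of that old rounding, with made-up sizes, to contrast with the new unpadded path:

import numpy as np

num_slices_per_block = 8
slot_mapping = np.zeros((5, 3), dtype=np.int32)  # 5 real (kv_start, new_kv_start, len) rows

# Old behaviour (removed above): round 5 up to the next multiple of 8 and zero-pad.
padded_size = (slot_mapping.shape[0] + num_slices_per_block -
               1) // num_slices_per_block * num_slices_per_block  # -> 8
padded = np.pad(slot_mapping,
                [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
                constant_values=0)
assert padded.shape == (8, 3)

# New behaviour: the (5, 3) array is simply transposed and handed to the kernel
# together with the real slice count; no host-side rounding is needed here.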
@@ -14,6 +14,7 @@ def _kv_cache_update_kernel(
     # Prefetch
     slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
     # new_kv_start, slice_len)
+    num_slices_ref,  # [1]
     # Input
     new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
     kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
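The new num_slices_ref parameter is the ref for the shape-[1] num_kv_update_slices operand added to scalar_prefetches further down; it carries the true number of valid slices, and both copy loops below compare their slice offset against num_slices_ref[0] before reading slices_ref.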
@@ -32,8 +33,10 @@ def _kv_cache_update_kernel(
     # Copy from new_kv_hbm_ref to scratch
     for i in range(num_slices_per_block):
         offset_i = i + block_idx * num_slices_per_block
-        new_kv_start = slices_ref[1, offset_i]
-        length = slices_ref[2, offset_i]
+        new_kv_start = jax.lax.select(offset_i < num_slices_ref[0],
+                                      slices_ref[1, offset_i], 0)
+        length = jax.lax.select(offset_i < num_slices_ref[0],
+                                slices_ref[2, offset_i], 0)
         async_copy = pltpu.make_async_copy(
             new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
             scratch.at[i, pl.ds(0, length), ...],
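The guard above replaces the direct reads of slices_ref: slots at or beyond num_slices_ref[0] collapse to start 0 and length 0, so their async copies move nothing instead of reading the padded tail of slices_ref. A standalone illustration of the same select pattern with toy values (plain JAX, outside the Pallas kernel):

import jax
import jax.numpy as jnp

num_slices = jnp.array([3], dtype=jnp.int32)            # plays the role of num_slices_ref
starts = jnp.array([4, 9, 17, 7, 7], dtype=jnp.int32)   # plays the role of slices_ref[1, :]
lengths = jnp.array([2, 3, 1, 5, 5], dtype=jnp.int32)   # plays the role of slices_ref[2, :]

for offset_i in range(5):
    start = jax.lax.select(offset_i < num_slices[0], starts[offset_i], 0)
    length = jax.lax.select(offset_i < num_slices[0], lengths[offset_i], 0)
    print(offset_i, int(start), int(length))
# Offsets 0-2 keep their values; offsets 3 and 4 come out as start=0, length=0,
# which would make the corresponding copy a no-op.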
@@ -49,8 +52,10 @@ def _kv_cache_update_kernel(
     async_copies.clear()
     for i in range(num_slices_per_block):
         offset_i = i + block_idx * num_slices_per_block
-        kv_cache_start = slices_ref[0, offset_i]
-        length = slices_ref[2, offset_i]
+        kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0],
+                                        slices_ref[0, offset_i], 0)
+        length = jax.lax.select(offset_i < num_slices_ref[0],
+                                slices_ref[2, offset_i], 0)
         async_copy = pltpu.make_async_copy(
             scratch.at[i, pl.ds(0, length), ...],
             kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
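The write-back loop gets the same treatment: for slots past the real slice count, kv_cache_start and length are forced to zero, so copying the padded tail from scratch back into kv_cache_hbm_ref touches nothing.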
@@ -77,7 +82,6 @@ def kv_cache_update(
     page_size: int = 32,
     num_slices_per_block: int = 8,
 ):
-    assert slices.shape[1] % num_slices_per_block == 0
     _, num_combined_kv_heads, head_dim = new_kv.shape
     assert kv_cache.shape[1] == num_combined_kv_heads
     assert kv_cache.shape[2] == head_dim
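With the masking done inside the kernel, the input no longer has to satisfy slices.shape[1] % num_slices_per_block == 0, so this assert is dropped. Callers still pad slices to a stable shape to avoid recompilation, just not to a multiple of the per-block slice count.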
@@ -93,7 +97,7 @@ def kv_cache_update(
     out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
     out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
 
-    scalar_prefetches = [slices]
+    scalar_prefetches = [slices, num_kv_update_slices]
     scratch = pltpu.VMEM(
         (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
         new_kv.dtype,
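Scalar-prefetch operands become the kernel's leading ref arguments (this is how Pallas TPU scalar prefetch works, via a grid spec's num_scalar_prefetch), so adding num_kv_update_slices here is what makes num_slices_ref appear as the second prefetch parameter in the kernel signature above.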
@@ -745,7 +745,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_kv_update_slices = slot_mapping_metadata.shape[0]
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             padded_total_num_scheduled_tokens, self.max_num_reqs,
-            self.block_size, self._num_slices_per_kv_cache_update_block)
+            self.block_size)
         slot_mapping_metadata = np.pad(
             slot_mapping_metadata,
             [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
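On the runner side, the true slice count is captured before padding (num_kv_update_slices = slot_mapping_metadata.shape[0]); the metadata is still zero-padded to a stable size, but padded_num_slices is no longer rounded up to a multiple of self._num_slices_per_kv_cache_update_block.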
@@ -1244,8 +1244,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
-            num_tokens, self.max_num_reqs, self.block_size,
-            self._num_slices_per_kv_cache_update_block)
+            num_tokens, self.max_num_reqs, self.block_size)
         num_kv_update_slices = torch.tensor([padded_num_slices],
                                             dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros((3, padded_num_slices),
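In this placeholder-input path the real slice count is not known, so the padded upper bound is reused as the count: num_kv_update_slices is a shape-[1] int32 tensor holding padded_num_slices, the same [1] layout the kernel expects for num_slices_ref.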
@@ -1963,17 +1962,17 @@ def copy_kv_blocks(
         _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
 
 
-def _get_padded_num_kv_cache_update_slices(
-        num_tokens: int, max_num_reqs: int, page_size: int,
-        num_slices_per_kv_cache_update_block: int) -> int:
+def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
+                                           page_size: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
     recompilation."""
     # NOTE(chengjiyao): let's say R_i is the token num for i-th request,
     # so it occupies most 2 + R_i // page_size pages. The total maximum
     # possible number of pages needed is sum(2 + R_i // page_size), which
     # is <= 2 * max_num_reqs + sum(R_i) // page_size
     # = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
-    padded_num_slices = (
-        padded_num_slices + num_slices_per_kv_cache_update_block - 1
-    ) // num_slices_per_kv_cache_update_block * \
-        num_slices_per_kv_cache_update_block
     return padded_num_slices
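A worked example of the new bound with hypothetical sizes, contrasting it with the old round-up:

# Hypothetical sizes: max_num_reqs=4, page_size=16, num_tokens=33.
max_num_reqs, page_size, num_tokens = 4, 16, 33
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size  # 8 + 2 = 10
padded_num_slices = min(padded_num_slices, num_tokens)          # still 10
# The removed code additionally rounded 10 up to the next multiple of
# num_slices_per_kv_cache_update_block (16 for a block size of 8); since the
# kernel now masks unused slots itself, that final round-up is unnecessary.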