mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[TPU] kv cache update kernel supports dynamic grid (#20235)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
@ -32,6 +32,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
|
||||
new_kv_xla = new_kv_cpu.to(torch_xla.device())
|
||||
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
|
||||
dtype=np.int32)
|
||||
num_kv_update_slices = len(slice_lens)
|
||||
kv_cache_start_indices = np.array([
|
||||
page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
|
||||
page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
|
||||
@ -52,12 +53,15 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
|
||||
num_kv_update_slices_xla = torch.tensor([num_kv_update_slices],
|
||||
device=torch_xla.device(),
|
||||
dtype=torch.int32)
|
||||
torch_xla.sync()
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
|
||||
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
|
||||
new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
|
||||
num_slices_per_block)
|
||||
new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla,
|
||||
page_size, num_slices_per_block)
|
||||
kv_cache_xla.copy_(new_kv_cache_xla)
|
||||
torch_xla.sync()
|
||||
|
||||
|
@ -7,11 +7,13 @@ import jax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
|
||||
from vllm.utils import cdiv
|
||||
|
||||
|
||||
def _kv_cache_update_kernel(
|
||||
# Prefetch
|
||||
slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start,
|
||||
# slice_len)
|
||||
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
|
||||
# new_kv_start, slice_len)
|
||||
# Input
|
||||
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
|
||||
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
|
||||
@ -70,6 +72,7 @@ def kv_cache_update(
|
||||
Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
|
||||
kv_cache: jax.
|
||||
Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
|
||||
num_kv_update_slices: jax.Array, # [1]
|
||||
*,
|
||||
page_size: int = 32,
|
||||
num_slices_per_block: int = 8,
|
||||
@ -107,7 +110,7 @@ def kv_cache_update(
|
||||
num_scalar_prefetch=len(scalar_prefetches),
|
||||
in_specs=in_specs,
|
||||
out_specs=out_specs,
|
||||
grid=(slices.shape[1] // num_slices_per_block, ),
|
||||
grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
|
||||
scratch_shapes=scratch_shapes,
|
||||
),
|
||||
out_shape=out_shape,
|
||||
|
@ -111,6 +111,7 @@ class PallasMetadata:
|
||||
context_lens: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
num_seqs: torch.Tensor
|
||||
num_kv_update_slices: torch.Tensor
|
||||
num_slices_per_kv_cache_update_block: int
|
||||
|
||||
|
||||
@ -219,7 +220,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
write_to_kv_cache(
|
||||
key, value, kv_cache, slot_mapping,
|
||||
attn_metadata.num_slices_per_kv_cache_update_block)
|
||||
attn_metadata.num_slices_per_kv_cache_update_block,
|
||||
attn_metadata.num_kv_update_slices)
|
||||
|
||||
output = torch.ops.xla.ragged_paged_attention(
|
||||
query,
|
||||
@ -252,6 +254,7 @@ def write_to_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
num_slices_per_kv_cache_update_block: int,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
) -> None:
|
||||
""" Write the key and values to the KV cache.
|
||||
|
||||
@ -271,7 +274,7 @@ def write_to_kv_cache(
|
||||
|
||||
kv_cache = kv_cache.flatten(0, 1)
|
||||
new_kv_cache = torch.ops.xla.kv_cache_update_op(
|
||||
kv, slot_mapping, kv_cache, page_size,
|
||||
kv, slot_mapping, kv_cache, num_kv_update_slices, page_size,
|
||||
num_slices_per_kv_cache_update_block)
|
||||
# NOTE: the in-place copy will be optimized away by XLA compiler.
|
||||
kv_cache.copy_(new_kv_cache)
|
||||
@ -279,32 +282,39 @@ def write_to_kv_cache(
|
||||
|
||||
@requires_jax
|
||||
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor, page_size: int,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor, page_size: int,
|
||||
num_slices_per_block: int):
|
||||
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
|
||||
new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
|
||||
"page_size": page_size,
|
||||
"num_slices_per_block": num_slices_per_block
|
||||
})
|
||||
new_kv_cache = xb.call_jax(
|
||||
kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
|
||||
"page_size": page_size,
|
||||
"num_slices_per_block": num_slices_per_block
|
||||
})
|
||||
return new_kv_cache
|
||||
|
||||
|
||||
XLA_LIB.define(
|
||||
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
|
||||
"int page_size, int num_slices_per_block) -> Tensor", )
|
||||
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
|
||||
"Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
|
||||
"-> Tensor", )
|
||||
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
|
||||
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor, page_size: int,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor, page_size: int,
|
||||
num_slices_per_block: int) -> torch.Tensor:
|
||||
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
|
||||
page_size, num_slices_per_block)
|
||||
num_kv_update_slices, page_size,
|
||||
num_slices_per_block)
|
||||
return new_kv_cache
|
||||
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
|
||||
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor, page_size: int,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
page_size: int,
|
||||
num_slices_per_block: int) -> torch.Tensor:
|
||||
return kv_cache
|
||||
|
@ -713,8 +713,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.device)
|
||||
block_tables = block_tables.to(self.device)
|
||||
|
||||
# Calculate the slot mapping
|
||||
slot_mapping_metadata = self._get_slot_mapping_metadata(
|
||||
num_reqs, num_scheduled_tokens_per_req)
|
||||
num_kv_update_slices = slot_mapping_metadata.shape[0]
|
||||
padded_num_slices = _get_padded_num_kv_cache_update_slices(
|
||||
padded_total_num_scheduled_tokens, self.max_num_reqs,
|
||||
self.block_size)
|
||||
@ -745,6 +747,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_seqs=torch.tensor([num_reqs],
|
||||
dtype=torch.int32,
|
||||
device=self.device),
|
||||
num_kv_update_slices=torch.tensor([num_kv_update_slices],
|
||||
dtype=torch.int32,
|
||||
device=self.device),
|
||||
num_slices_per_kv_cache_update_block=
|
||||
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
|
||||
)
|
||||
@ -1174,6 +1179,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=torch.int32).to(self.device)
|
||||
padded_num_slices = _get_padded_num_kv_cache_update_slices(
|
||||
num_tokens, self.max_num_reqs, self.block_size)
|
||||
num_kv_update_slices = torch.tensor([padded_num_slices],
|
||||
dtype=torch.int32).to(self.device)
|
||||
slot_mapping = torch.zeros((3, padded_num_slices),
|
||||
dtype=torch.int32).to(self.device)
|
||||
block_tables = torch.zeros((num_reqs, num_blocks),
|
||||
@ -1193,6 +1200,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
context_lens=context_lens,
|
||||
query_start_loc=query_start_loc,
|
||||
num_seqs=num_seqs,
|
||||
num_kv_update_slices=num_kv_update_slices,
|
||||
num_slices_per_kv_cache_update_block=
|
||||
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
|
||||
)
|
||||
|
Reference in New Issue
Block a user