mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[TPU] add kv cache update kernel (#19928)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
|
||||
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
|
||||
run_and_track_test 16 "test_kv_cache_update_kernel.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
|
||||
|
||||
# After all tests have been attempted, exit with the overall status.
|
||||
if [ "$overall_script_exit_code" -ne 0 ]; then
|
||||
|
71
tests/v1/tpu/test_kv_cache_update_kernel.py
Normal file
71
tests/v1/tpu/test_kv_cache_update_kernel.py
Normal file
@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla
|
||||
|
||||
import vllm.v1.attention.backends.pallas # noqa: F401
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(),
|
||||
reason="This is a test for TPU only")
|
||||
@pytest.mark.parametrize("page_size", [32, 33])
|
||||
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
|
||||
@pytest.mark.parametrize("head_dim", [128, 256])
|
||||
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
|
||||
def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
|
||||
head_dim: int, num_slices_per_block: int):
|
||||
page_num = 1000
|
||||
padded_num_tokens = 128
|
||||
kv_cache_cpu = torch.zeros(
|
||||
(page_num * page_size, combined_kv_head_num, head_dim),
|
||||
dtype=torch.bfloat16,
|
||||
device="cpu")
|
||||
kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
|
||||
new_kv_cpu = torch.randn(
|
||||
(padded_num_tokens, combined_kv_head_num, head_dim),
|
||||
dtype=torch.bfloat16,
|
||||
device="cpu")
|
||||
new_kv_xla = new_kv_cpu.to(torch_xla.device())
|
||||
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
|
||||
dtype=np.int32)
|
||||
kv_cache_start_indices = np.array([
|
||||
page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
|
||||
page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
|
||||
],
|
||||
dtype=np.int32)
|
||||
new_kv_cache_indices = np.concatenate(
|
||||
[np.array([0], dtype=np.int32),
|
||||
np.cumsum(slice_lens[:-1])])
|
||||
slot_mapping = np.stack(
|
||||
[kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
|
||||
padded_size = (slot_mapping.shape[0] + num_slices_per_block -
|
||||
1) // num_slices_per_block * num_slices_per_block
|
||||
slot_mapping = np.pad(slot_mapping,
|
||||
[[0, padded_size - slot_mapping.shape[0]], [0, 0]],
|
||||
constant_values=0)
|
||||
slot_mapping = np.transpose(slot_mapping)
|
||||
slot_mapping_cpu = torch.tensor(slot_mapping,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
|
||||
torch_xla.sync()
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
|
||||
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
|
||||
new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
|
||||
num_slices_per_block)
|
||||
kv_cache_xla.copy_(new_kv_cache_xla)
|
||||
torch_xla.sync()
|
||||
|
||||
for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
|
||||
slice_lens):
|
||||
kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]
|
||||
|
||||
assert torch.allclose(kv_cache_xla.cpu(),
|
||||
kv_cache_cpu,
|
||||
atol=1e-4,
|
||||
rtol=1e-4)
|
@ -47,7 +47,7 @@ def test_ragged_paged_attention():
|
||||
key = torch.zeros(num_tokens, num_kv_heads * head_size)
|
||||
value = torch.zeros(num_tokens, num_kv_heads * head_size)
|
||||
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
|
||||
slot_mapping = torch.zeros(num_tokens, dtype=torch.int64)
|
||||
slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
|
||||
max_num_reqs = 8
|
||||
max_num_blocks_per_req = 8
|
||||
block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
|
||||
@ -65,6 +65,7 @@ def test_ragged_paged_attention():
|
||||
context_lens=context_lens,
|
||||
query_start_loc=query_start_loc,
|
||||
num_seqs=num_seqs,
|
||||
num_slices_per_kv_cache_update_block=8,
|
||||
)
|
||||
|
||||
with patch("torch.ops.xla.ragged_paged_attention"
|
||||
|
117
vllm/attention/ops/pallas_kv_cache_update.py
Normal file
117
vllm/attention/ops/pallas_kv_cache_update.py
Normal file
@ -0,0 +1,117 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
|
||||
import jax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
|
||||
|
||||
def _kv_cache_update_kernel(
|
||||
# Prefetch
|
||||
slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start,
|
||||
# slice_len)
|
||||
# Input
|
||||
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
|
||||
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
|
||||
# head_dim]
|
||||
# Output
|
||||
_, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
|
||||
# Scratch
|
||||
scratch, # [num_slices_per_block, page_size, num_combined_kv_heads,
|
||||
# head_dim]
|
||||
sem,
|
||||
):
|
||||
async_copies = []
|
||||
block_idx = pl.program_id(0)
|
||||
num_slices_per_block = scratch.shape[0]
|
||||
|
||||
# Copy from new_kv_hbm_ref to scratch
|
||||
for i in range(num_slices_per_block):
|
||||
offset_i = i + block_idx * num_slices_per_block
|
||||
new_kv_start = slices_ref[1, offset_i]
|
||||
length = slices_ref[2, offset_i]
|
||||
async_copy = pltpu.make_async_copy(
|
||||
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
|
||||
scratch.at[i, pl.ds(0, length), ...],
|
||||
sem,
|
||||
)
|
||||
async_copy.start()
|
||||
async_copies.append(async_copy)
|
||||
|
||||
for async_copy in async_copies:
|
||||
async_copy.wait()
|
||||
|
||||
# Copy from scratch to kv_cache_hbm_ref
|
||||
async_copies.clear()
|
||||
for i in range(num_slices_per_block):
|
||||
offset_i = i + block_idx * num_slices_per_block
|
||||
kv_cache_start = slices_ref[0, offset_i]
|
||||
length = slices_ref[2, offset_i]
|
||||
async_copy = pltpu.make_async_copy(
|
||||
scratch.at[i, pl.ds(0, length), ...],
|
||||
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
|
||||
sem,
|
||||
)
|
||||
async_copy.start()
|
||||
async_copies.append(async_copy)
|
||||
for async_copy in async_copies:
|
||||
async_copy.wait()
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=["page_size", "num_slices_per_block"],
|
||||
)
|
||||
def kv_cache_update(
|
||||
new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim]
|
||||
slices: jax.
|
||||
Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
|
||||
kv_cache: jax.
|
||||
Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
|
||||
*,
|
||||
page_size: int = 32,
|
||||
num_slices_per_block: int = 8,
|
||||
):
|
||||
assert slices.shape[1] % num_slices_per_block == 0
|
||||
_, num_combined_kv_heads, head_dim = new_kv.shape
|
||||
assert kv_cache.shape[1] == num_combined_kv_heads
|
||||
assert kv_cache.shape[2] == head_dim
|
||||
assert head_dim % 128 == 0
|
||||
# TODO: Add dynamic check to make sure that the all the slice lengths are
|
||||
# smaller or equal to page_size
|
||||
|
||||
in_specs = [
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
|
||||
]
|
||||
|
||||
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
|
||||
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
|
||||
|
||||
scalar_prefetches = [slices]
|
||||
scratch = pltpu.VMEM(
|
||||
(num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
|
||||
new_kv.dtype,
|
||||
)
|
||||
|
||||
scratch_shapes = [
|
||||
scratch,
|
||||
pltpu.SemaphoreType.DMA,
|
||||
]
|
||||
|
||||
kernel = pl.pallas_call(
|
||||
_kv_cache_update_kernel,
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=len(scalar_prefetches),
|
||||
in_specs=in_specs,
|
||||
out_specs=out_specs,
|
||||
grid=(slices.shape[1] // num_slices_per_block, ),
|
||||
scratch_shapes=scratch_shapes,
|
||||
),
|
||||
out_shape=out_shape,
|
||||
input_output_aliases={len(scalar_prefetches) + 1: 0},
|
||||
)
|
||||
|
||||
return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
|
@ -5,8 +5,12 @@ from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
# Required to register custom ops.
|
||||
import torch_xla.core.xla_builder as xb
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
# Required to register custom ops.
|
||||
from torch.library import impl
|
||||
from torch_xla._internal.jax_workarounds import requires_jax
|
||||
from torch_xla.experimental.custom_kernel import XLA_LIB
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
@ -107,6 +111,7 @@ class PallasMetadata:
|
||||
context_lens: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
num_seqs: torch.Tensor
|
||||
num_slices_per_kv_cache_update_block: int
|
||||
|
||||
|
||||
class PallasAttentionBackendImpl(AttentionImpl):
|
||||
@ -212,7 +217,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
# Write input keys and values to the KV cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
write_to_kv_cache(key, value, kv_cache, slot_mapping)
|
||||
write_to_kv_cache(
|
||||
key, value, kv_cache, slot_mapping,
|
||||
attn_metadata.num_slices_per_kv_cache_update_block)
|
||||
|
||||
output = torch.ops.xla.ragged_paged_attention(
|
||||
query,
|
||||
@ -244,6 +251,7 @@ def write_to_kv_cache(
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
num_slices_per_kv_cache_update_block: int,
|
||||
) -> None:
|
||||
""" Write the key and values to the KV cache.
|
||||
|
||||
@ -251,9 +259,9 @@ def write_to_kv_cache(
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||
|
||||
num_slices_per_kv_cache_update_block: int
|
||||
"""
|
||||
_, _, num_combined_kv_heads, head_size = kv_cache.shape
|
||||
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
|
||||
head_size = cdiv(head_size,
|
||||
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
|
||||
@ -262,4 +270,41 @@ def write_to_kv_cache(
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
|
||||
|
||||
kv_cache = kv_cache.flatten(0, 1)
|
||||
kv_cache.index_copy_(0, slot_mapping, kv)
|
||||
new_kv_cache = torch.ops.xla.kv_cache_update_op(
|
||||
kv, slot_mapping, kv_cache, page_size,
|
||||
num_slices_per_kv_cache_update_block)
|
||||
# NOTE: the in-place copy will be optimized away by XLA compiler.
|
||||
kv_cache.copy_(new_kv_cache)
|
||||
|
||||
|
||||
@requires_jax
|
||||
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor, page_size: int,
|
||||
num_slices_per_block: int):
|
||||
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
|
||||
new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
|
||||
"page_size": page_size,
|
||||
"num_slices_per_block": num_slices_per_block
|
||||
})
|
||||
return new_kv_cache
|
||||
|
||||
|
||||
XLA_LIB.define(
|
||||
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
|
||||
"int page_size, int num_slices_per_block) -> Tensor", )
|
||||
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
|
||||
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor, page_size: int,
|
||||
num_slices_per_block: int) -> torch.Tensor:
|
||||
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
|
||||
page_size, num_slices_per_block)
|
||||
return new_kv_cache
|
||||
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
|
||||
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor, page_size: int,
|
||||
num_slices_per_block: int) -> torch.Tensor:
|
||||
return kv_cache
|
||||
|
@ -53,12 +53,11 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Here we utilize the behavior that out-of-bound index is ignored.
|
||||
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
||||
_PAD_SLOT_ID = 1_000_000_000
|
||||
INVALID_TOKEN_ID = -1
|
||||
# Smallest output size
|
||||
MIN_NUM_SEQS = 8
|
||||
# Block size used for kv cache updating kernel
|
||||
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
|
||||
|
||||
|
||||
#########################################################
|
||||
@ -526,6 +525,69 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
def _get_slot_mapping_metadata(self, num_reqs,
|
||||
num_scheduled_tokens_per_req):
|
||||
"""
|
||||
Computes metadata for mapping slots to blocks in the key-value (KV)
|
||||
cache for a batch of requests.
|
||||
|
||||
This function determines, for each request in the batch, how the
|
||||
scheduled tokens are distributed across memory blocks, and generates
|
||||
metadata needed to map slices of tokens to their corresponding positions
|
||||
in the KV cache.
|
||||
|
||||
Args:
|
||||
num_reqs (int): Number of requests in the current batch.
|
||||
num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
|
||||
to be scheduled for each request.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A 2D array of shape (total_block_len, 3), where each row
|
||||
contains:
|
||||
- kv_cache_start_index (int): The starting index in the KV cache
|
||||
for the corresponding slice.
|
||||
- new_kv_start_index (int): The starting index in the new KV
|
||||
cache for the corresponding slice.
|
||||
- slice_len (int): The length of the slice.
|
||||
"""
|
||||
slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
|
||||
slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \
|
||||
num_scheduled_tokens_per_req
|
||||
local_block_start_idx = slices_start // self.block_size
|
||||
local_block_end_idx = (slices_end - 1) // self.block_size
|
||||
no_repeat_req_indices = self.arange_np[:num_reqs]
|
||||
global_block_start_idx = (
|
||||
no_repeat_req_indices * self.max_num_blocks_per_req +
|
||||
local_block_start_idx)
|
||||
block_lens = local_block_end_idx - local_block_start_idx + 1
|
||||
global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
|
||||
slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
|
||||
global_block_indices = global_block_start_idx + slice_arange
|
||||
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
|
||||
block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
|
||||
total_block_len = np.sum(block_lens)
|
||||
slot_mapping_slices = np.repeat(np.array([[0, self.block_size]],
|
||||
dtype=np.int32),
|
||||
total_block_len,
|
||||
axis=0)
|
||||
cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
|
||||
np.cumsum(block_lens, out=cu_block_lens[1:])
|
||||
for req_idx in range(num_reqs):
|
||||
slot_mapping_slices[cu_block_lens[req_idx]][
|
||||
0] = slices_start[req_idx] % self.block_size
|
||||
slot_mapping_slices[
|
||||
cu_block_lens[req_idx + 1] -
|
||||
1][1] = (slices_end[req_idx] - 1) % self.block_size + 1
|
||||
slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
|
||||
cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
|
||||
np.cumsum(slice_lens, out=cu_slices_lens[1:])
|
||||
kv_cache_start_indices = slot_mapping_slices[:, 0] + \
|
||||
(block_numbers * self.block_size)
|
||||
new_kv_start_indices = cu_slices_lens[:-1]
|
||||
slot_mapping_metadata = np.stack(
|
||||
[kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1)
|
||||
return slot_mapping_metadata
|
||||
|
||||
def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
|
||||
start_index: int):
|
||||
assert scheduler_output.total_num_scheduled_tokens > 0
|
||||
@ -603,26 +665,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
torch.from_numpy(token_indices),
|
||||
out=self.input_ids_cpu[:total_num_scheduled_tokens])
|
||||
|
||||
# Calculate the slot mapping.
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
|
||||
# where K is the max_num_blocks_per_req and the block size is 2.
|
||||
# NOTE(woosuk): We can't simply use `token_indices // block_size` here
|
||||
# because M (max_model_len) is not necessarily divisible by block_size.
|
||||
# req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions_np // self.block_size)
|
||||
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
||||
# because torch.index_select is much faster than np.take for large
|
||||
# tensors.
|
||||
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
|
||||
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
|
||||
block_offsets = positions_np % self.block_size
|
||||
np.add(block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.input_batch.block_table[0].
|
||||
slot_mapping_np[:total_num_scheduled_tokens])
|
||||
|
||||
# Prepare the attention metadata.
|
||||
self.query_start_loc_np[0] = 0
|
||||
np.cumsum(num_scheduled_tokens_per_req,
|
||||
@ -645,12 +687,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.position_ids = self.positions_cpu[:
|
||||
padded_total_num_scheduled_tokens].to(
|
||||
self.device)
|
||||
self.input_batch.block_table[0].slot_mapping_cpu[
|
||||
total_num_scheduled_tokens:] = _PAD_SLOT_ID
|
||||
slot_mapping = (
|
||||
self.input_batch.block_table[0].
|
||||
slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
|
||||
self.device))
|
||||
if use_max_model_len:
|
||||
block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, :
|
||||
self.max_num_blocks_per_req]
|
||||
@ -675,6 +711,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.device)
|
||||
block_tables = block_tables.to(self.device)
|
||||
|
||||
slot_mapping_metadata = self._get_slot_mapping_metadata(
|
||||
num_reqs, num_scheduled_tokens_per_req)
|
||||
padded_num_slices = _get_padded_num_kv_cache_update_slices(
|
||||
padded_total_num_scheduled_tokens, self.max_num_reqs,
|
||||
self.block_size)
|
||||
slot_mapping_metadata = np.pad(
|
||||
slot_mapping_metadata,
|
||||
[[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
|
||||
constant_values=0)
|
||||
slot_mapping_metadata = np.transpose(slot_mapping_metadata)
|
||||
slot_mapping_metadata = torch.tensor(slot_mapping_metadata,
|
||||
device=self.device)
|
||||
|
||||
if self.lora_config is not None:
|
||||
# We need to respect padding when activating LoRA adapters
|
||||
padded_num_scheduled_tokens_per_req = np.copy(
|
||||
@ -687,13 +736,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
padded_num_scheduled_tokens_per_req)
|
||||
|
||||
attn_metadata = PallasMetadata(
|
||||
slot_mapping=slot_mapping,
|
||||
slot_mapping=slot_mapping_metadata,
|
||||
block_tables=block_tables,
|
||||
context_lens=seq_lens,
|
||||
query_start_loc=query_start_loc,
|
||||
num_seqs=torch.tensor([num_reqs],
|
||||
dtype=torch.int32,
|
||||
device=self.device),
|
||||
num_slices_per_kv_cache_update_block=
|
||||
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
|
||||
)
|
||||
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
|
||||
# request in the batch. While we should not sample any token from this
|
||||
@ -1119,8 +1170,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
actual_num_reqs = min(num_tokens, num_reqs)
|
||||
position_ids = torch.zeros(num_tokens,
|
||||
dtype=torch.int32).to(self.device)
|
||||
slot_mapping = torch.zeros(num_tokens,
|
||||
dtype=torch.int64).to(self.device)
|
||||
padded_num_slices = _get_padded_num_kv_cache_update_slices(
|
||||
num_tokens, self.max_num_reqs, self.block_size)
|
||||
slot_mapping = torch.zeros((3, padded_num_slices),
|
||||
dtype=torch.int32).to(self.device)
|
||||
block_tables = torch.zeros((num_reqs, num_blocks),
|
||||
dtype=torch.int32).to(self.device)
|
||||
query_lens = [1] * num_reqs
|
||||
@ -1138,6 +1191,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
context_lens=context_lens,
|
||||
query_start_loc=query_start_loc,
|
||||
num_seqs=num_seqs,
|
||||
num_slices_per_kv_cache_update_block=
|
||||
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
|
||||
)
|
||||
|
||||
if self.is_multimodal_model:
|
||||
@ -1742,6 +1797,19 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
|
||||
return paddings[index]
|
||||
|
||||
|
||||
def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
|
||||
page_size: int) -> int:
|
||||
"""Calculates the padded number of KV cache update slices to avoid
|
||||
recompilation."""
|
||||
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
|
||||
padded_num_slices = min(padded_num_slices, num_tokens)
|
||||
padded_num_slices = (
|
||||
padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
|
||||
) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
|
||||
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
|
||||
return padded_num_slices
|
||||
|
||||
|
||||
def replace_set_lora(model):
|
||||
|
||||
def _tpu_set_lora(
|
||||
|
Reference in New Issue
Block a user