[v1] Refactor KVCacheConfig (#14079)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-03-21 19:56:27 +08:00
committed by GitHub
parent 61e8c18350
commit 93a00d7dde
10 changed files with 318 additions and 110 deletions

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams
@ -8,7 +9,10 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, PrefixCachingMetrics,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
hash_request_tokens,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
@ -314,3 +318,107 @@ def test_metrics():
assert metrics.aggregated_query_total == 0
assert metrics.aggregated_query_hit == 0
assert not metrics.query_queue
def test_unify_kv_cache_configs():
def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla)
same_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
]
unify_kv_cache_configs(same_kv_cache_config)
assert same_kv_cache_config[0].num_blocks == 10
assert same_kv_cache_config[1].num_blocks == 10
need_sort_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
],
),
]
unify_kv_cache_configs(need_sort_kv_cache_config)
assert need_sort_kv_cache_config[0].num_blocks == 10
assert need_sort_kv_cache_config[1].num_blocks == 10
diff_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=8)),
],
),
]
with pytest.raises(AssertionError):
unify_kv_cache_configs(diff_kv_cache_config)

View File

@ -7,8 +7,8 @@ from typing import Any, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
KVCacheTensor)
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, KVCacheTensor)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
@ -449,7 +449,7 @@ def hash_request_tokens(block_size: int,
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int):
"""
Checks whether `available_memory` is enough for the KV cache to hold at
@ -457,7 +457,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Raises:
@ -484,12 +484,43 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
f"`max_model_len` when initializing the engine.")
def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
def create_kv_cache_group_specs(
kv_cache_spec: dict[str, KVCacheSpec],
grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
"""
Create KVCacheGroupSpec object for each kv cache group layer.
The layers in the same group should share the same
KVCacheSpec.
Args:
kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names:
A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same
KVCacheSpec.
Returns:
A list of KVCacheGroupSpec objects, one for each group.
"""
kv_cache_groups = []
for layer_names_one_group in grouped_layer_names:
layer_spec = kv_cache_spec[layer_names_one_group[0]]
assert all(
kv_cache_spec[layer_name] == layer_spec
for layer_name in layer_names_one_group[1:]), (
"All layers in the same KV cache group must share the same "
"KVCacheSpec.")
kv_cache_groups.append(
KVCacheGroupSpec(layer_names_one_group, layer_spec))
return kv_cache_groups
def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
"""
Whether all layers in the given KVCacheSpec have the same type of KV cache.
Args:
kv_cache_spec: The KVCacheSpec of the model
kv_cache_spec: The kv cache spec of each attention layer in the model
Returns:
True if all layers have the same type, False otherwise.
@ -500,18 +531,16 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
available_memory: int,
num_layers: int) -> KVCacheConfig:
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
num_layers: The number of layers in the model.
Returns:
The generated KVCacheConfig
@ -521,7 +550,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
assert len(page_sizes) == 1
page_size = page_sizes.pop()
num_blocks = int(available_memory // page_size // num_layers)
num_blocks = int(available_memory // page_size // len(kv_cache_spec))
num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None:
@ -541,6 +570,9 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
max_model_len_str, max_concurrency)
per_layer_size = page_size * num_blocks
# All layers have the same KV cache spec, so we create one kv cache group
# for all layers.
grouped_layer_names = [list(kv_cache_spec.keys())]
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
@ -548,41 +580,69 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
layer_name: KVCacheTensor(size=per_layer_size)
for layer_name in kv_cache_spec
},
groups=[[layer_name for layer_name in kv_cache_spec]],
kv_cache_spec=kv_cache_spec)
kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
grouped_layer_names),
)
return kv_cache_config
def get_kv_cache_configs(vllm_config: VllmConfig,
kv_cache_specs: list[KVCacheSpec],
available_memory: int) -> list[KVCacheConfig]:
def get_kv_cache_config(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.
Args:
vllm_config: The global VllmConfig
kv_cache_specs: The kv cache specs of the model
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfigs
"""
# Use the max number of layers to conservatively determine
# the number of blocks.
num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
kv_cache_configs = []
for kv_cache_spec in kv_cache_specs:
check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
kv_cache_configs.append(
_get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory,
num_layers))
else:
raise NotImplementedError
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
raise NotImplementedError
def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
"""
Make the KV cache configurations for each worker consistent, so that all
workers can be controlled by the same KVCacheManager.
This function verifies that the layer group of each worker are the same,
and changes the num_blocks of each worker to the smallest among all workers.
Args:
kv_cache_configs: The KV cache configurations for each worker. Will be
in-place modified to make them consistent.
"""
# Sort the kv cache groups by the type_id of their KV cache spec.
# This can avoid the inconsistency caused by the order of groups.
for kv_cache_config in kv_cache_configs:
kv_cache_config.kv_cache_groups.sort(
key=lambda x: x.kv_cache_spec.type_id)
# Verify that the groups of each rank are the same.
for kv_cache_config in kv_cache_configs[1:]:
for group_rank_0, group_rank_i in zip(
kv_cache_configs[0].kv_cache_groups,
kv_cache_config.kv_cache_groups):
assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec
# Change the num_blocks of each rank to the smallest among all ranks. We
# do not need to shrink the tensor size because it is valid to only use the
# first `num_blocks` blocks of the tensor.
min_num_blocks = min(kv_cache_config.num_blocks
for kv_cache_config in kv_cache_configs)
for kv_cache_config in kv_cache_configs:
kv_cache_config.num_blocks = min_num_blocks
return kv_cache_configs

View File

@ -21,7 +21,8 @@ from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
@ -120,15 +121,27 @@ class EngineCore:
# memory can be allocated for kv cache.
available_gpu_memory = self.model_executor.determine_available_memory()
assert len(kv_cache_specs) == len(available_gpu_memory)
# Get the kv cache tensor size
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
available_gpu_memory)
num_gpu_blocks_set = set(config.num_blocks
for config in kv_cache_configs)
assert len(num_gpu_blocks_set) == 1, (
f"num_gpu_blocks need to be the same across workers, "
f"but they are different: {num_gpu_blocks_set}")
num_gpu_blocks = num_gpu_blocks_set.pop()
kv_cache_configs = [
get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
available_gpu_memory_one_worker)
for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
zip(kv_cache_specs, available_gpu_memory)
]
# Since we use a shared centralized controller, we need the
# `kv_cache_config` to be consistent across all workers to make sure
# all the memory operators can be applied to all workers.
unify_kv_cache_configs(kv_cache_configs)
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to get the number of blocks.
assert all([
cfg.num_blocks == kv_cache_configs[0].num_blocks
for cfg in kv_cache_configs
])
num_gpu_blocks = kv_cache_configs[0].num_blocks
num_cpu_blocks = 0
# Initialize kv cache and warmup the execution

View File

@ -62,14 +62,11 @@ class Executor(ExecutorBase):
args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model")
def determine_available_memory(self) -> int: # in bytes
def determine_available_memory(self) -> list[int]: # in bytes
output = self.collective_rpc("determine_available_memory")
# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return min(output)
return output
def get_kv_cache_specs(self) -> list[KVCacheSpec]:
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
output = self.collective_rpc("get_kv_cache_spec")
return output
@ -95,7 +92,7 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
def determine_available_memory(self) -> int: # in bytes
def determine_available_memory(self) -> list[int]: # in bytes
# same as determine_num_available_blocks in v0,
# we need to get the min across all ranks.
memory = super().determine_available_memory()
@ -103,4 +100,4 @@ class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
cpu_group = get_world_group().cpu_group
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
return memory_tensor.item()
return [memory_tensor.item()]

View File

@ -11,7 +11,7 @@ logger = init_logger(__name__)
@dataclass
class KVCacheSpecBase:
class KVCacheSpec:
"""
A base class for specifying the KV cache format of one layer.
"""
@ -55,7 +55,7 @@ class KVCacheSpecBase:
@dataclass
class FullAttentionSpec(KVCacheSpecBase):
class FullAttentionSpec(KVCacheSpec):
num_kv_heads: int
head_size: int
dtype: torch.dtype
@ -76,9 +76,6 @@ class FullAttentionSpec(KVCacheSpecBase):
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
KVCacheSpec = dict[str, KVCacheSpecBase]
@dataclass
class KVCacheTensor:
"""
@ -89,6 +86,18 @@ class KVCacheTensor:
size: int # The size of KV cache Tensor in bytes
@dataclass
class KVCacheGroupSpec:
"""
Represents a group of model layers that share the same KV cache block table.
These layers are regarded as one layer in the KV cache manager.
"""
# The names of model layers in this group
layer_names: list[str]
# The KV cache spec of this manager layer
kv_cache_spec: KVCacheSpec
@dataclass
class KVCacheConfig:
"""
@ -99,17 +108,24 @@ class KVCacheConfig:
"""layer_name -> how to initialize KV cache for that layer"""
tensors: dict[str, KVCacheTensor]
"""
A list of kv-cache groups. Each group includes a set of layers with
the same kv-cache spec, and the total page_size of layers inside a group
is same across all groups (as the KVCacheManager only supports allocating
pages of the same size). For example:
1. A model only uses full attention: one group with all layers in the model.
2. (not implemented yet) A model with the same number of full attention
layers and sliding window attention layers: two groups, one for full
attention layers and one for sliding window attention layers.
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
The kv cache groups of the model.
The layers in the models are repeated with some patterns, e.g., a model
with 10 full attention layers and 20 sliding window attention layers can be
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
The KVCacheManager allocates different block tables for each of the 3 layers
in the pattern, and repeats each of them 10 times to generate the
block_table for the 30 layers in the model.
Therefore, we can group the layers in the model into 3 groups, each of which
contains 10 layers in the model.
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers.
2. (WIP) A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 groups, each of which represents 10 layers in the model.
"""
groups: list[list[str]]
"""the KVCacheSpec of the model"""
kv_cache_spec: KVCacheSpec
kv_cache_groups: list[KVCacheGroupSpec]

View File

@ -1510,34 +1510,46 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.groups) > 1:
if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size)
dtype = layer_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
else:
raise NotImplementedError
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec
for layer_name in kv_cache_group.layer_names:
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
# `num_blocks` is the number of blocks the model runner can use.
# `kv_cache_config.num_blocks` is the number of blocks that
# KVCacheManager may allocate.
# Since different GPUs may have different number of layers and
# different memory capacities, `num_blocks` can be different on
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
if isinstance(kv_cache_spec, FullAttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.
raise ValueError("Unknown KV cache spec type.")
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
@ -1549,7 +1561,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: KVCacheSpec = {}
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
continue

View File

@ -185,7 +185,7 @@ class Worker(WorkerBase):
return int(available_kv_cache_memory)
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:

View File

@ -309,7 +309,7 @@ class TPUModelRunner:
assert self.model is not None
return self.model
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
@ -320,7 +320,7 @@ class TPUModelRunner:
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: KVCacheSpec = {}
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
@ -837,31 +837,33 @@ class TPUModelRunner:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.groups) > 1:
if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size)
dtype = layer_spec.dtype
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec
for layer_name in kv_cache_group.layer_names:
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, FullAttentionSpec):
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
tpu_k_cache = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
tpu_k_cache = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
else:
raise NotImplementedError
bind_kv_cache(
kv_caches,

View File

@ -189,7 +189,7 @@ class TPUWorker:
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:

View File

@ -51,7 +51,7 @@ class WorkerBase(WorkerBaseV0):
self.device: Optional[torch.device] = None
self.model_runner: Optional[nn.Module] = None
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""Get specifications for KV cache implementation."""
raise NotImplementedError