[v1] Move block management logic from KVCacheManager to SpecializedManager (#17474)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang
2025-05-09 23:25:34 +08:00
committed by GitHub
parent 9f64e93415
commit 200da9a517
6 changed files with 268 additions and 154 deletions
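
At a glance: per-request block bookkeeping (req_to_blocks, num_cached_block) and the free/allocate/cache logic around it move off KVCacheManager and onto the per-attention-type manager, which KVCacheManager now reaches through self.single_type_manager. Below is a minimal stand-in sketch of the resulting shape; the class names here are illustrative only, and the real definitions are in the diffs that follow.

from collections import defaultdict


class SingleTypeManagerSketch:
    """Owns per-request block state, as the new SingleTypeKVCacheManager does."""

    def __init__(self) -> None:
        self.req_to_blocks: dict[str, list] = defaultdict(list)
        self.num_cached_block: dict[str, int] = {}

    def free(self, request_id: str) -> None:
        # Pop the request's blocks; tail blocks are freed first, mirroring
        # the free() logic moved into the single-type manager below.
        for _block in reversed(self.req_to_blocks.pop(request_id, [])):
            pass  # block_pool.free_blocks(...) in the real code
        self.num_cached_block.pop(request_id, None)


class KVCacheManagerSketch:
    """After this commit, KVCacheManager delegates per-request state here."""

    def __init__(self) -> None:
        self.single_type_manager = SingleTypeManagerSketch()

    def free(self, request_id: str) -> None:
        self.single_type_manager.free(request_id)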

View File

@@ -539,7 +539,7 @@ def test_allocate_with_lookahead():
max_model_len=100)
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_new_tokens=3,
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
)
assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks
@@ -550,7 +550,7 @@ def test_allocate_with_lookahead():
# required_blocks = ceil((3 + 2) /4) = 2
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_new_tokens=3,
num_lookahead_tokens=2,
)
assert len(blocks.blocks) == 2
@@ -561,7 +561,7 @@ def test_allocate_with_lookahead():
max_model_len=100)
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_new_tokens=3,
num_lookahead_tokens=4,
)
assert len(blocks.blocks) == 2
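
Aside: the rename from num_tokens to num_new_tokens only changes the keyword; the block arithmetic in these asserts is unchanged. A standalone check of the first case, with block_size=4 inferred from the test's ceil(5/4)=2 comment:

from math import ceil

block_size = 4                                  # inferred from the test comments
num_new_tokens, num_lookahead_tokens = 3, 2
# Total slots required = new tokens + lookahead tokens.
assert ceil((num_new_tokens + num_lookahead_tokens) / block_size) == 2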

View File

@@ -299,7 +299,8 @@ def test_decode():
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4)
assert new_blocks is not None and len(new_blocks.blocks) == 0
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-1].block_hash is None
# Append slots with allocating a new block.
req0.num_computed_tokens = 59
@@ -309,8 +310,10 @@ def test_decode():
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19)
assert new_blocks is not None and len(new_blocks.blocks) == 1
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-2].block_hash is not None
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-1].block_hash is None
def test_evict():
@@ -689,7 +692,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks)
block_part0 = manager.req_to_blocks[req0.request_id]
block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2)
@@ -697,7 +700,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert computed_blocks.blocks == block_part0
assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks)
block_part1 = manager.req_to_blocks[req1.request_id]
block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| ... |
manager.free(req1)

View File

@@ -812,10 +812,11 @@ def _assert_right_kv_cache_manager(
# Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
for req_id in req_ids:
blocks = scheduler.kv_cache_manager.req_to_blocks[req_id]
blocks = (scheduler.kv_cache_manager.single_type_manager.
req_to_blocks[req_id])
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
assert (scheduler.kv_cache_manager.num_cached_block[req_id] ==
EXPECTED_TOTAL_BLOCKS)
assert (scheduler.kv_cache_manager.single_type_manager.
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
@@ -1195,9 +1196,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(scheduler.kv_cache_manager.req_to_blocks) == 0
assert len(
scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.num_cached_block) == 0
assert len(
scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (

View File

@@ -8,6 +8,14 @@ from vllm.v1.core.specialized_manager import SlidingWindowManager
from vllm.v1.kv_cache_interface import SlidingWindowSpec
def get_sliding_window_manager(sliding_window_spec, block_pool):
return SlidingWindowManager(sliding_window_spec,
block_pool,
use_eagle=False,
num_kv_cache_groups=1,
caching_hash_fn=lambda x: x)
def test_sliding_window_possible_cached_prefix():
sliding_window_spec = SlidingWindowSpec(
block_size=2,
@@ -19,9 +27,7 @@ def test_sliding_window_possible_cached_prefix():
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec,
block_pool,
use_eagle=False)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
def run_one_case(block_is_cached, expect_length):
block_hash_list = [
@@ -81,9 +87,7 @@ def test_sliding_window_remove_skipped_blocks():
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec,
block_pool,
use_eagle=False)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
null_block_id = block_pool.null_block.block_id
@@ -104,39 +108,35 @@
1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
]
block_table = id_to_block_table(original_block_ids)
removed = manager.remove_skipped_blocks(block_table, 0)
assert_block_id(removed, [])
manager.req_to_blocks["test"] = block_table
manager.remove_skipped_blocks("test", 0)
assert_block_id(block_table, original_block_ids)
# 4 tokens are computed. Only token 0 is out of the sliding window. As
# block 1000 also contains token 1 that is in the sliding window, block 1000
# cannot be removed.
removed = manager.remove_skipped_blocks(block_table, 4)
assert_block_id(removed, [])
manager.remove_skipped_blocks("test", 4)
assert_block_id(block_table, original_block_ids)
# 5 tokens are computed. Token 0 & 1 are out of the sliding window.
# Block 1000 can be removed.
removed = manager.remove_skipped_blocks(block_table, 5)
assert_block_id(removed, [original_block_ids[0]])
manager.remove_skipped_blocks("test", 5)
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
# 6 tokens are computed. Token 0-2 are out of the sliding window.
# Cannot remove new block as the block 1001 is still used by token 3.
removed = manager.remove_skipped_blocks(block_table, 6)
assert_block_id(removed, [])
manager.remove_skipped_blocks("test", 6)
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
# 7 tokens are computed. Token 0-3 are out of the sliding window.
# Block 1001 can be removed and block 1000 is already removed.
removed = manager.remove_skipped_blocks(block_table, 7)
assert_block_id(removed, [original_block_ids[1]])
manager.remove_skipped_blocks("test", 7)
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
# 11 tokens are computed. Token 0-7 are out of the sliding window.
# Block 1002 & 1003 can be removed now. Block 1003 represents a longer
# sequence, and is expected to be evicted earlier than 1002, so the order
# of removed blocks should be [1003, 1002].
removed = manager.remove_skipped_blocks(block_table, 11)
assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
manager.remove_skipped_blocks("test", 11)
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
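
The updated test exercises the new contract: remove_skipped_blocks is keyed by request_id, mutates the manager's own req_to_blocks in place (replacing skipped blocks with the null block), and frees the removed blocks itself instead of returning them. A self-contained sketch of the window arithmetic, assuming block_size=2 and sliding_window=4 as the test's comments imply; the helper below is illustrative, not the real method:

NULL = "null"  # stand-in for block_pool.null_block

def remove_skipped(blocks, num_computed_tokens, sliding_window, block_size):
    # First token that must stay in the window; earlier blocks may go.
    last_useful_token = num_computed_tokens - sliding_window + 1
    last_useful_block = last_useful_token // block_size
    freed = []
    for i in range(last_useful_block - 1, -1, -1):
        if blocks[i] == NULL:
            break  # everything before this point was already removed
        freed.append(blocks[i])
        blocks[i] = NULL
    return freed  # the real manager passes these to block_pool.free_blocks


blocks = [1000, 1001, 1002, 1003]
assert remove_skipped(blocks, 5, 4, 2) == [1000]  # tokens 0-1 leave the window
assert blocks == [NULL, 1001, 1002, 1003]
assert remove_skipped(blocks, 7, 4, 2) == [1001]  # block 1000 already removed
assert blocks == [NULL, NULL, 1002, 1003]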

View File

@@ -1,17 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.utils import cdiv, sha256
from vllm.utils import sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_request_tokens)
from vllm.v1.core.specialized_manager import get_specialized_manager
from vllm.v1.core.specialized_manager import get_manager_for_kv_cache_spec
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
@@ -56,7 +55,6 @@ class KVCacheManager:
self.block_size = kv_cache_spec.block_size
self.num_gpu_blocks = kv_cache_config.num_blocks
self.max_model_len = max_model_len
self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size)
self.enable_caching = enable_caching
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
@@ -68,30 +66,20 @@ class KVCacheManager:
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching,
enable_kv_cache_events)
self.specialized_manager = get_specialized_manager(
self.single_type_manager = get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
use_eagle=self.use_eagle,
num_kv_cache_groups=1,
caching_hash_fn=self.caching_hash_fn,
)
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
self.req_to_blocks: defaultdict[str,
list[KVCacheBlock]] = defaultdict(list)
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
# `get_computed_blocks` or `allocate_slots`.
self.req_to_block_hashes: defaultdict[
str, list[BlockHashType]] = defaultdict(list)
# {req_id: The number of cached blocks for this given request}
# This is used to track the number of cached blocks for each request.
# This is only used to track the RUNNING requests; we do not track the
# data for preempted ones.
self.num_cached_block: dict[str, int] = {}
@property
def usage(self) -> float:
"""Get the KV cache usage.
@@ -159,7 +147,7 @@ class KVCacheManager:
last_block_hash = None
computed_blocks = (
self.specialized_manager.find_longest_cache_hit(block_hashes))
self.single_type_manager.find_longest_cache_hit(block_hashes))
if self.log_stats:
assert self.prefix_cache_stats is not None
@@ -181,7 +169,7 @@ class KVCacheManager:
def allocate_slots(
self,
request: Request,
num_tokens: int,
num_new_tokens: int,
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0,
) -> Optional[KVCacheBlocks]:
@@ -189,7 +177,7 @@
Args:
request: The request to allocate slots.
num_tokens: The number of tokens to allocate, including external
num_new_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks).
new_computed_blocks: The new computed blocks just hitting the
@@ -215,44 +203,38 @@ class KVCacheManager:
Returns:
A list of new allocated blocks.
"""
if num_tokens == 0:
raise ValueError("num_tokens must be greater than 0")
if num_new_tokens == 0:
raise ValueError("num_new_tokens must be greater than 0")
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = []
req_blocks = self.req_to_blocks[request.request_id]
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
removed_blocks = self.specialized_manager.remove_skipped_blocks(
req_blocks, request.num_computed_tokens)
self.block_pool.free_blocks(removed_blocks)
self.single_type_manager.remove_skipped_blocks(
request.request_id, request.num_computed_tokens)
# The total number of computed tokens is the request's already-computed
# tokens plus the new prefix caching hits.
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_block_list) * self.block_size)
num_required_blocks = cdiv(
num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_block_list))
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
num_blocks_to_allocate = (
self.single_type_manager.get_num_blocks_to_allocate(
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
))
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = sum(1
for blk in new_computed_block_list
if blk.ref_cnt == 0)
if (num_new_blocks > self.block_pool.get_num_free_blocks() -
num_evictable_computed_blocks):
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
# Cannot allocate new blocks
return None
@@ -266,74 +248,33 @@ class KVCacheManager:
# Append the new computed blocks to the request's block list now, to
# avoid the case where the new blocks cannot be allocated.
req_blocks.extend(new_computed_block_list)
self.single_type_manager.save_new_computed_blocks(
request.request_id, new_computed_block_list)
# Start to handle new blocks
if num_new_blocks <= 0:
# No new block is needed.
new_blocks = []
else:
# Get new blocks from the free block pool.
num_new_blocks = min(
num_new_blocks,
self.block_pool.get_num_free_blocks(),
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
self.max_num_blocks_per_req - len(req_blocks),
)
assert num_new_blocks > 0
# Concatenate the computed block IDs and the new block IDs.
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
new_blocks = self.single_type_manager.allocate_new_blocks(
request.request_id, num_tokens_need_slot)
if not self.enable_caching:
return KVCacheBlocks(new_blocks)
# Use `new_computed_block_list` for a new request, and
# `num_cached_block` for a running request.
num_cached_blocks = self.num_cached_block.get(
request.request_id, len(new_computed_block_list))
# Speculated tokens might be rejected in the future, so we do
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.
num_full_blocks_after_append = (num_computed_tokens + num_tokens - len(
request.spec_token_ids)) // self.block_size
self.single_type_manager.cache_blocks(
request, self.req_to_block_hashes[request.request_id],
num_computed_tokens + num_new_tokens - len(request.spec_token_ids))
self.block_pool.cache_full_blocks(
request=request,
blocks=req_blocks,
block_hashes=self.req_to_block_hashes[request.request_id],
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks_after_append,
block_size=self.block_size,
hash_fn=self.caching_hash_fn,
)
self.num_cached_block[
request.request_id] = num_full_blocks_after_append
return KVCacheBlocks(new_blocks)
def free(self, request: Request) -> None:
"""Free the blocks allocated for the request.
When caching is enabled, we free the blocks in reverse order so that
the tail blocks are evicted first.
We free the blocks in reverse order so that the tail blocks are evicted
first when caching is enabled.
Args:
request: The request to free the blocks.
"""
# Default to [] in case a request is freed (aborted) before alloc.
blocks = self.req_to_blocks.pop(request.request_id, [])
ordered_blocks: Iterable[KVCacheBlock] = blocks
if self.enable_caching:
# Free blocks in reverse order so that the tail blocks are
# freed first.
ordered_blocks = reversed(blocks)
self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request.request_id, None)
self.single_type_manager.free(request.request_id)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
@@ -390,14 +331,8 @@ class KVCacheManager:
int: The number of common prefix blocks.
"""
assert request.status == RequestStatus.RUNNING
blocks = self.req_to_blocks[request.request_id]
num_common_blocks = 0
for block in blocks:
if block.ref_cnt == num_running_requests:
num_common_blocks += 1
else:
break
return num_common_blocks
return self.single_type_manager.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request.

View File

@@ -1,17 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
from vllm.v1.request import Request
class SpecializedManager(ABC):
class SingleTypeKVCacheManager(ABC):
"""
An abstract base class for specialized managers that handle the kv
cache management logic of different attention layers.
An abstract base class for a manager that handles the kv cache management
logic of one specific type of attention layer.
"""
def __init__(
@@ -19,12 +22,18 @@ class SpecializedManager(ABC):
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
use_eagle: bool,
num_kv_cache_groups: int,
caching_hash_fn: Callable,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
use_eagle: Whether to use eagle.
num_kv_cache_groups: The number of kv cache groups managed by this
manager.
caching_hash_fn: The caching hash function.
"""
self.block_size = kv_cache_spec.block_size
@@ -34,6 +43,149 @@
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
self.req_to_blocks: defaultdict[str,
list[KVCacheBlock]] = defaultdict(list)
# {req_id: The number of cached blocks for this given request}
# This is used to track the number of cached blocks for each request.
# This is only used to track the RUNNING requests; we do not track the
# data for preempted ones.
self.num_cached_block: dict[str, int] = {}
self.num_kv_cache_groups = num_kv_cache_groups
self.caching_hash_fn = caching_hash_fn
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: list[KVCacheBlock]) -> int:
"""
Get the number of blocks that need to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
Returns:
The number of blocks.
"""
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = (num_required_blocks - len(new_computed_blocks) -
len(self.req_to_blocks[request_id]))
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it will be changed from a free block
# to a computed block when the request is allocated, so we also count
# it toward the blocks that need to be allocated.
num_evictable_computed_blocks = sum(blk.ref_cnt == 0
for blk in new_computed_blocks)
return ((num_new_blocks + num_evictable_computed_blocks) *
self.num_kv_cache_groups)
def save_new_computed_blocks(
self, request_id: str,
new_computed_blocks: list[KVCacheBlock]) -> None:
"""
Add the new computed blocks to the request.
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
"""
if request_id not in self.num_cached_block:
# A new request.
req_blocks = self.req_to_blocks[request_id]
assert len(req_blocks) == 0
req_blocks.extend(new_computed_blocks)
self.num_cached_block[request_id] = len(new_computed_blocks)
else:
# A running request. Should not have new computed blocks.
assert len(new_computed_blocks) == 0
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[KVCacheBlock]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
Returns:
The new allocated blocks.
"""
req_blocks = self.req_to_blocks[request_id]
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = num_required_blocks - len(req_blocks)
if num_new_blocks <= 0:
return []
else:
new_blocks = self.block_pool.get_new_blocks(
num_new_blocks * self.num_kv_cache_groups)
req_blocks.extend(new_blocks)
return new_blocks
def cache_blocks(self, request: Request, block_hashes: list[BlockHashType],
num_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
num_cached_blocks = self.num_cached_block[request.request_id]
num_full_blocks = num_tokens // self.block_size
self.block_pool.cache_full_blocks(
request=request,
blocks=self.req_to_blocks[request.request_id],
block_hashes=block_hashes,
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=self.block_size,
hash_fn=self.caching_hash_fn,
)
self.num_cached_block[request.request_id] = num_full_blocks
def free(self, request_id: str) -> None:
# Default to [] in case a request is freed (aborted) before alloc.
req_blocks = self.req_to_blocks.pop(request_id, [])
# Free blocks in reverse order so that the tail blocks are
# freed first.
ordered_blocks = reversed(req_blocks)
self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request_id, None)
@abstractmethod
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
"""
Get the number of common prefix blocks for a request.
Args:
request_id: The request ID.
num_running_requests: The total number of requests in the RUNNING status.
Returns:
The number of common prefix blocks.
"""
raise NotImplementedError
@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
@@ -41,7 +193,8 @@ class SpecializedManager(ABC):
Get the longest cache hit prefix of the blocks. If no cache hit is
found, return an empty list. If eagle is enabled, drop the last matched
block to force recomputation of the last block to get the required hidden
states for eagle drafting head.
states for eagle drafting head. Needs to be customized for each attention
type.
Args:
block_hashes: The block hashes of the request.
@@ -55,24 +208,23 @@
raise NotImplementedError
@abstractmethod
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
"""
Remove the blocks that are no longer needed from `blocks`. The removed
blocks should be replaced by null_block. Return the removed blocks in
eviction order, where the first returned block should be evicted first.
Don't free the removed blocks in this function.
Free the removed blocks in this function instead of returning them.
Needs to be customized for each attention type.
Args:
blocks: The list of blocks to be updated.
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
Returns:
The removed blocks in eviction order.
"""
raise NotImplementedError
class FullAttentionManager(SpecializedManager):
class FullAttentionManager(SingleTypeKVCacheManager):
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
@@ -89,17 +241,28 @@ class FullAttentionManager(SpecializedManager):
computed_blocks.pop()
return computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# No need to remove blocks for full attention.
return []
pass
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
blocks = self.req_to_blocks[request_id]
num_common_blocks = 0
for block in blocks:
if block.ref_cnt == num_running_requests:
num_common_blocks += 1
else:
break
return num_common_blocks
class SlidingWindowManager(SpecializedManager):
class SlidingWindowManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
use_eagle: bool):
super().__init__(kv_cache_spec, block_pool, use_eagle)
use_eagle: bool, **kwargs) -> None:
super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs)
self.sliding_window = kv_cache_spec.sliding_window
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
@@ -148,13 +311,13 @@ class SlidingWindowManager(SpecializedManager):
computed_blocks.pop()
return computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Remove the blocks that are no longer in the sliding window and are
# skipped during the attention computation.
last_useful_token = num_computed_tokens - self.sliding_window + 1
last_useful_block = last_useful_token // self.block_size
blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
for i in range(last_useful_block - 1, -1, -1):
if blocks[i] == self._null_block:
@@ -164,17 +327,27 @@
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
return removed_blocks
self.block_pool.free_blocks(removed_blocks)
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
"""
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
So it's not correct to count ref_cnt like FullAttentionManager. Return
0 here for correctness. Need to support cascade attention + sliding
window in the future.
"""
return 0
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
}
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
**kwargs) -> SpecializedManager:
def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec,
**kwargs) -> SingleTypeKVCacheManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
manager = manager_class(kv_cache_spec, **kwargs)
return manager
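
A hypothetical call pattern for the renamed factory, using the same keyword arguments KVCacheManager passes above. Constructing the real KVCacheSpec and BlockPool is elided here; see kv_cache_interface.py and block_pool.py for those.

manager = get_manager_for_kv_cache_spec(
    kv_cache_spec=kv_cache_spec,   # assumed: a FullAttentionSpec or SlidingWindowSpec
    block_pool=block_pool,         # assumed: an existing BlockPool instance
    use_eagle=False,
    num_kv_cache_groups=1,
    caching_hash_fn=hash,
)
new_blocks = manager.allocate_new_blocks("req-0", num_tokens=48)
manager.free("req-0")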