[v1][KVCacheManager] Avoid full cache hit by controlling max_length (#17999)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Chen Zhang
2025-05-13 14:50:38 +08:00
committed by GitHub
parent e57e4d6e9e
commit f0d610a8ae
3 changed files with 35 additions and 38 deletions

View File

@ -17,8 +17,9 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
def test_sliding_window_possible_cached_prefix():
block_size = 2
sliding_window_spec = SlidingWindowSpec(
block_size=2,
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
@ -44,7 +45,9 @@ def test_sliding_window_possible_cached_prefix():
i: block_pool.blocks[i + 10]
}
computed_blocks = manager.find_longest_cache_hit(block_hash_list)
computed_blocks = manager.find_longest_cache_hit(
block_hash_list,
len(block_hash_list) * block_size)
assert len(computed_blocks) == expect_length
assert all(block == block_pool.null_block

View File

@ -146,21 +146,16 @@ class KVCacheManager:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
has to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash = block_hashes.pop()
else:
last_block_hash = None
# NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just
# the single last token, because allocate_slots() requires
# num_computed_tokens to be block-size aligned. Removing this limitation
# could slightly improve performance in the future.
max_cache_hit_length = request.num_tokens - 1
computed_blocks = (
self.single_type_manager.find_longest_cache_hit(block_hashes))
computed_blocks = self.single_type_manager.find_longest_cache_hit(
block_hashes, max_cache_hit_length)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
@ -171,12 +166,6 @@ class KVCacheManager:
self.prefix_cache_stats.queries += request.num_tokens
self.prefix_cache_stats.hits += num_computed_tokens
if last_block_hash is not None:
# Add back the last block hash if it was removed.
# NOTE: Because block_hashes is cached in req_to_block_hashes,
# we shouldn't modify it directly.
block_hashes.append(last_block_hash)
return KVCacheBlocks(computed_blocks), num_computed_tokens
def allocate_slots(

View File

@ -187,17 +187,19 @@ class SingleTypeKVCacheManager(ABC):
raise NotImplementedError
@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
max_length: int) -> list[KVCacheBlock]:
"""
Get the longest cache hit prefix of the blocks. If no cache hit is
found, return an empty list. If eagle is enabled, drop the last matched
block to force recompute the last block to get the required hidden
states for eagle drafting head. Need to be customized for each attention
type.
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. If no cache hit is found, return an empty list.
If eagle is enabled, drop the last matched block to force recompute the
last block to get the required hidden states for eagle drafting head.
Need to be customized for each attention type.
Args:
block_hashes: The block hashes of the request.
max_length: The maximum length of the cache hit prefix.
Returns:
A list of cached blocks with skipped blocks replaced by null block.
For example, sliding window manager should return a list like
@ -226,10 +228,12 @@ class SingleTypeKVCacheManager(ABC):
class FullAttentionManager(SingleTypeKVCacheManager):
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
max_length: int) -> list[KVCacheBlock]:
computed_blocks: list[KVCacheBlock] = []
for block_hash in block_hashes:
max_num_blocks = max_length // self.block_size
for i in range(max_num_blocks):
block_hash = block_hashes[i]
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
@ -276,19 +280,20 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
self.sliding_window_contiguous_blocks += 1
self._null_block = block_pool.null_block
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
max_length: int) -> list[KVCacheBlock]:
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(len(block_hashes)) to
# O(len(block_hashes) / sliding_window_contiguous_blocks +
# optimize the time complexity from O(max_num_blocks) to
# O(max_num_blocks / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
computed_blocks = [self._null_block] * len(block_hashes)
max_num_blocks = max_length // self.block_size
computed_blocks = [self._null_block] * max_num_blocks
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
for i in range(len(block_hashes) - 1, -1, -1):
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := self.block_pool.get_cached_block(
block_hashes[i]):
computed_blocks[i] = cached_block