mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[v1][KVCacheManager] Avoid full cache hit by controlling max_length (#17999)
Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@ -17,8 +17,9 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
|
||||
|
||||
|
||||
def test_sliding_window_possible_cached_prefix():
|
||||
block_size = 2
|
||||
sliding_window_spec = SlidingWindowSpec(
|
||||
block_size=2,
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
@ -44,7 +45,9 @@ def test_sliding_window_possible_cached_prefix():
|
||||
i: block_pool.blocks[i + 10]
|
||||
}
|
||||
|
||||
computed_blocks = manager.find_longest_cache_hit(block_hash_list)
|
||||
computed_blocks = manager.find_longest_cache_hit(
|
||||
block_hash_list,
|
||||
len(block_hash_list) * block_size)
|
||||
assert len(computed_blocks) == expect_length
|
||||
|
||||
assert all(block == block_pool.null_block
|
||||
|
@ -146,21 +146,16 @@ class KVCacheManager:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.requests += 1
|
||||
|
||||
if len(block_hashes) * self.block_size == request.num_tokens:
|
||||
# When prompt length is divisible by the block size and all
|
||||
# blocks are cached, we need to recompute the last token. This
|
||||
# has to be achieved by re-computing an entire block because
|
||||
# allocate_slots() assumes num_computed_tokens is always a
|
||||
# multiple of the block size. To achieve this, remove the last
|
||||
# block hash from the block_hashes for find_longest_cache_hit.
|
||||
# This limitation can potentially be removed in the future to
|
||||
# slightly improve the performance.
|
||||
last_block_hash = block_hashes.pop()
|
||||
else:
|
||||
last_block_hash = None
|
||||
# NOTE: When all tokens hit the cache, we must recompute the last token
|
||||
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
|
||||
# This can trigger recomputation of an entire block, rather than just
|
||||
# the single last token, because allocate_slots() requires
|
||||
# num_computed_tokens to be block-size aligned. Removing this limitation
|
||||
# could slightly improve performance in the future.
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
|
||||
computed_blocks = (
|
||||
self.single_type_manager.find_longest_cache_hit(block_hashes))
|
||||
computed_blocks = self.single_type_manager.find_longest_cache_hit(
|
||||
block_hashes, max_cache_hit_length)
|
||||
# NOTE(woosuk): Since incomplete blocks are not eligible for
|
||||
# sharing, `num_computed_tokens` is always a multiple of
|
||||
# `block_size`.
|
||||
@ -171,12 +166,6 @@ class KVCacheManager:
|
||||
self.prefix_cache_stats.queries += request.num_tokens
|
||||
self.prefix_cache_stats.hits += num_computed_tokens
|
||||
|
||||
if last_block_hash is not None:
|
||||
# Add back the last block hash if it was removed.
|
||||
# NOTE: Because block_hashes is cached in req_to_block_hashes,
|
||||
# we shouldn't modify it directly.
|
||||
block_hashes.append(last_block_hash)
|
||||
|
||||
return KVCacheBlocks(computed_blocks), num_computed_tokens
|
||||
|
||||
def allocate_slots(
|
||||
|
@ -187,17 +187,19 @@ class SingleTypeKVCacheManager(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
||||
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
|
||||
max_length: int) -> list[KVCacheBlock]:
|
||||
"""
|
||||
Get the longest cache hit prefix of the blocks. If no cache hit is
|
||||
found, return an empty list. If eagle is enabled, drop the last matched
|
||||
block to force recompute the last block to get the required hidden
|
||||
states for eagle drafting head. Needs to be customized for each attention
|
||||
type.
|
||||
Get the longest cache hit prefix of the blocks that is not longer than
|
||||
`max_length`. If no cache hit is found, return an empty list.
|
||||
If eagle is enabled, drop the last matched block to force recompute the
|
||||
last block to get the required hidden states for eagle drafting head.
|
||||
Needs to be customized for each attention type.
|
||||
|
||||
Args:
|
||||
block_hashes: The block hashes of the request.
|
||||
max_length: The maximum length of the cache hit prefix.
|
||||
|
||||
Returns:
|
||||
A list of cached blocks with skipped blocks replaced by null block.
|
||||
For example, sliding window manager should return a list like
|
||||
@ -226,10 +228,12 @@ class SingleTypeKVCacheManager(ABC):
|
||||
|
||||
class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
||||
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
|
||||
max_length: int) -> list[KVCacheBlock]:
|
||||
computed_blocks: list[KVCacheBlock] = []
|
||||
for block_hash in block_hashes:
|
||||
max_num_blocks = max_length // self.block_size
|
||||
for i in range(max_num_blocks):
|
||||
block_hash = block_hashes[i]
|
||||
# block_hashes is a chain of block hashes. If a block hash is not
|
||||
# in the cached_block_hash_to_id, the following block hashes are
|
||||
# not computed yet for sure.
|
||||
@ -276,19 +280,20 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
self.sliding_window_contiguous_blocks += 1
|
||||
self._null_block = block_pool.null_block
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
||||
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
|
||||
max_length: int) -> list[KVCacheBlock]:
|
||||
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
|
||||
# optimize the time complexity from O(len(block_hashes)) to
|
||||
# O(len(block_hashes) / sliding_window_contiguous_blocks +
|
||||
# optimize the time complexity from O(max_num_blocks) to
|
||||
# O(max_num_blocks / sliding_window_contiguous_blocks +
|
||||
# sliding_window_contiguous_blocks),
|
||||
# which is good for low cache hit rate scenarios.
|
||||
computed_blocks = [self._null_block] * len(block_hashes)
|
||||
max_num_blocks = max_length // self.block_size
|
||||
computed_blocks = [self._null_block] * max_num_blocks
|
||||
num_contiguous_blocks = 0
|
||||
|
||||
match_found = False
|
||||
# Search from right to left and early stop when a match is found.
|
||||
for i in range(len(block_hashes) - 1, -1, -1):
|
||||
for i in range(max_num_blocks - 1, -1, -1):
|
||||
if cached_block := self.block_pool.get_cached_block(
|
||||
block_hashes[i]):
|
||||
computed_blocks[i] = cached_block
|
||||
|
Reference in New Issue
Block a user