[v1] Move block management logic from KVCacheManager to SpecializedManager (#17474)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -539,7 +539,7 @@ def test_allocate_with_lookahead():
                                       max_model_len=100)
     blocks = kv_cache_manager.allocate_slots(
         request,
-        num_tokens=3,
+        num_new_tokens=3,
         num_lookahead_tokens=2,  # Total required: 3+2=5 tokens
     )
     assert len(blocks.blocks) == 2  # ceil(5/4)=2 blocks
@@ -550,7 +550,7 @@ def test_allocate_with_lookahead():
     # required_blocks = ceil((3 + 2) / 4) = 2
     blocks = kv_cache_manager.allocate_slots(
         request,
-        num_tokens=3,
+        num_new_tokens=3,
         num_lookahead_tokens=2,
     )
     assert len(blocks.blocks) == 2
@@ -561,7 +561,7 @@ def test_allocate_with_lookahead():
                                       max_model_len=100)
     blocks = kv_cache_manager.allocate_slots(
         request,
-        num_tokens=3,
+        num_new_tokens=3,
         num_lookahead_tokens=4,
     )
     assert len(blocks.blocks) == 2
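The three hunks above make the same mechanical change at each call site: the number of tokens being scheduled is now passed as num_new_tokens rather than num_tokens (the signature change itself appears in the kv_cache_manager.py hunks further down). A minimal sketch of the block arithmetic the test comments describe, assuming block_size=4 as implied by "ceil(5/4)=2 blocks":

from math import ceil

block_size = 4            # assumed from the "ceil(5/4)=2 blocks" comment
num_new_tokens = 3        # tokens scheduled in this step
num_lookahead_tokens = 2  # extra slots reserved for speculative decoding

# Total slots required: 3 + 2 = 5 tokens -> ceil(5 / 4) = 2 blocks.
num_required_blocks = ceil(
    (num_new_tokens + num_lookahead_tokens) / block_size)
assert num_required_blocks == 2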
@@ -299,7 +299,8 @@ def test_decode():
     req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 4)
     assert new_blocks is not None and len(new_blocks.blocks) == 0
-    assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
+    assert manager.single_type_manager.req_to_blocks[
+        req0.request_id][-1].block_hash is None

     # Append slots with allocating a new block.
     req0.num_computed_tokens = 59
@@ -309,8 +310,10 @@ def test_decode():
     req0.append_output_token_ids(7)
     new_blocks = manager.allocate_slots(req0, 19)
     assert new_blocks is not None and len(new_blocks.blocks) == 1
-    assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
-    assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
+    assert manager.single_type_manager.req_to_blocks[
+        req0.request_id][-2].block_hash is not None
+    assert manager.single_type_manager.req_to_blocks[
+        req0.request_id][-1].block_hash is None


 def test_evict():
@@ -689,7 +692,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     manager.allocate_slots(req0, 48, computed_blocks)
-    block_part0 = manager.req_to_blocks[req0.request_id]
+    block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]

     # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
     req1 = make_request("1", common_token_ids * 2)
@@ -697,7 +700,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     assert computed_blocks.blocks == block_part0
     assert num_computed_tokens == 3 * 16
     manager.allocate_slots(req1, 48, computed_blocks)
-    block_part1 = manager.req_to_blocks[req1.request_id]
+    block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
     # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
     # | Req1-5(F)| ... |
     manager.free(req1)
@@ -812,10 +812,11 @@ def _assert_right_kv_cache_manager(
     # Make sure the request stats are right.
     EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
     for req_id in req_ids:
-        blocks = scheduler.kv_cache_manager.req_to_blocks[req_id]
+        blocks = (scheduler.kv_cache_manager.single_type_manager.
+                  req_to_blocks[req_id])
         hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
-        assert (scheduler.kv_cache_manager.num_cached_block[req_id] ==
-                EXPECTED_TOTAL_BLOCKS)
+        assert (scheduler.kv_cache_manager.single_type_manager.
+                num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
         assert len(blocks) == EXPECTED_TOTAL_BLOCKS
         assert len(hashes) == EXPECTED_TOTAL_BLOCKS

@@ -1195,9 +1196,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
     assert len(scheduler.encoder_cache_manager.cached) == 0

     # KVCache Manager.
-    assert len(scheduler.kv_cache_manager.req_to_blocks) == 0
+    assert len(
+        scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
     assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
-    assert len(scheduler.kv_cache_manager.num_cached_block) == 0
+    assert len(
+        scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
     num_free_blocks = (
         scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
     assert num_free_blocks == (
@@ -8,6 +8,14 @@ from vllm.v1.core.specialized_manager import SlidingWindowManager
 from vllm.v1.kv_cache_interface import SlidingWindowSpec


+def get_sliding_window_manager(sliding_window_spec, block_pool):
+    return SlidingWindowManager(sliding_window_spec,
+                                block_pool,
+                                use_eagle=False,
+                                num_kv_cache_groups=1,
+                                caching_hash_fn=lambda x: x)
+
+
 def test_sliding_window_possible_cached_prefix():
     sliding_window_spec = SlidingWindowSpec(
         block_size=2,
@@ -19,9 +27,7 @@ def test_sliding_window_possible_cached_prefix():
     )

     block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
-    manager = SlidingWindowManager(sliding_window_spec,
-                                   block_pool,
-                                   use_eagle=False)
+    manager = get_sliding_window_manager(sliding_window_spec, block_pool)

     def run_one_case(block_is_cached, expect_length):
         block_hash_list = [
@@ -81,9 +87,7 @@ def test_sliding_window_remove_skipped_blocks():

     block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)

-    manager = SlidingWindowManager(sliding_window_spec,
-                                   block_pool,
-                                   use_eagle=False)
+    manager = get_sliding_window_manager(sliding_window_spec, block_pool)

     null_block_id = block_pool.null_block.block_id

@@ -104,39 +108,35 @@ def test_sliding_window_remove_skipped_blocks():
         1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
     ]
     block_table = id_to_block_table(original_block_ids)
-    removed = manager.remove_skipped_blocks(block_table, 0)
-    assert_block_id(removed, [])
+    manager.req_to_blocks["test"] = block_table
+
+    manager.remove_skipped_blocks("test", 0)
     assert_block_id(block_table, original_block_ids)

     # 4 tokens are computed. Only token 0 is out of the sliding window. As
     # block 1000 also contains token 1 that is in the sliding window, block 1000
     # cannot be removed.
-    removed = manager.remove_skipped_blocks(block_table, 4)
-    assert_block_id(removed, [])
+    manager.remove_skipped_blocks("test", 4)
     assert_block_id(block_table, original_block_ids)

     # 5 tokens are computed. Token 0 & 1 are out of the sliding window.
     # Block 1000 can be removed.
-    removed = manager.remove_skipped_blocks(block_table, 5)
-    assert_block_id(removed, [original_block_ids[0]])
+    manager.remove_skipped_blocks("test", 5)
     assert_block_id(block_table, [null_block_id] + original_block_ids[1:])

     # 6 tokens are computed. Token 0-2 are out of the sliding window.
     # Cannot remove new block as the block 1001 is still used by token 3.
-    removed = manager.remove_skipped_blocks(block_table, 6)
-    assert_block_id(removed, [])
+    manager.remove_skipped_blocks("test", 6)
     assert_block_id(block_table, [null_block_id] + original_block_ids[1:])

     # 7 tokens are computed. Token 0-3 are out of the sliding window.
     # Block 1001 can be removed and block 1000 is already removed.
-    removed = manager.remove_skipped_blocks(block_table, 7)
-    assert_block_id(removed, [original_block_ids[1]])
+    manager.remove_skipped_blocks("test", 7)
     assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])

     # 11 tokens are computed. Token 0-7 are out of the sliding window.
     # Block 1002 & 1003 can be removed now. Block 1003 represents a longer
     # sequence, and is expected to be evicted earlier than 1002, so the order
     # of removed blocks should be [1003, 1002].
-    removed = manager.remove_skipped_blocks(block_table, 11)
-    assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
+    manager.remove_skipped_blocks("test", 11)
     assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
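The steps above follow directly from the removal rule implemented in SlidingWindowManager.remove_skipped_blocks (see the specialized_manager.py hunks below). A small sketch of that arithmetic, assuming block_size=2 and sliding_window=4, the values consistent with the assertions above:

block_size = 2      # assumed; matches the two-token blocks in this test
sliding_window = 4  # assumed; the window size consistent with the asserts

def first_useful_block(num_computed_tokens: int) -> int:
    # The oldest token still inside the sliding window...
    last_useful_token = num_computed_tokens - sliding_window + 1
    # ...and the first block that still holds such a token. Blocks
    # [0, this index) can be replaced with the null block.
    return last_useful_token // block_size

assert first_useful_block(4) == 0   # nothing can be removed yet
assert first_useful_block(5) == 1   # block 1000 becomes removable
assert first_useful_block(7) == 2   # block 1001 becomes removable
assert first_useful_block(11) == 4  # blocks 1002 and 1003 become removable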
@@ -1,17 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0

 from collections import defaultdict
-from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Optional

 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
-from vllm.utils import cdiv, sha256
+from vllm.utils import sha256
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
                                          hash_request_tokens)
-from vllm.v1.core.specialized_manager import get_specialized_manager
+from vllm.v1.core.specialized_manager import get_manager_for_kv_cache_spec
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request, RequestStatus
@@ -56,7 +55,6 @@ class KVCacheManager:
         self.block_size = kv_cache_spec.block_size
         self.num_gpu_blocks = kv_cache_config.num_blocks
         self.max_model_len = max_model_len
-        self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size)

         self.enable_caching = enable_caching
         self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
@@ -68,30 +66,20 @@ class KVCacheManager:
         self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching,
                                     enable_kv_cache_events)

-        self.specialized_manager = get_specialized_manager(
+        self.single_type_manager = get_manager_for_kv_cache_spec(
             kv_cache_spec=kv_cache_spec,
             block_pool=self.block_pool,
             use_eagle=self.use_eagle,
+            num_kv_cache_groups=1,
+            caching_hash_fn=self.caching_hash_fn,
         )

-        # Mapping from request ID to blocks to track the blocks allocated
-        # for each request, so that we can free the blocks when the request
-        # is finished.
-        self.req_to_blocks: defaultdict[str,
-                                        list[KVCacheBlock]] = defaultdict(list)
-
         # Mapping from request ID to kv block hashes.
         # This is to avoid recomputing the block hashes for each call of
         # `get_computed_blocks` or `allocate_slots`.
         self.req_to_block_hashes: defaultdict[
             str, list[BlockHashType]] = defaultdict(list)

-        # {req_id: The number of cached blocks for this given request}
-        # This is used to track the number of cached blocks for each request.
-        # This is only used to track the RUNNING requests, we do not track the
-        # data for reempted ones.
-        self.num_cached_block: dict[str, int] = {}
-
     @property
     def usage(self) -> float:
         """Get the KV cache usage.
@@ -159,7 +147,7 @@ class KVCacheManager:
         last_block_hash = None

         computed_blocks = (
-            self.specialized_manager.find_longest_cache_hit(block_hashes))
+            self.single_type_manager.find_longest_cache_hit(block_hashes))

         if self.log_stats:
             assert self.prefix_cache_stats is not None
@@ -181,7 +169,7 @@ class KVCacheManager:
     def allocate_slots(
         self,
         request: Request,
-        num_tokens: int,
+        num_new_tokens: int,
         new_computed_blocks: Optional[KVCacheBlocks] = None,
         num_lookahead_tokens: int = 0,
     ) -> Optional[KVCacheBlocks]:
@@ -189,7 +177,7 @@ class KVCacheManager:

         Args:
             request: The request to allocate slots.
-            num_tokens: The number of tokens to allocate, including external
+            num_new_tokens: The number of tokens to allocate, including external
                 tokens. Note that this does not include tokens that have
                 already been computed locally (i.e. new_computed_blocks).
             new_computed_blocks: The new computed blocks just hitting the
@@ -215,44 +203,38 @@ class KVCacheManager:
         Returns:
             A list of new allocated blocks.
         """
-        if num_tokens == 0:
-            raise ValueError("num_tokens must be greater than 0")
+        if num_new_tokens == 0:
+            raise ValueError("num_new_tokens must be greater than 0")

         if new_computed_blocks is not None:
             new_computed_block_list = new_computed_blocks.blocks
         else:
             new_computed_block_list = []

-        req_blocks = self.req_to_blocks[request.request_id]
-
         # Free the blocks that are skipped during the attention computation
         # (e.g., tokens outside the sliding window).
         # We can do this even if we cannot schedule this request due to
         # insufficient free blocks.
         # Should call this function before allocating new blocks to reduce
         # the number of evicted blocks.
-        removed_blocks = self.specialized_manager.remove_skipped_blocks(
-            req_blocks, request.num_computed_tokens)
-        self.block_pool.free_blocks(removed_blocks)
+        self.single_type_manager.remove_skipped_blocks(
+            request.request_id, request.num_computed_tokens)

         # The number of computed tokens is the number of computed tokens plus
         # the new prefix caching hits
         num_computed_tokens = (request.num_computed_tokens +
                                len(new_computed_block_list) * self.block_size)
-        num_required_blocks = cdiv(
-            num_computed_tokens + num_tokens + num_lookahead_tokens,
-            self.block_size)
-        num_new_blocks = (num_required_blocks - len(req_blocks) -
-                          len(new_computed_block_list))
+        num_tokens_need_slot = min(
+            num_computed_tokens + num_new_tokens + num_lookahead_tokens,
+            self.max_model_len)
+        num_blocks_to_allocate = (
+            self.single_type_manager.get_num_blocks_to_allocate(
+                request_id=request.request_id,
+                num_tokens=num_tokens_need_slot,
+                new_computed_blocks=new_computed_block_list,
+            ))

-        # If a computed block of a request is an eviction candidate (in the
-        # free queue and ref_cnt == 0), it cannot be counted as a free block
-        # when allocating this request.
-        num_evictable_computed_blocks = sum(1
-                                            for blk in new_computed_block_list
-                                            if blk.ref_cnt == 0)
-        if (num_new_blocks > self.block_pool.get_num_free_blocks() -
-                num_evictable_computed_blocks):
+        if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
             # Cannot allocate new blocks
             return None

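The new accounting replaces the inline cdiv and eviction bookkeeping with a single query: the slot demand is capped at max_model_len, and get_num_blocks_to_allocate (added in specialized_manager.py below) folds the evictable-computed-block correction into its result. A worked sketch with hypothetical numbers:

block_size = 16
max_model_len = 2048
num_computed_tokens = 32   # 2 blocks already held by the request
num_new_tokens = 20
num_lookahead_tokens = 0

num_tokens_need_slot = min(
    num_computed_tokens + num_new_tokens + num_lookahead_tokens,
    max_model_len)                                    # 52
# Inside get_num_blocks_to_allocate: cdiv(52, 16) = 4 required blocks,
# minus the 2 already allocated -> 2 new blocks to reserve.
num_required_blocks = -(-num_tokens_need_slot // block_size)
num_blocks_to_allocate = num_required_blocks - 2
assert num_blocks_to_allocate == 2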
@@ -266,74 +248,33 @@ class KVCacheManager:

         # Append the new computed blocks to the request blocks until now to
         # avoid the case where the new blocks cannot be allocated.
-        req_blocks.extend(new_computed_block_list)
+        self.single_type_manager.save_new_computed_blocks(
+            request.request_id, new_computed_block_list)

-        # Start to handle new blocks
-
-        if num_new_blocks <= 0:
-            # No new block is needed.
-            new_blocks = []
-        else:
-            # Get new blocks from the free block pool.
-            num_new_blocks = min(
-                num_new_blocks,
-                self.block_pool.get_num_free_blocks(),
-                # Should not exceed the maximum number of blocks per request.
-                # This is especially because the block table has the shape
-                # [..., max_num_blocks_per_req].
-                self.max_num_blocks_per_req - len(req_blocks),
-            )
-            assert num_new_blocks > 0
-
-            # Concatenate the computed block IDs and the new block IDs.
-            new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
-            req_blocks.extend(new_blocks)
+        new_blocks = self.single_type_manager.allocate_new_blocks(
+            request.request_id, num_tokens_need_slot)

         if not self.enable_caching:
             return KVCacheBlocks(new_blocks)

-        # Use `new_computed_block_list` for a new request, and
-        # `num_cached_block` for a running request.
-        num_cached_blocks = self.num_cached_block.get(
-            request.request_id, len(new_computed_block_list))
         # Speculated tokens might be rejected in the future, so we do
         # not cache any speculated tokens. We only cache blocks with
         # generated (accepted) tokens.
-        num_full_blocks_after_append = (num_computed_tokens + num_tokens - len(
-            request.spec_token_ids)) // self.block_size
+        self.single_type_manager.cache_blocks(
+            request, self.req_to_block_hashes[request.request_id],
+            num_computed_tokens + num_new_tokens - len(request.spec_token_ids))

-        self.block_pool.cache_full_blocks(
-            request=request,
-            blocks=req_blocks,
-            block_hashes=self.req_to_block_hashes[request.request_id],
-            num_cached_blocks=num_cached_blocks,
-            num_full_blocks=num_full_blocks_after_append,
-            block_size=self.block_size,
-            hash_fn=self.caching_hash_fn,
-        )
-
-        self.num_cached_block[
-            request.request_id] = num_full_blocks_after_append
         return KVCacheBlocks(new_blocks)

     def free(self, request: Request) -> None:
         """Free the blocks allocated for the request.
-        When caching is enabled, we free the blocks in reverse order so that
-        the tail blocks are evicted first.
+        We free the blocks in reverse order so that the tail blocks are evicted
+        first when caching is enabled.

         Args:
             request: The request to free the blocks.
         """
-        # Default to [] in case a request is freed (aborted) before alloc.
-        blocks = self.req_to_blocks.pop(request.request_id, [])
-        ordered_blocks: Iterable[KVCacheBlock] = blocks
-        if self.enable_caching:
-            # Free blocks in reverse order so that the tail blocks are
-            # freed first.
-            ordered_blocks = reversed(blocks)
-
-        self.block_pool.free_blocks(ordered_blocks)
-        self.num_cached_block.pop(request.request_id, None)
+        self.single_type_manager.free(request.request_id)

     def reset_prefix_cache(self) -> bool:
         """Reset prefix cache. This function may be used in RLHF
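One subtlety in the cache_blocks call above: speculated tokens are subtracted before rounding down to full blocks, so a block whose tail holds not-yet-accepted draft tokens is never cached. A sketch with hypothetical numbers:

block_size = 16
num_computed_tokens = 48  # 3 full blocks already computed
num_new_tokens = 20       # of which 4 are speculated (draft) tokens
num_spec_tokens = 4

# Only generated (accepted) tokens count toward cacheable full blocks.
num_cacheable_tokens = (num_computed_tokens + num_new_tokens
                        - num_spec_tokens)             # 64
num_full_blocks = num_cacheable_tokens // block_size   # 4 blocks cached
assert num_full_blocks == 4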
@@ -390,14 +331,8 @@ class KVCacheManager:
             int: The number of common prefix blocks.
         """
         assert request.status == RequestStatus.RUNNING
-        blocks = self.req_to_blocks[request.request_id]
-        num_common_blocks = 0
-        for block in blocks:
-            if block.ref_cnt == num_running_requests:
-                num_common_blocks += 1
-            else:
-                break
-        return num_common_blocks
+        return self.single_type_manager.get_num_common_prefix_blocks(
+            request.request_id, num_running_requests)

     def free_block_hashes(self, request: Request) -> None:
         """Discard the block hashes for the request.
@@ -1,17 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0

 from abc import ABC, abstractmethod
+from collections import defaultdict
+from typing import Callable

+from vllm.utils import cdiv
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
                                         SlidingWindowSpec)
+from vllm.v1.request import Request


-class SpecializedManager(ABC):
+class SingleTypeKVCacheManager(ABC):
     """
-    An abstract base class for specialized managers that handle the kv
-    cache management logic of different attention layers.
+    An abstract base class for a manager that handles the kv cache management
+    logic of one specific type of attention layer.
     """

     def __init__(
@@ -19,12 +22,18 @@ class SingleTypeKVCacheManager(ABC):
         kv_cache_spec: KVCacheSpec,
         block_pool: BlockPool,
         use_eagle: bool,
+        num_kv_cache_groups: int,
+        caching_hash_fn: Callable,
     ) -> None:
         """
         Initializes the SpecializedManager.
         Args:
             kv_cache_spec: The kv_cache_spec for this manager.
             block_pool: The block pool.
             use_eagle: Whether to use eagle.
+            num_kv_cache_groups: The number of kv cache groups managed by this
+                manager.
+            caching_hash_fn: The caching hash function.
         """

         self.block_size = kv_cache_spec.block_size
@@ -34,6 +43,149 @@ class SingleTypeKVCacheManager(ABC):
         # Needs special handling for find_longest_cache_hit if eagle is enabled
         self.use_eagle = use_eagle

+        # Mapping from request ID to blocks to track the blocks allocated
+        # for each request, so that we can free the blocks when the request
+        # is finished.
+        self.req_to_blocks: defaultdict[str,
+                                        list[KVCacheBlock]] = defaultdict(list)
+
+        # {req_id: The number of cached blocks for this given request}
+        # This is used to track the number of cached blocks for each request.
+        # This is only used to track the RUNNING requests, we do not track the
+        # data for preempted ones.
+        self.num_cached_block: dict[str, int] = {}
+
+        self.num_kv_cache_groups = num_kv_cache_groups
+        self.caching_hash_fn = caching_hash_fn
+
+    def get_num_blocks_to_allocate(
+            self, request_id: str, num_tokens: int,
+            new_computed_blocks: list[KVCacheBlock]) -> int:
+        """
+        Get the number of blocks needed to be allocated for the request.
+
+        Args:
+            request_id: The request ID.
+            num_tokens: The total number of tokens that need a slot (including
+                tokens that are already allocated).
+            new_computed_blocks: The new computed blocks just hitting the
+                prefix caching.
+
+        Returns:
+            The number of blocks.
+        """
+
+        num_required_blocks = cdiv(num_tokens, self.block_size)
+        num_new_blocks = (num_required_blocks - len(new_computed_blocks) -
+                          len(self.req_to_blocks[request_id]))
+        # If a computed block of a request is an eviction candidate (in the
+        # free queue and ref_cnt == 0), it will be changed from a free block
+        # to a computed block when the request is allocated, so we also count
+        # it as needed to be allocated.
+        num_evictable_computed_blocks = sum(blk.ref_cnt == 0
+                                            for blk in new_computed_blocks)
+        return ((num_new_blocks + num_evictable_computed_blocks) *
+                self.num_kv_cache_groups)
+
+    def save_new_computed_blocks(
+            self, request_id: str,
+            new_computed_blocks: list[KVCacheBlock]) -> None:
+        """
+        Add the new computed blocks to the request.
+
+        Args:
+            request_id: The request ID.
+            new_computed_blocks: The new computed blocks just hitting the
+                prefix cache.
+        """
+        if request_id not in self.num_cached_block:
+            # A new request.
+            req_blocks = self.req_to_blocks[request_id]
+            assert len(req_blocks) == 0
+            req_blocks.extend(new_computed_blocks)
+            self.num_cached_block[request_id] = len(new_computed_blocks)
+        else:
+            # A running request. Should not have new computed blocks.
+            assert len(new_computed_blocks) == 0
+
+    def allocate_new_blocks(self, request_id: str,
+                            num_tokens: int) -> list[KVCacheBlock]:
+        """
+        Allocate new blocks for the request to give it at least `num_tokens`
+        token slots.
+
+        Args:
+            request_id: The request ID.
+            num_tokens: The total number of tokens that need a slot (including
+                tokens that are already allocated).
+
+        Returns:
+            The new allocated blocks.
+        """
+        req_blocks = self.req_to_blocks[request_id]
+        num_required_blocks = cdiv(num_tokens, self.block_size)
+        num_new_blocks = num_required_blocks - len(req_blocks)
+        if num_new_blocks <= 0:
+            return []
+        else:
+            new_blocks = self.block_pool.get_new_blocks(
+                num_new_blocks * self.num_kv_cache_groups)
+            req_blocks.extend(new_blocks)
+            return new_blocks
+
+    def cache_blocks(self, request: Request, block_hashes: list[BlockHashType],
+                     num_tokens: int) -> None:
+        """
+        Cache the blocks for the request.
+
+        Args:
+            request: The request.
+            block_hashes: The block hashes of the request.
+            num_tokens: The total number of tokens that need to be cached
+                (including tokens that are already cached).
+        """
+        num_cached_blocks = self.num_cached_block[request.request_id]
+        num_full_blocks = num_tokens // self.block_size
+
+        self.block_pool.cache_full_blocks(
+            request=request,
+            blocks=self.req_to_blocks[request.request_id],
+            block_hashes=block_hashes,
+            num_cached_blocks=num_cached_blocks,
+            num_full_blocks=num_full_blocks,
+            block_size=self.block_size,
+            hash_fn=self.caching_hash_fn,
+        )
+
+        self.num_cached_block[request.request_id] = num_full_blocks
+
+    def free(self, request_id: str) -> None:
+        # Default to [] in case a request is freed (aborted) before alloc.
+        req_blocks = self.req_to_blocks.pop(request_id, [])
+
+        # Free blocks in reverse order so that the tail blocks are
+        # freed first.
+        ordered_blocks = reversed(req_blocks)
+
+        self.block_pool.free_blocks(ordered_blocks)
+        self.num_cached_block.pop(request_id, None)
+
+    @abstractmethod
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> int:
+        """
+        Get the number of common prefix blocks for a request.
+
+        Args:
+            request_id: The request ID.
+            num_running_requests: The number of running requests.
+
+        Returns:
+            The number of common prefix blocks.
+        """
+
+        raise NotImplementedError
+
     @abstractmethod
     def find_longest_cache_hit(
             self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
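Taken together, the methods above give the base class the complete per-request lifecycle that KVCacheManager.allocate_slots (earlier in this diff) now drives. A minimal sketch of the call order; request_id, request, computed_blocks, block_hashes, and the token counts are placeholders, and building a concrete manager is shown with the factory at the end of this file:

needed = manager.get_num_blocks_to_allocate(
    request_id=request_id,
    num_tokens=num_tokens_need_slot,
    new_computed_blocks=computed_blocks)
if needed <= manager.block_pool.get_num_free_blocks():
    # Record prefix-cache hits first, then top up with fresh blocks.
    manager.save_new_computed_blocks(request_id, computed_blocks)
    new_blocks = manager.allocate_new_blocks(request_id, num_tokens_need_slot)
    # Cache only blocks that are full of generated (accepted) tokens.
    manager.cache_blocks(request, block_hashes, num_cacheable_tokens)

# When the request finishes (or is aborted):
manager.free(request_id)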
@@ -41,7 +193,8 @@ class SingleTypeKVCacheManager(ABC):
         Get the longest cache hit prefix of the blocks. If no cache hit is
         found, return an empty list. If eagle is enabled, drop the last matched
         block to force recompute the last block to get the required hidden
-        states for eagle drafting head.
+        states for eagle drafting head. Needs to be customized for each
+        attention type.

         Args:
             block_hashes: The block hashes of the request.
@@ -55,24 +208,23 @@ class SingleTypeKVCacheManager(ABC):
         raise NotImplementedError

     @abstractmethod
-    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
-                              num_computed_tokens: int) -> list[KVCacheBlock]:
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
         """
         Remove the blocks that are no longer needed from `blocks`. The removed
         blocks should be replaced by null_block. Return the removed blocks in
         eviction order, where the first returned block should be evicted first.
-        Don't free the removed blocks in this function.
+        Don't free the removed blocks in this function. Needs to be customized
+        for each attention type.

         Args:
-            blocks: The list of blocks to be updated.
+            request_id: The request ID.
             num_computed_tokens: The number of tokens that have been computed.
-        Returns:
-            The removed blocks in eviction order.
         """
         raise NotImplementedError


-class FullAttentionManager(SpecializedManager):
+class FullAttentionManager(SingleTypeKVCacheManager):

     def find_longest_cache_hit(
             self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
@@ -89,17 +241,28 @@ class FullAttentionManager(SingleTypeKVCacheManager):
             computed_blocks.pop()
         return computed_blocks

-    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
-                              num_computed_tokens: int) -> list[KVCacheBlock]:
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
         # No need to remove blocks for full attention.
-        return []
+        pass
+
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> int:
+        blocks = self.req_to_blocks[request_id]
+        num_common_blocks = 0
+        for block in blocks:
+            if block.ref_cnt == num_running_requests:
+                num_common_blocks += 1
+            else:
+                break
+        return num_common_blocks


-class SlidingWindowManager(SpecializedManager):
+class SlidingWindowManager(SingleTypeKVCacheManager):

     def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
-                 use_eagle: bool):
-        super().__init__(kv_cache_spec, block_pool, use_eagle)
+                 use_eagle: bool, **kwargs) -> None:
+        super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs)
         self.sliding_window = kv_cache_spec.sliding_window
         # The number of contiguous blocks needed for prefix cache hit.
         # -1 since the input token itself is also included in the window
@@ -148,13 +311,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
             computed_blocks.pop()
         return computed_blocks

-    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
-                              num_computed_tokens: int) -> list[KVCacheBlock]:
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
         # Remove the blocks that are no longer in the sliding window and
         # skipped during the attention computation.
         last_useful_token = num_computed_tokens - self.sliding_window + 1
         last_useful_block = last_useful_token // self.block_size

+        blocks = self.req_to_blocks[request_id]
         removed_blocks: list[KVCacheBlock] = []
         for i in range(last_useful_block - 1, -1, -1):
             if blocks[i] == self._null_block:
@@ -164,17 +327,27 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
                 break
             removed_blocks.append(blocks[i])
             blocks[i] = self._null_block
-        return removed_blocks
+        self.block_pool.free_blocks(removed_blocks)

+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> int:
+        """
+        NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
+        So it's not correct to count ref_cnt like FullAttentionManager. Return
+        0 here for correctness. Need to support cascade attention + sliding
+        window in the future.
+        """
+        return 0


-spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
+spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
     SlidingWindowSpec: SlidingWindowManager,
 }


-def get_specialized_manager(kv_cache_spec: KVCacheSpec,
-                            **kwargs) -> SpecializedManager:
+def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec,
+                                  **kwargs) -> SingleTypeKVCacheManager:
     manager_class = spec_manager_map[type(kv_cache_spec)]
     manager = manager_class(kv_cache_spec, **kwargs)
     return manager
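The renamed factory is exercised exactly this way by the get_sliding_window_manager helper near the top of this diff; a minimal sketch (sliding_window_spec is assumed to be constructed as in that test, with its fields elided here):

from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.specialized_manager import get_manager_for_kv_cache_spec

block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = get_manager_for_kv_cache_spec(sliding_window_spec,
                                        block_pool=block_pool,
                                        use_eagle=False,
                                        num_kv_cache_groups=1,
                                        caching_hash_fn=lambda x: x)
# spec_manager_map dispatches on the exact spec type:
# SlidingWindowSpec -> SlidingWindowManager.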