mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[V1][Spec Decode] KV cache slots for eagle heads (#16370)
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
@ -7,6 +7,7 @@ from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes, sha256
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
# disable yapf here as it formats differently than isort such that both fail
|
||||
# yapf: disable
|
||||
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
|
||||
@ -48,6 +49,18 @@ def make_request(request_id,
|
||||
)
|
||||
|
||||
|
||||
def new_kv_cache_spec(block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
use_mla=False):
|
||||
return FullAttentionSpec(block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
use_mla=use_mla)
|
||||
|
||||
|
||||
def test_none_hash():
|
||||
assert NONE_HASH is not None
|
||||
assert isinstance(NONE_HASH, int)
|
||||
@ -327,18 +340,6 @@ def test_metrics():
|
||||
|
||||
|
||||
def test_unify_kv_cache_configs():
|
||||
|
||||
def new_kv_cache_spec(block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
use_mla=False):
|
||||
return FullAttentionSpec(block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
use_mla=use_mla)
|
||||
|
||||
same_kv_cache_config = [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
@ -470,3 +471,64 @@ def test_estimate_max_model_len(model_id, max_model_len,
|
||||
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
|
||||
8 * GiB_bytes)
|
||||
assert estimated_max_len == want_estimated_max_len
|
||||
|
||||
|
||||
def test_allocate_with_lookahead():
|
||||
"""Verify that lookahead tokens correctly affect block allocation"""
|
||||
block_size = 4
|
||||
config = KVCacheConfig(
|
||||
num_blocks=10,
|
||||
tensors={
|
||||
"layer1": KVCacheTensor(100),
|
||||
},
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"],
|
||||
new_kv_cache_spec(block_size=block_size)),
|
||||
],
|
||||
)
|
||||
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[],
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
)
|
||||
|
||||
# Test case 1: Requires additional lookahead tokens
|
||||
kv_cache_manager = KVCacheManager(kv_cache_config=config,
|
||||
max_model_len=100,
|
||||
num_preallocate_tokens=0)
|
||||
blocks = kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_tokens=3,
|
||||
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
|
||||
)
|
||||
assert len(blocks) == 2 # ceil(5/4)=2 blocks
|
||||
|
||||
# Test case 2: With precomputed blocks
|
||||
kv_cache_manager = KVCacheManager(kv_cache_config=config,
|
||||
max_model_len=100,
|
||||
num_preallocate_tokens=4)
|
||||
# num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
|
||||
# required_blocks = ceil((3 + 2) /4) = 2
|
||||
# total_blocks = 1 + 2 = 3
|
||||
blocks = kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_tokens=3,
|
||||
num_lookahead_tokens=2,
|
||||
)
|
||||
assert len(blocks) == 3
|
||||
|
||||
# Test case 3: With precomputed blocks
|
||||
# num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
|
||||
# required_blocks = ceil((3 + 4) / 4) = 2
|
||||
# total_blocks = 0 + 2 = 2
|
||||
kv_cache_manager = KVCacheManager(kv_cache_config=config,
|
||||
max_model_len=100,
|
||||
num_preallocate_tokens=4)
|
||||
blocks = kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_tokens=3,
|
||||
num_lookahead_tokens=4,
|
||||
)
|
||||
assert len(blocks) == 2
|
||||
|
@ -164,7 +164,8 @@ class KVCacheManager:
|
||||
self,
|
||||
request: Request,
|
||||
num_tokens: int,
|
||||
new_computed_blocks: Optional[list[KVCacheBlock]] = None
|
||||
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
|
||||
num_lookahead_tokens: int = 0,
|
||||
) -> Optional[list[KVCacheBlock]]:
|
||||
"""Add slots for a request with new tokens to append.
|
||||
|
||||
@ -174,6 +175,9 @@ class KVCacheManager:
|
||||
not include the tokens that have already been computed.
|
||||
new_computed_blocks: A list of new computed blocks just hitting the
|
||||
prefix caching.
|
||||
num_lookahead_tokens: The number of speculative tokens to allocate.
|
||||
This is used by spec decode proposers with kv-cache such
|
||||
as eagle.
|
||||
|
||||
Blocks layout:
|
||||
-----------------------------------------------------------------------
|
||||
@ -211,8 +215,9 @@ class KVCacheManager:
|
||||
# the new prefix caching hits
|
||||
num_computed_tokens = (request.num_computed_tokens +
|
||||
len(new_computed_blocks) * self.block_size)
|
||||
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
|
||||
self.block_size)
|
||||
num_required_blocks = cdiv(
|
||||
num_computed_tokens + num_tokens + num_lookahead_tokens,
|
||||
self.block_size)
|
||||
num_new_blocks = (num_required_blocks - len(req_blocks) -
|
||||
len(new_computed_blocks))
|
||||
|
||||
@ -246,8 +251,11 @@ class KVCacheManager:
|
||||
else:
|
||||
# Get new blocks from the free block pool considering
|
||||
# preallocated blocks.
|
||||
num_preallocate_blocks = max(
|
||||
0, self.num_preallocate_blocks -
|
||||
num_lookahead_tokens // self.block_size)
|
||||
num_new_blocks = min(
|
||||
num_new_blocks + self.num_preallocate_blocks,
|
||||
num_new_blocks + num_preallocate_blocks,
|
||||
self.block_pool.get_num_free_blocks(),
|
||||
# Should not exceed the maximum number of blocks per request.
|
||||
# This is especially because the block table has the shape
|
||||
|
@ -7,7 +7,8 @@ from collections import deque
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
|
||||
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
|
||||
@ -39,6 +40,7 @@ class Scheduler(SchedulerInterface):
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
structured_output_manager: StructuredOutputManager,
|
||||
speculative_config: SpeculativeConfig = None,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
include_finished_set: bool = False,
|
||||
log_stats: bool = False,
|
||||
@ -112,6 +114,11 @@ class Scheduler(SchedulerInterface):
|
||||
self.encoder_cache_manager = EncoderCacheManager(
|
||||
cache_size=encoder_cache_size)
|
||||
|
||||
self.num_lookahead_tokens = 0
|
||||
if speculative_config and speculative_config.method == "eagle":
|
||||
self.num_lookahead_tokens = \
|
||||
speculative_config.num_speculative_tokens
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
# NOTE(woosuk) on the scheduling algorithm:
|
||||
# There's no "decoding phase" nor "prefill phase" in the scheduler.
|
||||
@ -188,7 +195,9 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
while True:
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request, num_new_tokens)
|
||||
request,
|
||||
num_new_tokens,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens)
|
||||
if new_blocks is None:
|
||||
# The request cannot be scheduled.
|
||||
# Preempt the lowest-priority request.
|
||||
|
@ -98,6 +98,7 @@ class EngineCore:
|
||||
cache_config=vllm_config.cache_config,
|
||||
lora_config=vllm_config.lora_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
speculative_config=vllm_config.speculative_config,
|
||||
structured_output_manager=self.structured_output_manager,
|
||||
include_finished_set=vllm_config.parallel_config.data_parallel_size
|
||||
> 1,
|
||||
|
Reference in New Issue
Block a user