Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
Decode bench connector
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
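The connector is selected through the standard `kv_transfer_config` mechanism, exactly as the unit tests below do. A minimal offline sketch of how one might run a decode benchmark with it (the model name, prompt, and sampling settings are illustrative placeholders, not part of this change):

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    # Enable the benchmarking connector: the KV cache for the prompt is
    # filled with dummy values instead of being computed by a real prefill.
    llm = LLM(
        model="facebook/opt-125m",  # illustrative; any supported model
        kv_transfer_config=KVTransferConfig(
            kv_connector="DecodeBenchConnector",
            kv_role="kv_both",
        ),
    )

    outputs = llm.generate(["x " * 4096], SamplingParams(max_tokens=64))

With the cache pre-filled, only the decode steps do real attention work, so generation time approximates decode performance at the chosen context length.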
349  tests/v1/kv_connector/unit/test_decode_bench_connector.py  (new file)
@@ -0,0 +1,349 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for DecodeBenchConnector.

Tests the functionality of the DecodeBenchConnector which fills KV cache
with dummy values for decode performance benchmarking.
"""
import pytest
import torch

from vllm import SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
# ruff: noqa: E501
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import (
    DecodeBenchConnector, DecodeBenchConnectorMetadata)
from vllm.forward_context import ForwardContext
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                         init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request

from .utils import (EOS_TOKEN_ID, create_model_runner_output, create_scheduler,
                    create_vllm_config)


class DecodeBenchTestRunner:
    """Test runner for DecodeBenchConnector."""

    def __init__(self, block_size: int, num_gpu_blocks: int):
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks

        self.req_id = -1

        # Create vllm config with DecodeBenchConnector
        vllm_config = create_vllm_config(block_size=block_size,
                                         max_num_batched_tokens=1000)
        vllm_config.kv_transfer_config = KVTransferConfig(
            kv_connector="DecodeBenchConnector",
            kv_role="kv_both",
        )

        self.vllm_config = vllm_config
        self.scheduler: Scheduler = create_scheduler(vllm_config,
                                                     num_blocks=num_gpu_blocks)

        # Create worker-side connector
        self.worker_connector = DecodeBenchConnector(vllm_config,
                                                     KVConnectorRole.WORKER)

        # Create dummy KV caches for testing
        # Shape: [num_blocks, 2, num_heads, block_size, head_dim]
        # Using simplified shape for testing
        num_heads = 4
        head_dim = 64
        self.kv_caches = {
            f"layer_{i}":
            torch.zeros(num_gpu_blocks, 2, num_heads, block_size, head_dim)
            for i in range(2)  # 2 layers for testing
        }

        # Register KV caches with worker connector
        self.worker_connector.register_kv_caches(self.kv_caches)

        # Extract scheduler-side connector
        scheduler_connector = self.scheduler.connector
        assert scheduler_connector is not None
        assert isinstance(scheduler_connector, DecodeBenchConnector)
        self.scheduler_connector: DecodeBenchConnector = scheduler_connector

        init_none_hash(sha256)
        self._block_hasher = get_request_block_hasher(block_size, sha256)

        self._dummy_ctx: ForwardContext = ForwardContext(no_compile_layers={},
                                                         attn_metadata={},
                                                         virtual_engine=0)

    def new_request(self, token_ids: list[int]) -> Request:
        """Create a new request with given token IDs."""
        self.req_id += 1

        req = Request(
            request_id=str(self.req_id),
            prompt_token_ids=token_ids,
            sampling_params=SamplingParams(max_tokens=100),
            pooling_params=None,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=self._block_hasher,
        )

        self.scheduler.add_request(req)
        return req

    def run_single_step(self, token_id: int = 0):
        """Run a single scheduler + worker step."""
        scheduler_output = self.scheduler.schedule()

        # Get connector metadata
        kv_connector_metadata = scheduler_output.kv_connector_metadata
        assert kv_connector_metadata is not None
        assert isinstance(kv_connector_metadata, DecodeBenchConnectorMetadata)

        # Bind metadata and load KV
        self.worker_connector.bind_connector_metadata(kv_connector_metadata)
        self.worker_connector.start_load_kv(self._dummy_ctx)

        if scheduler_output.total_num_scheduled_tokens > 0:
            self.worker_connector.wait_for_save()

        self.worker_connector.clear_connector_metadata()

        # Create model runner output
        model_runner_output = create_model_runner_output(
            reqs=self.scheduler.running,
            token_id=token_id,
        )

        self.scheduler.update_from_output(scheduler_output,
                                          model_runner_output)

        return scheduler_output, kv_connector_metadata


def test_decode_bench_connector_basic():
    """Test basic functionality of DecodeBenchConnector."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size,
                                   num_gpu_blocks=num_gpu_blocks)

    # Create a request with multiple blocks worth of tokens
    num_tokens = block_size * 3  # 3 blocks
    token_ids = [1] * num_tokens

    req = runner.new_request(token_ids)

    # Run first step - should fill KV cache with dummy values
    scheduler_output, metadata = runner.run_single_step()

    # Check that get_num_new_matched_tokens returned correct value
    # Should be num_tokens - 1 (all except the last token for decode)
    expected_fill_tokens = num_tokens - 1

    # Check metadata has the request to fill
    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    block_ids, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
    assert num_tokens_to_fill == expected_fill_tokens

    # Calculate expected number of blocks
    expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
    assert len(block_ids) == expected_num_blocks

    # Verify KV caches were filled with constant value
    for layer_name, kv_cache in runner.kv_caches.items():
        for block_id in block_ids:
            # Check that the block was filled
            block_data = kv_cache[block_id]
            # Should be filled with constant value 0.015
            assert torch.allclose(block_data, torch.tensor(0.015))


def test_decode_bench_connector_no_refill():
    """Test that DecodeBenchConnector only fills once per request."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size,
                                   num_gpu_blocks=num_gpu_blocks)

    # Create a request
    num_tokens = block_size * 2
    token_ids = [1] * num_tokens

    runner.new_request(token_ids)

    # Run first step - should fill KV cache
    _, metadata1 = runner.run_single_step()
    assert len(metadata1.reqs_to_fill) == 1

    # Run second step - should NOT fill again (already filled)
    _, metadata2 = runner.run_single_step()
    assert len(metadata2.reqs_to_fill) == 0


def test_decode_bench_connector_single_token():
    """Test DecodeBenchConnector with single token request."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size,
                                   num_gpu_blocks=num_gpu_blocks)

    # Create a request with just 1 token
    # Should not fill anything (need at least 2 tokens: 1 to fill, 1 to decode)
    token_ids = [1]

    runner.new_request(token_ids)

    # Run step - should NOT fill KV cache
    _, metadata = runner.run_single_step()
    assert len(metadata.reqs_to_fill) == 0


def test_decode_bench_connector_two_tokens():
    """Test DecodeBenchConnector with two token request."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size,
                                   num_gpu_blocks=num_gpu_blocks)

    # Create a request with 2 tokens
    # Should fill 1 token (first token), decode the second
    token_ids = [1, 2]

    req = runner.new_request(token_ids)

    # Run step
    _, metadata = runner.run_single_step()

    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    block_ids, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
    assert num_tokens_to_fill == 1
    assert len(block_ids) == 1  # 1 token needs 1 block


def test_decode_bench_connector_large_context():
    """Test DecodeBenchConnector with large context size."""
    block_size = 16
    num_gpu_blocks = 1000

    runner = DecodeBenchTestRunner(block_size=block_size,
                                   num_gpu_blocks=num_gpu_blocks)

    # Create a request with many blocks
    num_blocks = 20
    num_tokens = block_size * num_blocks
    token_ids = list(range(num_tokens))

    req = runner.new_request(token_ids)

    # Run step
    _, metadata = runner.run_single_step()

    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    block_ids, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]

    # Should fill all tokens except the last one
    expected_fill_tokens = num_tokens - 1
    assert num_tokens_to_fill == expected_fill_tokens

    # Calculate expected number of blocks
    expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
    assert len(block_ids) == expected_num_blocks

    # Verify blocks were filled
    for layer_name, kv_cache in runner.kv_caches.items():
        for block_id in block_ids:
            block_data = kv_cache[block_id]
            assert torch.allclose(block_data, torch.tensor(0.015))


def test_decode_bench_connector_multiple_requests():
    """Test DecodeBenchConnector with multiple sequential requests."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size,
                                   num_gpu_blocks=num_gpu_blocks)

    # First request
    req1 = runner.new_request([1] * (block_size * 2))
    _, metadata1 = runner.run_single_step()

    assert len(metadata1.reqs_to_fill) == 1
    assert req1.request_id in metadata1.reqs_to_fill

    # Complete first request
    while runner.scheduler.running:
        runner.run_single_step()

    # Add EOS to finish
    scheduler_output = runner.scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=runner.scheduler.running,
        token_id=EOS_TOKEN_ID,
        use_eos=True,
    )
    runner.scheduler.update_from_output(scheduler_output, model_runner_output)

    # Second request - should also get filled
    req2 = runner.new_request([2] * (block_size * 3))
    _, metadata2 = runner.run_single_step()

    assert len(metadata2.reqs_to_fill) == 1
    assert req2.request_id in metadata2.reqs_to_fill

    # Different request should have different metadata
    _, num_tokens1 = metadata1.reqs_to_fill[req1.request_id]
    _, num_tokens2 = metadata2.reqs_to_fill[req2.request_id]

    assert num_tokens1 == block_size * 2 - 1
    assert num_tokens2 == block_size * 3 - 1


def test_decode_bench_connector_partial_block():
    """Test DecodeBenchConnector with partial block filling."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size,
                                   num_gpu_blocks=num_gpu_blocks)

    # Create a request that doesn't align to block boundaries
    # e.g., 2.5 blocks worth of tokens
    num_tokens = block_size * 2 + block_size // 2
    token_ids = [1] * num_tokens

    req = runner.new_request(token_ids)

    # Run step
    _, metadata = runner.run_single_step()

    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    block_ids, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]

    # Should fill all tokens except the last one
    expected_fill_tokens = num_tokens - 1
    assert num_tokens_to_fill == expected_fill_tokens

    # Should allocate 3 blocks to hold the partial data
    expected_num_blocks = 3
    assert len(block_ids) == expected_num_blocks


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
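The block counts asserted above follow from simple ceiling division. A standalone sketch of the same arithmetic, mirroring the tests' expectations without importing any vLLM code:

    def expected_blocks(num_prompt_tokens: int, block_size: int = 16) -> int:
        # The connector fills every prompt token except the last one, which is
        # left to be decoded.
        fill_tokens = max(0, num_prompt_tokens - 1)
        # Ceiling division: a partially used block still has to be allocated.
        return (fill_tokens + block_size - 1) // block_size

    assert expected_blocks(48) == 3  # 3 blocks of prompt -> 47 filled tokens
    assert expected_blocks(40) == 3  # 2.5 blocks of prompt -> 39 filled tokens
    assert expected_blocks(2) == 1   # one filled token still needs a block
    assert expected_blocks(1) == 0   # single-token request: nothing to fill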
vllm/distributed/kv_transfer/kv_connector/factory.py

@@ -4,15 +4,15 @@
 import importlib
 from typing import TYPE_CHECKING, Callable
 
+# yapf: disable
 import vllm.envs as envs
-# yapf conflicts with isort for this block
-# yapf: disable
 from vllm.distributed.kv_transfer.kv_connector.base import (
     KVConnectorBase, KVConnectorBaseType)
-# yapf: enable
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
 from vllm.logger import init_logger
 
+# yapf: enable
+
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
     from vllm.config.kv_transfer import KVTransferConfig
@@ -111,3 +111,8 @@ KVConnectorFactory.register_connector(
     "OffloadingConnector",
     "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
     "OffloadingConnector")
+
+KVConnectorFactory.register_connector(
+    "DecodeBenchConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
+    "DecodeBenchConnector")
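Registration hands the factory only strings, so the connector module is not imported until the connector is actually selected. A self-contained sketch of that lazy-import pattern (the `_LazyRegistry` class below is a generic stand-in, not vLLM's actual factory API):

    import importlib
    from typing import Callable


    class _LazyRegistry:
        """Maps a connector name to a loader that imports the class on demand."""

        def __init__(self) -> None:
            self._loaders: dict[str, Callable[[], type]] = {}

        def register(self, name: str, module_path: str, class_name: str) -> None:
            # Only strings are stored; nothing is imported at registration time.
            self._loaders[name] = lambda: getattr(
                importlib.import_module(module_path), class_name)

        def load(self, name: str) -> type:
            return self._loaders[name]()


    registry = _LazyRegistry()
    registry.register(
        "DecodeBenchConnector",
        "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
        "DecodeBenchConnector")
    # registry.load("DecodeBenchConnector") imports the module and returns the class.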
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py

@@ -2,5 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorRole)
+from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import (  # noqa: E501
+    DecodeBenchConnector)
 
-__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
+__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "DecodeBenchConnector"]
vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py  (new file)

@@ -0,0 +1,290 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DecodeBenchConnector: A KV Connector for decode instance performance testing.

This connector emulates a prefill-decode disaggregated setting by filling
the KV cache with dummy values, allowing measurement of decoder performance
under larger input sequence lengths (ISL) in resource-limited environments.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import torch

from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
                                                          KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata)
from vllm.logger import init_logger

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.config import VllmConfig
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.request import Request

logger = init_logger(__name__)


@dataclass
class DecodeBenchConnectorMetadata(KVConnectorMetadata):
    """Metadata for DecodeBenchConnector.

    Contains information about which requests need their KV cache filled
    with dummy values for benchmarking purposes.
    """
    # request_id -> (block_ids, num_tokens_to_fill)
    reqs_to_fill: dict[str, tuple[list[int], int]]


class DecodeBenchConnector(KVConnectorBase_V1):
    """
    A KV Connector for decode instance performance testing.

    This connector fills the KV cache with dummy (non-zero) values to
    emulate a prefill-decode disaggregated setting, enabling performance
    testing of the decoder with larger input sequence lengths.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config, role)

        self.connector_scheduler: Optional[
            DecodeBenchConnectorScheduler] = None
        self.connector_worker: Optional[DecodeBenchConnectorWorker] = None

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = DecodeBenchConnectorScheduler(
                vllm_config)
        elif role == KVConnectorRole.WORKER:
            self.connector_worker = DecodeBenchConnectorWorker(vllm_config)

    # ==============================
    # Worker-side methods
    # ==============================

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs: Any) -> None:
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata,
                          DecodeBenchConnectorMetadata)
        self.connector_worker.start_fill_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        # All operations are synchronous, so nothing to wait for
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata",
                      **kwargs: Any) -> None:
        # This connector doesn't save KV cache (benchmarking only)
        pass

    def wait_for_save(self):
        # This connector doesn't save KV cache (benchmarking only)
        pass

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[Optional[int], bool]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
            self, scheduler_output: "SchedulerOutput") -> KVConnectorMetadata:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)


class DecodeBenchConnectorScheduler:
    """Scheduler-side implementation for DecodeBenchConnector."""

    def __init__(self, vllm_config: "VllmConfig"):
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size

        # Track which requests have already been filled
        self._filled_requests: set[str] = set()

        # Track pending fills for the current scheduler step
        # request_id -> (block_ids, num_tokens_to_fill)
        self._pending_fills: dict[str, tuple[list[int], int]] = {}

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        For new requests, return the number of tokens that should be filled
        with dummy KV cache values.

        Returns:
            (num_tokens_to_fill, is_async)
            - num_tokens_to_fill: total tokens in the request minus 1
              (we fill everything except the last token for decode)
            - is_async: False (synchronous filling)
        """
        req_id = request.request_id

        # Only fill once per request on first scheduling
        if req_id in self._filled_requests or num_computed_tokens > 0:
            return 0, False

        # Fill all tokens except the last one (which will be decoded)
        # This simulates having processed a long prefill
        num_tokens_to_fill = max(0, request.num_tokens - 1)

        if num_tokens_to_fill == 0:
            return 0, False

        logger.debug(
            "DecodeBenchConnector: Request %s will fill %d tokens in KV cache",
            req_id, num_tokens_to_fill)

        # Return False for synchronous operation - the fill is fast enough
        # that async overhead isn't worth it
        return num_tokens_to_fill, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Called after blocks are allocated. Store the block IDs so we can
        fill them with dummy values.
        """
        req_id = request.request_id

        if num_external_tokens == 0:
            return

        # Get the block IDs that were allocated
        block_groups = blocks.get_block_ids()
        block_ids = block_groups[0]  # Get first group (for single-tensor KV)

        # Calculate how many blocks we need to fill
        # num_external_tokens are the tokens we said we'd provide
        num_blocks_to_fill = (num_external_tokens + self.block_size -
                              1) // self.block_size

        # Store the blocks to fill
        self._pending_fills[req_id] = (block_ids[:num_blocks_to_fill],
                                       num_external_tokens)
        self._filled_requests.add(req_id)

        logger.debug(
            "DecodeBenchConnector: Allocated %d blocks for request %s",
            num_blocks_to_fill, req_id)

    def build_connector_meta(
            self, scheduler_output: "SchedulerOutput") -> KVConnectorMetadata:
        """
        Build metadata containing information about which blocks to fill
        with dummy KV values.
        """
        meta = DecodeBenchConnectorMetadata(
            reqs_to_fill=self._pending_fills.copy())

        # Clear pending fills after building metadata
        self._pending_fills.clear()

        return meta


class DecodeBenchConnectorWorker:
    """Worker-side implementation for DecodeBenchConnector."""

    def __init__(self, vllm_config: "VllmConfig"):
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size

        # Will be populated via register_kv_caches
        self.kv_caches: Optional[dict[str, torch.Tensor]] = None

        # Cache for pre-filled dummy block to avoid repeated allocation
        self._dummy_block_cache: Optional[torch.Tensor] = None

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Store references to the KV cache tensors."""
        self.kv_caches = kv_caches
        logger.debug("DecodeBenchConnector: Registered %d KV cache layers",
                     len(kv_caches))

    def start_fill_kv(self, metadata: DecodeBenchConnectorMetadata):
        """
        Fill the allocated KV cache blocks with dummy (non-zero) values.

        This simulates having a populated KV cache from a prefill phase,
        allowing decode performance testing with larger context sizes.
        """
        if not metadata.reqs_to_fill:
            return

        assert self.kv_caches is not None, \
            "KV caches must be registered before filling"

        for req_id, (block_ids, num_tokens) in metadata.reqs_to_fill.items():
            self._fill_blocks(block_ids, num_tokens)
            logger.debug(
                "DecodeBenchConnector: Filled %d blocks (%d tokens) for "
                "request %s", len(block_ids), num_tokens, req_id)

    def _fill_blocks(self, block_ids: list[int], num_tokens: int):
        """
        Fill specified blocks with dummy non-zero values.

        Args:
            block_ids: List of block IDs to fill
            num_tokens: Total number of tokens to fill across these blocks
        """
        if not block_ids:
            return

        # Fill each layer's KV cache with constant value
        assert self.kv_caches is not None
        for layer_name, kv_cache in self.kv_caches.items():
            # Create dummy block cache once per device/dtype
            if self._dummy_block_cache is None:
                block_shape = kv_cache.shape[1:]
                self._dummy_block_cache = torch.full(block_shape,
                                                     0.015,
                                                     dtype=kv_cache.dtype,
                                                     device=kv_cache.device)

            # Convert block_ids to tensor on device
            block_ids_tensor = torch.tensor(block_ids,
                                            dtype=torch.long,
                                            device=kv_cache.device)

            # Filter invalid block IDs
            valid_mask = block_ids_tensor < kv_cache.shape[0]
            valid_block_ids = block_ids_tensor[valid_mask]

            if len(valid_block_ids) > 0:
                # Batch fill operation
                kv_cache[valid_block_ids] = self._dummy_block_cache

        logger.debug(
            "DecodeBenchConnector: Filled %d blocks with dummy values",
            len(block_ids))