[KV offload][4/N] Offloading KV connector (#22595)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Or Ozeri
2025-09-19 22:07:17 +03:00
committed by GitHub
parent b716ab93a7
commit c59a0eca42
6 changed files with 1111 additions and 1 deletion
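The connector plugs into the existing KVTransferConfig plumbing; the concrete offloading backend is chosen through the kv_connector_extra_config keys exercised by the unit test below ("spec_name", "spec_module_path", "block_size"). A minimal configuration sketch; the spec name and module path are placeholders, since this commit registers no concrete spec:

from vllm.config import KVTransferConfig

# Hypothetical wiring; "MyOffloadingSpec" / "my_pkg.my_offloading_spec" stand
# in for a user-provided OffloadingSpec subclass and the module defining it.
kv_transfer_config = KVTransferConfig(
    kv_connector="OffloadingConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "spec_name": "MyOffloadingSpec",
        "spec_module_path": "my_pkg.my_offloading_spec",
        # offloaded block size in tokens; must be a multiple of the GPU block size
        "block_size": 48,
    },
)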


@@ -0,0 +1,505 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any
from unittest.mock import MagicMock
import pytest
import torch
from vllm import SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
OffloadingConnector, OffloadingConnectorMetadata)
from vllm.forward_context import ForwardContext
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (BlockHash, get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
OffloadingManager, PrepareStoreOutput)
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
TransferResult, TransferSpec)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import Request
from .utils import (EOS_TOKEN_ID, create_model_runner_output, create_scheduler,
create_vllm_config)
class MockLoadStoreSpec(LoadStoreSpec):
def __init__(self, block_hashes: Iterable[BlockHash]):
self.block_hashes: list[BlockHash] = list(block_hashes)
@staticmethod
def medium() -> str:
return "Mock"
def __repr__(self) -> str:
return repr(self.block_hashes)
class MockOffloadingHandler(OffloadingHandler):
def __init__(self):
self.completed_transfers: list[TransferResult] = []
self.completed_specs: list[TransferSpec] = []
def get_finished(self) -> list[TransferResult]:
finished = self.completed_transfers
self.completed_transfers = []
return finished
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
self.completed_specs.append(spec)
self.completed_transfers.append((job_id, True))
return True
class MockOffloadingSpec(OffloadingSpec):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0
self.manager.prepare_load = lambda block_hashes: (MockLoadStoreSpec(
block_hashes))
self.handler = MockOffloadingHandler()
def get_manager(self) -> OffloadingManager:
return self.manager
def get_handlers(
self, _
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
OffloadingHandler]]:
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
def get_completed_transfers(self) -> list[TransferSpec]:
specs = self.handler.completed_specs
self.handler.completed_specs = []
return specs
@dataclass
class TransferSummary:
gpu_block_indices: list[int]
offload_addresses: list[Any]
class RequestRunner:
def __init__(self, offloaded_block_size: int, gpu_block_size: int,
num_gpu_blocks: int):
self.offloaded_block_size: int = offloaded_block_size
self.gpu_block_size: int = gpu_block_size
self.num_gpu_blocks: int = num_gpu_blocks
self.req_id: int = -1
vllm_config = create_vllm_config(block_size=gpu_block_size,
max_num_batched_tokens=1000)
vllm_config.kv_transfer_config = KVTransferConfig(
kv_connector="OffloadingConnector",
kv_role="kv_both",
kv_connector_extra_config={
"spec_name": "MockOffloadingSpec",
"spec_module_path":
"tests.v1.kv_connector.unit.test_offloading_connector",
"block_size": offloaded_block_size,
})
self.scheduler: Scheduler = create_scheduler(vllm_config,
num_blocks=num_gpu_blocks)
self.worker_connector = OffloadingConnector(vllm_config,
KVConnectorRole.WORKER)
# register worker kv_caches to enable OffloadingWorker creation
self.worker_connector.register_kv_caches(
kv_caches={"a": torch.empty(0)})
# extract connector of scheduler
scheduler_connector = self.scheduler.connector
assert scheduler_connector is not None
assert isinstance(scheduler_connector, OffloadingConnector)
self.scheduler_connector: OffloadingConnector = scheduler_connector
# extract mocked OffloadingManager of scheduler connector
connector_scheduler = scheduler_connector.connector_scheduler
assert connector_scheduler is not None
manager = connector_scheduler.manager
assert isinstance(manager, MagicMock)
self.manager: MagicMock = manager
assert connector_scheduler.gpu_block_size == gpu_block_size
assert connector_scheduler.offloaded_block_size == offloaded_block_size
# extract OffloadingSpec of worker_connector
connector_worker = self.worker_connector.connector_worker
assert connector_worker is not None
offloading_spec = connector_worker.spec
assert isinstance(offloading_spec, MockOffloadingSpec)
self.offloading_spec: MockOffloadingSpec = offloading_spec
# mapping (offloading address) -> gpu_block_index
self.offloaded: dict[Any, int] = {}
self.pending_loads_count: int = 0
self.pending_stores_count: int = 0
self.completed_loads: list[TransferSummary] = []
self.completed_stores: list[TransferSummary] = []
# maps {gpu_block_id: block index within the request}
self.gpu_block_index: dict[int, int] = {}
init_none_hash(sha256)
self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)
self._dummy_ctx: ForwardContext = ForwardContext(no_compile_layers={},
attn_metadata={},
virtual_engine=0)
def new_request(self, token_ids: list[int]):
assert not self.scheduler.requests
self.req_id += 1
req = Request(
request_id=str(self.req_id),
prompt_token_ids=token_ids,
sampling_params=SamplingParams(max_tokens=1000),
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=self._block_hasher,
)
self.scheduler.add_request(req)
def _wait_for_transfers(self):
block_size_factor = self.offloaded_block_size // self.gpu_block_size
while self.pending_loads_count or self.pending_stores_count:
for transfer_spec in (
self.offloading_spec.get_completed_transfers()):
src_spec, dst_spec = transfer_spec
if isinstance(src_spec, GPULoadStoreSpec):
store = True
gpu_spec = src_spec
offload_spec = dst_spec
else:
store = False
gpu_spec = dst_spec
offload_spec = src_spec
assert isinstance(offload_spec, MockLoadStoreSpec)
assert isinstance(gpu_spec, GPULoadStoreSpec)
gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(
self.gpu_block_index[block_id.item()])
# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))
if store:
assert len(gpu_block_indices) == len(offload_addresses)
self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses))
self.pending_stores_count -= 1
else:
remainder_sub_block_count = (len(offload_addresses) -
len(gpu_block_indices))
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[
remainder_sub_block_count:]
self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses))
self.pending_loads_count -= 1
def _update_gpu_block_idx(self):
for blocks in (self.scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks.values()):
for block_idx, block in enumerate(blocks):
self.gpu_block_index[block.block_id] = block_idx
def _run(self, decoded_tokens: list[int]):
"""
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.
Args:
decoded_tokens: the tokens to yield at each step.
"""
tokens_iter = iter(decoded_tokens)
token_id = next(tokens_iter, None)
while token_id is not None:
assert self.scheduler.requests
scheduler_output = self.scheduler.schedule()
self._update_gpu_block_idx()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata,
OffloadingConnectorMetadata)
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
self.pending_stores_count += len(
kv_connector_metadata.reqs_to_store)
self.worker_connector.bind_connector_metadata(
kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)
if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save()
finished_sending, finished_recving = (
self.worker_connector.get_finished(
scheduler_output.finished_req_ids))
self.worker_connector.clear_connector_metadata()
model_runner_output = create_model_runner_output(
reqs=self.scheduler.running,
finished_sending=list(finished_sending),
finished_recving=list(finished_recving),
token_id=token_id)
if self.scheduler.running:
token_id = next(tokens_iter, None)
self.scheduler.update_from_output(scheduler_output,
model_runner_output)
self._wait_for_transfers()
# run extra steps to collect finished stores
if EOS_TOKEN_ID in decoded_tokens:
assert not self.scheduler.running
while self.scheduler.requests:
scheduler_output = self.scheduler.schedule()
finished_sending, finished_recving = (
self.worker_connector.get_finished(
scheduler_output.finished_req_ids))
assert not finished_recving
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending)
self.scheduler.update_from_output(scheduler_output,
model_runner_output)
def run(
self,
decoded_tokens: list[int],
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
):
"""
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.
Args:
decoded_tokens: the tokens to yield at each step.
expected_stored_gpu_block_indexes: GPU block indexes
that are expected to be written during the run.
expected_loaded_gpu_block_indexes: GPU block indexes
that are expected to be loaded during the run.
"""
self.manager.reset_mock()
self._run(decoded_tokens)
loaded_gpu_block_indexes: set[int] = set()
for transfer in self.completed_loads:
for gpu_block_idx, offloaded_address in zip(
transfer.gpu_block_indices, transfer.offload_addresses):
loaded_gpu_block_indexes.add(gpu_block_idx)
assert gpu_block_idx == self.offloaded[offloaded_address]
assert (
set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes)
self.completed_loads.clear()
stored_gpu_block_indexes: set[int] = set()
for transfer in self.completed_stores:
for gpu_block_idx, offloaded_address in zip(
transfer.gpu_block_indices, transfer.offload_addresses):
stored_gpu_block_indexes.add(gpu_block_idx)
self.offloaded[offloaded_address] = gpu_block_idx
assert (
set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes)
self.completed_stores.clear()
@pytest.fixture
def request_runner():
runners = []
def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks):
runner = RequestRunner(offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks)
runners.append(runner)
return runner
yield runner_factory # pass factory to the test
def generate_store_output(block_hashes: Iterable[BlockHash]):
block_hashes = list(block_hashes)
return PrepareStoreOutput(
block_hashes_to_store=list(block_hashes),
store_spec=MockLoadStoreSpec(block_hashes),
block_hashes_evicted=[],
)
def test_offloading_connector(request_runner):
offloaded_block_size = 12
gpu_block_size = 4
num_gpu_blocks = 100
block_size_factor = offloaded_block_size // gpu_block_size
runner = request_runner(offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks)
# 3 blocks, store just the middle block (skip first and last)
# blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
runner.new_request(token_ids=[0] * offloaded_block_size * 3)
runner.manager.prepare_store.side_effect = \
lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))
# add one token short of a full block -> no offload
runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
runner.manager.prepare_store.assert_not_called()
# +1 token -> single block, fail prepare_store
runner.manager.prepare_store.side_effect = \
lambda block_hashes: None
runner.run(decoded_tokens=[0])
runner.manager.prepare_store.assert_called()
# 1 more block, now set block_hashes_to_store = []
runner.manager.prepare_store.side_effect = \
lambda block_hashes: generate_store_output([])
runner.run(decoded_tokens=[0] * offloaded_block_size)
# 1 more block, now check touch was called with all 6 blocks
runner.manager.prepare_store.side_effect = \
lambda block_hashes: generate_store_output(block_hashes)
runner.run(decoded_tokens=[0] * offloaded_block_size,
expected_stored_gpu_block_indexes=(15, 16, 17))
runner.manager.touch.assert_called()
block_hashes1 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes1) == 6
# terminate request
runner.run(decoded_tokens=[EOS_TOKEN_ID])
# create a new request differing only in the last token
runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
runner.run(decoded_tokens=[0],
expected_stored_gpu_block_indexes=tuple(
range(6 * block_size_factor)))
runner.manager.touch.assert_called()
block_hashes2 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes2) == 6
# verify hashes are the same, except for the last block
assert block_hashes1[:5] == block_hashes2[:5]
assert block_hashes1[5] != block_hashes2[5]
# terminate request
runner.run(decoded_tokens=[EOS_TOKEN_ID])
# full_block_tokens - num_computed_tokens < offloaded_block_size
runner.new_request(token_ids=[0] * gpu_block_size + [1] *
(offloaded_block_size - gpu_block_size))
runner.manager.prepare_store.side_effect = \
lambda block_hashes: generate_store_output([])
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_not_called()
# single block lookup with no hits
runner.new_request(token_ids=[1] * offloaded_block_size)
runner.manager.prepare_store.side_effect = \
lambda block_hashes: generate_store_output([])
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_called()
assert len(list(runner.manager.lookup.call_args.args[0])) == 1
# single block lookup with a hit
runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = \
lambda block_hashes: generate_store_output([])
runner.manager.lookup.return_value = 1
runner.run(decoded_tokens=[EOS_TOKEN_ID],
expected_loaded_gpu_block_indexes=(0, 1, 2))
# single block lookup with a hit in a middle block
runner.new_request(token_ids=[0] * offloaded_block_size * 2 +
[1] * offloaded_block_size)
runner.manager.prepare_store.side_effect = \
lambda block_hashes: generate_store_output([])
runner.manager.lookup.return_value = 1
runner.run(decoded_tokens=[EOS_TOKEN_ID],
expected_loaded_gpu_block_indexes=(3, 4, 5))
# test take_events
def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
return [BlockHash(str(i).encode()) for i in int_hashes]
def take_events() -> Iterable[OffloadingEvent]:
yield OffloadingEvent(block_hashes=to_hashes([1, 2, 3]),
block_size=16,
medium="A",
removed=False)
yield OffloadingEvent(block_hashes=to_hashes([4, 5, 6]),
block_size=32,
medium="B",
removed=True)
runner.manager.take_events.side_effect = take_events
events = list(runner.scheduler_connector.take_events())
assert len(events) == 2
event = events[0]
assert isinstance(event, BlockStored)
assert event.block_hashes == to_hashes([1, 2, 3])
assert event.block_size == 16
assert event.medium == "A"
assert event.token_ids == []
assert event.parent_block_hash is None
assert event.lora_id is None
event = events[1]
assert isinstance(event, BlockRemoved)
assert event.block_hashes == to_hashes([4, 5, 6])
assert event.medium == "B"
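The expected GPU block indexes asserted above follow from the block-size arithmetic: with offloaded_block_size=12 and gpu_block_size=4 the block_size_factor is 3, so offloaded block index 1 covers GPU block indexes 3-5. A standalone sketch of that mapping (the helper name is illustrative, not part of the commit):

def gpu_blocks_of_offloaded_block(offloaded_block_idx: int,
                                  offloaded_block_size: int = 12,
                                  gpu_block_size: int = 4) -> list[int]:
    # each offloaded block spans block_size_factor consecutive GPU blocks
    block_size_factor = offloaded_block_size // gpu_block_size
    first = offloaded_block_idx * block_size_factor
    return list(range(first, first + block_size_factor))


# the middle offloaded block of the first request covers GPU blocks 3, 4 and 5
assert gpu_blocks_of_offloaded_block(1) == [3, 4, 5]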


@@ -176,6 +176,7 @@ def create_model_runner_output(
finished_sending: Optional[list[str]] = None,
finished_recving: Optional[list[str]] = None,
use_eos: bool = False,
token_id: int = 0,
) -> ModelRunnerOutput:
"""Make dummy model runner output for testing."""
@@ -184,7 +185,7 @@ def create_model_runner_output(
req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}
# Make sampled tokens.
- sampled_token = EOS_TOKEN_ID if use_eos else 0
+ sampled_token = EOS_TOKEN_ID if use_eos else token_id
sampled_token_ids = [[sampled_token] for _ in req_ids]
kv_connector_output = None if (


@@ -106,3 +106,8 @@ KVConnectorFactory.register_connector(
"MultiConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
"MultiConnector")
KVConnectorFactory.register_connector(
"OffloadingConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
"OffloadingConnector")


@@ -0,0 +1,485 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
from typing import Any, Optional
import torch
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_offload.abstract import OffloadingManager
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import OffloadingWorker, TransferSpec
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import Request
ReqId = str
logger = init_logger(__name__)
@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
reqs_to_load: dict[ReqId, TransferSpec]
reqs_to_store: dict[ReqId, TransferSpec]
class OffloadingConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
super().__init__(vllm_config, role)
spec = OffloadingSpecFactory.create_spec(vllm_config)
self.connector_scheduler: Optional[OffloadingConnectorScheduler] = None
self.connector_worker: Optional[OffloadingConnectorWorker] = None
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = OffloadingConnectorScheduler(spec)
elif role == KVConnectorRole.WORKER:
self.connector_worker = OffloadingConnectorWorker(spec)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata,
OffloadingConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
pass
def wait_for_save(self):
assert self.connector_worker is not None
assert isinstance(self._connector_metadata,
OffloadingConnectorMetadata)
self.connector_worker.start_store_kv(self._connector_metadata)
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
assert self.connector_worker is not None
return self.connector_worker.get_finished(finished_req_ids)
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def update_connector_output(self, connector_output: KVConnectorOutput):
assert self.connector_scheduler is not None
self.connector_scheduler.update_connector_output(connector_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
def take_events(self) -> Iterable[KVCacheEvent]:
assert self.connector_scheduler is not None
return self.connector_scheduler.take_events()
class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, spec: OffloadingSpec):
self.gpu_block_size = spec.gpu_block_size
self.offloaded_block_size = spec.offloaded_block_size
self.block_size_factor = (self.offloaded_block_size //
self.gpu_block_size)
self.manager: OffloadingManager = spec.get_manager()
self._requests: dict[ReqId, Request] = {}
# list of GPU block IDs per request
self._request_block_ids: dict[ReqId, list[int]] = {}
# requests to load for the current scheduler step
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
# request blocks are stored in order
# index of next block (of size offloaded_block_size) to offload
self._next_stored_block_idx: dict[ReqId, int] = {}
# request ID -> set(block hashes being stored/loaded)
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
def _get_block_hashes(
self,
req: Request,
start_idx: int = 0,
end_idx: Optional[int] = None,
) -> Iterable[BlockHash]:
return islice(
req.block_hashes,
self.block_size_factor * start_idx + self.block_size_factor - 1,
self.block_size_factor * end_idx if end_idx else None,
self.block_size_factor)
def get_num_new_matched_tokens(
self, request: Request,
num_computed_tokens: int) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded beyond the
num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- The number of tokens that can be loaded beyond what is
already computed.
- `True` if tokens will be loaded asynchronously
(between scheduler steps).
"""
num_blocks = request.num_tokens // self.offloaded_block_size
assert (len(request.block_hashes) //
self.block_size_factor == num_blocks)
block_hashes = self._get_block_hashes(request)
self.manager.touch(block_hashes)
full_block_tokens = self.offloaded_block_size * num_blocks
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
# less than a full block would be loaded, skip
return 0, False
start_block_idx = num_computed_tokens // self.offloaded_block_size
hits = self.manager.lookup(
self._get_block_hashes(request, start_idx=start_block_idx))
if hits == 0:
return 0, False
num_hit_tokens = (self.offloaded_block_size *
(start_block_idx + hits) - num_computed_tokens)
logger.debug(
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
request.request_id,
num_hit_tokens,
num_computed_tokens,
)
if num_hit_tokens < self.offloaded_block_size:
return 0, False
return num_hit_tokens, True
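# Worked example (illustrative, not part of the diff): with
# offloaded_block_size=12, a prompt of at least 24 tokens,
# num_computed_tokens=4 and lookup() reporting 2 offloaded-block hits
# starting at block 0, num_hit_tokens = 12 * (0 + 2) - 4 = 20 >= 12,
# so 20 tokens are reported as loadable asynchronously.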
def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
num_external_tokens: int):
self._requests[request.request_id] = request
# the block ids are updated in _get_reqs_to_store
self._request_block_ids[request.request_id] = []
if num_external_tokens == 0:
return
block_groups = blocks.get_block_ids()
block_ids = block_groups[0]
num_computed_gpu_blocks = sum(block.block_hash is not None
for block in blocks.blocks[0])
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
full_block_tokens = num_computed_tokens + num_external_tokens
assert full_block_tokens % self.offloaded_block_size == 0
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
assert (num_external_tokens == num_pending_gpu_blocks *
self.gpu_block_size)
start_block_idx = num_computed_tokens // self.offloaded_block_size
num_blocks = full_block_tokens // self.offloaded_block_size
assert (len(request.block_hashes) // self.block_size_factor
>= num_blocks)
block_hashes = self._get_block_hashes(request,
start_idx=start_block_idx,
end_idx=num_blocks)
src_spec = self.manager.prepare_load(block_hashes)
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
block_hashes = self._get_block_hashes(request,
start_idx=start_block_idx,
end_idx=num_blocks)
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
self._reqs_being_loaded[request.request_id].update(block_hashes)
self._next_stored_block_idx[request.request_id] = num_blocks
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
reqs_to_store: dict[ReqId, TransferSpec] = {}
# iterate over both new and cached requests
for req_id, new_block_id_groups, preempted in yield_req_data(
scheduler_output):
if preempted:
self._request_block_ids[req_id] = []
if new_block_id_groups:
new_block_ids = new_block_id_groups[0]
self._request_block_ids[req_id] += new_block_ids
block_ids = self._request_block_ids[req_id]
req = self._requests[req_id]
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
total_tokens = req.num_computed_tokens + new_tokens
num_blocks = total_tokens // self.offloaded_block_size
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
num_new_blocks = num_blocks - start_block_idx
if num_new_blocks <= 0:
continue
num_gpu_blocks = num_blocks * self.block_size_factor
assert len(req.block_hashes) >= num_gpu_blocks
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks)
store_output = self.manager.prepare_store(new_block_hashes)
if store_output is None:
logger.warning("Cannot store %s blocks", num_new_blocks)
break
self._next_stored_block_idx[req_id] = num_blocks
if not store_output.block_hashes_to_store:
continue
block_hashes_to_store = set(store_output.block_hashes_to_store)
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
self.manager.touch(block_hashes)
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks)
dst_spec = store_output.store_spec
src_block_ids: list[int] = []
for idx, blk_hash in enumerate(new_block_hashes):
if blk_hash not in block_hashes_to_store:
continue
offloaded_block_idx = start_block_idx + idx
gpu_block_idx = offloaded_block_idx * self.block_size_factor
for i in range(self.block_size_factor):
src_block_ids.append(block_ids[gpu_block_idx + i])
src_spec = GPULoadStoreSpec(src_block_ids)
reqs_to_store[req_id] = (src_spec, dst_spec)
self._reqs_being_stored[req_id] |= block_hashes_to_store
logger.debug(
"Request %s offloading %s blocks starting from block #%d",
req_id,
len(block_hashes_to_store),
start_block_idx,
)
return reqs_to_store
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
meta = OffloadingConnectorMetadata(
reqs_to_load=self._reqs_to_load,
reqs_to_store=self._get_reqs_to_store(scheduler_output))
self._reqs_to_load = {}
return meta
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
for req_id in connector_output.finished_sending or []:
block_hashes = self._reqs_being_stored.pop(req_id, None)
if block_hashes:
self.manager.complete_store(block_hashes)
for req_id in connector_output.finished_recving or []:
block_hashes = self._reqs_being_loaded.pop(req_id, None)
if block_hashes:
self.manager.complete_load(block_hashes)
def request_finished(
self,
request: Request,
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
req_id = request.request_id
self._requests.pop(req_id, None)
self._request_block_ids.pop(req_id, None)
self._next_stored_block_idx.pop(req_id, None)
request_being_stored = req_id in self._reqs_being_stored
return request_being_stored, None
def take_events(self) -> Iterable[KVCacheEvent]:
"""Take the KV cache events from the connector.
Returns:
An iterable of KV cache events.
"""
for event in self.manager.take_events():
if event.removed:
yield BlockRemoved(block_hashes=event.block_hashes,
medium=event.medium)
else:
yield BlockStored(block_hashes=event.block_hashes,
parent_block_hash=None,
token_ids=[],
lora_id=None,
block_size=event.block_size,
medium=event.medium)
class OffloadingConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, spec: OffloadingSpec):
self.spec = spec
self.worker = OffloadingWorker()
self._job_counter = 0
# job_id -> (req_id, store)
self._jobs: dict[int, tuple[ReqId, bool]] = {}
# req_id -> active load job ID
self._load_job: dict[ReqId, int] = {}
# req_id -> set(active store job IDs)
self._store_jobs = defaultdict[ReqId, set[int]](set)
self._finished_reqs_waiting_for_store: set[ReqId] = set()
def _generate_job_id(self) -> int:
job_id = self._job_counter
self._job_counter = job_id + 1
return job_id
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for src_cls, dst_cls, handler in (self.spec.get_handlers(kv_caches)):
self.worker.register_handler(src_cls, dst_cls, handler)
def start_load_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_load.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, False)
assert req_id not in self._load_job
self._load_job[req_id] = job_id
assert self.worker.transfer_async(job_id, transfer_spec)
def start_store_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_store.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, True)
self._store_jobs[req_id].add(job_id)
assert self.worker.transfer_async(job_id, transfer_spec)
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""
Notifies the worker-side connector of the IDs of requests that have
finished generating tokens.
Returns:
A tuple (finished_sending, finished_recving) of request IDs whose
asynchronous store / load transfers have completed.
"""
finished_sending = set()
finished_recving = set()
for job_id, success in self.worker.get_finished():
# we currently do not support job failures
assert success
req_id, store = self._jobs.pop(job_id)
if store:
req_jobs = self._store_jobs[req_id]
req_jobs.remove(job_id)
if req_jobs:
continue
if req_id in self._finished_reqs_waiting_for_store:
self._finished_reqs_waiting_for_store.remove(req_id)
finished_sending.add(req_id)
del self._store_jobs[req_id]
else:
req_job = self._load_job[req_id]
assert job_id == req_job
del self._load_job[req_id]
finished_recving.add(req_id)
for req_id in finished_req_ids:
pending_req_jobs = self._store_jobs.get(req_id)
if pending_req_jobs:
self._finished_reqs_waiting_for_store.add(req_id)
elif pending_req_jobs is not None:
finished_sending.add(req_id)
del self._store_jobs[req_id]
return finished_sending, finished_recving
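# Illustrative note (not part of the diff): a request that finishes while
# store jobs are still in flight is parked in _finished_reqs_waiting_for_store
# and only reported in finished_sending once its last store job completes,
# so the scheduler does not free its blocks under an in-progress offload.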
def yield_req_data(
scheduler_output) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
"""
Yields:
(req_id, new_block_id_groups, preempted)
"""
# new requests
for req_data in scheduler_output.scheduled_new_reqs:
yield req_data.req_id, req_data.block_ids, False
# cached requests
cached_reqs = scheduler_output.scheduled_cached_reqs
yield from zip(cached_reqs.req_ids, cached_reqs.new_block_ids,
cached_reqs.resumed_from_preemption)
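A detail worth noting in the scheduler above: _get_block_hashes keys each offloaded block by the hash of its last GPU-sized sub-block, which is what the islice starting at block_size_factor - 1 and striding by block_size_factor selects. A self-contained sketch of that selection, assuming a block_size_factor of 3 (the helper name is illustrative, not part of the commit):

from itertools import islice

def offloaded_block_hashes(gpu_block_hashes: list[str],
                           block_size_factor: int = 3,
                           start_idx: int = 0,
                           end_idx: int | None = None) -> list[str]:
    # take the hash of the last GPU block inside each offloaded block
    return list(
        islice(gpu_block_hashes,
               block_size_factor * start_idx + block_size_factor - 1,
               block_size_factor * end_idx if end_idx else None,
               block_size_factor))


# with 6 GPU-block hashes and a factor of 3, offloaded blocks 0 and 1
# are keyed by the hashes at indexes 2 and 5
assert offloaded_block_hashes(["h0", "h1", "h2", "h3", "h4", "h5"]) == ["h2", "h5"]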


@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from typing import TYPE_CHECKING, Callable
from vllm.logger import init_logger
from vllm.v1.kv_offload.spec import OffloadingSpec
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class OffloadingSpecFactory:
_registry: dict[str, Callable[[], type[OffloadingSpec]]] = {}
@classmethod
def register_spec(cls, name: str, module_path: str,
class_name: str) -> None:
"""Register a spec with a lazy-loading module and class name."""
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> type[OffloadingSpec]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
cls._registry[name] = loader
@classmethod
def create_spec(
cls,
config: "VllmConfig",
) -> OffloadingSpec:
kv_transfer_config = config.kv_transfer_config
assert kv_transfer_config is not None
extra_config = kv_transfer_config.kv_connector_extra_config
spec_name = extra_config.get("spec_name", "CPUOffloadingSpec")
if spec_name in cls._registry:
spec_cls = cls._registry[spec_name]()
else:
spec_module_path = extra_config.get("spec_module_path")
if spec_module_path is None:
raise ValueError(f"Unsupported spec type: {spec_name}")
spec_module = importlib.import_module(spec_module_path)
spec_cls = getattr(spec_module, spec_name)
assert issubclass(spec_cls, OffloadingSpec)
logger.info("Creating offloading spec with name: %s", spec_name)
return spec_cls(config)
# Register various specs here.
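In-tree specs would be registered here; out-of-tree specs can either call register_spec at import time or be loaded lazily through the "spec_module_path" extra-config key handled in create_spec above. A registration sketch with hypothetical names:

from vllm.v1.kv_offload.factory import OffloadingSpecFactory

# hypothetical out-of-tree spec; the loader imports my_pkg.my_offloading_spec
# lazily the first time the spec is requested
OffloadingSpecFactory.register_spec(
    name="MyOffloadingSpec",
    module_path="my_pkg.my_offloading_spec",
    class_name="MyOffloadingSpec",
)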


@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class OffloadingSpec(ABC):
"""Spec for an offloading connector"""
def __init__(self, vllm_config: "VllmConfig"):
logger.warning(
"Initializing OffloadingSpec. This API is experimental and "
"subject to change in the future as we iterate the design.")
self.vllm_config = vllm_config
kv_transfer_config = vllm_config.kv_transfer_config
assert kv_transfer_config is not None
self.extra_config = kv_transfer_config.kv_connector_extra_config
self.gpu_block_size = vllm_config.cache_config.block_size
self.offloaded_block_size = int(
self.extra_config.get("block_size", self.gpu_block_size))
assert self.offloaded_block_size % self.gpu_block_size == 0
@abstractmethod
def get_manager(self) -> OffloadingManager:
"""
Get an OffloadingManager that will be used
by the scheduler-side offloading connector to track
offloaded blocks and manage evictions.
"""
pass
@abstractmethod
def get_handlers(
self, kv_caches: dict[str, torch.Tensor]
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
OffloadingHandler]]:
"""
Get offloading handlers along with their respective src and dst types.
Args:
kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.
Yields:
Tuples of (src_type, dst_type, offloading_handler).
"""
pass