diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py new file mode 100644 index 0000000000..f9a4d2fb4d --- /dev/null +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -0,0 +1,505 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +from collections.abc import Iterable, Iterator +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm import SamplingParams +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_events import BlockRemoved, BlockStored +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import ( + OffloadingConnector, OffloadingConnectorMetadata) +from vllm.forward_context import ForwardContext +from vllm.utils import sha256 +from vllm.v1.core.kv_cache_utils import (BlockHash, get_request_block_hasher, + init_none_hash) +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent, + OffloadingManager, PrepareStoreOutput) +from vllm.v1.kv_offload.mediums import GPULoadStoreSpec +from vllm.v1.kv_offload.spec import OffloadingSpec +from vllm.v1.kv_offload.worker.worker import (OffloadingHandler, + TransferResult, TransferSpec) +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput +from vllm.v1.request import Request + +from .utils import (EOS_TOKEN_ID, create_model_runner_output, create_scheduler, + create_vllm_config) + + +class MockLoadStoreSpec(LoadStoreSpec): + + def __init__(self, block_hashes: Iterable[BlockHash]): + self.block_hashes: list[BlockHash] = list(block_hashes) + + @staticmethod + def medium() -> str: + return "Mock" + + def __repr__(self) -> str: + return repr(self.block_hashes) + + +class MockOffloadingHandler(OffloadingHandler): + + def __init__(self): + self.completed_transfers: list[TransferResult] = [] + self.completed_specs: list[TransferSpec] = [] + + def get_finished(self) -> list[TransferResult]: + finished = self.completed_transfers + self.completed_transfers = [] + return finished + + def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: + self.completed_specs.append(spec) + self.completed_transfers.append((job_id, True)) + return True + + +class MockOffloadingSpec(OffloadingSpec): + + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + + self.manager = MagicMock(spec=OffloadingManager) + self.manager.lookup.return_value = 0 + self.manager.prepare_load = lambda block_hashes: (MockLoadStoreSpec( + block_hashes)) + self.handler = MockOffloadingHandler() + + def get_manager(self) -> OffloadingManager: + return self.manager + + def get_handlers( + self, _ + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], + OffloadingHandler]]: + + yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler + yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler + + def get_completed_transfers(self) -> list[TransferSpec]: + specs = self.handler.completed_specs + self.handler.completed_specs = [] + return specs + + +@dataclass +class TransferSummary: + gpu_block_indices: list[int] + offload_addresses: list[Any] + + +class RequestRunner: + + def __init__(self, offloaded_block_size: int, gpu_block_size: int, + num_gpu_blocks: int): + self.offloaded_block_size: int = offloaded_block_size + self.gpu_block_size: int = gpu_block_size + self.num_gpu_blocks: int = num_gpu_blocks + + self.req_id: int = -1 + + vllm_config = create_vllm_config(block_size=gpu_block_size, + max_num_batched_tokens=1000) + vllm_config.kv_transfer_config = KVTransferConfig( + kv_connector="OffloadingConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "spec_name": "MockOffloadingSpec", + "spec_module_path": + "tests.v1.kv_connector.unit.test_offloading_connector", + "block_size": offloaded_block_size, + }) + + self.scheduler: Scheduler = create_scheduler(vllm_config, + num_blocks=num_gpu_blocks) + self.worker_connector = OffloadingConnector(vllm_config, + KVConnectorRole.WORKER) + + # register worker kv_caches to enable OffloadingWorker creations + self.worker_connector.register_kv_caches( + kv_caches={"a": torch.empty(0)}) + + # extract connector of scheduler + scheduler_connector = self.scheduler.connector + assert scheduler_connector is not None + assert isinstance(scheduler_connector, OffloadingConnector) + self.scheduler_connector: OffloadingConnector = scheduler_connector + + # extract mocked OffloadingManager of scheduler connector + connector_scheduler = scheduler_connector.connector_scheduler + assert connector_scheduler is not None + manager = connector_scheduler.manager + assert isinstance(manager, MagicMock) + self.manager: MagicMock = manager + + assert connector_scheduler.gpu_block_size == gpu_block_size + assert connector_scheduler.offloaded_block_size == offloaded_block_size + + # extract OffloadingSpec of worker_connector + connector_worker = self.worker_connector.connector_worker + assert connector_worker is not None + offloading_spec = connector_worker.spec + assert isinstance(offloading_spec, MockOffloadingSpec) + self.offloading_spec: MockOffloadingSpec = offloading_spec + + # mapping (offloading address) -> gpu_block_index + self.offloaded: dict[Any, int] = {} + + self.pending_loads_count: int = 0 + self.pending_stores_count: int = 0 + + self.completed_loads: list[TransferSummary] = [] + self.completed_stores: list[TransferSummary] = [] + + # maps {block_id: block_offset} + self.gpu_block_index: dict[int, int] = {} + + init_none_hash(sha256) + self._block_hasher = get_request_block_hasher(gpu_block_size, sha256) + + self._dummy_ctx: ForwardContext = ForwardContext(no_compile_layers={}, + attn_metadata={}, + virtual_engine=0) + + def new_request(self, token_ids: list[int]): + assert not self.scheduler.requests + self.req_id += 1 + + req = Request( + request_id=str(self.req_id), + prompt_token_ids=token_ids, + sampling_params=SamplingParams(max_tokens=1000), + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + block_hasher=self._block_hasher, + ) + + self.scheduler.add_request(req) + + def _wait_for_transfers(self): + block_size_factor = self.offloaded_block_size // self.gpu_block_size + + while self.pending_loads_count or self.pending_stores_count: + for transfer_spec in ( + self.offloading_spec.get_completed_transfers()): + src_spec, dst_spec = transfer_spec + + if isinstance(src_spec, GPULoadStoreSpec): + store = True + gpu_spec = src_spec + offload_spec = dst_spec + else: + store = False + gpu_spec = dst_spec + offload_spec = src_spec + + assert isinstance(offload_spec, MockLoadStoreSpec) + assert isinstance(gpu_spec, GPULoadStoreSpec) + + gpu_block_indices: list[int] = [] + for block_id in gpu_spec.block_ids: + gpu_block_indices.append( + self.gpu_block_index[block_id.item()]) + + # list of (block_hash, sub_block_offset) + offload_addresses: list[Any] = [] + for block_hash in offload_spec.block_hashes: + for sub_block_idx in range(block_size_factor): + offload_addresses.append((block_hash, sub_block_idx)) + + if store: + assert len(gpu_block_indices) == len(offload_addresses) + + self.completed_stores.append( + TransferSummary(gpu_block_indices, offload_addresses)) + self.pending_stores_count -= 1 + else: + remainder_sub_block_count = (len(offload_addresses) - + len(gpu_block_indices)) + assert remainder_sub_block_count >= 0 + assert remainder_sub_block_count < block_size_factor + offload_addresses = offload_addresses[ + remainder_sub_block_count:] + + self.completed_loads.append( + TransferSummary(gpu_block_indices, offload_addresses)) + self.pending_loads_count -= 1 + + def _update_gpu_block_idx(self): + for blocks in (self.scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks.values()): + for block_idx, block in enumerate(blocks): + self.gpu_block_index[block.block_id] = block_idx + + def _run(self, decoded_tokens: list[int]): + """ + Runs multiple engine (scheduler + worker) steps. + Assumes a single request is running. + + Args: + decoded_tokens: the tokens to yield at each step. + """ + + tokens_iter = iter(decoded_tokens) + token_id = next(tokens_iter, None) + while token_id is not None: + assert self.scheduler.requests + + scheduler_output = self.scheduler.schedule() + self._update_gpu_block_idx() + + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, + OffloadingConnectorMetadata) + + self.pending_loads_count += len(kv_connector_metadata.reqs_to_load) + self.pending_stores_count += len( + kv_connector_metadata.reqs_to_store) + + self.worker_connector.bind_connector_metadata( + kv_connector_metadata) + self.worker_connector.start_load_kv(self._dummy_ctx) + + if scheduler_output.total_num_scheduled_tokens > 0: + self.worker_connector.wait_for_save() + + finished_sending, finished_recving = ( + self.worker_connector.get_finished( + scheduler_output.finished_req_ids)) + + self.worker_connector.clear_connector_metadata() + + model_runner_output = create_model_runner_output( + reqs=self.scheduler.running, + finished_sending=list(finished_sending), + finished_recving=list(finished_recving), + token_id=token_id) + + if self.scheduler.running: + token_id = next(tokens_iter, None) + + self.scheduler.update_from_output(scheduler_output, + model_runner_output) + + self._wait_for_transfers() + + # run one more step to update finished stored + if EOS_TOKEN_ID in decoded_tokens: + assert not self.scheduler.running + + while self.scheduler.requests: + scheduler_output = self.scheduler.schedule() + + finished_sending, finished_recving = ( + self.worker_connector.get_finished( + scheduler_output.finished_req_ids)) + + assert not finished_recving + + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.kv_connector_output = KVConnectorOutput( + finished_sending=finished_sending) + + self.scheduler.update_from_output(scheduler_output, + model_runner_output) + + def run( + self, + decoded_tokens: list[int], + expected_stored_gpu_block_indexes: tuple[int, ...] = (), + expected_loaded_gpu_block_indexes: tuple[int, ...] = (), + ): + """ + Runs multiple engine (scheduler + worker) steps. + Assumes a single request is running. + + Args: + decoded_tokens: the tokens to yield at each step. + expected_stored_gpu_block_indexes: GPU block indexes + that are expected to be written during the run. + expected_loaded_gpu_block_indexes: GPU block indexes + that are expected to be loaded during the run. + """ + + self.manager.reset_mock() + self._run(decoded_tokens) + + loaded_gpu_block_indexes: set[int] = set() + for transfer in self.completed_loads: + for gpu_block_idx, offloaded_address in zip( + transfer.gpu_block_indices, transfer.offload_addresses): + loaded_gpu_block_indexes.add(gpu_block_idx) + assert gpu_block_idx == self.offloaded[offloaded_address] + + assert ( + set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes) + self.completed_loads.clear() + + stored_gpu_block_indexes: set[int] = set() + for transfer in self.completed_stores: + for gpu_block_idx, offloaded_address in zip( + transfer.gpu_block_indices, transfer.offload_addresses): + stored_gpu_block_indexes.add(gpu_block_idx) + self.offloaded[offloaded_address] = gpu_block_idx + + assert ( + set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes) + self.completed_stores.clear() + + +@pytest.fixture +def request_runner(): + runners = [] + + def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks): + runner = RequestRunner(offloaded_block_size=offloaded_block_size, + gpu_block_size=gpu_block_size, + num_gpu_blocks=num_gpu_blocks) + runners.append(runner) + return runner + + yield runner_factory # pass factory to the test + + +def generate_store_output(block_hashes: Iterable[BlockHash]): + block_hashes = list(block_hashes) + return PrepareStoreOutput( + block_hashes_to_store=list(block_hashes), + store_spec=MockLoadStoreSpec(block_hashes), + block_hashes_evicted=[], + ) + + +def test_offloading_connector(request_runner): + offloaded_block_size = 12 + gpu_block_size = 4 + num_gpu_blocks = 100 + block_size_factor = offloaded_block_size // gpu_block_size + + runner = request_runner(offloaded_block_size=offloaded_block_size, + gpu_block_size=gpu_block_size, + num_gpu_blocks=num_gpu_blocks) + + # 3 blocks, store just the middle block (skip first and last) + # blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8] + runner.new_request(token_ids=[0] * offloaded_block_size * 3) + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output(list(block_hashes)[1:2]) + runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5)) + + # add block missing 1 token -> no offload + runner.run(decoded_tokens=[0] * (offloaded_block_size - 1)) + runner.manager.prepare_store.assert_not_called() + + # +1 token -> single block, fail prepare_store + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: None + runner.run(decoded_tokens=[0]) + runner.manager.prepare_store.assert_called() + + # 1 more block, now set block_hashes_to_store = [] + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output([]) + runner.run(decoded_tokens=[0] * offloaded_block_size) + + # 1 more block, now check touch was called with all 6 blocks + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output(block_hashes) + runner.run(decoded_tokens=[0] * offloaded_block_size, + expected_stored_gpu_block_indexes=(15, 16, 17)) + runner.manager.touch.assert_called() + block_hashes1 = list(runner.manager.touch.call_args.args[0]) + assert len(block_hashes1) == 6 + + # terminate request + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + + # create a new request differing only on the last token + runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1]) + runner.run(decoded_tokens=[0], + expected_stored_gpu_block_indexes=tuple( + range(6 * block_size_factor))) + runner.manager.touch.assert_called() + block_hashes2 = list(runner.manager.touch.call_args.args[0]) + assert len(block_hashes2) == 6 + + # verify hashes are the same, except for the last block + assert block_hashes1[:5] == block_hashes2[:5] + assert block_hashes1[5] != block_hashes2[5] + + # terminate request + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + + # full_block_tokens - num_computed_tokens < offloaded_block_size + runner.new_request(token_ids=[0] * gpu_block_size + [1] * + (offloaded_block_size - gpu_block_size)) + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output([]) + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + runner.manager.lookup.assert_not_called() + + # single block lookup with no hits + runner.new_request(token_ids=[1] * offloaded_block_size) + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output([]) + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + runner.manager.lookup.assert_called() + assert len(list(runner.manager.lookup.call_args.args[0])) == 1 + + # single block lookup with a hit + runner.scheduler.reset_prefix_cache() + runner.new_request(token_ids=[0] * offloaded_block_size) + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output([]) + runner.manager.lookup.return_value = 1 + runner.run(decoded_tokens=[EOS_TOKEN_ID], + expected_loaded_gpu_block_indexes=(0, 1, 2)) + + # single block lookup with a hit in a middle block + runner.new_request(token_ids=[0] * offloaded_block_size * 2 + + [1] * offloaded_block_size) + runner.manager.prepare_store.side_effect = \ + lambda block_hashes: generate_store_output([]) + runner.manager.lookup.return_value = 1 + runner.run(decoded_tokens=[EOS_TOKEN_ID], + expected_loaded_gpu_block_indexes=(3, 4, 5)) + + # test take_events + def to_hashes(int_hashes: list[int]) -> list[BlockHash]: + return [BlockHash(str(i).encode()) for i in int_hashes] + + def take_events() -> Iterable[OffloadingEvent]: + yield OffloadingEvent(block_hashes=to_hashes([1, 2, 3]), + block_size=16, + medium="A", + removed=False) + yield OffloadingEvent(block_hashes=to_hashes([4, 5, 6]), + block_size=32, + medium="B", + removed=True) + + runner.manager.take_events.side_effect = take_events + events = list(runner.scheduler_connector.take_events()) + assert len(events) == 2 + event = events[0] + assert isinstance(event, BlockStored) + assert event.block_hashes == to_hashes([1, 2, 3]) + assert event.block_size == 16 + assert event.medium == "A" + assert event.token_ids == [] + assert event.parent_block_hash is None + assert event.lora_id is None + event = events[1] + assert isinstance(event, BlockRemoved) + assert event.block_hashes == to_hashes([4, 5, 6]) + assert event.medium == "B" diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 0cae1c7bc0..de52668e3d 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -176,6 +176,7 @@ def create_model_runner_output( finished_sending: Optional[list[str]] = None, finished_recving: Optional[list[str]] = None, use_eos: bool = False, + token_id: int = 0, ) -> ModelRunnerOutput: """Make dummy model runner output for testing.""" @@ -184,7 +185,7 @@ def create_model_runner_output( req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)} # Make sampled tokens. - sampled_token = EOS_TOKEN_ID if use_eos else 0 + sampled_token = EOS_TOKEN_ID if use_eos else token_id sampled_token_ids = [[sampled_token] for _ in req_ids] kv_connector_output = None if ( diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 670f9c26b2..873f130ed8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -106,3 +106,8 @@ KVConnectorFactory.register_connector( "MultiConnector", "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", "MultiConnector") + +KVConnectorFactory.register_connector( + "OffloadingConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector", + "OffloadingConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py new file mode 100644 index 0000000000..c23efa6045 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -0,0 +1,485 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import defaultdict +from collections.abc import Iterable, Iterator +from dataclasses import dataclass +from itertools import islice +from typing import Any, Optional + +import torch + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata) +from vllm.forward_context import ForwardContext +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_offload.abstract import OffloadingManager +from vllm.v1.kv_offload.factory import OffloadingSpecFactory +from vllm.v1.kv_offload.mediums import GPULoadStoreSpec +from vllm.v1.kv_offload.spec import OffloadingSpec +from vllm.v1.kv_offload.worker.worker import OffloadingWorker, TransferSpec +from vllm.v1.outputs import KVConnectorOutput +from vllm.v1.request import Request + +ReqId = str + +logger = init_logger(__name__) + + +@dataclass +class OffloadingConnectorMetadata(KVConnectorMetadata): + reqs_to_load: dict[ReqId, TransferSpec] + reqs_to_store: dict[ReqId, TransferSpec] + + +class OffloadingConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + super().__init__(vllm_config, role) + + spec = OffloadingSpecFactory.create_spec(vllm_config) + + self.connector_scheduler: Optional[OffloadingConnectorScheduler] = None + self.connector_worker: Optional[OffloadingConnectorWorker] = None + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = OffloadingConnectorScheduler(spec) + elif role == KVConnectorRole.WORKER: + self.connector_worker = OffloadingConnectorWorker(spec) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + OffloadingConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + pass + + def wait_for_save(self): + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + OffloadingConnectorMetadata) + self.connector_worker.start_store_kv(self._connector_metadata) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + assert self.connector_worker is not None + return self.connector_worker.get_finished(finished_req_ids) + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def update_connector_output(self, connector_output: KVConnectorOutput): + assert self.connector_scheduler is not None + self.connector_scheduler.update_connector_output(connector_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + def take_events(self) -> Iterable[KVCacheEvent]: + assert self.connector_scheduler is not None + return self.connector_scheduler.take_events() + + +class OffloadingConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, spec: OffloadingSpec): + self.gpu_block_size = spec.gpu_block_size + self.offloaded_block_size = spec.offloaded_block_size + self.block_size_factor = (self.offloaded_block_size // + self.gpu_block_size) + self.manager: OffloadingManager = spec.get_manager() + + self._requests: dict[ReqId, Request] = {} + # list of GPU block IDs per request + self._request_block_ids: dict[ReqId, list[int]] = {} + # requests to load for the current scheduler step + self._reqs_to_load: dict[ReqId, TransferSpec] = {} + # request blocks are stored in order + # index of next block (of size offloaded_block_size) to offload + self._next_stored_block_idx: dict[ReqId, int] = {} + + # request ID -> set(block hashes being stored/load) + self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set) + self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set) + + def _get_block_hashes( + self, + req: Request, + start_idx: int = 0, + end_idx: Optional[int] = None, + ) -> Iterable[BlockHash]: + return islice( + req.block_hashes, + self.block_size_factor * start_idx + self.block_size_factor - 1, + self.block_size_factor * end_idx if end_idx else None, + self.block_size_factor) + + def get_num_new_matched_tokens( + self, request: Request, + num_computed_tokens: int) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded beyond the + num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + A tuple with the following elements: + - The number of tokens that can be loaded beyond what is + already computed. + - `True` if tokens will be loaded asynchronously + (between scheduler steps). + """ + num_blocks = request.num_tokens // self.offloaded_block_size + + assert (len(request.block_hashes) // + self.block_size_factor == num_blocks) + block_hashes = self._get_block_hashes(request) + + self.manager.touch(block_hashes) + + full_block_tokens = self.offloaded_block_size * num_blocks + if full_block_tokens - num_computed_tokens < self.offloaded_block_size: + # we can load less than a block, skip + return 0, False + + start_block_idx = num_computed_tokens // self.offloaded_block_size + hits = self.manager.lookup( + self._get_block_hashes(request, start_idx=start_block_idx)) + if hits == 0: + return 0, False + + num_hit_tokens = (self.offloaded_block_size * + (start_block_idx + hits) - num_computed_tokens) + logger.debug( + "Request %s hit %s offloaded tokens after %s GPU hit tokens", + request.request_id, + num_hit_tokens, + num_computed_tokens, + ) + if num_hit_tokens < self.offloaded_block_size: + return 0, False + + return num_hit_tokens, True + + def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, + num_external_tokens: int): + self._requests[request.request_id] = request + # the block ids are updated in _get_reqs_to_store + self._request_block_ids[request.request_id] = [] + + if num_external_tokens == 0: + return + + block_groups = blocks.get_block_ids() + block_ids = block_groups[0] + + num_computed_gpu_blocks = sum(block.block_hash is not None + for block in blocks.blocks[0]) + num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size + full_block_tokens = num_computed_tokens + num_external_tokens + assert full_block_tokens % self.offloaded_block_size == 0 + + num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks + assert (num_external_tokens == num_pending_gpu_blocks * + self.gpu_block_size) + + start_block_idx = num_computed_tokens // self.offloaded_block_size + num_blocks = full_block_tokens // self.offloaded_block_size + + assert (len(request.block_hashes) // self.block_size_factor + >= num_blocks) + block_hashes = self._get_block_hashes(request, + start_idx=start_block_idx, + end_idx=num_blocks) + + src_spec = self.manager.prepare_load(block_hashes) + dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:]) + + block_hashes = self._get_block_hashes(request, + start_idx=start_block_idx, + end_idx=num_blocks) + + self._reqs_to_load[request.request_id] = (src_spec, dst_spec) + self._reqs_being_loaded[request.request_id].update(block_hashes) + self._next_stored_block_idx[request.request_id] = num_blocks + + def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): + reqs_to_store: dict[ReqId, TransferSpec] = {} + # iterate over both new and cached requests + for req_id, new_block_id_groups, preempted in yield_req_data( + scheduler_output): + + if preempted: + self._request_block_ids[req_id] = [] + + if new_block_id_groups: + new_block_ids = new_block_id_groups[0] + self._request_block_ids[req_id] += new_block_ids + + block_ids = self._request_block_ids[req_id] + + req = self._requests[req_id] + new_tokens = scheduler_output.num_scheduled_tokens[req_id] + total_tokens = req.num_computed_tokens + new_tokens + num_blocks = total_tokens // self.offloaded_block_size + start_block_idx = self._next_stored_block_idx.get(req_id, 0) + num_new_blocks = num_blocks - start_block_idx + + if num_new_blocks <= 0: + continue + + num_gpu_blocks = num_blocks * self.block_size_factor + assert len(req.block_hashes) >= num_gpu_blocks + + new_block_hashes = self._get_block_hashes( + req, start_idx=start_block_idx, end_idx=num_blocks) + store_output = self.manager.prepare_store(new_block_hashes) + if store_output is None: + logger.warning("Cannot store %s blocks", num_new_blocks) + break + + self._next_stored_block_idx[req_id] = num_blocks + + if not store_output.block_hashes_to_store: + continue + block_hashes_to_store = set(store_output.block_hashes_to_store) + + block_hashes = self._get_block_hashes(req, end_idx=num_blocks) + self.manager.touch(block_hashes) + + new_block_hashes = self._get_block_hashes( + req, start_idx=start_block_idx, end_idx=num_blocks) + dst_spec = store_output.store_spec + src_block_ids: list[int] = [] + for idx, blk_hash in enumerate(new_block_hashes): + if blk_hash not in block_hashes_to_store: + continue + offloaded_block_idx = start_block_idx + idx + gpu_block_idx = offloaded_block_idx * self.block_size_factor + for i in range(self.block_size_factor): + src_block_ids.append(block_ids[gpu_block_idx + i]) + src_spec = GPULoadStoreSpec(src_block_ids) + + reqs_to_store[req_id] = (src_spec, dst_spec) + self._reqs_being_stored[req_id] |= block_hashes_to_store + + logger.debug( + "Request %s offloading %s blocks starting from block #%d", + req_id, + len(block_hashes_to_store), + start_block_idx, + ) + + return reqs_to_store + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + meta = OffloadingConnectorMetadata( + reqs_to_load=self._reqs_to_load, + reqs_to_store=self._get_reqs_to_store(scheduler_output)) + self._reqs_to_load = {} + return meta + + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + for req_id in connector_output.finished_sending or []: + block_hashes = self._reqs_being_stored.pop(req_id, None) + if block_hashes: + self.manager.complete_store(block_hashes) + + for req_id in connector_output.finished_recving or []: + block_hashes = self._reqs_being_loaded.pop(req_id, None) + if block_hashes: + self.manager.complete_load(block_hashes) + + def request_finished( + self, + request: Request, + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + req_id = request.request_id + self._requests.pop(req_id, None) + self._request_block_ids.pop(req_id, None) + self._next_stored_block_idx.pop(req_id, None) + + request_being_stored = req_id in self._reqs_being_stored + return request_being_stored, None + + def take_events(self) -> Iterable[KVCacheEvent]: + """Take the KV cache events from the connector. + + Returns: + A list of KV cache events. + """ + for event in self.manager.take_events(): + if event.removed: + yield BlockRemoved(block_hashes=event.block_hashes, + medium=event.medium) + else: + yield BlockStored(block_hashes=event.block_hashes, + parent_block_hash=None, + token_ids=[], + lora_id=None, + block_size=event.block_size, + medium=event.medium) + + +class OffloadingConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, spec: OffloadingSpec): + self.spec = spec + self.worker = OffloadingWorker() + + self._job_counter = 0 + + # req_id -> (job_id, store) + self._jobs: dict[int, tuple[ReqId, bool]] = {} + # req_id -> active job IDs + self._load_job: dict[ReqId, int] = {} + # req_id -> set(active job IDs) + self._store_jobs = defaultdict[ReqId, set[int]](set) + + self._finished_reqs_waiting_for_store: set[ReqId] = set() + + def _generate_job_id(self) -> int: + job_id = self._job_counter + self._job_counter = job_id + 1 + return job_id + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + for src_cls, dst_cls, handler in (self.spec.get_handlers(kv_caches)): + self.worker.register_handler(src_cls, dst_cls, handler) + + def start_load_kv(self, metadata: OffloadingConnectorMetadata): + for req_id, transfer_spec in metadata.reqs_to_load.items(): + job_id = self._generate_job_id() + self._jobs[job_id] = (req_id, False) + assert req_id not in self._load_job + self._load_job[req_id] = job_id + assert self.worker.transfer_async(job_id, transfer_spec) + + def start_store_kv(self, metadata: OffloadingConnectorMetadata): + for req_id, transfer_spec in metadata.reqs_to_store.items(): + job_id = self._generate_job_id() + self._jobs[job_id] = (req_id, True) + self._store_jobs[req_id].add(job_id) + assert self.worker.transfer_async(job_id, transfer_spec) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + Returns a list of request IDs that finished loading or storing. + + Returns: + ids of requests that have finished asynchronous transfer + tuple of (sending/saving ids, recving/loading ids). + """ + finished_sending = set() + finished_recving = set() + for job_id, success in self.worker.get_finished(): + # we currently do not support job failures + assert success + req_id, store = self._jobs.pop(job_id) + if store: + req_jobs = self._store_jobs[req_id] + req_jobs.remove(job_id) + if req_jobs: + continue + + if req_id in self._finished_reqs_waiting_for_store: + self._finished_reqs_waiting_for_store.remove(req_id) + finished_sending.add(req_id) + del self._store_jobs[req_id] + else: + req_job = self._load_job[req_id] + assert job_id == req_job + del self._load_job[req_id] + finished_recving.add(req_id) + + for req_id in finished_req_ids: + pending_req_jobs = self._store_jobs.get(req_id) + if pending_req_jobs: + self._finished_reqs_waiting_for_store.add(req_id) + elif pending_req_jobs is not None: + finished_sending.add(req_id) + del self._store_jobs[req_id] + + return finished_sending, finished_recving + + +def yield_req_data( + scheduler_output) -> Iterator[tuple[str, tuple[list[int], ...], bool]]: + """ + Yields: + (req_id, new_block_id_groups, preempted) + """ + # new requests + for req_data in scheduler_output.scheduled_new_reqs: + yield req_data.req_id, req_data.block_ids, False + + # cached requests + cached_reqs = scheduler_output.scheduled_cached_reqs + yield from zip(cached_reqs.req_ids, cached_reqs.new_block_ids, + cached_reqs.resumed_from_preemption) diff --git a/vllm/v1/kv_offload/factory.py b/vllm/v1/kv_offload/factory.py new file mode 100644 index 0000000000..6365ab4a6d --- /dev/null +++ b/vllm/v1/kv_offload/factory.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib +from typing import TYPE_CHECKING, Callable + +from vllm.logger import init_logger +from vllm.v1.kv_offload.spec import OffloadingSpec + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class OffloadingSpecFactory: + _registry: dict[str, Callable[[], type[OffloadingSpec]]] = {} + + @classmethod + def register_spec(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a spec with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[OffloadingSpec]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_spec( + cls, + config: "VllmConfig", + ) -> OffloadingSpec: + kv_transfer_config = config.kv_transfer_config + assert kv_transfer_config is not None + extra_config = kv_transfer_config.kv_connector_extra_config + spec_name = extra_config.get("spec_name", "CPUOffloadingSpec") + if spec_name in cls._registry: + spec_cls = cls._registry[spec_name]() + else: + spec_module_path = extra_config.get("spec_module_path") + if spec_module_path is None: + raise ValueError(f"Unsupported spec type: {spec_name}") + spec_module = importlib.import_module(spec_module_path) + spec_cls = getattr(spec_module, spec_name) + assert issubclass(spec_cls, OffloadingSpec) + logger.info("Creating offloading spec with name: %s", spec_name) + return spec_cls(config) + + +# Register various specs here. diff --git a/vllm/v1/kv_offload/spec.py b/vllm/v1/kv_offload/spec.py new file mode 100644 index 0000000000..ed23d5e519 --- /dev/null +++ b/vllm/v1/kv_offload/spec.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Iterator +from typing import TYPE_CHECKING + +import torch + +from vllm.logger import init_logger +from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager +from vllm.v1.kv_offload.worker.worker import OffloadingHandler + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class OffloadingSpec(ABC): + """Spec for an offloading connector""" + + def __init__(self, vllm_config: "VllmConfig"): + logger.warning( + "Initializing OffloadingSpec. This API is experimental and " + "subject to change in the future as we iterate the design.") + self.vllm_config = vllm_config + + kv_transfer_config = vllm_config.kv_transfer_config + assert kv_transfer_config is not None + self.extra_config = kv_transfer_config.kv_connector_extra_config + + self.gpu_block_size = vllm_config.cache_config.block_size + self.offloaded_block_size = int( + self.extra_config.get("block_size", self.gpu_block_size)) + + assert self.offloaded_block_size % self.gpu_block_size == 0 + + @abstractmethod + def get_manager(self) -> OffloadingManager: + """ + Get an OffloadingManager that will be used + by the scheduler-side offloading connector to track + offloaded blocks and manage evictions. + """ + pass + + @abstractmethod + def get_handlers( + self, kv_caches: dict[str, torch.Tensor] + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], + OffloadingHandler]]: + """ + Get offloading handlers along with their respective src and dst types. + + Args: + kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor. + + Yields: + Tuples of (src_type, dst_type, offloading_handler). + """ + pass