[KV offload][4/N] Offloading KV connector (#22595)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
tests/v1/kv_connector/unit/test_offloading_connector.py (new file, 505 lines)
@@ -0,0 +1,505 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any
from unittest.mock import MagicMock

import pytest
import torch

from vllm import SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
    OffloadingConnector, OffloadingConnectorMetadata)
from vllm.forward_context import ForwardContext
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (BlockHash, get_request_block_hasher,
                                         init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
                                         OffloadingManager, PrepareStoreOutput)
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
                                              TransferResult, TransferSpec)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import Request

from .utils import (EOS_TOKEN_ID, create_model_runner_output, create_scheduler,
                    create_vllm_config)


class MockLoadStoreSpec(LoadStoreSpec):

    def __init__(self, block_hashes: Iterable[BlockHash]):
        self.block_hashes: list[BlockHash] = list(block_hashes)

    @staticmethod
    def medium() -> str:
        return "Mock"

    def __repr__(self) -> str:
        return repr(self.block_hashes)


class MockOffloadingHandler(OffloadingHandler):

    def __init__(self):
        self.completed_transfers: list[TransferResult] = []
        self.completed_specs: list[TransferSpec] = []

    def get_finished(self) -> list[TransferResult]:
        finished = self.completed_transfers
        self.completed_transfers = []
        return finished

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        self.completed_specs.append(spec)
        self.completed_transfers.append((job_id, True))
        return True


class MockOffloadingSpec(OffloadingSpec):

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

        self.manager = MagicMock(spec=OffloadingManager)
        self.manager.lookup.return_value = 0
        self.manager.prepare_load = lambda block_hashes: (MockLoadStoreSpec(
            block_hashes))
        self.handler = MockOffloadingHandler()

    def get_manager(self) -> OffloadingManager:
        return self.manager

    def get_handlers(
        self, _
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
                        OffloadingHandler]]:

        yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
        yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler

    def get_completed_transfers(self) -> list[TransferSpec]:
        specs = self.handler.completed_specs
        self.handler.completed_specs = []
        return specs


@dataclass
class TransferSummary:
    gpu_block_indices: list[int]
    offload_addresses: list[Any]


class RequestRunner:

    def __init__(self, offloaded_block_size: int, gpu_block_size: int,
                 num_gpu_blocks: int):
        self.offloaded_block_size: int = offloaded_block_size
        self.gpu_block_size: int = gpu_block_size
        self.num_gpu_blocks: int = num_gpu_blocks

        self.req_id: int = -1

        vllm_config = create_vllm_config(block_size=gpu_block_size,
                                         max_num_batched_tokens=1000)
        vllm_config.kv_transfer_config = KVTransferConfig(
            kv_connector="OffloadingConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "spec_name": "MockOffloadingSpec",
                "spec_module_path":
                "tests.v1.kv_connector.unit.test_offloading_connector",
                "block_size": offloaded_block_size,
            })

        self.scheduler: Scheduler = create_scheduler(vllm_config,
                                                     num_blocks=num_gpu_blocks)
        self.worker_connector = OffloadingConnector(vllm_config,
                                                    KVConnectorRole.WORKER)

        # register worker kv_caches to enable OffloadingWorker creations
        self.worker_connector.register_kv_caches(
            kv_caches={"a": torch.empty(0)})

        # extract connector of scheduler
        scheduler_connector = self.scheduler.connector
        assert scheduler_connector is not None
        assert isinstance(scheduler_connector, OffloadingConnector)
        self.scheduler_connector: OffloadingConnector = scheduler_connector

        # extract mocked OffloadingManager of scheduler connector
        connector_scheduler = scheduler_connector.connector_scheduler
        assert connector_scheduler is not None
        manager = connector_scheduler.manager
        assert isinstance(manager, MagicMock)
        self.manager: MagicMock = manager

        assert connector_scheduler.gpu_block_size == gpu_block_size
        assert connector_scheduler.offloaded_block_size == offloaded_block_size

        # extract OffloadingSpec of worker_connector
        connector_worker = self.worker_connector.connector_worker
        assert connector_worker is not None
        offloading_spec = connector_worker.spec
        assert isinstance(offloading_spec, MockOffloadingSpec)
        self.offloading_spec: MockOffloadingSpec = offloading_spec

        # mapping (offloading address) -> gpu_block_index
        self.offloaded: dict[Any, int] = {}

        self.pending_loads_count: int = 0
        self.pending_stores_count: int = 0

        self.completed_loads: list[TransferSummary] = []
        self.completed_stores: list[TransferSummary] = []

        # maps {block_id: block_offset}
        self.gpu_block_index: dict[int, int] = {}

        init_none_hash(sha256)
        self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)

        self._dummy_ctx: ForwardContext = ForwardContext(no_compile_layers={},
                                                         attn_metadata={},
                                                         virtual_engine=0)

    def new_request(self, token_ids: list[int]):
        assert not self.scheduler.requests
        self.req_id += 1

        req = Request(
            request_id=str(self.req_id),
            prompt_token_ids=token_ids,
            sampling_params=SamplingParams(max_tokens=1000),
            pooling_params=None,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=self._block_hasher,
        )

        self.scheduler.add_request(req)

    def _wait_for_transfers(self):
        block_size_factor = self.offloaded_block_size // self.gpu_block_size

        while self.pending_loads_count or self.pending_stores_count:
            for transfer_spec in (
                    self.offloading_spec.get_completed_transfers()):
                src_spec, dst_spec = transfer_spec

                if isinstance(src_spec, GPULoadStoreSpec):
                    store = True
                    gpu_spec = src_spec
                    offload_spec = dst_spec
                else:
                    store = False
                    gpu_spec = dst_spec
                    offload_spec = src_spec

                assert isinstance(offload_spec, MockLoadStoreSpec)
                assert isinstance(gpu_spec, GPULoadStoreSpec)

                gpu_block_indices: list[int] = []
                for block_id in gpu_spec.block_ids:
                    gpu_block_indices.append(
                        self.gpu_block_index[block_id.item()])

                # list of (block_hash, sub_block_offset)
                offload_addresses: list[Any] = []
                for block_hash in offload_spec.block_hashes:
                    for sub_block_idx in range(block_size_factor):
                        offload_addresses.append((block_hash, sub_block_idx))

                if store:
                    assert len(gpu_block_indices) == len(offload_addresses)

                    self.completed_stores.append(
                        TransferSummary(gpu_block_indices, offload_addresses))
                    self.pending_stores_count -= 1
                else:
                    remainder_sub_block_count = (len(offload_addresses) -
                                                 len(gpu_block_indices))
                    assert remainder_sub_block_count >= 0
                    assert remainder_sub_block_count < block_size_factor
                    offload_addresses = offload_addresses[
                        remainder_sub_block_count:]

                    self.completed_loads.append(
                        TransferSummary(gpu_block_indices, offload_addresses))
                    self.pending_loads_count -= 1

    def _update_gpu_block_idx(self):
        for blocks in (self.scheduler.kv_cache_manager.coordinator.
                       single_type_managers[0].req_to_blocks.values()):
            for block_idx, block in enumerate(blocks):
                self.gpu_block_index[block.block_id] = block_idx

    def _run(self, decoded_tokens: list[int]):
        """
        Runs multiple engine (scheduler + worker) steps.
        Assumes a single request is running.

        Args:
            decoded_tokens: the tokens to yield at each step.
        """

        tokens_iter = iter(decoded_tokens)
        token_id = next(tokens_iter, None)
        while token_id is not None:
            assert self.scheduler.requests

            scheduler_output = self.scheduler.schedule()
            self._update_gpu_block_idx()

            kv_connector_metadata = scheduler_output.kv_connector_metadata
            assert kv_connector_metadata is not None
            assert isinstance(kv_connector_metadata,
                              OffloadingConnectorMetadata)

            self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
            self.pending_stores_count += len(
                kv_connector_metadata.reqs_to_store)

            self.worker_connector.bind_connector_metadata(
                kv_connector_metadata)
            self.worker_connector.start_load_kv(self._dummy_ctx)

            if scheduler_output.total_num_scheduled_tokens > 0:
                self.worker_connector.wait_for_save()

            finished_sending, finished_recving = (
                self.worker_connector.get_finished(
                    scheduler_output.finished_req_ids))

            self.worker_connector.clear_connector_metadata()

            model_runner_output = create_model_runner_output(
                reqs=self.scheduler.running,
                finished_sending=list(finished_sending),
                finished_recving=list(finished_recving),
                token_id=token_id)

            if self.scheduler.running:
                token_id = next(tokens_iter, None)

            self.scheduler.update_from_output(scheduler_output,
                                              model_runner_output)

            self._wait_for_transfers()

        # run one more step to update finished stored
        if EOS_TOKEN_ID in decoded_tokens:
            assert not self.scheduler.running

            while self.scheduler.requests:
                scheduler_output = self.scheduler.schedule()

                finished_sending, finished_recving = (
                    self.worker_connector.get_finished(
                        scheduler_output.finished_req_ids))

                assert not finished_recving

                model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
                model_runner_output.kv_connector_output = KVConnectorOutput(
                    finished_sending=finished_sending)

                self.scheduler.update_from_output(scheduler_output,
                                                  model_runner_output)

    def run(
        self,
        decoded_tokens: list[int],
        expected_stored_gpu_block_indexes: tuple[int, ...] = (),
        expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
    ):
        """
        Runs multiple engine (scheduler + worker) steps.
        Assumes a single request is running.

        Args:
            decoded_tokens: the tokens to yield at each step.
            expected_stored_gpu_block_indexes: GPU block indexes
                that are expected to be written during the run.
            expected_loaded_gpu_block_indexes: GPU block indexes
                that are expected to be loaded during the run.
        """

        self.manager.reset_mock()
        self._run(decoded_tokens)

        loaded_gpu_block_indexes: set[int] = set()
        for transfer in self.completed_loads:
            for gpu_block_idx, offloaded_address in zip(
                    transfer.gpu_block_indices, transfer.offload_addresses):
                loaded_gpu_block_indexes.add(gpu_block_idx)
                assert gpu_block_idx == self.offloaded[offloaded_address]

        assert (
            set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes)
        self.completed_loads.clear()

        stored_gpu_block_indexes: set[int] = set()
        for transfer in self.completed_stores:
            for gpu_block_idx, offloaded_address in zip(
                    transfer.gpu_block_indices, transfer.offload_addresses):
                stored_gpu_block_indexes.add(gpu_block_idx)
                self.offloaded[offloaded_address] = gpu_block_idx

        assert (
            set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes)
        self.completed_stores.clear()


@pytest.fixture
def request_runner():
    runners = []

    def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks):
        runner = RequestRunner(offloaded_block_size=offloaded_block_size,
                               gpu_block_size=gpu_block_size,
                               num_gpu_blocks=num_gpu_blocks)
        runners.append(runner)
        return runner

    yield runner_factory  # pass factory to the test


def generate_store_output(block_hashes: Iterable[BlockHash]):
    block_hashes = list(block_hashes)
    return PrepareStoreOutput(
        block_hashes_to_store=list(block_hashes),
        store_spec=MockLoadStoreSpec(block_hashes),
        block_hashes_evicted=[],
    )


def test_offloading_connector(request_runner):
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100
    block_size_factor = offloaded_block_size // gpu_block_size

    runner = request_runner(offloaded_block_size=offloaded_block_size,
                            gpu_block_size=gpu_block_size,
                            num_gpu_blocks=num_gpu_blocks)

    # 3 blocks, store just the middle block (skip first and last)
    # blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
    runner.new_request(token_ids=[0] * offloaded_block_size * 3)
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
    runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))

    # add block missing 1 token -> no offload
    runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
    runner.manager.prepare_store.assert_not_called()

    # +1 token -> single block, fail prepare_store
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: None
    runner.run(decoded_tokens=[0])
    runner.manager.prepare_store.assert_called()

    # 1 more block, now set block_hashes_to_store = []
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output([])
    runner.run(decoded_tokens=[0] * offloaded_block_size)

    # 1 more block, now check touch was called with all 6 blocks
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output(block_hashes)
    runner.run(decoded_tokens=[0] * offloaded_block_size,
               expected_stored_gpu_block_indexes=(15, 16, 17))
    runner.manager.touch.assert_called()
    block_hashes1 = list(runner.manager.touch.call_args.args[0])
    assert len(block_hashes1) == 6

    # terminate request
    runner.run(decoded_tokens=[EOS_TOKEN_ID])

    # create a new request differing only on the last token
    runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
    runner.run(decoded_tokens=[0],
               expected_stored_gpu_block_indexes=tuple(
                   range(6 * block_size_factor)))
    runner.manager.touch.assert_called()
    block_hashes2 = list(runner.manager.touch.call_args.args[0])
    assert len(block_hashes2) == 6

    # verify hashes are the same, except for the last block
    assert block_hashes1[:5] == block_hashes2[:5]
    assert block_hashes1[5] != block_hashes2[5]

    # terminate request
    runner.run(decoded_tokens=[EOS_TOKEN_ID])

    # full_block_tokens - num_computed_tokens < offloaded_block_size
    runner.new_request(token_ids=[0] * gpu_block_size + [1] *
                       (offloaded_block_size - gpu_block_size))
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output([])
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_not_called()

    # single block lookup with no hits
    runner.new_request(token_ids=[1] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output([])
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_called()
    assert len(list(runner.manager.lookup.call_args.args[0])) == 1

    # single block lookup with a hit
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output([])
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[EOS_TOKEN_ID],
               expected_loaded_gpu_block_indexes=(0, 1, 2))

    # single block lookup with a hit in a middle block
    runner.new_request(token_ids=[0] * offloaded_block_size * 2 +
                       [1] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output([])
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[EOS_TOKEN_ID],
               expected_loaded_gpu_block_indexes=(3, 4, 5))

    # test take_events
    def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
        return [BlockHash(str(i).encode()) for i in int_hashes]

    def take_events() -> Iterable[OffloadingEvent]:
        yield OffloadingEvent(block_hashes=to_hashes([1, 2, 3]),
                              block_size=16,
                              medium="A",
                              removed=False)
        yield OffloadingEvent(block_hashes=to_hashes([4, 5, 6]),
                              block_size=32,
                              medium="B",
                              removed=True)

    runner.manager.take_events.side_effect = take_events
    events = list(runner.scheduler_connector.take_events())
    assert len(events) == 2
    event = events[0]
    assert isinstance(event, BlockStored)
    assert event.block_hashes == to_hashes([1, 2, 3])
    assert event.block_size == 16
    assert event.medium == "A"
    assert event.token_ids == []
    assert event.parent_block_hash is None
    assert event.lora_id is None
    event = events[1]
    assert isinstance(event, BlockRemoved)
    assert event.block_hashes == to_hashes([4, 5, 6])
    assert event.medium == "B"
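A quick way to read the expected GPU block indexes in this test: with gpu_block_size=4 and offloaded_block_size=12, the block size factor is 3, so offloaded block k covers GPU blocks 3k through 3k+2. The snippet below is an illustrative sketch of that mapping (it is not part of the test itself) and reproduces why storing only the middle offloaded block of a 36-token prompt yields GPU block indexes (3, 4, 5).

# Illustrative sketch (not part of the test): mapping an offloaded block
# index to the GPU block indexes it covers, using the test's constants.
gpu_block_size = 4
offloaded_block_size = 12
block_size_factor = offloaded_block_size // gpu_block_size  # 3


def gpu_blocks_of(offloaded_block_idx: int) -> tuple[int, ...]:
    start = offloaded_block_idx * block_size_factor
    return tuple(range(start, start + block_size_factor))


# A 36-token prompt spans offloaded blocks 0, 1, 2; storing only the middle
# one touches GPU blocks (3, 4, 5), matching
# expected_stored_gpu_block_indexes=(3, 4, 5) above.
assert gpu_blocks_of(1) == (3, 4, 5)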
tests/v1/kv_connector/unit/utils.py (modified)
@@ -176,6 +176,7 @@ def create_model_runner_output(
     finished_sending: Optional[list[str]] = None,
     finished_recving: Optional[list[str]] = None,
     use_eos: bool = False,
+    token_id: int = 0,
 ) -> ModelRunnerOutput:
     """Make dummy model runner output for testing."""

@@ -184,7 +185,7 @@ def create_model_runner_output(
     req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}

     # Make sampled tokens.
-    sampled_token = EOS_TOKEN_ID if use_eos else 0
+    sampled_token = EOS_TOKEN_ID if use_eos else token_id
     sampled_token_ids = [[sampled_token] for _ in req_ids]

     kv_connector_output = None if (
@@ -106,3 +106,8 @@ KVConnectorFactory.register_connector(
     "MultiConnector",
     "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
     "MultiConnector")
+
+KVConnectorFactory.register_connector(
+    "OffloadingConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
+    "OffloadingConnector")
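For context, a deployment selects this connector through KVTransferConfig, exactly as the unit test above does. The sketch below is a minimal, hypothetical configuration; the spec_name, spec_module_path, and block_size values are placeholders, not defaults shipped by this commit.

# Minimal sketch: wiring the OffloadingConnector into a config.
# The spec_name/spec_module_path values are placeholders; the unit test
# points them at its own MockOffloadingSpec in the same way.
from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="OffloadingConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        # OffloadingSpec implementation to load (create_spec falls back to
        # "CPUOffloadingSpec" when spec_name is not given)
        "spec_name": "MyOffloadingSpec",
        "spec_module_path": "my_package.my_offloading_spec",
        # offloaded block size; must be a multiple of the GPU block size
        "block_size": 64,
    })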
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py (new file, 485 lines)
@@ -0,0 +1,485 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
from typing import Any, Optional

import torch

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
                                                          KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_offload.abstract import OffloadingManager
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import OffloadingWorker, TransferSpec
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import Request

ReqId = str

logger = init_logger(__name__)


@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
    reqs_to_load: dict[ReqId, TransferSpec]
    reqs_to_store: dict[ReqId, TransferSpec]


class OffloadingConnector(KVConnectorBase_V1):

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        super().__init__(vllm_config, role)

        spec = OffloadingSpecFactory.create_spec(vllm_config)

        self.connector_scheduler: Optional[OffloadingConnectorScheduler] = None
        self.connector_worker: Optional[OffloadingConnectorWorker] = None
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = OffloadingConnectorScheduler(spec)
        elif role == KVConnectorRole.WORKER:
            self.connector_worker = OffloadingConnectorWorker(spec)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata,
                          OffloadingConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        pass

    def wait_for_save(self):
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata,
                          OffloadingConnectorMetadata)
        self.connector_worker.start_store_kv(self._connector_metadata)

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        assert self.connector_worker is not None
        return self.connector_worker.get_finished(finished_req_ids)

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def update_connector_output(self, connector_output: KVConnectorOutput):
        assert self.connector_scheduler is not None
        self.connector_scheduler.update_connector_output(connector_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    def take_events(self) -> Iterable[KVCacheEvent]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.take_events()


class OffloadingConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, spec: OffloadingSpec):
        self.gpu_block_size = spec.gpu_block_size
        self.offloaded_block_size = spec.offloaded_block_size
        self.block_size_factor = (self.offloaded_block_size //
                                  self.gpu_block_size)
        self.manager: OffloadingManager = spec.get_manager()

        self._requests: dict[ReqId, Request] = {}
        # list of GPU block IDs per request
        self._request_block_ids: dict[ReqId, list[int]] = {}
        # requests to load for the current scheduler step
        self._reqs_to_load: dict[ReqId, TransferSpec] = {}
        # request blocks are stored in order
        # index of next block (of size offloaded_block_size) to offload
        self._next_stored_block_idx: dict[ReqId, int] = {}

        # request ID -> set(block hashes being stored/loaded)
        self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
        self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)

    def _get_block_hashes(
        self,
        req: Request,
        start_idx: int = 0,
        end_idx: Optional[int] = None,
    ) -> Iterable[BlockHash]:
        return islice(
            req.block_hashes,
            self.block_size_factor * start_idx + self.block_size_factor - 1,
            self.block_size_factor * end_idx if end_idx else None,
            self.block_size_factor)

    def get_num_new_matched_tokens(
            self, request: Request,
            num_computed_tokens: int) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded beyond the
        num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded beyond what is
                  already computed.
                - `True` if tokens will be loaded asynchronously
                  (between scheduler steps).
        """
        num_blocks = request.num_tokens // self.offloaded_block_size

        assert (len(request.block_hashes) //
                self.block_size_factor == num_blocks)
        block_hashes = self._get_block_hashes(request)

        self.manager.touch(block_hashes)

        full_block_tokens = self.offloaded_block_size * num_blocks
        if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
            # we can load less than a block, skip
            return 0, False

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        hits = self.manager.lookup(
            self._get_block_hashes(request, start_idx=start_block_idx))
        if hits == 0:
            return 0, False

        num_hit_tokens = (self.offloaded_block_size *
                          (start_block_idx + hits) - num_computed_tokens)
        logger.debug(
            "Request %s hit %s offloaded tokens after %s GPU hit tokens",
            request.request_id,
            num_hit_tokens,
            num_computed_tokens,
        )
        if num_hit_tokens < self.offloaded_block_size:
            return 0, False

        return num_hit_tokens, True

    def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
                                 num_external_tokens: int):
        self._requests[request.request_id] = request
        # the block ids are updated in _get_reqs_to_store
        self._request_block_ids[request.request_id] = []

        if num_external_tokens == 0:
            return

        block_groups = blocks.get_block_ids()
        block_ids = block_groups[0]

        num_computed_gpu_blocks = sum(block.block_hash is not None
                                      for block in blocks.blocks[0])
        num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
        full_block_tokens = num_computed_tokens + num_external_tokens
        assert full_block_tokens % self.offloaded_block_size == 0

        num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
        assert (num_external_tokens == num_pending_gpu_blocks *
                self.gpu_block_size)

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        num_blocks = full_block_tokens // self.offloaded_block_size

        assert (len(request.block_hashes) // self.block_size_factor
                >= num_blocks)
        block_hashes = self._get_block_hashes(request,
                                              start_idx=start_block_idx,
                                              end_idx=num_blocks)

        src_spec = self.manager.prepare_load(block_hashes)
        dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])

        block_hashes = self._get_block_hashes(request,
                                              start_idx=start_block_idx,
                                              end_idx=num_blocks)

        self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
        self._reqs_being_loaded[request.request_id].update(block_hashes)
        self._next_stored_block_idx[request.request_id] = num_blocks

    def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
        reqs_to_store: dict[ReqId, TransferSpec] = {}
        # iterate over both new and cached requests
        for req_id, new_block_id_groups, preempted in yield_req_data(
                scheduler_output):

            if preempted:
                self._request_block_ids[req_id] = []

            if new_block_id_groups:
                new_block_ids = new_block_id_groups[0]
                self._request_block_ids[req_id] += new_block_ids

            block_ids = self._request_block_ids[req_id]

            req = self._requests[req_id]
            new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            total_tokens = req.num_computed_tokens + new_tokens
            num_blocks = total_tokens // self.offloaded_block_size
            start_block_idx = self._next_stored_block_idx.get(req_id, 0)
            num_new_blocks = num_blocks - start_block_idx

            if num_new_blocks <= 0:
                continue

            num_gpu_blocks = num_blocks * self.block_size_factor
            assert len(req.block_hashes) >= num_gpu_blocks

            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks)
            store_output = self.manager.prepare_store(new_block_hashes)
            if store_output is None:
                logger.warning("Cannot store %s blocks", num_new_blocks)
                break

            self._next_stored_block_idx[req_id] = num_blocks

            if not store_output.block_hashes_to_store:
                continue
            block_hashes_to_store = set(store_output.block_hashes_to_store)

            block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
            self.manager.touch(block_hashes)

            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks)
            dst_spec = store_output.store_spec
            src_block_ids: list[int] = []
            for idx, blk_hash in enumerate(new_block_hashes):
                if blk_hash not in block_hashes_to_store:
                    continue
                offloaded_block_idx = start_block_idx + idx
                gpu_block_idx = offloaded_block_idx * self.block_size_factor
                for i in range(self.block_size_factor):
                    src_block_ids.append(block_ids[gpu_block_idx + i])
            src_spec = GPULoadStoreSpec(src_block_ids)

            reqs_to_store[req_id] = (src_spec, dst_spec)
            self._reqs_being_stored[req_id] |= block_hashes_to_store

            logger.debug(
                "Request %s offloading %s blocks starting from block #%d",
                req_id,
                len(block_hashes_to_store),
                start_block_idx,
            )

        return reqs_to_store

    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        meta = OffloadingConnectorMetadata(
            reqs_to_load=self._reqs_to_load,
            reqs_to_store=self._get_reqs_to_store(scheduler_output))
        self._reqs_to_load = {}
        return meta

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        for req_id in connector_output.finished_sending or []:
            block_hashes = self._reqs_being_stored.pop(req_id, None)
            if block_hashes:
                self.manager.complete_store(block_hashes)

        for req_id in connector_output.finished_recving or []:
            block_hashes = self._reqs_being_loaded.pop(req_id, None)
            if block_hashes:
                self.manager.complete_load(block_hashes)

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        req_id = request.request_id
        self._requests.pop(req_id, None)
        self._request_block_ids.pop(req_id, None)
        self._next_stored_block_idx.pop(req_id, None)

        request_being_stored = req_id in self._reqs_being_stored
        return request_being_stored, None

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Take the KV cache events from the connector.

        Returns:
            A list of KV cache events.
        """
        for event in self.manager.take_events():
            if event.removed:
                yield BlockRemoved(block_hashes=event.block_hashes,
                                   medium=event.medium)
            else:
                yield BlockStored(block_hashes=event.block_hashes,
                                  parent_block_hash=None,
                                  token_ids=[],
                                  lora_id=None,
                                  block_size=event.block_size,
                                  medium=event.medium)


class OffloadingConnectorWorker:
    """Implementation of Worker side methods"""

    def __init__(self, spec: OffloadingSpec):
        self.spec = spec
        self.worker = OffloadingWorker()

        self._job_counter = 0

        # job_id -> (req_id, store)
        self._jobs: dict[int, tuple[ReqId, bool]] = {}
        # req_id -> active load job ID
        self._load_job: dict[ReqId, int] = {}
        # req_id -> set(active job IDs)
        self._store_jobs = defaultdict[ReqId, set[int]](set)

        self._finished_reqs_waiting_for_store: set[ReqId] = set()

    def _generate_job_id(self) -> int:
        job_id = self._job_counter
        self._job_counter = job_id + 1
        return job_id

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        for src_cls, dst_cls, handler in (self.spec.get_handlers(kv_caches)):
            self.worker.register_handler(src_cls, dst_cls, handler)

    def start_load_kv(self, metadata: OffloadingConnectorMetadata):
        for req_id, transfer_spec in metadata.reqs_to_load.items():
            job_id = self._generate_job_id()
            self._jobs[job_id] = (req_id, False)
            assert req_id not in self._load_job
            self._load_job[req_id] = job_id
            assert self.worker.transfer_async(job_id, transfer_spec)

    def start_store_kv(self, metadata: OffloadingConnectorMetadata):
        for req_id, transfer_spec in metadata.reqs_to_store.items():
            job_id = self._generate_job_id()
            self._jobs[job_id] = (req_id, True)
            self._store_jobs[req_id].add(job_id)
            assert self.worker.transfer_async(job_id, transfer_spec)

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """
        Notifies the worker-side connector of the IDs of requests that have
        finished generating tokens.
        Returns the request IDs that finished loading or storing.

        Returns:
            IDs of requests that have finished asynchronous transfers,
            as a tuple of (sending/saving ids, recving/loading ids).
        """
        finished_sending = set()
        finished_recving = set()
        for job_id, success in self.worker.get_finished():
            # we currently do not support job failures
            assert success
            req_id, store = self._jobs.pop(job_id)
            if store:
                req_jobs = self._store_jobs[req_id]
                req_jobs.remove(job_id)
                if req_jobs:
                    continue

                if req_id in self._finished_reqs_waiting_for_store:
                    self._finished_reqs_waiting_for_store.remove(req_id)
                    finished_sending.add(req_id)
                    del self._store_jobs[req_id]
            else:
                req_job = self._load_job[req_id]
                assert job_id == req_job
                del self._load_job[req_id]
                finished_recving.add(req_id)

        for req_id in finished_req_ids:
            pending_req_jobs = self._store_jobs.get(req_id)
            if pending_req_jobs:
                self._finished_reqs_waiting_for_store.add(req_id)
            elif pending_req_jobs is not None:
                finished_sending.add(req_id)
                del self._store_jobs[req_id]

        return finished_sending, finished_recving


def yield_req_data(
        scheduler_output) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
    """
    Yields:
        (req_id, new_block_id_groups, preempted)
    """
    # new requests
    for req_data in scheduler_output.scheduled_new_reqs:
        yield req_data.req_id, req_data.block_ids, False

    # cached requests
    cached_reqs = scheduler_output.scheduled_cached_reqs
    yield from zip(cached_reqs.req_ids, cached_reqs.new_block_ids,
                   cached_reqs.resumed_from_preemption)
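One detail of OffloadingConnectorScheduler worth calling out: _get_block_hashes identifies each offloaded block by the request's GPU-block hash at the end of that block, which is why the islice starts at block_size_factor * start_idx + block_size_factor - 1 and steps by block_size_factor. The following is a small standalone sketch of that selection using a hypothetical list of per-GPU-block hashes (not real request data).

# Standalone sketch of the hash selection done by _get_block_hashes.
# With block_size_factor = 3, offloaded block i is keyed by the hash of
# its last GPU sub-block, i.e. req.block_hashes[3 * i + 2].
from itertools import islice

block_size_factor = 3
# hypothetical per-GPU-block hashes for 9 GPU blocks (= 3 offloaded blocks)
gpu_block_hashes = [f"h{i}" for i in range(9)]

offloaded_block_hashes = list(
    islice(gpu_block_hashes, block_size_factor - 1, None, block_size_factor))
assert offloaded_block_hashes == ["h2", "h5", "h8"]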
vllm/v1/kv_offload/factory.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from typing import TYPE_CHECKING, Callable

from vllm.logger import init_logger
from vllm.v1.kv_offload.spec import OffloadingSpec

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


class OffloadingSpecFactory:
    _registry: dict[str, Callable[[], type[OffloadingSpec]]] = {}

    @classmethod
    def register_spec(cls, name: str, module_path: str,
                      class_name: str) -> None:
        """Register a spec with a lazy-loading module and class name."""
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type[OffloadingSpec]:
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        cls._registry[name] = loader

    @classmethod
    def create_spec(
        cls,
        config: "VllmConfig",
    ) -> OffloadingSpec:
        kv_transfer_config = config.kv_transfer_config
        assert kv_transfer_config is not None
        extra_config = kv_transfer_config.kv_connector_extra_config
        spec_name = extra_config.get("spec_name", "CPUOffloadingSpec")
        if spec_name in cls._registry:
            spec_cls = cls._registry[spec_name]()
        else:
            spec_module_path = extra_config.get("spec_module_path")
            if spec_module_path is None:
                raise ValueError(f"Unsupported spec type: {spec_name}")
            spec_module = importlib.import_module(spec_module_path)
            spec_cls = getattr(spec_module, spec_name)
        assert issubclass(spec_cls, OffloadingSpec)
        logger.info("Creating offloading spec with name: %s", spec_name)
        return spec_cls(config)


# Register various specs here.
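Specs can also be pre-registered instead of being resolved through spec_module_path at runtime. The sketch below is a hedged usage example of register_spec; the spec name and module it registers are hypothetical and not part of this commit, which only leaves the "Register various specs here." placeholder.

# Hypothetical registration of a custom spec; mirrors how
# KVConnectorFactory.register_connector is used for connectors.
from vllm.v1.kv_offload.factory import OffloadingSpecFactory

OffloadingSpecFactory.register_spec(
    "MyOffloadingSpec",               # spec_name referenced in the config
    "my_package.my_offloading_spec",  # module imported lazily on first use
    "MyOffloadingSpec")               # class name inside that module

# create_spec() will then instantiate it whenever kv_connector_extra_config
# sets {"spec_name": "MyOffloadingSpec"}.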
vllm/v1/kv_offload/spec.py (new file, 61 lines)
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import TYPE_CHECKING

import torch

from vllm.logger import init_logger
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.worker.worker import OffloadingHandler

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


class OffloadingSpec(ABC):
    """Spec for an offloading connector"""

    def __init__(self, vllm_config: "VllmConfig"):
        logger.warning(
            "Initializing OffloadingSpec. This API is experimental and "
            "subject to change in the future as we iterate the design.")
        self.vllm_config = vllm_config

        kv_transfer_config = vllm_config.kv_transfer_config
        assert kv_transfer_config is not None
        self.extra_config = kv_transfer_config.kv_connector_extra_config

        self.gpu_block_size = vllm_config.cache_config.block_size
        self.offloaded_block_size = int(
            self.extra_config.get("block_size", self.gpu_block_size))

        assert self.offloaded_block_size % self.gpu_block_size == 0

    @abstractmethod
    def get_manager(self) -> OffloadingManager:
        """
        Get an OffloadingManager that will be used
        by the scheduler-side offloading connector to track
        offloaded blocks and manage evictions.
        """
        pass

    @abstractmethod
    def get_handlers(
        self, kv_caches: dict[str, torch.Tensor]
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
                        OffloadingHandler]]:
        """
        Get offloading handlers along with their respective src and dst types.

        Args:
            kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.

        Yields:
            Tuples of (src_type, dst_type, offloading_handler).
        """
        pass