mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[KV offload][1/N] Introduce an offloading component (#19848)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@ -280,6 +280,7 @@ steps:
|
||||
# split the test to avoid interference
|
||||
- pytest -v -s v1/core
|
||||
- pytest -v -s v1/executor
|
||||
- pytest -v -s v1/offloading
|
||||
- pytest -v -s v1/sample
|
||||
- pytest -v -s v1/logits_processors
|
||||
- pytest -v -s v1/worker
|
||||
|
152
tests/v1/offloading/test_worker.py
Normal file
152
tests/v1/offloading/test_worker.py
Normal file
@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.v1.offloading.abstract import LoadStoreSpec
|
||||
from vllm.v1.offloading.worker.worker import (OffloadingHandler,
|
||||
OffloadingWorker, TransferResult,
|
||||
TransferSpec)
|
||||
|
||||
|
||||
class LoadStoreSpec1(LoadStoreSpec):
    """
    Test spec for medium "1".

    Its flags script how a transfer using this spec behaves in the
    test handlers defined in this file.
    """

    def __init__(self,
                 submit_success: bool = True,
                 async_success: bool = True,
                 exception: bool = False):
        # set to True by the test to mark the transfer as completed
        self.finished = False
        # False makes the 1->2 handler refuse the submission
        self.submit_success = submit_success
        # the success flag reported once the transfer has finished
        self.async_success = async_success
        # True makes the 1->2 handler raise on submission
        self.exception = exception

    @staticmethod
    def medium() -> str:
        return "1"

    def __repr__(self) -> str:
        return f"{self.medium()}: {id(self)}"
|
||||
|
||||
|
||||
class LoadStoreSpec2(LoadStoreSpec):
    """
    Test spec for medium "2". It carries no state of its own.
    """

    @staticmethod
    def medium() -> str:
        return "2"

    def __repr__(self) -> str:
        return f"{self.medium()}: {id(self)}"
|
||||
|
||||
|
||||
class OffloadingHandler1To2(OffloadingHandler):
    """
    Test handler for transfers from medium "1" to medium "2".

    The outcome of each transfer is driven by the flags
    of its source LoadStoreSpec1.
    """

    def __init__(self):
        # in-flight transfers: job_id -> source spec
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Record the transfer as in-flight, unless the source spec asks
        for an exception or for a submission failure.

        Returns:
            True if the transfer was recorded, False if the source spec
            has submit_success unset.

        Raises:
            Exception: if the source spec has its exception flag set.
        """
        src, dst = spec
        assert isinstance(src, LoadStoreSpec1)
        assert isinstance(dst, LoadStoreSpec2)

        if src.exception:
            raise Exception("An expected exception. Don't worry!")
        if not src.submit_success:
            return False

        self.transfers[job_id] = src
        return True

    def get_finished(self) -> list[TransferResult]:
        """
        Return (job_id, async_success) for every in-flight transfer whose
        source spec is marked finished, and stop tracking those transfers.
        """
        finished = []
        # iterate over a snapshot since entries are deleted in the loop
        for job_id, spec in list(self.transfers.items()):
            if spec.finished:
                finished.append((job_id, spec.async_success))
                del self.transfers[job_id]
        return finished
|
||||
|
||||
|
||||
class OffloadingHandler2To1(OffloadingHandler):
    """
    Test handler for transfers from medium "2" to medium "1".

    Submission always succeeds. Completion is driven by the flags
    of the destination LoadStoreSpec1.
    """

    def __init__(self):
        # in-flight transfers: job_id -> destination spec
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Record the transfer as in-flight, keyed by job_id.

        Returns:
            Always True.
        """
        src, dst = spec
        assert isinstance(src, LoadStoreSpec2)
        assert isinstance(dst, LoadStoreSpec1)

        # the destination is tracked, as it is the side carrying the flags
        self.transfers[job_id] = dst
        return True

    def get_finished(self) -> list[TransferResult]:
        """
        Return (job_id, async_success) for every in-flight transfer whose
        destination spec is marked finished, and stop tracking those
        transfers.
        """
        finished = []
        # iterate over a snapshot since entries are deleted in the loop
        for job_id, spec in list(self.transfers.items()):
            if spec.finished:
                finished.append((job_id, spec.async_success))
                del self.transfers[job_id]
        return finished
|
||||
|
||||
|
||||
def test_offloading_worker():
    """
    Tests OffloadingWorker with 2 handlers.
    One handler performs 1->2 transfers, and the other handles 2->1.
    """
    worker = OffloadingWorker()
    handler1to2 = OffloadingHandler1To2()
    handler2to1 = OffloadingHandler2To1()
    worker.register_handler(LoadStoreSpec1, LoadStoreSpec2, handler1to2)
    worker.register_handler(LoadStoreSpec2, LoadStoreSpec1, handler2to1)

    # 1st transfer 1->2 (exception)
    src1 = LoadStoreSpec1(exception=True)
    dst1 = LoadStoreSpec2()
    assert not worker.transfer_async(1, (src1, dst1))

    # 2nd transfer 1->2 (failure to submit)
    src2 = LoadStoreSpec1(submit_success=False)
    dst2 = LoadStoreSpec2()
    assert not worker.transfer_async(2, (src2, dst2))

    # 3rd transfer 1->2 (failure)
    src3 = LoadStoreSpec1(async_success=False)
    dst3 = LoadStoreSpec2()
    assert worker.transfer_async(3, (src3, dst3))

    # 4th transfer 1->2 (success)
    src4 = LoadStoreSpec1()
    dst4 = LoadStoreSpec2()
    assert worker.transfer_async(4, (src4, dst4))
    assert set(handler1to2.transfers.keys()) == {3, 4}

    # 5th transfer 2->1
    src5 = LoadStoreSpec2()
    dst5 = LoadStoreSpec1()
    assert worker.transfer_async(5, (src5, dst5))
    assert set(handler2to1.transfers.keys()) == {5}

    # no transfer completed yet
    assert worker.get_finished() == []

    # complete 3rd, 4th
    src3.finished = True
    src4.finished = True

    # 6th transfer 1->2
    src6 = LoadStoreSpec1()
    dst6 = LoadStoreSpec2()
    assert worker.transfer_async(6, (src6, dst6))

    # 7th transfer 2->1
    src7 = LoadStoreSpec2()
    dst7 = LoadStoreSpec1()
    assert worker.transfer_async(7, (src7, dst7))

    # 6th and 7th transfers started
    assert 6 in handler1to2.transfers
    assert 7 in handler2to1.transfers

    # verify result of 3rd and 4th transfers
    # (sorted, as the worker does not promise an order across handlers)
    assert (sorted(worker.get_finished()) == [(3, False), (4, True)])

    # complete 6th and 7th transfers
    # (the 5th transfer is never marked finished, so it is never reported)
    src6.finished = True
    dst7.finished = True
    assert (sorted(worker.get_finished()) == [(6, True), (7, True)])
|
165
vllm/v1/offloading/abstract.py
Normal file
165
vllm/v1/offloading/abstract.py
Normal file
@ -0,0 +1,165 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
OffloadingManager class for managing KV data offloading in vLLM v1
|
||||
|
||||
This class runs in the scheduler, tracks which blocks are offloaded
|
||||
and their address.
|
||||
|
||||
The class provides the following primitives:
|
||||
lookup() - find the length of the maximal series of blocks,
|
||||
starting from the first one, that are all offloaded.
|
||||
prepare_load() - prepare given blocks to be read.
|
||||
The given blocks will be protected from eviction.
|
||||
This function returns a LoadStoreSpec which encapsulates
|
||||
information required for performing the load.
|
||||
touch() - marks the given blocks as recently used. Can be used
|
||||
to track block's LRU. This function is separated from the
|
||||
prepare_load function to allow setting block recency even
|
||||
for blocks which do not need reading from the cache, such as
|
||||
blocks that are cached by the GPU prefix cache.
|
||||
complete_load() - mark blocks which were previously prepared to be
|
||||
loaded as done loading. This is to re-allow their eviction.
|
||||
prepare_store() - prepare the given blocks to be written.
|
||||
Returns a PrepareStoreOutput encapsulating offloading information,
|
||||
as well as a list of blocks that were evicted as a result.
|
||||
complete_store() - marks a previous store as completed.
|
||||
Following this call, the given blocks will become loadable.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
|
||||
class LoadStoreSpec(ABC):
    """
    Abstract metadata that encapsulates information allowing a worker
    to load, and optionally also to store, blocks of KV data.
    """

    @staticmethod
    @abstractmethod
    def medium() -> str:
        """
        Returns a string representation of the medium type
        this store/load targets.

        Declared static so that it can be called on the class itself,
        without creating an instance.
        """
        pass
|
||||
|
||||
|
||||
@dataclass
class PrepareStoreOutput:
    """
    Result of OffloadingManager.prepare_store().
    """
    # hashes of the blocks that need storing
    block_hashes_to_store: list[BlockHash]
    # where to store the blocks
    store_spec: LoadStoreSpec
    # hashes of the blocks that were evicted as a result
    block_hashes_evicted: list[BlockHash]
|
||||
|
||||
|
||||
@dataclass
class OffloadingEvent:
    """
    A store or removal of offloaded blocks, as handed out by
    OffloadingManager.take_events().
    """
    # hashes of the blocks this event refers to
    block_hashes: list[BlockHash]
    # NOTE(review): unit not shown here, presumably tokens per block
    block_size: int
    # NOTE(review): presumably a LoadStoreSpec.medium() value - confirm
    medium: str
    # True if blocks are removed, False if stored
    removed: bool
|
||||
|
||||
|
||||
class OffloadingManager(ABC):
    """
    Abstract interface for tracking which KV blocks are offloaded.

    Subclasses must implement lookup(), prepare_load() and prepare_store().
    The remaining methods default to no-ops and may be overridden.
    """

    @abstractmethod
    def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
        """
        Finds the length of the maximal series of blocks, starting from the
        first one, that are all offloaded.

        Args:
            block_hashes: the hashes identifying the blocks to lookup.

        Returns:
            An integer representing the maximal number of blocks that
            are currently offloaded.
        """
        pass

    @abstractmethod
    def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
        """
        Prepare the given blocks to be read.
        The given blocks will be protected from eviction until
        complete_load is called.
        It assumes all given blocks are offloaded.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A LoadStoreSpec that can be used by a worker to locate and load
            the actual offloaded KV data.
        """
        pass

    def touch(self, block_hashes: Iterable[BlockHash]) -> None:
        """
        Mark the given blocks as recently used.
        This could in practice mean moving them to the end of an LRU list.
        The default implementation does nothing.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    def complete_load(self, block_hashes: Iterable[BlockHash]) -> None:
        """
        Marks previous blocks that were prepared to load as done loading.
        The default implementation does nothing.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    @abstractmethod
    def prepare_store(
            self,
            block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]:
        """
        Prepare the given blocks to be offloaded.
        The given blocks will be protected from eviction until
        complete_store is called.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A PrepareStoreOutput indicating which blocks need storing,
            where to store them (LoadStoreSpec), and list of blocks that
            were evicted as a result.
            None is returned if the blocks cannot be stored.
        """
        pass

    def complete_store(self,
                       block_hashes: Iterable[BlockHash],
                       success: bool = True) -> None:
        """
        Marks blocks which were previously prepared to be stored, as stored.
        Following this call, the blocks become loadable.
        If success is False, blocks that were not marked as stored will be
        removed.
        The default implementation does nothing.

        Args:
            block_hashes: the hashes identifying the blocks.
            success: whether the blocks were stored successfully.
        """
        return

    def take_events(self) -> Iterable[OffloadingEvent]:
        """
        Take the offloading events from the manager.

        Returns:
            New OffloadingEvents collected since the last call.
            The default implementation returns an empty tuple.
        """
        return ()
|
39
vllm/v1/offloading/mediums.py
Normal file
39
vllm/v1/offloading/mediums.py
Normal file
@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.v1.offloading.abstract import LoadStoreSpec
|
||||
|
||||
|
||||
class BlockIDsLoadStoreSpec(LoadStoreSpec, ABC):
    """
    Spec for loading/storing KV blocks from given block numbers.

    Still abstract: medium() is left for subclasses to define.
    """

    def __init__(self, block_ids: list[int]):
        # kept as an int64 numpy array rather than as the given list
        self.block_ids = np.array(block_ids, dtype=np.int64)

    def __repr__(self) -> str:
        return repr(self.block_ids)
|
||||
|
||||
|
||||
class GPULoadStoreSpec(BlockIDsLoadStoreSpec):
    """
    Spec for loading/storing KV blocks to GPU memory.
    """

    @staticmethod
    def medium() -> str:
        return "GPU"
|
||||
|
||||
|
||||
class CPULoadStoreSpec(BlockIDsLoadStoreSpec):
    """
    Spec for loading/storing KV blocks to CPU memory.
    """

    @staticmethod
    def medium() -> str:
        return "CPU"
|
142
vllm/v1/offloading/worker/worker.py
Normal file
142
vllm/v1/offloading/worker/worker.py
Normal file
@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.offloading.abstract import LoadStoreSpec
|
||||
|
||||
# a single transfer spec (src_blocks_spec, dst_blocks_spec)
|
||||
TransferSpec = tuple[LoadStoreSpec, LoadStoreSpec]
|
||||
# transfers are forwarded to workers by (src_medium, dst_medium)
|
||||
TransferType = tuple[str, str]
|
||||
# transfer result (job_id, success)
|
||||
TransferResult = tuple[int, bool]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OffloadingHandler(ABC):
    """
    OffloadingHandler class for managing asynchronous KV data transfers

    This class runs in the worker.
    It kicks off async KV data transfer requests, and allows
    collecting back completion statuses.

    The class provides the following primitives:
        transfer_async() - kicks off a new transfer job
        get_finished() - returns a list of (job_id, success) results
            of newly finished jobs.
    """

    @abstractmethod
    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Initiates an asynchronous transfer of KV data.

        Args:
            job_id: a unique ID that will be used when notifying back on
                transfer completion.
            spec: the (src, dst) spec of the KV data transfer.

        Returns:
            True if transfer was submitted successfully.
        """
        pass

    @abstractmethod
    def get_finished(self) -> list[TransferResult]:
        """
        Get transfers finished since last call.

        Returns:
            A list of (job_id, success) of transfers.
        """
        pass
|
||||
|
||||
|
||||
class OffloadingWorker:
    """
    OffloadingWorker class for managing asynchronous KV data transfers
    using multiple OffloadingHandlers

    This class runs in the worker.
    It kicks off async KV data transfer requests, by delegating
    to one of its registered OffloadingHandlers, based on the transfer type.

    The class provides the following primitives:
        register_handler() - registers a new handler to handle
            a specific transfer type
        transfer_async() - kicks off a new transfer job
            using one of the registered handlers.
        get_finished() - returns a list of (job_id, success) results
            of newly finished jobs, from all handlers.
    """

    def __init__(self):
        # a set, so that a handler registered for several transfer types
        # is polled only once by get_finished()
        self.handlers: set[OffloadingHandler] = set()
        self.transfer_type_to_handler: dict[TransferType,
                                            OffloadingHandler] = {}

    def register_handler(self, src_cls: type[LoadStoreSpec],
                         dst_cls: type[LoadStoreSpec],
                         handler: OffloadingHandler) -> None:
        """
        Registers a new handler.
        At most one handler can be registered per (src, dst) medium pair.

        Args:
            src_cls: the source type of transfers handled by this handler.
            dst_cls: the destination type of transfers handled by this handler.
            handler: the handler that will handle transfers.
        """
        transfer_type = (src_cls.medium(), dst_cls.medium())
        assert transfer_type not in self.transfer_type_to_handler, (
            f"A handler is already registered for {transfer_type!r} transfers")
        self.handlers.add(handler)
        self.transfer_type_to_handler[transfer_type] = handler

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Initiates an asynchronous transfer of KV data.

        An exception raised by the handler is logged and
        reported as a failed submission; it is not propagated.

        Args:
            job_id: a unique ID that will be used when notifying back on
                transfer completion.
            spec: the (src, dst) spec of the KV data transfer.

        Returns:
            True if transfer was submitted successfully.
        """
        src, dst = spec
        transfer_type = (src.medium(), dst.medium())
        handler = self.transfer_type_to_handler.get(transfer_type)
        assert handler is not None, (
            f"No handler registered for {transfer_type!r} transfers")

        try:
            success = handler.transfer_async(job_id, spec)
        except Exception as e:
            logger.warning("Exception in %r transfer %d: %r",
                           transfer_type,
                           job_id,
                           e,
                           exc_info=True)
            return False

        if not success:
            logger.warning("Failed to submit %r transfer %d", transfer_type,
                           job_id)
        else:
            logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id,
                         spec)

        return success

    def get_finished(self) -> list[TransferResult]:
        """
        Get transfers finished since last call.

        Returns:
            A list of (job_id, success) of transfers.
            Handlers are kept in a set, so no particular order
            is promised across handlers.
        """
        finished = []
        for handler in self.handlers:
            finished.extend(handler.get_finished())
        return finished
|
Reference in New Issue
Block a user