v1: Support KV events from connectors (#19737)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent):
|
||||
token_ids: list[int]
|
||||
block_size: int
|
||||
lora_id: Optional[int]
|
||||
medium: Optional[str]
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
    """KV cache event reporting blocks that were removed."""

    # Hash values identifying the removed blocks.
    block_hashes: list[int]
    # Optional label of the storage medium the blocks were removed from
    # (presumably a value such as "GPU" -- confirm against the publisher).
    medium: Optional[str]
|
||||
|
||||
|
||||
class AllBlocksCleared(KVCacheEvent):
|
||||
|
@ -40,16 +40,21 @@ class KVCacheEvent(
|
||||
"""Base class for all KV cache-related events"""
|
||||
|
||||
|
||||
MEDIUM_GPU = "GPU"
|
||||
|
||||
|
||||
class BlockStored(KVCacheEvent):
    """KV cache event reporting blocks that were stored."""

    # Hash values identifying the stored blocks.
    block_hashes: list[int]
    # Hash of the parent block, or None.
    # NOTE(review): presumably the block preceding block_hashes[0] -- confirm.
    parent_block_hash: Optional[int]
    # Token ids associated with the stored blocks.
    token_ids: list[int]
    # Block size reported by the producer of the event.
    block_size: int
    # Id of the request's LoRA, or None when the request has no LoRA.
    lora_id: Optional[int]
    # Optional label of the storage medium holding the blocks
    # (the block pool reports MEDIUM_GPU, i.e. "GPU").
    medium: Optional[str]
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
    """KV cache event reporting blocks that were removed."""

    # Hash values identifying the removed blocks.
    block_hashes: list[int]
    # Optional label of the storage medium the blocks were removed from
    # (the block pool reports MEDIUM_GPU, i.e. "GPU").
    medium: Optional[str]
|
||||
|
||||
|
||||
class AllBlocksCleared(KVCacheEvent):
|
||||
|
@ -19,6 +19,8 @@ The class provides the following primitives:
|
||||
Returns whether KV cache should be freed now or will be
|
||||
freed asynchronously and optionally returns KV transfer
|
||||
params.
|
||||
take_events() - returns new KV events that were collected
|
||||
by the connector since the last call.
|
||||
|
||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||
the Connector based on the metadata.
|
||||
@ -34,6 +36,7 @@ The class provides the following primitives:
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
|
||||
|
||||
import torch
|
||||
@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
return False, None
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
    """
    Hand over the KV cache events collected by the connector.

    This default implementation has nothing to report and always
    returns an empty tuple.

    Returns:
        New KV cache events since the last call.
    """
    no_events: tuple = ()
    return no_events
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(
|
||||
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
||||
|
@ -1,12 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
|
||||
return async_saves > 0, kv_txfer_params
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
    """Yield the pending KV cache events of every wrapped connector.

    The connectors in ``self._connectors`` are drained one after another,
    in order. This is a generator, so no connector is queried until the
    result is iterated.
    """
    for connector in self._connectors:
        for event in connector.take_events():
            yield event
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(
|
||||
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
||||
|
@ -4,8 +4,9 @@ from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
|
||||
BlockStored, KVCacheEvent)
|
||||
from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
|
||||
BlockRemoved, BlockStored,
|
||||
KVCacheEvent)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock)
|
||||
@ -156,6 +157,7 @@ class BlockPool:
|
||||
block_size=block_size,
|
||||
lora_id=request.lora_request.id
|
||||
if request.lora_request else None,
|
||||
medium=MEDIUM_GPU,
|
||||
))
|
||||
|
||||
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
|
||||
@ -218,7 +220,8 @@ class BlockPool:
|
||||
# we disable hybrid kv cache manager when kv cache event is
|
||||
# enabled, so there is only one group.
|
||||
self.kv_event_queue.append(
|
||||
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
|
||||
BlockRemoved(block_hashes=[block_hash.get_hash_value()],
|
||||
medium=MEDIUM_GPU))
|
||||
return True
|
||||
|
||||
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
|
||||
|
@ -589,7 +589,19 @@ class Scheduler(SchedulerInterface):
|
||||
meta = self.connector.build_connector_meta(scheduler_output)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
# collect KV cache events from KV cache manager
|
||||
events = self.kv_cache_manager.take_events()
|
||||
|
||||
# collect KV cache events from connector
|
||||
if self.connector is not None:
|
||||
connector_events = self.connector.take_events()
|
||||
if connector_events:
|
||||
if events is None:
|
||||
events = list(connector_events)
|
||||
else:
|
||||
events.extend(connector_events)
|
||||
|
||||
# publish collected KV cache events
|
||||
if events:
|
||||
batch = KVEventBatch(ts=time.time(), events=events)
|
||||
self.kv_event_publisher.publish(batch)
|
||||
|
Reference in New Issue
Block a user