v1: Support KV events from connectors (#19737)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2025-09-01 04:13:21 +03:00
committed by GitHub
parent 752d2e1c36
commit 14b4326b94
6 changed files with 44 additions and 3 deletions

View File

@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent):
token_ids: list[int] token_ids: list[int]
block_size: int block_size: int
lora_id: Optional[int] lora_id: Optional[int]
medium: Optional[str]
class BlockRemoved(KVCacheEvent): class BlockRemoved(KVCacheEvent):
block_hashes: list[int] block_hashes: list[int]
medium: Optional[str]
class AllBlocksCleared(KVCacheEvent): class AllBlocksCleared(KVCacheEvent):

View File

@ -40,16 +40,21 @@ class KVCacheEvent(
"""Base class for all KV cache-related events""" """Base class for all KV cache-related events"""
MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent): class BlockStored(KVCacheEvent):
block_hashes: list[int] block_hashes: list[int]
parent_block_hash: Optional[int] parent_block_hash: Optional[int]
token_ids: list[int] token_ids: list[int]
block_size: int block_size: int
lora_id: Optional[int] lora_id: Optional[int]
medium: Optional[str]
class BlockRemoved(KVCacheEvent): class BlockRemoved(KVCacheEvent):
block_hashes: list[int] block_hashes: list[int]
medium: Optional[str]
class AllBlocksCleared(KVCacheEvent): class AllBlocksCleared(KVCacheEvent):

View File

@ -19,6 +19,8 @@ The class provides the following primitives:
Returns whether KV cache should be freed now or will be Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer freed asynchronously and optionally returns KV transfer
params. params.
take_events() - returns new KV events that were collected
by the connector since the last call.
Worker-side: runs in each worker, loads/saves KV cache to/from Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata. the Connector based on the metadata.
@ -34,6 +36,7 @@ The class provides the following primitives:
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch import torch
@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request from vllm.v1.request import Request
@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
""" """
return False, None return False, None
def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
return ()
@classmethod @classmethod
def get_required_kvcache_layout( def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]: cls, vllm_config: "VllmConfig") -> Optional[str]:

View File

@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy import copy
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from vllm.config import KVTransferConfig, VllmConfig from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1):
return async_saves > 0, kv_txfer_params return async_saves > 0, kv_txfer_params
def take_events(self) -> Iterable[KVCacheEvent]:
for c in self._connectors:
yield from c.take_events()
@classmethod @classmethod
def get_required_kvcache_layout( def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]: cls, vllm_config: "VllmConfig") -> Optional[str]:

View File

@ -4,8 +4,9 @@ from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional from typing import Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
BlockStored, KVCacheEvent) BlockRemoved, BlockStored,
KVCacheEvent)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock) FreeKVCacheBlockQueue, KVCacheBlock)
@ -156,6 +157,7 @@ class BlockPool:
block_size=block_size, block_size=block_size,
lora_id=request.lora_request.id lora_id=request.lora_request.id
if request.lora_request else None, if request.lora_request else None,
medium=MEDIUM_GPU,
)) ))
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
@ -218,7 +220,8 @@ class BlockPool:
# we disable hybrid kv cache manager when kv cache event is # we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group. # enabled, so there is only one group.
self.kv_event_queue.append( self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.get_hash_value()])) BlockRemoved(block_hashes=[block_hash.get_hash_value()],
medium=MEDIUM_GPU))
return True return True
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:

View File

@ -589,7 +589,19 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output) meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta scheduler_output.kv_connector_metadata = meta
# collect KV cache events from KV cache manager
events = self.kv_cache_manager.take_events() events = self.kv_cache_manager.take_events()
# collect KV cache events from connector
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
# publish collected KV cache events
if events: if events:
batch = KVEventBatch(ts=time.time(), events=events) batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch) self.kv_event_publisher.publish(batch)