From 14b4326b9470c098d537cce3834033a7f3b2c386 Mon Sep 17 00:00:00 2001 From: Or Ozeri Date: Mon, 1 Sep 2025 04:13:21 +0300 Subject: [PATCH] v1: Support KV events from connectors (#19737) Signed-off-by: Or Ozeri --- examples/online_serving/kv_events_subscriber.py | 2 ++ vllm/distributed/kv_events.py | 5 +++++ .../distributed/kv_transfer/kv_connector/v1/base.py | 13 +++++++++++++ .../kv_transfer/kv_connector/v1/multi_connector.py | 6 ++++++ vllm/v1/core/block_pool.py | 9 ++++++--- vllm/v1/core/sched/scheduler.py | 12 ++++++++++++ 6 files changed, 44 insertions(+), 3 deletions(-) diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index 584db53db4..f238c66234 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent): token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): block_hashes: list[int] + medium: Optional[str] class AllBlocksCleared(KVCacheEvent): diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 2d7935773d..37f8f72fa9 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -40,16 +40,21 @@ class KVCacheEvent( """Base class for all KV cache-related events""" +MEDIUM_GPU = "GPU" + + class BlockStored(KVCacheEvent): block_hashes: list[int] parent_block_hash: Optional[int] token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): block_hashes: list[int] + medium: Optional[str] class AllBlocksCleared(KVCacheEvent): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 5601ee74be..2804003f5a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -19,6 +19,8 @@ The class provides the following primitives: Returns whether KV cache should be freed now or will be freed asynchronously and optionally returns KV transfer params. + take_events() - returns new KV events that were collected + by the connector since the last call. Worker-side: runs in each worker, loads/saves KV cache to/from the Connector based on the metadata. @@ -34,6 +36,7 @@ The class provides the following primitives: import enum from abc import ABC, abstractmethod +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Callable, Literal, Optional import torch @@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig + from vllm.distributed.kv_events import KVCacheEvent from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request @@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC): """ return False, None + def take_events(self) -> Iterable["KVCacheEvent"]: + """ + Take the KV cache events from the connector. + + Yields: + New KV cache events since the last call. + """ + return () + @classmethod def get_required_kvcache_layout( cls, vllm_config: "VllmConfig") -> Optional[str]: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index d3f6a226dc..65bcb4d93b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +from collections.abc import Iterable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import torch from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1): return async_saves > 0, kv_txfer_params + def take_events(self) -> Iterable[KVCacheEvent]: + for c in self._connectors: + yield from c.take_events() + @classmethod def get_required_kvcache_layout( cls, vllm_config: "VllmConfig") -> Optional[str]: diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index fdd96c3e95..b537cac8e1 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -4,8 +4,9 @@ from collections import defaultdict from collections.abc import Iterable from typing import Optional -from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, - BlockStored, KVCacheEvent) +from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, + BlockRemoved, BlockStored, + KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, FreeKVCacheBlockQueue, KVCacheBlock) @@ -156,6 +157,7 @@ class BlockPool: block_size=block_size, lora_id=request.lora_request.id if request.lora_request else None, + medium=MEDIUM_GPU, )) def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: @@ -218,7 +220,8 @@ class BlockPool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.get_hash_value()])) + BlockRemoved(block_hashes=[block_hash.get_hash_value()], + medium=MEDIUM_GPU)) return True def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 30a443499d..d4391b1c21 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -589,7 +589,19 @@ class Scheduler(SchedulerInterface): meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta + # collect KV cache events from KV cache manager events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events if events: batch = KVEventBatch(ts=time.time(), events=events) self.kv_event_publisher.publish(batch)