v1: Support KV events from connectors (#19737)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent):
|
|||||||
token_ids: list[int]
|
token_ids: list[int]
|
||||||
block_size: int
|
block_size: int
|
||||||
lora_id: Optional[int]
|
lora_id: Optional[int]
|
||||||
|
medium: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class BlockRemoved(KVCacheEvent):
|
class BlockRemoved(KVCacheEvent):
|
||||||
block_hashes: list[int]
|
block_hashes: list[int]
|
||||||
|
medium: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class AllBlocksCleared(KVCacheEvent):
|
class AllBlocksCleared(KVCacheEvent):
|
||||||
|
@@ -40,16 +40,21 @@ class KVCacheEvent(
|
|||||||
"""Base class for all KV cache-related events"""
|
"""Base class for all KV cache-related events"""
|
||||||
|
|
||||||
|
|
||||||
|
MEDIUM_GPU = "GPU"
|
||||||
|
|
||||||
|
|
||||||
class BlockStored(KVCacheEvent):
|
class BlockStored(KVCacheEvent):
|
||||||
block_hashes: list[int]
|
block_hashes: list[int]
|
||||||
parent_block_hash: Optional[int]
|
parent_block_hash: Optional[int]
|
||||||
token_ids: list[int]
|
token_ids: list[int]
|
||||||
block_size: int
|
block_size: int
|
||||||
lora_id: Optional[int]
|
lora_id: Optional[int]
|
||||||
|
medium: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class BlockRemoved(KVCacheEvent):
|
class BlockRemoved(KVCacheEvent):
|
||||||
block_hashes: list[int]
|
block_hashes: list[int]
|
||||||
|
medium: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class AllBlocksCleared(KVCacheEvent):
|
class AllBlocksCleared(KVCacheEvent):
|
||||||
|
@@ -19,6 +19,8 @@ The class provides the following primitives:
|
|||||||
Returns whether KV cache should be freed now or will be
|
Returns whether KV cache should be freed now or will be
|
||||||
freed asynchronously and optionally returns KV transfer
|
freed asynchronously and optionally returns KV transfer
|
||||||
params.
|
params.
|
||||||
|
take_events() - returns new KV events that were collected
|
||||||
|
by the connector since the last call.
|
||||||
|
|
||||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||||
the Connector based on the metadata.
|
the Connector based on the metadata.
|
||||||
@@ -34,6 +36,7 @@ The class provides the following primitives:
|
|||||||
|
|
||||||
import enum
|
import enum
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Iterable
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
|
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.kv_events import KVCacheEvent
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
@@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||||
|
"""
|
||||||
|
Take the KV cache events from the connector.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
New KV cache events since the last call.
|
||||||
|
"""
|
||||||
|
return ()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_required_kvcache_layout(
|
def get_required_kvcache_layout(
|
||||||
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
||||||
|
@@ -1,12 +1,14 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import copy
|
import copy
|
||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import KVTransferConfig, VllmConfig
|
from vllm.config import KVTransferConfig, VllmConfig
|
||||||
|
from vllm.distributed.kv_events import KVCacheEvent
|
||||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||||
KVConnectorFactory)
|
KVConnectorFactory)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
@@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
|
|
||||||
return async_saves > 0, kv_txfer_params
|
return async_saves > 0, kv_txfer_params
|
||||||
|
|
||||||
|
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||||
|
for c in self._connectors:
|
||||||
|
yield from c.take_events()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_required_kvcache_layout(
|
def get_required_kvcache_layout(
|
||||||
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
||||||
|
@@ -4,8 +4,9 @@ from collections import defaultdict
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
|
from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
|
||||||
BlockStored, KVCacheEvent)
|
BlockRemoved, BlockStored,
|
||||||
|
KVCacheEvent)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||||
FreeKVCacheBlockQueue, KVCacheBlock)
|
FreeKVCacheBlockQueue, KVCacheBlock)
|
||||||
@@ -156,6 +157,7 @@ class BlockPool:
|
|||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
lora_id=request.lora_request.id
|
lora_id=request.lora_request.id
|
||||||
if request.lora_request else None,
|
if request.lora_request else None,
|
||||||
|
medium=MEDIUM_GPU,
|
||||||
))
|
))
|
||||||
|
|
||||||
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
|
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
|
||||||
@@ -218,7 +220,8 @@ class BlockPool:
|
|||||||
# we disable hybrid kv cache manager when kv cache event is
|
# we disable hybrid kv cache manager when kv cache event is
|
||||||
# enabled, so there is only one group.
|
# enabled, so there is only one group.
|
||||||
self.kv_event_queue.append(
|
self.kv_event_queue.append(
|
||||||
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
|
BlockRemoved(block_hashes=[block_hash.get_hash_value()],
|
||||||
|
medium=MEDIUM_GPU))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
|
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
|
||||||
|
@@ -589,7 +589,19 @@ class Scheduler(SchedulerInterface):
|
|||||||
meta = self.connector.build_connector_meta(scheduler_output)
|
meta = self.connector.build_connector_meta(scheduler_output)
|
||||||
scheduler_output.kv_connector_metadata = meta
|
scheduler_output.kv_connector_metadata = meta
|
||||||
|
|
||||||
|
# collect KV cache events from KV cache manager
|
||||||
events = self.kv_cache_manager.take_events()
|
events = self.kv_cache_manager.take_events()
|
||||||
|
|
||||||
|
# collect KV cache events from connector
|
||||||
|
if self.connector is not None:
|
||||||
|
connector_events = self.connector.take_events()
|
||||||
|
if connector_events:
|
||||||
|
if events is None:
|
||||||
|
events = list(connector_events)
|
||||||
|
else:
|
||||||
|
events.extend(connector_events)
|
||||||
|
|
||||||
|
# publish collected KV cache events
|
||||||
if events:
|
if events:
|
||||||
batch = KVEventBatch(ts=time.time(), events=events)
|
batch = KVEventBatch(ts=time.time(), events=events)
|
||||||
self.kv_event_publisher.publish(batch)
|
self.kv_event_publisher.publish(batch)
|
||||||
|
Reference in New Issue
Block a user