# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1

The class provides the following primitives:
    Scheduler-side: runs in the scheduler, binds metadata, which
    is used by the worker-side to load/save KV cache.
        get_num_new_matched_tokens() - get number of new tokens
            that exist in the remote KV cache. Might be called multiple
            times for a given request and should be side-effect free.
        update_state_after_alloc() - update KVConnector state after
            temporary buffer alloc by the CacheManager.
        build_connector_meta() - build the connector metadata for the
            current step (this also resets the connector's per-step
            state); the worker-side receives it via
            bind_connector_metadata().
        update_connector_output() - update KVConnector state after
            output is received from worker-side connectors.
        request_finished() - called when a request is finished, with
            the computed kv cache blocks for the request.
            Returns whether KV cache should be freed now or will be
            freed asynchronously and optionally returns KV transfer
            params.
        take_events() - returns new KV events that were collected
            by the connector since the last call.

    Worker-side: runs in each worker, loads/saves KV cache to/from
    the Connector based on the metadata.
        start_load_kv() - starts loading all KVs (maybe async)
        wait_for_layer_load() - blocks until layer i load is done

        save_kv_layer() - starts saving KV for layer i (maybe async)
        wait_for_save() - blocks until all saves are done

        get_finished() - called with ids of finished requests, returns
            ids of requests that have completed async sending/recving.
"""

import enum
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional

import torch

from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.config import VllmConfig
    from vllm.distributed.kv_events import KVCacheEvent
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
CopyBlocksOp = Callable[[
    dict[str, torch.Tensor],
    dict[str, torch.Tensor],
    list[int],
    list[int],
    Literal["h2d", "d2h"],
], None]

logger = init_logger(__name__)


class KVConnectorRole(enum.Enum):
    # Connector running in the scheduler process
    SCHEDULER = 0

    # Connector running in the worker process
    WORKER = 1


class KVConnectorMetadata(ABC):  # noqa: B024
    """
    Abstract Metadata used to communicate between the
    Scheduler KVConnector and Worker KVConnector.
    """
    pass


class KVConnectorBase_V1(ABC):
    """
    Base class for v1 KV connectors.

    Each connector instance is created with a role (scheduler or worker)
    and implements the corresponding half of the interface: the
    scheduler-side methods build per-step metadata, and the worker-side
    methods consume it (via bind_connector_metadata()) to load/save KV.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        """
        Args:
            vllm_config (VllmConfig): the vllm config.
            role (KVConnectorRole): whether this instance runs in the
                scheduler process or in a worker process.
        """
        logger.warning(
            "Initializing KVConnectorBase_V1. This API is experimental and "
            "subject to change in the future as we iterate the design.")
        # Set by bind_connector_metadata() and reset to None by
        # clear_connector_metadata() around each model execution.
        self._connector_metadata: Optional[KVConnectorMetadata] = None
        self._vllm_config = vllm_config
        self._role = role

    @property
    def role(self) -> KVConnectorRole:
        """The role (scheduler or worker) this connector was created with."""
        return self._role

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        KV cache loading and saving.

        Args:
            connector_metadata (KVConnectorMetadata): the connector metadata.
        """
        self._connector_metadata = connector_metadata

    def clear_connector_metadata(self) -> None:
        """Clear the connector metadata.

        This function should be called by the model runner every time
        after the model execution.
        """
        self._connector_metadata = None

    def _get_connector_metadata(self) -> KVConnectorMetadata:
        """Get the connector metadata.

        This function should only be called inside the connector.

        Returns:
            KVConnectorMetadata: the connector metadata.

        Raises:
            AssertionError: if no metadata is currently bound (i.e.
                bind_connector_metadata() was not called, or
                clear_connector_metadata() was called since).
        """

        # Should only be called while set to valid metadata.
        assert self._connector_metadata is not None, (
            "Connector metadata is not bound; bind_connector_metadata() "
            "must be called before accessing it.")
        return self._connector_metadata

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        The default implementation is a no-op.

        Args:
            kv_caches: mapping from layer name to that layer's KV cache
                tensor.
        """
        return

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        """
        Set the xPU-specific ops for copying KV between host and device.
        Needed when host buffer is used for kv transfer (e.g., in
        NixlConnector).

        The default implementation is a no-op.

        Args:
            copy_operation (CopyBlocksOp): callable taking (source tensor
                dict, destination tensor dict, source block indices,
                destination block indices, direction), where direction
                is "h2d" or "d2h".
        """
        return

    @abstractmethod
    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation
        """
        pass

    @abstractmethod
    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within attention layer to ensure
        async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        pass

    @abstractmethod
    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        pass

    @abstractmethod
    def wait_for_save(self):
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of the paged KV buffer before saving
        is done.
        """
        pass

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens on the worker.
        The scheduler process (via the Executors) will use this output
        to track which workers are done.

        Args:
            finished_req_ids (set[str]): ids of requests that have
                finished generating tokens on this worker.

        Returns:
            ids of requests that have finished asynchronous transfer
            (requests that previously returned True from request_finished()),
            tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
            The default implementation returns (None, None).
        """
        return None, None

    # ==============================
    # Scheduler-side methods
    # ==============================

    @abstractmethod
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded from the
                  external KV cache beyond what is already computed.
                - `True` if external KV cache tokens will be loaded
                  asynchronously (between scheduler steps). Must be
                  `False` if the first element is 0.
        """
        pass

    @abstractmethod
    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.

        If get_num_new_matched_tokens previously returned True for a
        request, this function may be called twice for that same request -
        first when blocks are allocated for the connector tokens to be
        asynchronously loaded into, and second when any additional blocks
        are allocated, after the load/transfer is complete.

        Args:
            request (Request): the request object.
            blocks (KVCacheBlocks): the blocks allocated for the request.
            num_external_tokens (int): the number of tokens that will be
                loaded from the external KV cache.
        """
        pass

    @abstractmethod
    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.

        Returns:
            KVConnectorMetadata: the metadata for this step.
        """
        pass

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        The default implementation is a no-op.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        return

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Called when a request has finished, before its blocks are freed.

        Args:
            request (Request): the finished request.
            block_ids (list[int]): ids of the KV cache blocks held by
                the request.

        Returns:
            A tuple of:
                - True if the request is being saved/sent asynchronously
                  and blocks should not be freed until the request_id is
                  returned from get_finished().
                - Optional KVTransferParams to be included in the request
                  outputs returned by the engine.
            The default implementation returns (False, None).
        """
        return False, None

    def take_events(self) -> Iterable["KVCacheEvent"]:
        """
        Take the KV cache events from the connector.

        Returns:
            An iterable of new KV cache events since the last call.
            The default implementation returns an empty tuple.
        """
        return ()

    @classmethod
    def get_required_kvcache_layout(
            cls, vllm_config: "VllmConfig") -> Optional[str]:
        """
        Get the required KV cache layout for this connector.

        Args:
            vllm_config (VllmConfig): the vllm config.

        Returns:
            str: the required KV cache layout. e.g. HND, or NHD.
            None if the connector does not require a specific layout.

        Raises:
            TypeError: if called on KVConnectorBase_V1 itself rather
                than on a concrete subclass.
        """

        if cls is KVConnectorBase_V1:
            raise TypeError("get_required_kvcache_layout should not be called "
                            "on the abstract base class")
        return None