[V0 deprecation][P/D] Deprecate v0 KVConnectorBase code (1/2) (#21785)

Signed-off-by: Linkun Chen <github@lkchen.net>
Authored by lkchen on 2025-08-04 19:11:33 -07:00, committed by GitHub
parent 5ea71ff46f
commit f4f4e7ef27
13 changed files with 31 additions and 1040 deletions
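For orientation, a minimal caller-side sketch of the factory change in this series, assuming an already-built VllmConfig named vllm_config whose kv_transfer_config is set: create_connector_v0 is removed, and the former create_connector_v1 becomes plain create_connector (V1 only).

# Sketch only: callers move from create_connector_v0 / create_connector_v1 to
# a single create_connector entry point, which requires VLLM_USE_V1.
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
def make_worker_connector(vllm_config):
    # vllm_config is assumed to carry a populated kv_transfer_config.
    return KVConnectorFactory.create_connector(
        config=vllm_config, role=KVConnectorRole.WORKER)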

View File

@@ -749,7 +749,6 @@ steps:
# this test fails consistently.
# TODO: investigate and fix
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- pytest -v -s models/multimodal/generation/test_maverick.py

View File

@@ -1,120 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import subprocess
import sys
import time
from subprocess import Popen
import pytest
import requests
import torch
# Fixture to set up environment variables and tear down servers after tests
@pytest.fixture(scope="module", autouse=True)
def setup_servers():
if torch.cuda.device_count() < 2:
pytest.skip("Skipping test: fewer than 2 GPUs available")
# Set up environment variables
VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
shell=True).decode().strip()
os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP
# Start prefill instance
prefill_cmd = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
"meta-llama/Llama-3.2-1B-Instruct",
"--port",
"8100",
"--gpu-memory-utilization",
"0.5",
"--max-model-len",
"1000",
"--kv-transfer-config",
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\
'"kv_rank":0,"kv_parallel_size":2}',
]
prefill_env = os.environ.copy()
prefill_env["CUDA_VISIBLE_DEVICES"] = "0"
prefill_proc = Popen(prefill_cmd, env=prefill_env)
# Start decode instance
decode_cmd = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
"meta-llama/Llama-3.2-1B-Instruct",
"--port",
"8200",
"--gpu-memory-utilization",
"0.5",
"--max-model-len",
"1000",
"--kv-transfer-config",
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\
'"kv_rank":1,"kv_parallel_size":2}',
]
decode_env = os.environ.copy()
decode_env["CUDA_VISIBLE_DEVICES"] = "1"
decode_proc = Popen(decode_cmd, env=decode_env)
# Wait for servers to be ready
assert wait_for_server(8100), "Prefill server did not start in time"
assert wait_for_server(8200), "Decode server did not start in time"
# Yield to the test function and handle teardown after tests
yield
# Cleanup: kill the processes
prefill_proc.terminate()
decode_proc.terminate()
# Additional cleanup if needed
prefill_proc.wait()
decode_proc.wait()
# Helper function to wait for server
def wait_for_server(port, timeout=240):
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(f"http://localhost:{port}/v1/completions")
if response.status_code in [200, 405]:
return True
except requests.ConnectionError:
time.sleep(1)
return False
# Test function to send curl requests and validate responses
@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
def test_disaggregated_prefilling(prompt):
# Send to prefill
response = requests.post("http://localhost:8100/v1/completions",
headers={"Content-Type": "application/json"},
json={
"model": "meta-llama/Llama-3.2-1B-Instruct",
"prompt": prompt,
"max_tokens": 1,
"temperature": 0
})
assert response.status_code == 200
# Send to decode
response = requests.post("http://localhost:8200/v1/completions",
headers={"Content-Type": "application/json"},
json={
"model": "meta-llama/Llama-3.2-1B-Instruct",
"prompt": prompt,
"max_tokens": 10,
"temperature": 0
})
assert response.status_code == 200
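As an aside on the deleted test above: it hand-writes the --kv-transfer-config JSON for each role. A small sketch of how those producer/consumer strings are shaped, with field names copied from the test; the PyNcclConnector it names belongs to the v0 code this series removes.

# Sketch: build the --kv-transfer-config strings the deleted test passed on
# the command line. Field names come verbatim from the test above.
import json
def kv_transfer_config(role: str, rank: int, parallel_size: int = 2) -> str:
    assert role in ("kv_producer", "kv_consumer")
    return json.dumps({"kv_connector": "PyNcclConnector", "kv_role": role,
                       "kv_rank": rank, "kv_parallel_size": parallel_size})
prefill_cfg = kv_transfer_config("kv_producer", rank=0)  # prefill instance
decode_cfg = kv_transfer_config("kv_consumer", rank=1)   # decode instance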

View File

@@ -1,142 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Union
import torch
"""Defines the base type for KV cache connectors."""
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
KVConnectorBase = KVConnectorBase_V1
KVConnectorBaseType = KVConnectorBase_V1
class KVConnectorBase(ABC):
"""
Abstract base class for a KV connector.
The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""
@abstractmethod
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
connector when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
"""
Send KV caches and hidden states to the connector.
This method processes the input tokens, KV caches, and
hidden/intermediate states for a given model and sends the data to the
decode instance.
Args:
model_executable (torch.nn.Module): The model executable containing
start and end layer information.
model_input (ModelInputForGPUWithSamplingMetadata): The input
metadata from vLLM.
kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
for each layer.
hidden_or_intermediate_states (Union[torch.Tensor,
IntermediateTensors]):
The hidden or intermediate states associated with the tokens.
Returns:
None
"""
raise NotImplementedError
@abstractmethod
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
"""
Receive KV caches and hidden states from the connector.
This method attempts to retrieve KV caches and hidden states for the
input tokens. If all required KV caches and hidden states are received,
model execution can be bypassed; otherwise it falls back to normal vLLM
model forwarding.
Args:
model_executable (torch.nn.Module):
The model executable from vLLM modelrunner.
model_input (ModelInputForGPUWithSamplingMetadata):
The model input from vLLM modelrunner.
kv_caches (list[torch.Tensor]):
List of KV caches for each layer.
Returns:
- hidden_or_intermediate_states (torch.Tensor or
IntermediateTensors):
Concatenated hidden states if all required data is retrieved,
otherwise `None`.
- bypass_model_exec (bool):
Indicates whether the model execution can be skipped (True) or
needs to be redone (False).
- model_input (ModelInputForGPUWithSamplingMetadata):
Optionally adjusted input metadata for re-execution when
`bypass_model_exec=False`.
"""
raise NotImplementedError
@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout, e.g. HND or NHD.
None if the connector does not require a specific layout.
"""
return None
KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
__all__ = ["KVConnectorBase", "KVConnectorBaseType"]
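After this change the module body reduces to re-exports, so the old import path keeps resolving. A minimal sketch of what downstream code can rely on, assuming both import paths stay as shown in the diff:

# Sketch: the legacy names are now plain aliases of the V1 base class.
from vllm.distributed.kv_transfer.kv_connector.base import (
    KVConnectorBase, KVConnectorBaseType)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
assert KVConnectorBase is KVConnectorBase_V1
assert KVConnectorBaseType is KVConnectorBase_V1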

View File

@@ -5,14 +5,10 @@ import importlib
from typing import TYPE_CHECKING, Callable
import vllm.envs as envs
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.logger import init_logger
from .base import KVConnectorBase
if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -20,7 +16,7 @@ logger = init_logger(__name__)
class KVConnectorFactory:
_registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {}
_registry: dict[str, Callable[[], type[KVConnectorBase]]] = {}
@classmethod
def register_connector(cls, name: str, module_path: str,
@@ -29,28 +25,23 @@ class KVConnectorFactory:
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> type[KVConnectorBaseType]:
def loader() -> type[KVConnectorBase]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
cls._registry[name] = loader
@classmethod
def create_connector_v0(cls, rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
if envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V0 Connector, "
def create_connector(
cls,
config: "VllmConfig",
role: KVConnectorRole,
) -> KVConnectorBase:
if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, "
f"but found {envs.VLLM_USE_V1=}")
connector_cls = cls.get_connector_class(config.kv_transfer_config)
assert issubclass(connector_cls, KVConnectorBase)
return connector_cls(rank, local_rank, config)
@classmethod
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
kv_transfer_config = config.kv_transfer_config
connector_name = kv_transfer_config.kv_connector
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
@@ -61,21 +52,7 @@ class KVConnectorFactory:
f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
return connector_cls
@classmethod
def create_connector_v1(
cls,
config: "VllmConfig",
role: KVConnectorRole,
) -> KVConnectorBase_V1:
if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, "
f"but found {envs.VLLM_USE_V1=}")
kv_transfer_config = config.kv_transfer_config
connector_cls = cls.get_connector_class(kv_transfer_config)
assert issubclass(connector_cls, KVConnectorBase_V1)
assert issubclass(connector_cls, KVConnectorBase)
logger.info("Creating v1 connector with name: %s and engine_id: %s",
connector_cls.__name__, kv_transfer_config.engine_id)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
@@ -92,25 +69,6 @@ class KVConnectorFactory:
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
KVConnectorFactory.register_connector(
"PyNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
KVConnectorFactory.register_connector(
"LMCacheConnector",
"vllm.distributed.kv_transfer.kv_connector.lmcache_connector",
"LMCacheConnector")
KVConnectorFactory.register_connector(
"MooncakeStoreConnector",
"vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
"MooncakeStoreConnector")
KVConnectorFactory.register_connector(
"SharedStorageConnector",

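For illustration, a sketch of registering and instantiating a connector through the slimmed-down factory; my_pkg.my_connector and MyConnector are hypothetical names, and vllm_config is assumed to name that connector in its kv_transfer_config.

# Sketch only; "my_pkg.my_connector" / "MyConnector" are hypothetical.
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
KVConnectorFactory.register_connector(
    "MyConnector",          # value expected in kv_transfer_config.kv_connector
    "my_pkg.my_connector",  # module imported lazily by the registered loader
    "MyConnector")          # class name resolved from that module
def make_scheduler_connector(vllm_config):
    # Raises ValueError unless VLLM_USE_V1 is enabled, per the factory above.
    return KVConnectorFactory.create_connector(
        config=vllm_config, role=KVConnectorRole.SCHEDULER)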
View File

@@ -1,99 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
LMCache KV Cache Connector for Distributed Machine Learning Inference
The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache;
(2) offload and share KV caches.
"""
from typing import TYPE_CHECKING, Union
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
class LMCacheConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
):
self.transfer_config = config.kv_transfer_config
self.vllm_config = config
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME
from lmcache.integration.vllm.vllm_adapter import (
RetrieveStatus, StoreStatus, init_lmcache_engine,
lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
lmcache_store_kv)
logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
self.transfer_config)
# TODO (Jiayi): Find model_config, parallel_config, and cache_config
self.engine = init_lmcache_engine(config.model_config,
config.parallel_config,
config.cache_config)
self.lmcache_engine_name = ENGINE_NAME
self.lmcache_engine_builder = LMCacheEngineBuilder
self.model_config = config.model_config
self.parallel_config = config.parallel_config
self.cache_config = config.cache_config
self.lmcache_retrieve_kv = lmcache_retrieve_kv
self.lmcache_store_kv = lmcache_store_kv
self.lmcache_should_retrieve = lmcache_should_retrieve
self.lmcache_should_store = lmcache_should_store
self.store_status = StoreStatus
self.retrieve_status = RetrieveStatus
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
retrieve_status = self.lmcache_should_retrieve(model_input)
model_input, bypass_model_exec, hidden_or_intermediate_states =\
self.lmcache_retrieve_kv(
model_executable, model_input, self.cache_config, kv_caches,
retrieve_status)
return hidden_or_intermediate_states, bypass_model_exec, model_input
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
store_status = self.lmcache_should_store(model_input)
self.lmcache_store_kv(
self.model_config,
self.parallel_config,
self.cache_config,
model_executable,
model_input,
kv_caches,
store_status,
)
def close(self):
self.lmcache_engine_builder.destroy(self.lmcache_engine_name)

View File

@@ -1,203 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
MooncakeStore Connector for Distributed Machine Learning Inference
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
database-style KVStore.
"""
import hashlib
from typing import TYPE_CHECKING, Union
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.utils import (
model_aware_kv_ops_helper as kv_helper)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
class MooncakeStoreConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
):
self.kv_transfer_config = config.kv_transfer_config
self.kv_helper = kv_helper(config)
self.local_tp_rank = local_rank
# Init kv_store
if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector":
# Check if MOONCAKE_CONFIG_PATH is set
import os
use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None
if not use_mooncake_store:
raise ValueError(
"To use MooncakeStoreConnector, you need to pass the ENV: "
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
else:
from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501
MooncakeStore)
logger.info(
"Initializing KVStoreConnector under kv_transfer_config %s",
self.kv_transfer_config)
self.kv_store = MooncakeStore(config)
else:
logger.error("Can not find %s",
self.kv_transfer_config.kv_connector)
assert self.kv_store is not None
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
connector when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
self.kv_store.close()
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
num_heads, head_size = self.kv_helper.get_model_args(model_executable)
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
store_key_prefix = self.tensor_hash(current_tokens)
keys, values = [], []
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
key_cache, value_cache = self.kv_helper.get_kv_from_cache(
kv_cache, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(value_cache[current_slot_mapping].unsqueeze(0))
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
kvcache_to_sent = torch.stack((keys, values), dim=0)
store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}"
self.kv_store.put(store_kvcache_key, kvcache_to_sent)
hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}"
self.kv_store.put(hidden_key,
hidden_or_intermediate_states[start_pos:end_pos])
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
bypass_model_exec = True
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
hidden_or_intermediate_states_for_one_req = []
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
# This can happen during inflight batching. See:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You should set --enable_chunked_prefill=False "
"and --max_num_batched_tokens "
"should be equal to max_seq_len_to_capture")
bypass_model_exec = False
assert start_pos == num_prefill_tokens
break
current_tokens = input_tokens_tensor[start_pos:end_pos]
# get roi for current seq
load_key_prefix = self.tensor_hash(current_tokens)
load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}"
remote_kv = self.kv_store.get(load_kvcache_key)
hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}"
hidden = self.kv_store.get(hidden_key)
if remote_kv is None or hidden is None:
# didn't find any match.
bypass_model_exec = False
continue
num_computed_tokens = current_tokens.shape[0]
# update the end position based on how many tokens are cached.
end_pos = start_pos + num_computed_tokens
# call self.kv_store to get kv layer by layer
for layer_id in range(start_layer, end_layer):
layer = model_executable.model.layers[layer_id]
# get kvcache object
kv_cache = kv_caches[layer_id - start_layer]
# get remote kvcache
remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
layer_id]
self.kv_helper.put_kv_to_cache(model_executable, remote_k,
remote_v, layer, kv_cache,
slot_mapping, start_pos,
end_pos)
hidden_or_intermediate_states_for_one_req.append(hidden)
if not bypass_model_exec:
logger.warning(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None
else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
return hidden_or_intermediate_states, bypass_model_exec, model_input
@staticmethod
def tensor_hash(tensor: torch.Tensor) -> int:
"""Calculate the hash value of the tensor."""
tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes()
hash_object = hashlib.blake2b(tensor_bytes)
hash_hex = hash_object.hexdigest()
return int(hash_hex[:16], 16)
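The keying scheme above is self-contained enough to sketch on its own: blake2b over the raw token bytes, truncated to 64 bits, then suffixed with the local TP rank so each shard stores under its own key.

# Standalone sketch of the deleted connector's store-key scheme.
import hashlib
import torch
def tensor_hash(tensor: torch.Tensor) -> int:
    # blake2b over the token bytes, truncated to 64 bits (16 hex chars).
    tensor_bytes = tensor.detach().cpu().numpy().tobytes()
    return int(hashlib.blake2b(tensor_bytes).hexdigest()[:16], 16)
def store_keys(current_tokens: torch.Tensor, local_tp_rank: int) -> tuple[str, str]:
    prefix = tensor_hash(current_tokens)
    return f"{prefix}_{local_tp_rank}", f"{prefix}_hidden_{local_tp_rank}"
kv_key, hidden_key = store_keys(torch.tensor([1, 2, 3, 4]), local_tp_rank=0)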

View File

@@ -1,329 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simple KV Cache Connector for Distributed Machine Learning Inference
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
from typing import TYPE_CHECKING, Optional, Union
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.utils import (
model_aware_kv_ops_helper as kv_helper)
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
class SimpleConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
):
self.config = config.kv_transfer_config
self.kv_helper = kv_helper(config)
if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
PyNcclPipe)
logger.info(
"Initializing PyNcclConfig under kv_transfer_config %s",
self.config)
elif self.config.kv_connector == "MooncakeConnector":
# Check if MOONCAKE_CONFIG_PATH is set
import os
use_mooncake_distributed_pipe = os.getenv(
'MOONCAKE_CONFIG_PATH') is not None
if not use_mooncake_distributed_pipe:
raise ValueError(
"To use MooncakeConnector, you need to pass the ENV: "
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
else:
from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501
MooncakePipe)
logger.info(
"Initializing MooncakeConfig under kv_transfer_config %s",
self.config)
self.lookup_buffer_size = self.config.kv_buffer_size
self.producer_buffer: Optional[SimpleBuffer] = None
self.consumer_buffer: Optional[SimpleBuffer] = None
self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe]
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
# 2 pipes for every rank in the world
port_offset_base = 2 * rank
# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:
if self.config.kv_connector == "PyNcclConnector":
self.producer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
elif self.config.kv_connector == "MooncakeConnector":
self.producer_data_pipe = MooncakePipe(
local_rank=local_rank,
config=self.config,
)
# We only need to initialize MooncakePipe once
self.producer_signal_pipe = self.producer_data_pipe
self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
self.producer_data_pipe,
self.config.kv_buffer_size)
else:
# the current vLLM instance is KV consumer, so it needs to connect
# its recv pipe to the send pipe of KV producer
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
elif self.config.kv_connector == "MooncakeConnector":
self.consumer_data_pipe = MooncakePipe(
local_rank=local_rank,
config=self.config,
)
self.consumer_signal_pipe = self.consumer_data_pipe
self.consumer_buffer = SimpleBuffer(
self.consumer_signal_pipe,
self.consumer_data_pipe,
self.config.kv_buffer_size,
)
def select(self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select."
return self.consumer_buffer.drop_select(input_tokens, roi)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
assert self.producer_buffer is not None, "Please initialize the "\
"producer buffer before calling insert."
self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
num_heads, head_size = self.kv_helper.get_model_args(model_executable)
# query_lens covers the tokens whose KV caches were just computed,
# so we send those KV caches to the decode instance.
# FIXME(Kuntai): This assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You have some decode requests while using "
"SimpleConnector. Their KVCache won't be sent.")
break
current_tokens = input_tokens_tensor[start_pos:end_pos]
keys, values = [], []
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
key_cache, value_cache = self.kv_helper.get_kv_from_cache(
kv_cache, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(value_cache[current_slot_mapping].unsqueeze(0))
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
self.insert(current_tokens,
torch.ones_like(current_tokens,
dtype=bool), keys, values,
hidden_or_intermediate_states[start_pos:end_pos])
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
# When bypass_model_exec is set to False, it means that at least for one
# request its corresponding KV cache or hidden state is missing.
# In this case we need to do prefilling to recompute missing KV cache
# and hidden states.
bypass_model_exec = True
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
hidden_or_intermediate_states_for_one_req = []
input_tokens_list = []
num_computed_tokens_list = []
start_pos_list = []
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
# This can happen during inflight batching. See:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You should set --enable_chunked_prefill=False "
"and --max_num_batched_tokens "
"should be equal to --max_seq_len_to_capture")
bypass_model_exec = False
assert start_pos == num_prefill_tokens
break
current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
ret = self.select(current_tokens,
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
bypass_model_exec = False
num_computed_tokens_list.append(0)
continue
roi: torch.Tensor = ret[1]
keys: torch.Tensor = ret[2]
values: torch.Tensor = ret[3]
hidden: torch.Tensor = ret[4]
num_computed_tokens = roi.shape[0]
num_computed_tokens_list.append(num_computed_tokens)
# check if both KV cache and the hidden states are received
# If not, need to redo the forwarding to compute missing states
if not all([(num_computed_tokens == num_tokens), hidden is not None
]):
bypass_model_exec = False
# update the end position based on how many tokens are cached.
end_pos = start_pos + num_computed_tokens
# put received KV caches into paged memory
for cur_layer in range(start_layer, end_layer):
layer_id = cur_layer - start_layer
kv_cache = kv_caches[layer_id]
layer = model_executable.model.layers[cur_layer]
# get remote kvcache
remote_k, remote_v = keys[layer_id], values[layer_id]
self.kv_helper.put_kv_to_cache(model_executable, remote_k,
remote_v, layer, kv_cache,
slot_mapping, start_pos,
end_pos)
hidden_or_intermediate_states_for_one_req.append(hidden)
if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
logger.warning(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None
else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
return hidden_or_intermediate_states, bypass_model_exec, model_input
def close(self):
self.producer_data_pipe.close()
self.consumer_data_pipe.close()
if self.config.kv_connector == "PyNcclConnector":
self.producer_signal_pipe.close()
self.consumer_signal_pipe.close()
elif self.config.kv_connector == "MooncakeConnector":
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
# close the data_pipe.
pass
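The producer/consumer exchange above boils down to an insert/drop_select lookup buffer keyed by the prompt tokens and an ROI mask. A toy sketch with an in-memory dict standing in for SimpleBuffer and the NCCL pipes; the class below is illustrative, not a vLLM API.

# Toy sketch of the v0 insert/drop_select buffer protocol; a dict replaces
# SimpleBuffer + PyNcclPipe. A miss returns a list whose first entry is None,
# mirroring the `ret[0] is None` check in recv_kv_caches_and_hidden_states.
from typing import Optional
import torch
class InMemoryLookupBuffer:
    def __init__(self) -> None:
        self._store: dict[tuple[int, ...], list[torch.Tensor]] = {}
    def insert(self, input_tokens, roi, key, value, hidden) -> None:
        self._store[tuple(input_tokens.tolist())] = [input_tokens, roi, key, value, hidden]
    def drop_select(self, input_tokens, roi) -> list[Optional[torch.Tensor]]:
        return self._store.pop(tuple(input_tokens.tolist()), [None] * 5)
buf = InMemoryLookupBuffer()
tokens = torch.tensor([1, 2, 3])
roi = torch.ones_like(tokens, dtype=torch.bool)
buf.insert(tokens, roi, torch.zeros(2, 3), torch.zeros(2, 3), torch.zeros(3, 8))
assert buf.drop_select(tokens, roi)[0] is not None  # hit on the producer side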

View File

@@ -13,8 +13,8 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1)
from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@@ -106,9 +106,8 @@ def get_kv_connector_cache_layout():
vllm_config = get_current_vllm_config()
kv_config = vllm_config.kv_transfer_config
if kv_config is not None:
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
vllm_config)
required_kvcache_layout = (
KVConnectorBase_V1.get_required_kvcache_layout(vllm_config))
if required_kvcache_layout is not None:
return required_kvcache_layout
logger.info_once("Connectors do not specify a " \

View File

@@ -52,7 +52,7 @@ class MultiConnector(KVConnectorBase_V1):
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id)
self._connectors.append(
KVConnectorFactory.create_connector_v1(temp_config, role))
KVConnectorFactory.create_connector(temp_config, role))
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
@@ -223,9 +223,9 @@ class MultiConnector(KVConnectorBase_V1):
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
temp_vllm_config.kv_transfer_config = kv_transfer_config
required_kvcache_layout = KVConnectorFactory.get_connector_class(
kv_transfer_config).get_required_kvcache_layout(
temp_vllm_config)
required_kvcache_layout = (
KVConnectorBase_V1.get_required_kvcache_layout(
temp_vllm_config))
if required_kvcache_layout is not None:
layouts.add(required_kvcache_layout)

View File

@@ -1,77 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A centralized entrypoint to perform distributed KV cache transfer.
This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
1. `send_kv_caches_and_hidden_states`
2. `recv_kv_caches_and_hidden_states`
"""
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.config import VllmConfig
import torch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
class KVTransferAgent:
"""
A class designated for distributed KV transfer
Target use cases:
1. Disaggregated prefill
2. Remote KV cache storage
"""
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
self.config = config
if config.kv_transfer_config is None:
raise ValueError("KVTransferConfig is not set in the VllmConfig,"
" cannot initialize KVConnector.")
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
"TransferAgent should only be used when kv_connector is set."
self.connector = KVConnectorFactory.create_connector_v0(
rank, local_rank, config)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
self.connector.send_kv_caches_and_hidden_states(
model_executable, model_input, kv_caches,
hidden_or_intermediate_states)
def close(self) -> None:
self.connector.close()
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
return self.connector.recv_kv_caches_and_hidden_states(
model_executable, model_input, kv_caches)

View File

@@ -8,7 +8,6 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -61,11 +60,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
if (vllm_config.kv_transfer_config.is_kv_transfer_instance
and _KV_CONNECTOR_AGENT is None):
if envs.VLLM_USE_V1:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
config=vllm_config, role=KVConnectorRole.WORKER)
else:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
config=vllm_config,
)
raise ValueError("V0 is no longer supported")

View File

@@ -83,7 +83,7 @@ class Scheduler(SchedulerInterface):
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"Multiple KV cache groups are not currently supported "
"with KV connectors")
self.connector = KVConnectorFactory.create_connector_v1(
self.connector = KVConnectorFactory.create_connector(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
self.kv_event_publisher = EventPublisherFactory.create(

View File

@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
@@ -31,7 +31,7 @@ class KVConnectorModelRunnerMixin:
# Update the KVConnector with the KVConnector metadata before forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert isinstance(kv_connector, KVConnectorBase)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
@@ -93,7 +93,7 @@ class KVConnectorModelRunnerMixin:
# Update the KVConnector with the KVConnector metadata before forward().
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert isinstance(kv_connector, KVConnectorBase)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)