[V0 deprecation][P/D] Deprecate v0 KVConnectorBase code (1/2) (#21785)
Signed-off-by: Linkun Chen <github@lkchen.net>
@@ -749,7 +749,6 @@ steps:
# this test fails consistently.
# TODO: investigate and fix
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- pytest -v -s models/multimodal/generation/test_maverick.py
@@ -1,120 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import subprocess
import sys
import time
from subprocess import Popen

import pytest
import requests
import torch


# Fixture to set up environment variables and teardown servers after tests
@pytest.fixture(scope="module", autouse=True)
def setup_servers():
    if torch.cuda.device_count() < 2:
        pytest.skip("Skipping test: fewer than 2 GPUs available")

    # Set up environment variables
    VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
                                           shell=True).decode().strip()
    os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP

    # Start prefill instance
    prefill_cmd = [
        sys.executable,
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        "meta-llama/Llama-3.2-1B-Instruct",
        "--port",
        "8100",
        "--gpu-memory-utilization",
        "0.5",
        "--max-model-len",
        "1000",
        "--kv-transfer-config",
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\
        '"kv_rank":0,"kv_parallel_size":2}',
    ]
    prefill_env = os.environ.copy()
    prefill_env["CUDA_VISIBLE_DEVICES"] = "0"
    prefill_proc = Popen(prefill_cmd, env=prefill_env)

    # Start decode instance
    decode_cmd = [
        sys.executable,
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        "meta-llama/Llama-3.2-1B-Instruct",
        "--port",
        "8200",
        "--gpu-memory-utilization",
        "0.5",
        "--max-model-len",
        "1000",
        "--kv-transfer-config",
        '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\
        '"kv_rank":1,"kv_parallel_size":2}',
    ]
    decode_env = os.environ.copy()
    decode_env["CUDA_VISIBLE_DEVICES"] = "1"
    decode_proc = Popen(decode_cmd, env=decode_env)

    # Wait for servers to be ready
    assert wait_for_server(8100), "Prefill server did not start in time"
    assert wait_for_server(8200), "Decode server did not start in time"

    # Yield to the test function and handle teardown after tests
    yield

    # Cleanup: kill the processes
    prefill_proc.terminate()
    decode_proc.terminate()

    # Additional cleanup if needed
    prefill_proc.wait()
    decode_proc.wait()


# Helper function to wait for server
def wait_for_server(port, timeout=240):
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            response = requests.get(f"http://localhost:{port}/v1/completions")
            if response.status_code in [200, 405]:
                return True
        except requests.ConnectionError:
            time.sleep(1)
    return False


# Test function to send curl requests and validate responses
@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
def test_disaggregated_prefilling(prompt):
    # Send to prefill
    response = requests.post("http://localhost:8100/v1/completions",
                             headers={"Content-Type": "application/json"},
                             json={
                                 "model": "meta-llama/Llama-3.2-1B-Instruct",
                                 "prompt": prompt,
                                 "max_tokens": 1,
                                 "temperature": 0
                             })
    assert response.status_code == 200

    # Send to decode
    response = requests.post("http://localhost:8200/v1/completions",
                             headers={"Content-Type": "application/json"},
                             json={
                                 "model": "meta-llama/Llama-3.2-1B-Instruct",
                                 "prompt": prompt,
                                 "max_tokens": 10,
                                 "temperature": 0
                             })
    assert response.status_code == 200
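The deleted test above drives one prefill (producer) and one decode (consumer) server whose `--kv-transfer-config` payloads differ only in `kv_role` and `kv_rank`. A small sketch of how those two JSON strings could be built programmatically; the helper function and the use of `json.dumps` are illustrative, not part of this diff:

```python
import json


def pynccl_disagg_config(role: str, rank: int) -> str:
    # Mirrors the JSON literals in the deleted test: a two-instance
    # PyNcclConnector setup with one kv_producer and one kv_consumer.
    return json.dumps({
        "kv_connector": "PyNcclConnector",
        "kv_role": role,  # "kv_producer" for prefill, "kv_consumer" for decode
        "kv_rank": rank,  # 0 for the prefill instance, 1 for the decode instance
        "kv_parallel_size": 2,
    })


prefill_config = pynccl_disagg_config("kv_producer", 0)
decode_config = pynccl_disagg_config("kv_consumer", 1)
```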
@@ -1,142 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase Class for Distributed KV Cache & Hidden State communication

The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Union

import torch
"""Defines the base type for KV cache connectors."""

from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
KVConnectorBase = KVConnectorBase_V1
KVConnectorBaseType = KVConnectorBase_V1


class KVConnectorBase(ABC):
    """
    Abstract base class for a KV connector.

    The class provides two primary abstract methods:
    1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
    2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
    """

    @abstractmethod
    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ):
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        """Close the buffer and release resources.

        This method is responsible for cleaning up resources related to the
        connector when it is no longer needed.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """
        Send KV caches and hidden states to the connector.

        This method processes the input tokens, KV caches, and
        hidden/intermediate states for a given model and sends the data to the
        decode instance.

        Args:
            model_executable (torch.nn.Module): The model executable containing
                start and end layer information.
            model_input (ModelInputForGPUWithSamplingMetadata): The input
                metadata from vLLM.
            kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
                for each layer.
            hidden_or_intermediate_states (Union[torch.Tensor,
            IntermediateTensors]):
                The hidden or intermediate states associated with the tokens.

        Returns:
            None

        """

        raise NotImplementedError

    @abstractmethod
    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """
        Receive KV caches and hidden states from the connector.

        This method attempts to retrieve KV caches and hidden states for input
        tokens. If all required KV caches and hidden states are received, it
        will bypass model input, else it will fall back to normal vLLM model
        forwarding.

        Args:
            model_executable (torch.nn.Module):
                The model executable from vLLM modelrunner.
            model_input (ModelInputForGPUWithSamplingMetadata):
                The model input from vLLM modelrunner.
            kv_caches (list[torch.Tensor]):
                List of KV caches for each layer.

        Returns:
            - hidden_or_intermediate_states (torch.Tensor or
              IntermediateTensors):
                Concatenated hidden states if all required data is retrieved,
                otherwise `None`.
            - bypass_model_exec (bool):
                Indicates whether the model execution can be skipped (True) or
                needs to be redone (False).
            - model_input (ModelInputForGPUWithSamplingMetadata):
                Optionally adjusted input metadata for re-execution when
                `bypass_model_exec=False`.

        """

        raise NotImplementedError

    @classmethod
    def get_required_kvcache_layout(
            cls, vllm_config: "VllmConfig") -> Optional[str]:
        """
        Get the required KV cache layout for this connector.
        Args:
            vllm_config (VllmConfig): the vllm config.

        Returns:
            str: the required KV cache layout. e.g. HND, or NHD.
            None if the connector does not require a specific layout.
        """
        return None


KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
__all__ = ["KVConnectorBase", "KVConnectorBaseType"]
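With this change, `base.py` no longer defines its own abstract class; the legacy names simply alias the v1 base, so existing imports continue to resolve. A minimal sketch of what callers can rely on (the asserts are illustrative, not part of the diff):

```python
from vllm.distributed.kv_transfer.kv_connector.base import (
    KVConnectorBase, KVConnectorBaseType)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1

# Both legacy names now point at the v1 base type, so isinstance() and
# issubclass() checks written against KVConnectorBase keep working.
assert KVConnectorBase is KVConnectorBase_V1
assert KVConnectorBaseType is KVConnectorBase_V1
```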
@@ -5,14 +5,10 @@ import importlib
from typing import TYPE_CHECKING, Callable

import vllm.envs as envs
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
                                                          KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.logger import init_logger

from .base import KVConnectorBase

if TYPE_CHECKING:
    from vllm.config import VllmConfig

@@ -20,7 +16,7 @@ logger = init_logger(__name__)


class KVConnectorFactory:
    _registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {}
    _registry: dict[str, Callable[[], type[KVConnectorBase]]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str,
@@ -29,28 +25,23 @@ class KVConnectorFactory:
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type[KVConnectorBaseType]:
        def loader() -> type[KVConnectorBase]:
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        cls._registry[name] = loader

    @classmethod
    def create_connector_v0(cls, rank: int, local_rank: int,
                            config: "VllmConfig") -> KVConnectorBase:
        if envs.VLLM_USE_V1:
            raise ValueError("Attempting to initialize a V0 Connector, "
    def create_connector(
        cls,
        config: "VllmConfig",
        role: KVConnectorRole,
    ) -> KVConnectorBase:
        if not envs.VLLM_USE_V1:
            raise ValueError("Attempting to initialize a V1 Connector, "
                             f"but found {envs.VLLM_USE_V1=}")

        connector_cls = cls.get_connector_class(config.kv_transfer_config)
        assert issubclass(connector_cls, KVConnectorBase)
        return connector_cls(rank, local_rank, config)

    @classmethod
    def get_connector_class(
        cls, kv_transfer_config: "KVTransferConfig"
    ) -> type[KVConnectorBaseType]:
        """Get the connector class by name."""
        kv_transfer_config = config.kv_transfer_config
        connector_name = kv_transfer_config.kv_connector
        if connector_name in cls._registry:
            connector_cls = cls._registry[connector_name]()
@@ -61,21 +52,7 @@
                f"Unsupported connector type: {connector_name}")
        connector_module = importlib.import_module(connector_module_path)
        connector_cls = getattr(connector_module, connector_name)
        return connector_cls

    @classmethod
    def create_connector_v1(
        cls,
        config: "VllmConfig",
        role: KVConnectorRole,
    ) -> KVConnectorBase_V1:
        if not envs.VLLM_USE_V1:
            raise ValueError("Attempting to initialize a V1 Connector, "
                             f"but found {envs.VLLM_USE_V1=}")

        kv_transfer_config = config.kv_transfer_config
        connector_cls = cls.get_connector_class(kv_transfer_config)
        assert issubclass(connector_cls, KVConnectorBase_V1)
        assert issubclass(connector_cls, KVConnectorBase)
        logger.info("Creating v1 connector with name: %s and engine_id: %s",
                    connector_cls.__name__, kv_transfer_config.engine_id)
        # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
@@ -92,25 +69,6 @@
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
KVConnectorFactory.register_connector(
    "PyNcclConnector",
    "vllm.distributed.kv_transfer.kv_connector.simple_connector",
    "SimpleConnector")

KVConnectorFactory.register_connector(
    "MooncakeConnector",
    "vllm.distributed.kv_transfer.kv_connector.simple_connector",
    "SimpleConnector")

KVConnectorFactory.register_connector(
    "LMCacheConnector",
    "vllm.distributed.kv_transfer.kv_connector.lmcache_connector",
    "LMCacheConnector")

KVConnectorFactory.register_connector(
    "MooncakeStoreConnector",
    "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
    "MooncakeStoreConnector")

KVConnectorFactory.register_connector(
    "SharedStorageConnector",
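In the call sites updated later in this diff, the old `create_connector_v0`/`create_connector_v1` pair collapses into a single `create_connector(config, role)` entry point that only supports V1. A minimal usage sketch based on those call sites; the wrapper function itself is illustrative:

```python
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole


def make_connector(vllm_config, *, scheduler_side: bool):
    # The role is explicit in V1: the scheduler builds a SCHEDULER connector,
    # worker processes build a WORKER connector.
    role = (KVConnectorRole.SCHEDULER
            if scheduler_side else KVConnectorRole.WORKER)
    return KVConnectorFactory.create_connector(config=vllm_config, role=role)
```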
@@ -1,99 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
LMCache KV Cache Connector for Distributed Machine Learning Inference

The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache;
(2) offload and share KV caches.
"""

from typing import TYPE_CHECKING, Union

import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)


class LMCacheConnector(KVConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):

        self.transfer_config = config.kv_transfer_config
        self.vllm_config = config

        from lmcache.experimental.cache_engine import LMCacheEngineBuilder
        from lmcache.integration.vllm.utils import ENGINE_NAME
        from lmcache.integration.vllm.vllm_adapter import (
            RetrieveStatus, StoreStatus, init_lmcache_engine,
            lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
            lmcache_store_kv)
        logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                    self.transfer_config)

        # TODO (Jiayi): Find model_config, parallel_config, and cache_config
        self.engine = init_lmcache_engine(config.model_config,
                                          config.parallel_config,
                                          config.cache_config)
        self.lmcache_engine_name = ENGINE_NAME
        self.lmcache_engine_builder = LMCacheEngineBuilder

        self.model_config = config.model_config
        self.parallel_config = config.parallel_config
        self.cache_config = config.cache_config
        self.lmcache_retrieve_kv = lmcache_retrieve_kv
        self.lmcache_store_kv = lmcache_store_kv
        self.lmcache_should_retrieve = lmcache_should_retrieve
        self.lmcache_should_store = lmcache_should_store
        self.store_status = StoreStatus
        self.retrieve_status = RetrieveStatus

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

        retrieve_status = self.lmcache_should_retrieve(model_input)
        model_input, bypass_model_exec, hidden_or_intermediate_states =\
            self.lmcache_retrieve_kv(
                model_executable, model_input, self.cache_config, kv_caches,
                retrieve_status)
        return hidden_or_intermediate_states, bypass_model_exec, model_input

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:

        store_status = self.lmcache_should_store(model_input)
        self.lmcache_store_kv(
            self.model_config,
            self.parallel_config,
            self.cache_config,
            model_executable,
            model_input,
            kv_caches,
            store_status,
        )

    def close(self):
        self.lmcache_engine_builder.destroy(self.lmcache_engine_name)
@@ -1,203 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
MooncakeStore Connector for Distributed Machine Learning Inference
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
database-style KVStore.
"""
import hashlib
from typing import TYPE_CHECKING, Union

import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.utils import (
    model_aware_kv_ops_helper as kv_helper)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)


class MooncakeStoreConnector(KVConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):
        self.kv_transfer_config = config.kv_transfer_config
        self.kv_helper = kv_helper(config)
        self.local_tp_rank = local_rank

        # Init kv_store
        if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector":
            # Check if MOONCAKE_CONFIG_PATH is set
            import os
            use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None

            if not use_mooncake_store:
                raise ValueError(
                    "To use MooncakeStoreConnector, you need to pass the ENV: "
                    "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
            else:
                from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import (  # noqa: E501
                    MooncakeStore)
                logger.info(
                    "Initializing KVStoreConnector under kv_transfer_config %s",
                    self.kv_transfer_config)
                self.kv_store = MooncakeStore(config)
        else:
            logger.error("Can not find %s",
                         self.kv_transfer_config.kv_connector)

        assert self.kv_store is not None

    def close(self) -> None:
        """Close the buffer and release resources.
        This method is responsible for cleaning up resources related to the
        connector when it is no longer needed.
        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        self.kv_store.close()

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer
        num_heads, head_size = self.kv_helper.get_model_args(model_executable)

        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen

            current_tokens = input_tokens_tensor[start_pos:end_pos]
            store_key_prefix = self.tensor_hash(current_tokens)
            keys, values = [], []

            for layer_id in range(start_layer, end_layer):
                kv_cache = kv_caches[layer_id - start_layer]
                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
                    kv_cache, num_heads, head_size)
                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

                keys.append(key_cache[current_slot_mapping].unsqueeze(0))
                values.append(value_cache[current_slot_mapping].unsqueeze(0))

            keys = torch.cat(keys, dim=0)
            values = torch.cat(values, dim=0)
            kvcache_to_sent = torch.stack((keys, values), dim=0)
            store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}"
            self.kv_store.put(store_kvcache_key, kvcache_to_sent)

            hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}"
            self.kv_store.put(hidden_key,
                              hidden_or_intermediate_states[start_pos:end_pos])

        logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        bypass_model_exec = True
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer
        hidden_or_intermediate_states_for_one_req = []

        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen

            if start_pos >= num_prefill_tokens:
                # This can happen during inflight batching. See:
                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
                # - input_tokens[num_prefill_tokens:] contains decode tokens.
                logger.warning("You should set --enable_chunked_prefill=False "
                               "and --max_num_batched_tokens "
                               "should be equal to max_seq_len_to_capture")
                bypass_model_exec = False
                assert start_pos == num_prefill_tokens
                break

            current_tokens = input_tokens_tensor[start_pos:end_pos]

            # get roi for current seq
            load_key_prefix = self.tensor_hash(current_tokens)
            load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}"
            remote_kv = self.kv_store.get(load_kvcache_key)
            hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}"
            hidden = self.kv_store.get(hidden_key)

            if remote_kv is None or hidden is None:
                # didn't find any match.
                bypass_model_exec = False
                continue

            num_computed_tokens = current_tokens.shape[0]

            # update the end position based on how many tokens are cached.
            end_pos = start_pos + num_computed_tokens

            # call self.kv_store to get kv layer by layer
            for layer_id in range(start_layer, end_layer):
                layer = model_executable.model.layers[layer_id]
                # get kvcache object
                kv_cache = kv_caches[layer_id - start_layer]

                # get remote kvcache
                remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
                    layer_id]

                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
                                               remote_v, layer, kv_cache,
                                               slot_mapping, start_pos,
                                               end_pos)

            hidden_or_intermediate_states_for_one_req.append(hidden)

        if not bypass_model_exec:
            logger.warning(
                "[rank%d]: Failed to receive all KVs and hidden "
                "states, redo model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = None

        else:
            logger.debug(
                "[rank%d]: Successfully received all KVs and hidden "
                "states, skip model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = torch.cat(
                hidden_or_intermediate_states_for_one_req, dim=0)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    @staticmethod
    def tensor_hash(tensor: torch.Tensor) -> int:
        """Calculate the hash value of the tensor."""
        tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes()
        hash_object = hashlib.blake2b(tensor_bytes)
        hash_hex = hash_object.hexdigest()
        return int(hash_hex[:16], 16)
@@ -1,329 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simple KV Cache Connector for Distributed Machine Learning Inference

The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
MooncakePipe.

But the logic can be extended to support other pipe and lookup buffer.
"""
from typing import TYPE_CHECKING, Optional, Union

import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.utils import (
    model_aware_kv_ops_helper as kv_helper)
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
    SimpleBuffer)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)


class SimpleConnector(KVConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):

        self.config = config.kv_transfer_config
        self.kv_helper = kv_helper(config)

        if self.config.kv_connector == "PyNcclConnector":
            from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
                PyNcclPipe)
            logger.info(
                "Initializing PyNcclConfig under kv_transfer_config %s",
                self.config)
        elif self.config.kv_connector == "MooncakeConnector":
            # Check if MOONCAKE_CONFIG_PATH is set
            import os
            use_mooncake_distributed_pipe = os.getenv(
                'MOONCAKE_CONFIG_PATH') is not None

            if not use_mooncake_distributed_pipe:
                raise ValueError(
                    "To use MooncakeConnector, you need to pass the ENV: "
                    "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
            else:
                from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import (  # noqa: E501
                    MooncakePipe)
                logger.info(
                    "Initializing MooncakeConfig under kv_transfer_config %s",
                    self.config)

        self.lookup_buffer_size = self.config.kv_buffer_size

        self.producer_buffer: Optional[SimpleBuffer] = None
        self.consumer_buffer: Optional[SimpleBuffer] = None

        self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe]
        self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe]
        self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
        self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]

        # 2 pipes for every rank in the world
        port_offset_base = 2 * rank

        # In disaggregated prefill, the prefill vLLM only uses send pipe
        # and the decode vLLM only uses recv pipe
        if self.config.is_kv_producer:

            if self.config.kv_connector == "PyNcclConnector":
                self.producer_data_pipe = PyNcclPipe(
                    local_rank=local_rank,
                    config=self.config,
                    port_offset=port_offset_base,
                )
                self.producer_signal_pipe = PyNcclPipe(
                    local_rank=local_rank,
                    config=self.config,
                    port_offset=port_offset_base + 1,
                    device="cpu",
                )
            elif self.config.kv_connector == "MooncakeConnector":
                self.producer_data_pipe = MooncakePipe(
                    local_rank=local_rank,
                    config=self.config,
                )
                # We only need to initialize MooncakePipe once
                self.producer_signal_pipe = self.producer_data_pipe

            self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
                                                self.producer_data_pipe,
                                                self.config.kv_buffer_size)

        else:

            # the current vLLM instance is KV consumer, so it needs to connect
            # its recv pipe to the send pipe of KV producer
            if self.config.kv_connector == "PyNcclConnector":
                self.consumer_data_pipe = PyNcclPipe(
                    local_rank=local_rank,
                    config=self.config,
                    port_offset=port_offset_base,
                )
                self.consumer_signal_pipe = PyNcclPipe(
                    local_rank=local_rank,
                    config=self.config,
                    port_offset=port_offset_base + 1,
                    device="cpu",
                )
            elif self.config.kv_connector == "MooncakeConnector":
                self.consumer_data_pipe = MooncakePipe(
                    local_rank=local_rank,
                    config=self.config,
                )
                self.consumer_signal_pipe = self.consumer_data_pipe

            self.consumer_buffer = SimpleBuffer(
                self.consumer_signal_pipe,
                self.consumer_data_pipe,
                self.config.kv_buffer_size,
            )

    def select(self, input_tokens: Optional[torch.Tensor],
               roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:

        assert self.consumer_buffer is not None, "Please initialize the "\
            "consumer buffer before calling select."
        return self.consumer_buffer.drop_select(input_tokens, roi)

    def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
               key: torch.Tensor, value: torch.Tensor,
               hidden: torch.Tensor) -> None:

        assert self.producer_buffer is not None, "Please initialize the "\
            "producer buffer before calling insert."

        self.producer_buffer.insert(input_tokens, roi, key, value, hidden)

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:

        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer
        num_heads, head_size = self.kv_helper.get_model_args(model_executable)

        # query_lens contains new KV caches that are added to vLLM.
        # so we will send them to decode instance
        # FIXME(Kuntai): This assume that all requests are prefill.
        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen

            if start_pos >= num_prefill_tokens:
                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
                # - input_tokens[num_prefill_tokens:] contains decode tokens.
                logger.warning("You have some decode requests while using "
                               "SimpleConnector. Their KVCache won't be sent.")
                break

            current_tokens = input_tokens_tensor[start_pos:end_pos]

            keys, values = [], []

            for layer_id in range(start_layer, end_layer):
                kv_cache = kv_caches[layer_id - start_layer]
                key_cache, value_cache = self.kv_helper.get_kv_from_cache(
                    kv_cache, num_heads, head_size)

                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

                keys.append(key_cache[current_slot_mapping].unsqueeze(0))
                values.append(value_cache[current_slot_mapping].unsqueeze(0))

            keys = torch.cat(keys, dim=0)
            values = torch.cat(values, dim=0)

            self.insert(current_tokens,
                        torch.ones_like(current_tokens,
                                        dtype=bool), keys, values,
                        hidden_or_intermediate_states[start_pos:end_pos])

        logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

        # When bypass_model_exec is set to False, it means that at least for one
        # request its corresponding KV cache or hidden state is missing.
        # In this case we need to do prefilling to recompute missing KV cache
        # and hidden states.
        bypass_model_exec = True

        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer

        hidden_or_intermediate_states_for_one_req = []

        input_tokens_list = []
        num_computed_tokens_list = []
        start_pos_list = []

        # enumerate different requests
        # FIXME(Kuntai): This impl assumes that all requests are prefill.
        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen

            if start_pos >= num_prefill_tokens:
                # This can happen during inflight batching. See:
                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
                # - input_tokens[num_prefill_tokens:] contains decode tokens.
                logger.warning("You should set --enable_chunked_prefill=False "
                               "and --max_num_batched_tokens "
                               "should be equal to --max_seq_len_to_capture")
                bypass_model_exec = False
                assert start_pos == num_prefill_tokens
                break

            current_tokens = input_tokens_tensor[start_pos:end_pos]
            num_tokens = slen

            # collecting data for rebuilding the input
            input_tokens_list.append(current_tokens)
            start_pos_list.append(start_pos)

            ret = self.select(current_tokens,
                              torch.ones_like(current_tokens, dtype=bool))
            if ret[0] is None:
                # didn't find any match.
                bypass_model_exec = False
                num_computed_tokens_list.append(0)
                continue

            roi: torch.Tensor = ret[1]
            keys: torch.Tensor = ret[2]
            values: torch.Tensor = ret[3]
            hidden: torch.Tensor = ret[4]

            num_computed_tokens = roi.shape[0]
            num_computed_tokens_list.append(num_computed_tokens)

            # check if both KV cache and the hidden states are received
            # If not, need to redo the forwarding to compute missing states
            if not all([(num_computed_tokens == num_tokens), hidden is not None
                        ]):
                bypass_model_exec = False

            # update the end position based on how many tokens are cached.
            end_pos = start_pos + num_computed_tokens

            # put received KV caches into paged memory
            for cur_layer in range(start_layer, end_layer):

                layer_id = cur_layer - start_layer
                kv_cache = kv_caches[layer_id]
                layer = model_executable.model.layers[cur_layer]

                # get remote kvcache
                remote_k, remote_v = keys[layer_id], values[layer_id]

                self.kv_helper.put_kv_to_cache(model_executable, remote_k,
                                               remote_v, layer, kv_cache,
                                               slot_mapping, start_pos,
                                               end_pos)

            hidden_or_intermediate_states_for_one_req.append(hidden)

        if not bypass_model_exec:
            # Some of the KV cache is not retrieved
            # Here we will fall back to normal model forwarding
            # But optionally you can adjust model_input so that you only do
            # prefilling on those tokens that are missing KV caches.
            logger.warning(
                "[rank%d]: Failed to receive all KVs and hidden "
                "states, redo model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = None

        else:
            logger.debug(
                "[rank%d]: Successfully received all KVs and hidden "
                "states, skip model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = torch.cat(
                hidden_or_intermediate_states_for_one_req, dim=0)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    def close(self):
        self.producer_data_pipe.close()
        self.consumer_data_pipe.close()
        if self.config.kv_connector == "PyNcclConnector":
            self.producer_signal_pipe.close()
            self.consumer_signal_pipe.close()
        elif self.config.kv_connector == "MooncakeConnector":
            # MooncakePipe reuses data_pipe for signal_pipe, so we only have to
            # close the data_pipe.
            pass
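The deleted SimpleConnector above is built around a producer-side `insert()` and a consumer-side `select()` (backed by `SimpleBuffer.drop_select`), keyed by the request's token tensor. The following toy stand-in, which is not vLLM code, only illustrates the shape of that handshake:

```python
import torch


class ToyLookupBuffer:
    """Toy model of the insert/select handshake used by the deleted code."""

    def __init__(self):
        self._store: dict[bytes, list] = {}

    def insert(self, tokens, roi, key, value, hidden):
        # Producer side: stash KV and hidden states under a token-derived key.
        self._store[tokens.numpy().tobytes()] = [tokens, roi, key, value, hidden]

    def drop_select(self, tokens, roi):
        # Consumer side: returns [tokens, roi, keys, values, hidden],
        # or a list of Nones on a miss (ret[0] is None signals "recompute").
        return self._store.pop(tokens.numpy().tobytes(), [None] * 5)


buf = ToyLookupBuffer()
tokens = torch.tensor([1, 2, 3])
roi = torch.ones_like(tokens, dtype=bool)
buf.insert(tokens, roi, torch.zeros(2, 3), torch.zeros(2, 3), torch.zeros(3, 8))
assert buf.drop_select(tokens, roi)[0] is not None  # hit: bypass model exec
assert buf.drop_select(tokens, roi)[0] is None      # miss: fall back to prefill
```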
@@ -13,8 +13,8 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1)
from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

@@ -106,9 +106,8 @@ def get_kv_connector_cache_layout():
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    if kv_config is not None:
        connector_cls = KVConnectorFactory.get_connector_class(kv_config)
        required_kvcache_layout = connector_cls.get_required_kvcache_layout(
            vllm_config)
        required_kvcache_layout = (
            KVConnectorBase_V1.get_required_kvcache_layout(vllm_config))
        if required_kvcache_layout is not None:
            return required_kvcache_layout
    logger.info_once("Connectors do not specify a " \
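As the hunk above shows, layout resolution no longer asks the factory for a concrete class; it calls the classmethod on the v1 base directly, so a connector advertises a layout preference by overriding `get_required_kvcache_layout`. A hedged sketch of such an override; the subclass and the returned "HND" value are illustrative, not taken from this diff:

```python
from typing import Optional

from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1


class LayoutSensitiveConnector(KVConnectorBase_V1):
    """Illustrative subclass; only the layout hook is sketched."""

    @classmethod
    def get_required_kvcache_layout(cls, vllm_config) -> Optional[str]:
        # Return a layout string such as "HND" or "NHD" to require it,
        # or None (the base default) to express no preference.
        return "HND"
```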
@@ -52,7 +52,7 @@ class MultiConnector(KVConnectorBase_V1):
            temp_config.kv_transfer_config = KVTransferConfig(
                **ktc, engine_id=engine_id)
            self._connectors.append(
                KVConnectorFactory.create_connector_v1(temp_config, role))
                KVConnectorFactory.create_connector(temp_config, role))

        # A mapping from request id to the index of the connector chosen to
        # load the request from (if any).
@@ -223,9 +223,9 @@ class MultiConnector(KVConnectorBase_V1):
        for ktc in ktcs:
            kv_transfer_config = KVTransferConfig(**ktc)
            temp_vllm_config.kv_transfer_config = kv_transfer_config
            required_kvcache_layout = KVConnectorFactory.get_connector_class(
                kv_transfer_config).get_required_kvcache_layout(
                    temp_vllm_config)
            required_kvcache_layout = (
                KVConnectorBase_V1.get_required_kvcache_layout(
                    temp_vllm_config))
            if required_kvcache_layout is not None:
                layouts.add(required_kvcache_layout)
@@ -1,77 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A centralized entrypoint to perform distributed KV cache transfer.

This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
1. `send_kv_caches_and_hidden_states`
2. `recv_kv_caches_and_hidden_states
"""
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
    from vllm.config import VllmConfig

import torch

from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)


class KVTransferAgent:
    """
    A class designated for distributed KV transfer

    Target use cases:
        1. Disaggregated prefill
        2. Remote KV cache storage
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ):

        self.config = config

        if config.kv_transfer_config is None:
            raise ValueError("KVTransferConfig is not set in the VllmConfig,"
                             " cannot initialize KVConnector.")

        assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
            "TransferAgent should only be used when kv_connector is set."

        self.connector = KVConnectorFactory.create_connector_v0(
            rank, local_rank, config)

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:

        self.connector.send_kv_caches_and_hidden_states(
            model_executable, model_input, kv_caches,
            hidden_or_intermediate_states)

    def close(self) -> None:
        self.connector.close()

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

        return self.connector.recv_kv_caches_and_hidden_states(
            model_executable, model_input, kv_caches)
@@ -8,7 +8,6 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
                                                          KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group

if TYPE_CHECKING:
    from vllm.config import VllmConfig
@@ -61,11 +60,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    if (vllm_config.kv_transfer_config.is_kv_transfer_instance
            and _KV_CONNECTOR_AGENT is None):
        if envs.VLLM_USE_V1:
            _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
            _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
                config=vllm_config, role=KVConnectorRole.WORKER)
        else:
            _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
                rank=get_world_group().rank,
                local_rank=get_world_group().local_rank,
                config=vllm_config,
            )
            raise ValueError("V0 is no longer supported")
@@ -83,7 +83,7 @@ class Scheduler(SchedulerInterface):
            assert len(self.kv_cache_config.kv_cache_groups) == 1, (
                "Multiple KV cache groups are not currently supported "
                "with KV connectors")
            self.connector = KVConnectorFactory.create_connector_v1(
            self.connector = KVConnectorFactory.create_connector(
                config=self.vllm_config, role=KVConnectorRole.SCHEDULER)

        self.kv_event_publisher = EventPublisherFactory.create(
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
@@ -31,7 +31,7 @@ class KVConnectorModelRunnerMixin:
        # Update KVConnector with the KVConnector metadata forward().
        if has_kv_transfer_group():
            kv_connector = get_kv_transfer_group()
            assert isinstance(kv_connector, KVConnectorBase_V1)
            assert isinstance(kv_connector, KVConnectorBase)
            assert scheduler_output.kv_connector_metadata is not None
            kv_connector.bind_connector_metadata(
                scheduler_output.kv_connector_metadata)
@@ -93,7 +93,7 @@ class KVConnectorModelRunnerMixin:

        # Update KVConnector with the KVConnector metadata forward().
        kv_connector = get_kv_transfer_group()
        assert isinstance(kv_connector, KVConnectorBase_V1)
        assert isinstance(kv_connector, KVConnectorBase)
        assert scheduler_output.kv_connector_metadata is not None
        kv_connector.bind_connector_metadata(
            scheduler_output.kv_connector_metadata)