[Feature] Support KV cache offloading and disagg prefill with LMCache connector. (#12953)

Jiayi Yao, 2025-02-25 02:38:42 -06:00 (committed by GitHub)
parent 3173c3b34e
commit 2f42a4888c
5 changed files with 310 additions and 2 deletions

@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of CPU offloading
with LMCache.

Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache at https://github.com/LMCache/LMCache.
"""
import os
import time

from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# LMCache-related environment variables
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Enable local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "True"
# Set local CPU memory limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"

# This example script runs two requests with a shared prefix.
shared_prompt = "Hello, how are you?" * 1000
first_prompt = [
    shared_prompt + "Hello, my name is",
]
second_prompt = [
    shared_prompt + "Tell me a very long story",
]

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

ktc = KVTransferConfig.from_cli(
    '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
# Note that LMCache is not compatible with chunked prefill for now.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
          kv_transfer_config=ktc,
          max_model_len=8000,
          enable_chunked_prefill=False,
          gpu_memory_utilization=0.8)

outputs = llm.generate(first_prompt, sampling_params)
for output in outputs:
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")
print("First request done.")

time.sleep(1)

outputs = llm.generate(second_prompt, sampling_params)
for output in outputs:
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")
print("Second request done.")

# Clean up the LMCache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME)
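
Editor's note (not part of this commit): a minimal sketch of how the CPU offload could be verified, assuming it is placed just before the LMCacheEngineBuilder.destroy(...) call above so that llm, shared_prompt, and sampling_params are still in scope. The exact speedup depends on hardware, prompt length, and the LMCACHE_* settings.

# Hedged sketch, not in the diff: time a request that reuses the shared prefix.
# With LMCACHE_LOCAL_CPU enabled, the KV cache for the shared chunk should be
# retrieved from CPU memory instead of being recomputed on the GPU.
import time

third_prompt = [shared_prompt + "Please summarize the story so far."]

start = time.perf_counter()
outputs = llm.generate(third_prompt, sampling_params)
elapsed = time.perf_counter() - start

print(f"Third request took {elapsed:.2f}s; "
      f"output: {outputs[0].outputs[0].text!r}")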

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of disaggregated prefilling
with LMCache.

We will launch 2 vLLM instances (GPU 0 for prefill and GPU 1 for decode),
and launch an additional LMCache server.
KV cache is transferred in the following manner:
vLLM prefill node -> LMCache server -> vLLM decode node.

Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache at https://github.com/LMCache/LMCache.
"""
import os
import subprocess
import time
from multiprocessing import Event, Process

from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# LMCache-related environment variables
# The port to start LMCache server
port = 8100
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Disable local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "False"
# Set local CPU memory buffer limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
# Set the remote URL for LMCache server
os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
# Set the serializer/deserializer between vllm and LMCache server
# `naive` indicates using raw bytes of the tensor without any compression
os.environ["LMCACHE_REMOTE_SERDE"] = "naive"


def run_prefill(prefill_done, prompts):
    # We use GPU 0 for the prefill node.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"LMCacheConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
    )
    # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
    # of memory. Reduce the value if your GPU has less memory.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
              kv_transfer_config=ktc,
              max_model_len=8000,
              gpu_memory_utilization=0.8,
              enforce_eager=True)

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
    print("Prefill node is finished.")
    prefill_done.set()

    # Clean up the LMCache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)


def run_decode(prefill_done, prompts, timeout=1):
    # We use GPU 1 for the decode node.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
    )
    # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
    # of memory. Reduce the value if your GPU has less memory.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
              kv_transfer_config=ktc,
              max_model_len=8000,
              gpu_memory_utilization=0.8,
              enforce_eager=True)

    print("Waiting for prefill node to finish...")
    prefill_done.wait()
    time.sleep(timeout)

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")

    # Clean up the LMCache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)


def run_lmcache_server(port):
    server_proc = subprocess.Popen([
        "python", "-m", "lmcache.experimental.server", "localhost",
        str(port)
    ])
    return server_proc


if __name__ == "__main__":
    prompts = [
        "Hello, how are you?" * 1000,
    ]

    prefill_done = Event()
    prefill_process = Process(target=run_prefill, args=(prefill_done, prompts))
    decode_process = Process(target=run_decode, args=(prefill_done, prompts))
    lmcache_server_process = run_lmcache_server(port)

    # Start prefill node
    prefill_process.start()
    # Start decode node
    decode_process.start()

    # Clean up the processes
    decode_process.join()
    prefill_process.terminate()
    lmcache_server_process.terminate()
    lmcache_server_process.wait()
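
Editor's note (not part of this commit): the prefill node above registers as a KV producer (kv_rank 0) and the decode node as a KV consumer (kv_rank 1). A small sketch of how those roles parse is below; the is_kv_producer / is_kv_consumer properties are assumed to exist on KVTransferConfig in the vLLM version this commit targets.

# Hedged sketch, not in the diff: inspect the two role configurations.
from vllm.config import KVTransferConfig

producer = KVTransferConfig.from_cli(
    '{"kv_connector":"LMCacheConnector","kv_role":"kv_producer",'
    '"kv_rank":0,"kv_parallel_size":2}')
consumer = KVTransferConfig.from_cli(
    '{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer",'
    '"kv_rank":1,"kv_parallel_size":2}')

# Expected: the prefill node only produces KV, the decode node only consumes.
print(producer.is_kv_producer, producer.is_kv_consumer)  # True False
print(consumer.is_kv_producer, consumer.is_kv_consumer)  # False True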

@@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
    "MooncakeConnector",
    "vllm.distributed.kv_transfer.kv_connector.simple_connector",
    "SimpleConnector")

KVConnectorFactory.register_connector(
    "LMCacheConnector",
    "vllm.distributed.kv_transfer.kv_connector.lmcache_connector",
    "LMCacheConnector")

@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
"""
LMCache KV Cache Connector for Distributed Machine Learning Inference

The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache;
(2) offload and share KV caches.
"""
from typing import TYPE_CHECKING, List, Tuple, Union

import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)


class LMCacheConnector(KVConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):
        self.transfer_config = config.kv_transfer_config
        self.vllm_config = config

        from lmcache.experimental.cache_engine import LMCacheEngineBuilder
        from lmcache.integration.vllm.utils import ENGINE_NAME
        from lmcache.integration.vllm.vllm_adapter import (
            RetrieveStatus, StoreStatus, init_lmcache_engine,
            lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv)
        logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                    self.transfer_config)

        # TODO (Jiayi): Find model_config, parallel_config, and cache_config
        self.engine = init_lmcache_engine(config.model_config,
                                          config.parallel_config,
                                          config.cache_config)
        self.lmcache_engine_name = ENGINE_NAME
        self.lmcache_engine_builder = LMCacheEngineBuilder

        self.model_config = config.model_config
        self.parallel_config = config.parallel_config
        self.cache_config = config.cache_config
        self.lmcache_retrieve_kv = lmcache_retrieve_kv
        self.lmcache_store_kv = lmcache_store_kv
        self.lmcache_should_store = lmcache_should_store
        self.store_status = StoreStatus
        self.retrieve_status = RetrieveStatus

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor]
    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

        hidden_or_intermediate_states = None

        # TODO (Jiayi): Need to support chunked prefill
        retrieve_status = self.retrieve_status.PREFILL

        model_input, bypass_model_exec = self.lmcache_retrieve_kv(
            model_executable, model_input, self.cache_config, kv_caches,
            retrieve_status)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:

        num_reqs = 0
        seq_group_list = model_input.sampling_metadata.seq_groups
        assert seq_group_list is not None
        for seq_group in seq_group_list:
            seq_ids = seq_group.seq_ids
            for seq_id in seq_ids:
                num_reqs += 1

        # TODO (Jiayi): Only normal prefill is supported for now
        store_status = self.lmcache_should_store(model_input)
        self.lmcache_store_kv(
            self.model_config,
            self.parallel_config,
            self.cache_config,
            model_executable,
            model_input,
            kv_caches,
            store_status,
        )

    def close(self):
        self.lmcache_engine_builder.destroy(self.lmcache_engine_name)
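
Editor's note (not part of this commit): a sketch of the call pattern the KVConnectorBase interface implies. The driver function and the forward-pass call below are hypothetical; they only illustrate that recv_kv_caches_and_hidden_states runs before the model forward (and may bypass it entirely on a cache hit), while send_kv_caches_and_hidden_states runs afterwards.

# Hedged sketch, hypothetical driver code: how a model runner would be
# expected to use a KV connector such as LMCacheConnector.
def model_runner_step(connector, model_executable, model_input, kv_caches):
    # Consumer side: try to retrieve KV (and possibly hidden states) first.
    hidden_states, bypass_model_exec, model_input = (
        connector.recv_kv_caches_and_hidden_states(model_executable,
                                                   model_input, kv_caches))

    if not bypass_model_exec:
        # Cache miss (or producer side): run the normal forward pass.
        hidden_states = model_executable(model_input, kv_caches)  # hypothetical call

    # Producer side: store the freshly computed KV caches via LMCache.
    connector.send_kv_caches_and_hidden_states(model_executable, model_input,
                                               kv_caches, hidden_states)
    return hidden_states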

@@ -962,8 +962,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
        return
    if all([
-           vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER
-           is None
+           vllm_config.kv_transfer_config.is_kv_transfer_instance,
+           _KV_TRANSFER is None
    ]):
        _KV_TRANSFER = kv_transfer.KVTransferAgent(
            rank=get_world_group().rank,
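
Editor's note (not part of this commit's hunk): the guard changes from need_kv_parallel_group to is_kv_transfer_instance, so the KVTransferAgent is now created for any configured KV connector, including LMCache setups that do not require a dedicated KV parallel group. The sketch below shows the assumed shape of that property; the real definition lives in vllm/config.py and may differ in detail.

# Hedged sketch, not vLLM's actual class: illustrates the property the new
# guard relies on.
from typing import Optional


class _KVTransferConfigSketch:

    def __init__(self, kv_connector: Optional[str], kv_role: Optional[str]):
        self.kv_connector = kv_connector
        self.kv_role = kv_role

    @property
    def is_kv_transfer_instance(self) -> bool:
        # True whenever any KV connector/role is configured, regardless of
        # whether a dedicated KV parallel group is required.
        return self.kv_connector is not None and \
            self.kv_role in ("kv_producer", "kv_consumer", "kv_both")


# All three LMCache examples above (kv_both, kv_producer, kv_consumer) satisfy
# this, so ensure_kv_transfer_initialized() now builds a KVTransferAgent for them.
print(_KVTransferConfigSketch("LMCacheConnector", "kv_both").is_kv_transfer_instance)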