[P/D] Support CPU Transfer in NixlConnector (#18293)

Signed-off-by: Juncheng Gu <juncgu@gmail.com>
Signed-off-by: Richard Liu <ricliu@google.com>
Co-authored-by: Richard Liu <39319471+richardsliu@users.noreply.github.com>
Co-authored-by: Richard Liu <ricliu@google.com>
Juncheng Gu
2025-07-24 09:58:42 -07:00
committed by GitHub
parent 1e9ea8e69d
commit 6066284914
12 changed files with 893 additions and 110 deletions


@ -10,6 +10,7 @@ jinja2>=3.1.6
ray[default]
ray[data]
setuptools==78.1.0
nixl==0.3.0
# Install torch_xla
--pre


@ -0,0 +1,162 @@
#!/bin/bash
set -xe
# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}
# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}
# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}
OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}
# Clean up background jobs on SIGINT (Ctrl+C), SIGTERM, and exit
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Waits for vLLM server to start.
wait_for_server() {
local host=$1
local port=$2
timeout 1200 bash -c "
until curl -s ${host}:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9 || true
# pkill -f python || true
echo "Cleanup complete. Exiting."
}
launch_baseline() {
BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${BASELINE_HOST} \
--port ${BASELINE_PORT} \
--max-model-len ${MAX_MODEL_LEN} \
--seed 42 \
--block-size ${BLOCK_SIZE} \
--gpu-memory-utilization 0.5 \
--disable-log-requests \
--enforce-eager"
echo ${BASELINE_BASE_CMD}
ssh -tt ${BASELINE_HOST} "${BASELINE_BASE_CMD}" &
}
launch_pd() {
PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${PREFILL_HOST} \
--port ${PREFILL_PORT} \
--max-model-len ${MAX_MODEL_LEN} \
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--disable-log-requests \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${DECODE_HOST} \
--port ${DECODE_PORT} \
--max-model-len ${MAX_MODEL_LEN} \
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--disable-log-requests \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
echo ${PREFILL_BASE_CMD}
echo ${DECODE_BASE_CMD}
sleep 2
# execute on hosts
ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
sleep 1
wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
sleep 1
wait_for_server ${DECODE_HOST} ${DECODE_PORT}
sleep 1
}
launch_pd_proxy(){
PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
python3 ${EXP_ROOT}/toy_proxy_server.py \
--prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
--decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
--host=${PROXY_HOST} --port ${PROXY_PORT}"
echo ${PROXY_BASE_CMD}
ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}
run_tests(){
local service_url=$1
local mode=$2
python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE}
}
# run non-disagg. baseline & save outputs
launch_baseline
sleep 2
wait_for_server ${BASELINE_HOST} ${BASELINE_PORT}
run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline"
cleanup
sleep 10
# run disagg. & do exact-match with the outputs from baseline
launch_pd
launch_pd_proxy
sleep 10
run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg"
echo "-----P/D success----"
rm ${OUTPUT_FILE}
cleanup
exit 0


@ -0,0 +1,128 @@
#!/bin/bash
set -xe
# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}
# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}
# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}
OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}
# Clean up background jobs on SIGINT (Ctrl+C), SIGTERM, and exit
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Waits for vLLM server to start.
wait_for_server() {
local host=$1
local port=$2
timeout 1200 bash -c "
until curl -s ${host}:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9 || true
# pkill -f python || true
echo "Cleanup complete. Exiting."
}
launch_pd() {
PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${PREFILL_HOST} \
--port ${PREFILL_PORT} \
--max-model-len ${MAX_MODEL_LEN} \
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--disable-log-requests \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${DECODE_HOST} \
--port ${DECODE_PORT} \
--max-model-len ${MAX_MODEL_LEN} \
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--disable-log-requests \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
echo ${PREFILL_BASE_CMD}
echo ${DECODE_BASE_CMD}
sleep 2
# execute on hosts
ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
sleep 1
wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
sleep 1
wait_for_server ${DECODE_HOST} ${DECODE_PORT}
sleep 1
}
launch_pd_proxy(){
PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
python3 ${EXP_ROOT}/toy_proxy_server.py \
--prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
--decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
--host=${PROXY_HOST} --port ${PROXY_PORT}"
echo ${PROXY_BASE_CMD}
ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}
# run disagg. & do exact-match with the outputs from baseline
launch_pd
launch_pd_proxy
sleep 10
PREFILL_HOST=${PREFILL_HOST} \
PREFILL_PORT=${PREFILL_PORT} \
DECODE_HOST=${DECODE_HOST} \
DECODE_PORT=${DECODE_PORT} \
PROXY_HOST=${PROXY_HOST} \
PROXY_PORT=${PROXY_PORT} python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py


@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import os
import time
import openai
import requests
MAX_OUTPUT_LEN = 30
SAMPLE_PROMPTS = (
"Red Hat is the best company in the world to work for because it works on "
"open source software, which means that all the contributions are "
"delivered to the community. As a result, when working on projects like "
"vLLM we are able to meet many amazing people from various organizations "
"like AMD, Google, NVIDIA, ",
"We hold these truths to be self-evident, that all men are created equal, "
"that they are endowed by their Creator with certain unalienable Rights, "
"that among these are Life, Liberty and the pursuit of Happiness.--That "
"to secure these rights, Governments are instituted among Men, deriving "
"their just powers from the consent of the governed, ",
)
def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
"""
Checks if the vLLM server is ready by sending a GET request to the
/health endpoint.
Args:
url (str): The base URL of the vLLM server.
timeout (int): Timeout in seconds for the request.
retries (int): Number of retries if the server is not ready.
Returns:
bool: True if the server is ready, False otherwise.
"""
for attempt in range(retries):
try:
response = requests.get(url, timeout=timeout)
if response.status_code == 200:
return True
else:
print(f"Attempt {attempt + 1}: Server returned status code "
"{response.status_code}")
except requests.exceptions.RequestException as e:
print(f"Attempt {attempt + 1}: Error connecting to server: {e}")
time.sleep(1) # Wait before retrying
return False
def run_simple_prompt(base_url: str, model_name: str,
input_prompt: str) -> str:
client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
completion = client.completions.create(model=model_name,
prompt=input_prompt,
max_tokens=MAX_OUTPUT_LEN,
temperature=0.0,
seed=42)
# print("-" * 50)
# print(f"Completion results for {model_name}:")
# print(completion)
# print("-" * 50)
return completion.choices[0].text
def main():
"""
This script demonstrates how to accept two optional string arguments
("service_url" and "file_name") from the command line, each with a
default value of an empty string, using the argparse module.
"""
parser = argparse.ArgumentParser(description="vLLM client script")
parser.add_argument(
"--service_url",
type=str,
required=True,
help="The vLLM service URL.")
parser.add_argument(
"--model_name",
type=str,
required=True,
help="The model name.",
)
parser.add_argument(
"--mode",
type=str,
default="baseline",
help="Mode: 'baseline' (non-disagg.) or 'disagg'.",
)
parser.add_argument(
"--file_name",
type=str,
default=".vllm_output.txt",
help="The file that stores the output tokens.",
)
args = parser.parse_args()
for arg in vars(args):
print(f"{arg}: {getattr(args, arg)}")
if args.mode == "baseline":
# non-disagg
health_check_url = f"{args.service_url}/health"
else:
# disagg proxy
health_check_url = f"{args.service_url}/healthcheck"
if not os.path.exists(args.file_name):
raise ValueError(
f"In disagg mode, the output file {args.file_name} from "
"non-disagg. baseline does not exist.")
service_url = f"{args.service_url}/v1"
if not check_vllm_server(health_check_url):
raise RuntimeError(
f"vllm server: {args.service_url} is not ready yet!")
output_strs = dict()
for prompt in SAMPLE_PROMPTS:
output_str = run_simple_prompt(base_url=service_url,
model_name=args.model_name,
input_prompt=prompt)
print(f"Prompt: {prompt}, output: {output_str}")
output_strs[prompt] = output_str
if args.mode == "baseline":
# baseline: save outputs
try:
with open(args.file_name, 'w') as json_file:
json.dump(output_strs, json_file, indent=4)
except OSError as e:
print(f"Error writing to file: {e}")
raise
else:
# disagg. verify outputs
baseline_outputs = None
try:
with open(args.file_name) as json_file:
baseline_outputs = json.load(json_file)
except OSError as e:
print(f"Error writing to file: {e}")
raise
assert isinstance(baseline_outputs, dict)
assert len(baseline_outputs) == len(output_strs)
for prompt, output in baseline_outputs.items():
assert prompt in output_strs, f"{prompt} not included"
assert output == output_strs[prompt], (
f"baseline_output: {output} != PD output: {output_strs[prompt]}"
)
if __name__ == "__main__":
main()


@ -4,8 +4,11 @@ import os
import openai
PREFILL_HOST = os.getenv("PREFILL_HOST", "localhost")
PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_HOST = os.getenv("DECODE_HOST", "localhost")
DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_HOST = os.getenv("PROXY_HOST", "localhost")
PROXY_PORT = os.getenv("PROXY_PORT", None)
if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
@ -21,15 +24,15 @@ def test_edge_cases():
# Set the OpenAI API key and base URL
decode_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{DECODE_PORT}/v1",
base_url=f"http://{DECODE_HOST}:{DECODE_PORT}/v1",
)
prefill_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{PREFILL_PORT}/v1",
base_url=f"http://{PREFILL_HOST}:{PREFILL_PORT}/v1",
)
proxy_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{PROXY_PORT}/v1",
base_url=f"http://{PROXY_HOST}:{PROXY_PORT}/v1",
)
# Get the list of models


@ -3,6 +3,7 @@
import argparse
import itertools
import logging
import os
import uuid
from contextlib import asynccontextmanager
@ -11,9 +12,8 @@ import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger
logger = init_logger(__name__)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@asynccontextmanager


@ -32,7 +32,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch
@ -46,6 +46,12 @@ if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
CopyBlocksOp = Callable[[
dict[str, torch.Tensor], dict[
str, torch.Tensor], list[int], list[int], Literal["h2d", "d2h"]
], None]
logger = init_logger(__name__)
@ -127,6 +133,13 @@ class KVConnectorBase_V1(ABC):
"""
return
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
"""
Set the xPU-specific ops for copying KV between host and device.
Needed when host buffer is used for kv transfer (e.g., in NixlConnector)
"""
return
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
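
For reference, CopyBlocksOp is just a callable with the positional signature noted above (source caches, destination caches, source indices, destination indices, direction). Below is a minimal sketch of a conforming callable, assuming the block index is the leading dimension of every cache tensor; the function name is hypothetical, and the real TPU implementation is the copy_kv_blocks helper added to the TPU model runner later in this commit.

from typing import Literal

import torch


def naive_copy_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    # Copy the selected blocks layer by layer. The destination tensors
    # already live on the target device, so `direction` is informational
    # in this sketch.
    for layer_name, src_cache in src_kv_caches.items():
        dst_cache = dst_kv_caches[layer_name]
        dst_cache[dst_block_ids] = src_cache[src_block_ids].to(dst_cache.device)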


@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import logging
import math
import queue
import threading
@ -20,14 +21,14 @@ from vllm import envs
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.distributed.utils import divide
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import _Backend
from vllm.platforms import _Backend, current_platform
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
@ -40,6 +41,7 @@ if TYPE_CHECKING:
Transfer = tuple[int, float] # (xfer_handle, start_time)
EngineId = str
ReqId = str
GET_META_MSG = b"get_meta_msg"
logger = init_logger(__name__)
@ -52,6 +54,13 @@ except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
# Supported xPUs and types of kv transfer buffer.
# {xPU: tuple of supported kv buffer types}
_NIXL_SUPPORTED_XPUS = {
"cuda": ("cuda", ),
"tpu": ("cpu", ),
}
class NixlAgentMetadata(
msgspec.Struct,
@ -80,6 +89,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
self.reqs_to_send: dict[ReqId, float] = {}
def add_new_req(
@ -87,8 +97,12 @@ class NixlConnectorMetadata(KVConnectorMetadata):
request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
load_remote_cache: bool = True,
save_to_host: bool = False,
):
self.reqs_to_recv[request_id] = ReqMeta(
# save and load are mutually exclusive
assert load_remote_cache ^ save_to_host
_req = ReqMeta(
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
@ -97,6 +111,10 @@ class NixlConnectorMetadata(KVConnectorMetadata):
# P workers don't need to receive tp_size from proxy here.
tp_size=kv_transfer_params.get("tp_size", 1),
)
if save_to_host:
self.reqs_to_save[request_id] = _req
if load_remote_cache:
self.reqs_to_recv[request_id] = _req
class NixlConnector(KVConnectorBase_V1):
@ -155,6 +173,10 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None
self.connector_worker.set_host_xfer_buffer_ops(copy_operation)
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
@ -177,8 +199,11 @@ class NixlConnector(KVConnectorBase_V1):
pass
def wait_for_save(self):
"""NixlConnector does not save explicitly."""
pass
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
if self.connector_worker.use_host_buffer and \
self.connector_worker.copy_blocks:
self.connector_worker.save_kv_to_host(self._connector_metadata)
class NixlConnectorScheduler:
@ -193,12 +218,15 @@ class NixlConnectorScheduler:
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
vllm_config.parallel_config.data_parallel_rank *
vllm_config.parallel_config.tensor_parallel_size)
self.use_host_buffer = \
vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
logger.info("Initializing NIXL Scheduler %s", engine_id)
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
@ -248,7 +276,25 @@ class NixlConnectorScheduler:
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens, params)
if params is not None and params.get("do_remote_prefill"):
if not params:
return
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
# figure out full computed blocks to save
block_ids = blocks.get_block_ids()[0]
all_full = request.num_tokens % self.block_size == 0
full_block_ids = (block_ids if all_full else block_ids[:-1])
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
if full_block_ids:
self._reqs_need_save[request.request_id] = \
(request, full_block_ids)
elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
@ -260,6 +306,7 @@ class NixlConnectorScheduler:
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request, local_block_ids)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
@ -284,10 +331,21 @@ class NixlConnectorScheduler:
kv_transfer_params=req.kv_transfer_params,
)
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
for req_id, (req, block_ids) in self._reqs_need_save.items():
assert req.kv_transfer_params is not None
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
load_remote_cache=False,
save_to_host=True,
)
meta.reqs_to_send = self._reqs_need_send
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_need_save.clear()
self._reqs_need_send = {}
return meta
@ -379,9 +437,36 @@ class NixlConnectorWorker:
self.tp_rank = get_tensor_model_parallel_rank()
self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()
self.num_blocks = 0
# KV Caches and nixl tracking data.
self.kv_caches: dict[str, torch.Tensor] = {}
self.device_type = current_platform.device_type
self.kv_buffer_device: str = \
vllm_config.kv_transfer_config.kv_buffer_device
if self.device_type not in _NIXL_SUPPORTED_XPUS:
raise RuntimeError(f"{self.device_type} is not supported.")
elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[
self.device_type]:
raise RuntimeError(
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
"is not supported.")
self.device_kv_caches: dict[str, torch.Tensor] = {}
# cpu kv buffer for xfer
# used when xPU memory can not be registered under nixl
self.host_xfer_buffers: dict[str, torch.Tensor] = {}
self.use_host_buffer = self.kv_buffer_device == "cpu"
if self.kv_buffer_device == "cuda":
self.nixl_memory_type = "VRAM"
elif self.kv_buffer_device == "cpu":
self.nixl_memory_type = "DRAM"
else:
raise RuntimeError(
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
"is not supported.")
# Note: host xfer buffer ops when use_host_buffer is True
self.copy_blocks: Optional[CopyBlocksOp] = None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
@ -404,6 +489,7 @@ class NixlConnectorWorker:
# In progress transfers.
# [req_id -> list[handle]]
self._recving_metadata: dict[ReqId, ReqMeta] = {}
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
# Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {}
@ -440,6 +526,7 @@ class NixlConnectorWorker:
self.backend_name = backend.get_name()
attn_backend = backend_name_to_enum(self.backend_name)
self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
logger.debug("Detected attention backend %s", self.backend_name)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
@ -529,6 +616,31 @@ class NixlConnectorWorker:
# Remote rank -> agent name.
return {p_remote_rank: remote_agent_name}
def initialize_host_xfer_buffer(
self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
Initialize transfer buffer in CPU mem for accelerators
NOT directly supported by NIXL (e.g., tpu)
"""
xfer_buffers: dict[str, torch.Tensor] = {}
try:
for layer_name, kv_cache in kv_caches.items():
kv_shape = kv_cache.shape
kv_dtype = kv_cache.dtype
xfer_buffers[layer_name] = torch.empty(kv_shape,
dtype=kv_dtype,
device="cpu")
except MemoryError as e:
logger.error("NIXLConnectorWorker gets %s.", e)
raise
self.host_xfer_buffers = xfer_buffers
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
"""Assign copy (d2h, h2d) operations when host buffer is used."""
assert self.use_host_buffer
self.copy_blocks = copy_operation
def _background_nixl_handshake(self, req_id: str,
remote_engine_id: EngineId, meta: ReqMeta):
# Do NIXL handshake in background and add to _ready_requests when done.
@ -562,47 +674,76 @@ class NixlConnectorWorker:
_, first_kv_cache = next(iter(kv_caches.items()))
kv_elem_size = first_kv_cache.element_size()
if self.use_host_buffer:
self.initialize_host_xfer_buffer(kv_caches=kv_caches)
assert len(self.host_xfer_buffers) == len(kv_caches), (
f"host_buffer: {len(self.host_xfer_buffers)}, "
f"kv_caches: {len(kv_caches)}")
xfer_buffers = self.host_xfer_buffers
else:
xfer_buffers = kv_caches
assert not self.host_xfer_buffers, (
"host_xfer_buffer should not be initialized when "
f"kv_buffer_device is {self.kv_buffer_device}")
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# will only affect the strides. For MLA instead, we require no
# such thing and resort to the standard layout.
use_mla = len(first_kv_cache.shape) == 3
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator. Enforce check
# once it goes live, as a single kv layout is expected for xfers.
if use_mla:
# MLA case.
if self.device_type == "tpu":
assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
# tpu (v1) kv shape per layer:
# (num_blocks, block_size, num_kv_heads * 2, head_size)
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
block_size, n_kv_heads_x_2, head_dim = block_shape
self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
elif self.device_type == "cuda":
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator.
# Enforce check once it goes live, as a single kv layout
# is expected for xfers.
if use_mla:
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape[-3:]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape[-3:]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
else:
raise RuntimeError(
f"{self.device_type} ({self.backend_name}) is not supported.")
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.info(
"Registering KV_Caches: use_mla: %s, num_blocks: %s, "
"block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
self.num_blocks, block_shape, first_kv_cache.shape)
"Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
"use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
"per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
self.use_host_buffer, self.num_blocks, block_shape,
first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.kv_caches = kv_caches
self.device_kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = []
@ -614,19 +755,21 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
for cache_or_caches in kv_caches.values():
for cache_or_caches in xfer_buffers.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
else cache_or_caches
cache_list = [cache_or_caches] if use_mla \
or self._use_pallas_v1 or self._use_flashinfer \
else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
caches_data.append(
(base_addr, region_len, cache.device.index, ""))
# NOTE: use tp_rank for device_id since multi-node TP
# is rarely used.
caches_data.append((base_addr, region_len, self.tp_rank, ""))
kv_caches_base_addr.append(base_addr)
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
self.num_regions = len(caches_data)
self.num_layers = len(self.kv_caches.keys())
self.num_layers = len(xfer_buffers.keys())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
@ -648,7 +791,8 @@ class NixlConnectorWorker:
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
descs = self.nixl_wrapper.get_reg_descs(caches_data,
self.nixl_memory_type)
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
logger.debug("Done registering descs")
@ -666,11 +810,13 @@ class NixlConnectorWorker:
block_offset = block_id * self.block_len
addr = base_addr + block_offset
# (addr, len, device id)
# TODO: does device_id matter to DRAM?
blocks_data.append((addr, self.block_len, self.tp_rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.tp_rank)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
@ -755,6 +901,8 @@ class NixlConnectorWorker:
tp_ratio = divide(self._tp_size[self.engine_id],
self._tp_size[engine_id])
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
assert not self._use_pallas_v1 or tp_ratio == 1, \
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
# Handle tp_size>num_kv_heads: replicate KV cache.
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
@ -813,13 +961,43 @@ class NixlConnectorWorker:
self.tp_rank)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
self.nixl_memory_type)
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs)
return remote_agent_name
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
"""copy recved kv from host buffer to device."""
assert self.use_host_buffer
assert self.copy_blocks is not None
local_block_ids = meta.local_block_ids
self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches,
local_block_ids, local_block_ids, "h2d")
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"synced recved kv of request[%s] to device kv buffer,"
"local_block_ids: %s. ", req_id,
",".join(map(str, meta.local_block_ids)))
def save_kv_to_host(self, metadata: NixlConnectorMetadata):
"""copy kv from device to host buffer."""
assert self.use_host_buffer
assert self.copy_blocks is not None
for req_id, meta in metadata.reqs_to_save.items():
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"save_load_kv for request[%s] to host xfer buffer."
"local_block_ids: %s. ", req_id,
",".join(map(str, meta.local_block_ids)))
# blocking
self.copy_blocks(self.device_kv_caches, self.host_xfer_buffers,
meta.local_block_ids, meta.local_block_ids, "d2h")
def get_finished(self) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving on this specific worker.
@ -834,6 +1012,12 @@ class NixlConnectorWorker:
"and %s requests done recving", self.tp_rank,
len(done_sending), len(done_recving))
if self.use_host_buffer:
for req_id in done_recving:
meta = self._recving_metadata.pop(req_id)
assert meta, f"{req_id} not found in recving_metadata list"
self.sync_recved_kv_to_device(req_id, meta)
# Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter()
while self._reqs_to_send:
@ -904,6 +1088,8 @@ class NixlConnectorWorker:
"Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids))
if self.use_host_buffer:
self._recving_metadata[req_id] = meta
if remote_engine_id not in self._remote_agents:
# Initiate handshake with remote engine to exchange metadata.
with self._handshake_lock:


@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc
import time
from contextlib import contextmanager
@ -23,12 +22,10 @@ from vllm.config import (CompilationLevel, VllmConfig,
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
prepare_communication_buffer_for_model)
from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context)
from vllm.forward_context import DPMetadata, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@ -66,6 +63,8 @@ from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager
@ -88,7 +87,7 @@ else:
logger = init_logger(__name__)
class GPUModelRunner(LoRAModelRunnerMixin):
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__(
self,
@ -1357,7 +1356,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
# Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices,
@ -1745,52 +1745,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output
def propose_ngram_draft_token_ids(
self,
sampled_token_ids: list[list[int]],


@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define KV connector functionality mixin for model runners.
"""
import copy
from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
class KVConnectorModelRunnerMixin:
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None
def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
vllm_config: VllmConfig) -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output
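
As a usage sketch, a model runner opts into these helpers by adding the mixin to its bases and delegating the no-work path to kv_connector_no_forward, mirroring how the GPU and TPU model runners are wired up in this commit (the class name and attribute handling below are illustrative only):

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin)


class MyModelRunner(KVConnectorModelRunnerMixin):
    """Sketch of a runner that drives KV transfers via the mixin."""

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config

    def execute_model(self, scheduler_output):
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Nothing scheduled and no connector: truly idle step.
                return EMPTY_MODEL_RUNNER_OUTPUT
            # KV send/recv must still run even with no forward work.
            return self.kv_connector_no_forward(scheduler_output,
                                                self.vllm_config)
        # ... normal forward path elided ...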


@ -3,7 +3,7 @@
import bisect
import gc
import time
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from unittest.mock import patch
import numpy as np
@ -20,6 +20,8 @@ from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (ParallelConfig, VllmConfig,
get_layers_from_vllm_config, update_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
@ -46,6 +48,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
@ -97,7 +101,7 @@ MIN_NUM_SEQS = 8
# The dummy_run should be comprehensive, ensuring all potential input shapes and
# branch predictions are included as subgraph inputs to facilitate
# pre-compilation.
class TPUModelRunner(LoRAModelRunnerMixin):
class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__(
self,
@ -971,8 +975,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# Update cached state
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
if self.is_multimodal_model:
# Run the multimodal encoder if any.
@ -986,6 +994,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
start_index = 0
combined_selected_tokens: list[torch.Tensor] = []
combined_logprobs: list[LogprobsLists] = []
# NOTE: setup current batch's metadata for kv connector.
# Currently, only verified with NixlConnector
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
while start_index < self.input_batch.num_reqs:
attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
end_index = self._prepare_inputs(scheduler_output, start_index)
@ -1032,6 +1046,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
start_index = end_index
# NOTE: current kv load and save get h2d/d2h copies involved.
# Those copies are blocking. Once they become async., kv_save
# should be called right after each single forward pass,
# instead of the forwards of the entire input batch.
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
if tpu_sampling_metadata.logprobs:
@ -1126,6 +1148,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
)
# Check there are no new graphs compiled - all the graphs should be
@ -1637,6 +1661,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
for cache in self.kv_caches:
xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)
def reset_dynamo_cache(self):
if self.is_multimodal_model:
compiled_model = self.model.get_language_model().model
@ -1851,6 +1879,75 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
return paddings[index]
def _make_src_and_dst_indices(
src_block_ids: list[int],
dst_block_ids: list[int],
src_device: Union[torch.device, str],
dst_device: Union[torch.device, str],
) -> tuple[torch.Tensor, torch.Tensor]:
src_indices = torch.tensor(src_block_ids,
device=src_device,
dtype=torch.int64)
dst_indices = torch.tensor(dst_block_ids,
device=dst_device,
dtype=torch.int64)
return src_indices, dst_indices
@torch.compile(backend="openxla")
def _insert_blocks_to_tpu(
cpu_cache: torch.Tensor,
tpu_cache: torch.Tensor,
cpu_block_indices: torch.Tensor,
tpu_block_indices: torch.Tensor,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to(
tpu_cache.device)
@torch.compile(backend="openxla")
def _swap_out_tpu_blocks(
tpu_cache: torch.Tensor,
cpu_cache: torch.Tensor,
tpu_block_indices: torch.Tensor,
cpu_block_indices: torch.Tensor,
) -> None:
""" tpu blocks to cpu blocks"""
torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu()
def copy_kv_blocks(
src_kv_caches: dict[str, torch.Tensor],
dst_kv_caches: dict[str, torch.Tensor],
src_block_ids: list[int],
dst_block_ids: list[int],
direction: Literal["h2d", "d2h"],
) -> None:
"""Copy kv blocks between different buffers."""
if not src_kv_caches or not dst_kv_caches or \
not src_block_ids or not dst_block_ids or \
len(src_block_ids) != len(dst_block_ids):
return
src_device = next(iter(src_kv_caches.values())).device
dst_device = next(iter(dst_kv_caches.values())).device
src_indices, dst_indices = _make_src_and_dst_indices(
src_block_ids=src_block_ids,
dst_block_ids=dst_block_ids,
src_device=src_device,
dst_device=dst_device)
_copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \
_swap_out_tpu_blocks
for layer_name in src_kv_caches:
src_tensor = src_kv_caches[layer_name]
dst_tensor = dst_kv_caches[layer_name]
_copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
def _get_padded_num_kv_cache_update_slices(
num_tokens: int, max_num_reqs: int, page_size: int,
num_slices_per_kv_cache_update_block: int) -> int:


@ -12,9 +12,11 @@ import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
@ -118,7 +120,7 @@ class TPUWorker:
# Initialize the distributed environment.
self._init_tpu_worker_distributed_environment(
self.parallel_config, self.rank, self.distributed_init_method,
self.vllm_config, self.rank, self.distributed_init_method,
self.local_rank)
# Device initialization should happen after initializing
@ -242,7 +244,9 @@ class TPUWorker:
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output)
return output if self.is_driver_worker else None
# every worker's output is needed when kv_transfer_group is setup
return output if self.is_driver_worker or has_kv_transfer_group(
) else None
def profile(self, is_start: bool = True):
if self.rank < 1:
@ -294,7 +298,7 @@ class TPUWorker:
def _init_tpu_worker_distributed_environment(
self,
parallel_config: ParallelConfig,
vllm_config: VllmConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
@ -306,6 +310,7 @@ class TPUWorker:
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
parallel_config = vllm_config.parallel_config
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
@ -317,6 +322,8 @@ class TPUWorker:
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
try:
from tpu_commons.worker import TPUWorker as TPUCommonsWorker