[P/D] Support CPU Transfer in NixlConnector (#18293)
Signed-off-by: Juncheng Gu <juncgu@gmail.com>
Signed-off-by: Richard Liu <ricliu@google.com>
Co-authored-by: Richard Liu <39319471+richardsliu@users.noreply.github.com>
Co-authored-by: Richard Liu <ricliu@google.com>
@@ -10,6 +10,7 @@ jinja2>=3.1.6
ray[default]
ray[data]
setuptools==78.1.0
nixl==0.3.0

# Install torch_xla
--pre
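The new nixl==0.3.0 pin backs the CPU-buffer transfer path added below: the connector only enables NIXL transfers when the package imports cleanly, and otherwise logs "NIXL is not available" and falls back (see the NixlWrapper = None guard later in this diff). A rough sketch of that guard, with a hypothetical import path used purely for illustration:

# Sketch only: the exact module path of the NIXL Python binding is an
# assumption here; the real guard lives in the NixlConnector module.
try:
    from nixl._api import nixl_agent as NixlWrapper  # hypothetical import path
except ImportError:
    NixlWrapper = None  # connector disables itself and logs "NIXL is not available"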
@@ -0,0 +1,162 @@
#!/bin/bash
set -xe

# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}


# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}


# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}

OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT


# Waits for vLLM server to start.
wait_for_server() {
  local host=$1
  local port=$2
  timeout 1200 bash -c "
    until curl -s ${host}:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Cleanup function
cleanup() {
  echo "Caught Ctrl+C, cleaning up..."
  # Cleanup commands
  pgrep python | xargs kill -9 || true
  # pkill -f python || true
  echo "Cleanup complete. Exiting."
}

launch_baseline() {
  BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
      --host ${BASELINE_HOST} \
      --port ${BASELINE_PORT} \
      --max-model-len ${MAX_MODEL_LEN} \
      --seed 42 \
      --block-size ${BLOCK_SIZE} \
      --gpu-memory-utilization 0.5 \
      --disable-log-requests \
      --enforce-eager"
  echo ${BASELINE_BASE_CMD}
  ssh -tt ${BASELINE_HOST} "${BASELINE_BASE_CMD}" &
}

launch_pd() {
  PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
    VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
      --host ${PREFILL_HOST} \
      --port ${PREFILL_PORT} \
      --max-model-len ${MAX_MODEL_LEN} \
      --seed 42 \
      --block-size ${BLOCK_SIZE} \
      --enforce-eager \
      --gpu-memory-utilization 0.5 \
      --disable-log-requests \
      --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"


  DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
      --host ${DECODE_HOST} \
      --port ${DECODE_PORT} \
      --max-model-len ${MAX_MODEL_LEN} \
      --seed 42 \
      --block-size ${BLOCK_SIZE} \
      --enforce-eager \
      --gpu-memory-utilization 0.5 \
      --disable-log-requests \
      --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"

  echo ${PREFILL_BASE_CMD}
  echo ${DECODE_BASE_CMD}
  sleep 2

  # execute on hosts
  ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
  ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
  sleep 1
  wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
  sleep 1
  wait_for_server ${DECODE_HOST} ${DECODE_PORT}
  sleep 1
}

launch_pd_proxy(){
  PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    python3 ${EXP_ROOT}/toy_proxy_server.py \
      --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
      --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
      --host=${PROXY_HOST} --port ${PROXY_PORT}"
  echo ${PROXY_BASE_CMD}
  ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}

run_tests(){
  local service_url=$1
  local mode=$2
  python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE}
}


# run non-disagg. baseline & save outputs
launch_baseline
sleep 2
wait_for_server ${BASELINE_HOST} ${BASELINE_PORT}
run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline"
cleanup
sleep 10


# run disagg. & do exact-match with the outputs from baseline
launch_pd
launch_pd_proxy
sleep 10
run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg"
echo "-----P/D success----"

rm ${OUTPUT_FILE}
cleanup

exit 0
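The key piece in the launch_pd commands above is the kv-transfer-config JSON, which selects NixlConnector and routes KV blocks through a CPU buffer because TPU memory cannot be registered with NIXL directly. A small illustrative sketch of producing the same flag from Python (the dict keys come straight from the script; the helper itself is not part of this commit):

import json
import shlex

# Same settings the script passes to vllm serve on the prefill and decode hosts.
kv_transfer_config = {
    "kv_connector": "NixlConnector",
    "kv_role": "kv_both",
    "kv_buffer_device": "cpu",
}

# Quote it once for the shell; the script instead escapes the quotes by hand.
print("--kv-transfer-config " + shlex.quote(json.dumps(kv_transfer_config)))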
tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh (new file, 128 lines)
@@ -0,0 +1,128 @@
#!/bin/bash
set -xe

# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}


# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}


# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}

OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT

# Waits for vLLM server to start.
wait_for_server() {
  local host=$1
  local port=$2
  timeout 1200 bash -c "
    until curl -s ${host}:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Cleanup function
cleanup() {
  echo "Caught Ctrl+C, cleaning up..."
  # Cleanup commands
  pgrep python | xargs kill -9 || true
  # pkill -f python || true
  echo "Cleanup complete. Exiting."
}


launch_pd() {
  PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
    VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
      --host ${PREFILL_HOST} \
      --port ${PREFILL_PORT} \
      --max-model-len ${MAX_MODEL_LEN} \
      --seed 42 \
      --block-size ${BLOCK_SIZE} \
      --enforce-eager \
      --gpu-memory-utilization 0.5 \
      --disable-log-requests \
      --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"


  DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_USE_V1=1 \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
      --host ${DECODE_HOST} \
      --port ${DECODE_PORT} \
      --max-model-len ${MAX_MODEL_LEN} \
      --seed 42 \
      --block-size ${BLOCK_SIZE} \
      --enforce-eager \
      --gpu-memory-utilization 0.5 \
      --disable-log-requests \
      --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"

  echo ${PREFILL_BASE_CMD}
  echo ${DECODE_BASE_CMD}
  sleep 2

  # execute on hosts
  ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
  ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
  sleep 1
  wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
  sleep 1
  wait_for_server ${DECODE_HOST} ${DECODE_PORT}
  sleep 1
}

launch_pd_proxy(){
  PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    python3 ${EXP_ROOT}/toy_proxy_server.py \
      --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
      --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
      --host=${PROXY_HOST} --port ${PROXY_PORT}"
  echo ${PROXY_BASE_CMD}
  ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}


# run disagg. & do exact-match with the outputs from baseline
launch_pd
launch_pd_proxy
sleep 10

PREFILL_HOST=${PREFILL_HOST} \
  PREFILL_PORT=${PREFILL_PORT} \
  DECODE_HOST=${DECODE_HOST} \
  DECODE_PORT=${DECODE_PORT} \
  PROXY_HOST=${PROXY_HOST} \
  PROXY_PORT=${PROXY_PORT} python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import os
import time

import openai
import requests

MAX_OUTPUT_LEN = 30

SAMPLE_PROMPTS = (
    "Red Hat is the best company in the world to work for because it works on "
    "open source software, which means that all the contributions are "
    "delivered to the community. As a result, when working on projects like "
    "vLLM we are able to meet many amazing people from various organizations "
    "like AMD, Google, NVIDIA, ",
    "We hold these truths to be self-evident, that all men are created equal, "
    "that they are endowed by their Creator with certain unalienable Rights, "
    "that among these are Life, Liberty and the pursuit of Happiness.--That "
    "to secure these rights, Governments are instituted among Men, deriving "
    "their just powers from the consent of the governed, ",
)


def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
    """
    Checks if the vLLM server is ready by sending a GET request to the
    /health endpoint.

    Args:
        url (str): The base URL of the vLLM server.
        timeout (int): Timeout in seconds for the request.
        retries (int): Number of retries if the server is not ready.

    Returns:
        bool: True if the server is ready, False otherwise.
    """
    for attempt in range(retries):
        try:
            response = requests.get(url, timeout=timeout)
            if response.status_code == 200:
                return True
            else:
                print(f"Attempt {attempt + 1}: Server returned status code "
                      f"{response.status_code}")
        except requests.exceptions.RequestException as e:
            print(f"Attempt {attempt + 1}: Error connecting to server: {e}")
        time.sleep(1)  # Wait before retrying
    return False


def run_simple_prompt(base_url: str, model_name: str,
                      input_prompt: str) -> str:
    client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
    completion = client.completions.create(model=model_name,
                                           prompt=input_prompt,
                                           max_tokens=MAX_OUTPUT_LEN,
                                           temperature=0.0,
                                           seed=42)

    # print("-" * 50)
    # print(f"Completion results for {model_name}:")
    # print(completion)
    # print("-" * 50)
    return completion.choices[0].text


def main():
    """
    This script accepts the service URL, model name, run mode, and output
    file name from the command line via the argparse module, then compares
    disaggregated-serving outputs against the saved non-disagg. baseline.
    """
    parser = argparse.ArgumentParser(description="vLLM client script")

    parser.add_argument(
        "--service_url",  # Name of the first argument
        type=str,
        required=True,
        help="The vLLM service URL.")

    parser.add_argument(
        "--model_name",  # Name of the second argument
        type=str,
        required=True,
        help="model_name",
    )

    parser.add_argument(
        "--mode",  # Name of the third argument
        type=str,
        default="baseline",
        help="mode: baseline==non-disagg, or disagg",
    )

    parser.add_argument(
        "--file_name",  # Name of the fourth argument
        type=str,
        default=".vllm_output.txt",
        help="the file that saves the output tokens",
    )

    args = parser.parse_args()

    for arg in vars(args):
        print(f"{arg}: {getattr(args, arg)}")

    if args.mode == "baseline":
        # non-disagg
        health_check_url = f"{args.service_url}/health"
    else:
        # disagg proxy
        health_check_url = f"{args.service_url}/healthcheck"
        if not os.path.exists(args.file_name):
            raise ValueError(
                f"In disagg mode, the output file {args.file_name} from "
                "the non-disagg. baseline does not exist.")

    service_url = f"{args.service_url}/v1"

    if not check_vllm_server(health_check_url):
        raise RuntimeError(
            f"vllm server: {args.service_url} is not ready yet!")

    output_strs = dict()
    for prompt in SAMPLE_PROMPTS:
        output_str = run_simple_prompt(base_url=service_url,
                                       model_name=args.model_name,
                                       input_prompt=prompt)
        print(f"Prompt: {prompt}, output: {output_str}")
        output_strs[prompt] = output_str

    if args.mode == "baseline":
        # baseline: save outputs
        try:
            with open(args.file_name, 'w') as json_file:
                json.dump(output_strs, json_file, indent=4)
        except OSError as e:
            print(f"Error writing to file: {e}")
            raise
    else:
        # disagg. verify outputs
        baseline_outputs = None
        try:
            with open(args.file_name) as json_file:
                baseline_outputs = json.load(json_file)
        except OSError as e:
            print(f"Error reading baseline file: {e}")
            raise
        assert isinstance(baseline_outputs, dict)
        assert len(baseline_outputs) == len(output_strs)
        for prompt, output in baseline_outputs.items():
            assert prompt in output_strs, f"{prompt} not included"
            assert output == output_strs[prompt], (
                f"baseline_output: {output} != PD output: {output_strs[prompt]}"
            )


if __name__ == "__main__":
    main()
@@ -4,8 +4,11 @@ import os

import openai

PREFILL_HOST = os.getenv("PREFILL_HOST", "localhost")
PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_HOST = os.getenv("DECODE_HOST", "localhost")
DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_HOST = os.getenv("PROXY_HOST", "localhost")
PROXY_PORT = os.getenv("PROXY_PORT", None)

if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
@@ -21,15 +24,15 @@ def test_edge_cases():
    # Set the OpenAI API key and base URL
    decode_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://localhost:{DECODE_PORT}/v1",
        base_url=f"http://{DECODE_HOST}:{DECODE_PORT}/v1",
    )
    prefill_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://localhost:{PREFILL_PORT}/v1",
        base_url=f"http://{PREFILL_HOST}:{PREFILL_PORT}/v1",
    )
    proxy_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://localhost:{PROXY_PORT}/v1",
        base_url=f"http://{PROXY_HOST}:{PROXY_PORT}/v1",
    )

    # Get the list of models
@@ -3,6 +3,7 @@

import argparse
import itertools
import logging
import os
import uuid
from contextlib import asynccontextmanager
@@ -11,9 +12,8 @@ import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

from vllm.logger import init_logger

logger = init_logger(__name__)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


@asynccontextmanager
@@ -32,7 +32,7 @@ The class provides the following primitives:

import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional

import torch

@@ -46,6 +46,12 @@ if TYPE_CHECKING:
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
CopyBlocksOp = Callable[[
    dict[str, torch.Tensor], dict[
        str, torch.Tensor], list[int], list[int], Literal["h2d", "d2h"]
], None]

logger = init_logger(__name__)


@@ -127,6 +133,13 @@ class KVConnectorBase_V1(ABC):
        """
        return

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        """
        Set the xPU-specific ops for copying KV between host and device.
        Needed when host buffer is used for kv transfer (e.g., in NixlConnector)
        """
        return

    @abstractmethod
    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
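To make the new hook concrete, here is a minimal sketch of a callable that satisfies the CopyBlocksOp signature above and of how it would be handed to a connector. The naive per-layer copy is illustrative only; the real TPU implementation appears later in this diff as copy_kv_blocks.

from typing import Literal

import torch

def naive_copy_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    # Toy CopyBlocksOp: block index is assumed to be the leading dim of each cache.
    for layer_name, src in src_kv_caches.items():
        dst = dst_kv_caches[layer_name]
        dst[dst_block_ids] = src[src_block_ids].to(dst.device)

# A worker that owns host buffers would register it once at startup, e.g.:
# connector.set_host_xfer_buffer_ops(naive_copy_blocks)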
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import logging
import math
import queue
import threading
@@ -20,14 +21,14 @@ from vllm import envs
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
    CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
    get_tp_group)
from vllm.distributed.utils import divide
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import _Backend
from vllm.platforms import _Backend, current_platform
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
@@ -40,6 +41,7 @@ if TYPE_CHECKING:
Transfer = tuple[int, float]  # (xfer_handle, start_time)
EngineId = str
ReqId = str

GET_META_MSG = b"get_meta_msg"

logger = init_logger(__name__)
@@ -52,6 +54,13 @@
    logger.warning("NIXL is not available")
    NixlWrapper = None

# Supported xPUs and types of kv transfer buffer.
# {xPU: tuple of supported kv buffer types}
_NIXL_SUPPORTED_XPUS = {
    "cuda": ("cuda", ),
    "tpu": ("cpu", ),
}


class NixlAgentMetadata(
        msgspec.Struct,
@@ -80,6 +89,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):

    def __init__(self):
        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
        self.reqs_to_send: dict[ReqId, float] = {}

    def add_new_req(
@@ -87,8 +97,12 @@ class NixlConnectorMetadata(KVConnectorMetadata):
        request_id: ReqId,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
        load_remote_cache: bool = True,
        save_to_host: bool = False,
    ):
        self.reqs_to_recv[request_id] = ReqMeta(
        # save and load are mutually exclusive
        assert load_remote_cache ^ save_to_host
        _req = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=kv_transfer_params["remote_block_ids"],
            remote_engine_id=kv_transfer_params["remote_engine_id"],
@@ -97,6 +111,10 @@ class NixlConnectorMetadata(KVConnectorMetadata):
            # P workers don't need to receive tp_size from proxy here.
            tp_size=kv_transfer_params.get("tp_size", 1),
        )
        if save_to_host:
            self.reqs_to_save[request_id] = _req
        if load_remote_cache:
            self.reqs_to_recv[request_id] = _req


class NixlConnector(KVConnectorBase_V1):
@@ -155,6 +173,10 @@ class NixlConnector(KVConnectorBase_V1):
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        assert self.connector_worker is not None
        self.connector_worker.set_host_xfer_buffer_ops(copy_operation)

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests."""
@@ -177,8 +199,11 @@ class NixlConnector(KVConnectorBase_V1):
        pass

    def wait_for_save(self):
        """NixlConnector does not save explicitly."""
        pass
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, NixlConnectorMetadata)
        if self.connector_worker.use_host_buffer and \
                self.connector_worker.copy_blocks:
            self.connector_worker.save_kv_to_host(self._connector_metadata)


class NixlConnectorScheduler:
@@ -193,12 +218,15 @@ class NixlConnectorScheduler:
            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
            vllm_config.parallel_config.data_parallel_rank *
            vllm_config.parallel_config.tensor_parallel_size)
        self.use_host_buffer = \
            vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
        logger.info("Initializing NIXL Scheduler %s", engine_id)

        # Requests that need to start recv/send.
        # New requests are added by update_state_after_alloc in
        # the scheduler. Used to make metadata passed to Worker.
        self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
        self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
        # Reqs to send and their expiration time
        self._reqs_need_send: dict[ReqId, float] = {}

@@ -248,7 +276,25 @@ class NixlConnectorScheduler:
            "num_external_tokens=%s, kv_transfer_params=%s",
            num_external_tokens, params)

        if params is not None and params.get("do_remote_prefill"):
        if not params:
            return
        if self.use_host_buffer and params.get("do_remote_decode"):
            # NOTE: when accelerator is not directly supported by Nixl,
            # prefilled blocks need to be saved to host memory before transfer.

            # figure out full computed blocks to save
            block_ids = blocks.get_block_ids()[0]
            all_full = request.num_tokens % self.block_size == 0
            full_block_ids = (block_ids if all_full else block_ids[:-1])
            # TODO: skip the blocks that are already in the host xfer buffer.
            # Currently, the host xfer buffer block is 1-to-1 mapped to device
            # kv blocks, so host blocks won't be flushed as long as its device
            # block is not overwritten; and it will be safe to skip saving them
            # to host xfer buffer.
            if full_block_ids:
                self._reqs_need_save[request.request_id] = \
                    (request, full_block_ids)
        elif params.get("do_remote_prefill"):
            if params.get("remote_block_ids"):
                if all(p in params for p in ("remote_engine_id", "remote_host",
                                             "remote_port")):
@@ -260,6 +306,7 @@ class NixlConnectorScheduler:
                    # Get unhashed blocks to pull from remote.
                    self._reqs_need_recv[request.request_id] = (
                        request, local_block_ids)

                else:
                    logger.warning(
                        "Got invalid KVTransferParams: %s. This "
@@ -284,10 +331,21 @@ class NixlConnectorScheduler:
                kv_transfer_params=req.kv_transfer_params,
            )

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()
        for req_id, (req, block_ids) in self._reqs_need_save.items():
            assert req.kv_transfer_params is not None
            meta.add_new_req(
                request_id=req_id,
                local_block_ids=block_ids,
                kv_transfer_params=req.kv_transfer_params,
                load_remote_cache=False,
                save_to_host=True,
            )

        meta.reqs_to_send = self._reqs_need_send

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()
        self._reqs_need_save.clear()
        self._reqs_need_send = {}

        return meta
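The full-blocks arithmetic in update_state_after_alloc above is easiest to see with numbers. A small illustrative check, with values chosen arbitrarily rather than taken from the commit:

block_size = 32
num_tokens = 100                      # prompt length of a prefilled request
block_ids = [7, 12, 5, 9]             # 4 allocated device blocks

all_full = num_tokens % block_size == 0          # 100 % 32 == 4 -> False
full_block_ids = block_ids if all_full else block_ids[:-1]
# Only the 3 fully written blocks ([7, 12, 5]) are staged for the d2h save;
# the partially filled last block is skipped.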
@@ -379,9 +437,36 @@ class NixlConnectorWorker:
        self.tp_rank = get_tensor_model_parallel_rank()
        self.world_size = get_tensor_model_parallel_world_size()
        self.tp_group = get_tp_group()
        self.num_blocks = 0

        # KV Caches and nixl tracking data.
        self.kv_caches: dict[str, torch.Tensor] = {}
        self.device_type = current_platform.device_type
        self.kv_buffer_device: str = \
            vllm_config.kv_transfer_config.kv_buffer_device
        if self.device_type not in _NIXL_SUPPORTED_XPUS:
            raise RuntimeError(f"{self.device_type} is not supported.")
        elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[
                self.device_type]:
            raise RuntimeError(
                f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                "is not supported.")
        self.device_kv_caches: dict[str, torch.Tensor] = {}

        # cpu kv buffer for xfer
        # used when xPU memory can not be registered under nixl
        self.host_xfer_buffers: dict[str, torch.Tensor] = {}
        self.use_host_buffer = self.kv_buffer_device == "cpu"
        if self.kv_buffer_device == "cuda":
            self.nixl_memory_type = "VRAM"
        elif self.kv_buffer_device == "cpu":
            self.nixl_memory_type = "DRAM"
        else:
            raise RuntimeError(
                f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                "is not supported.")

        # Note: host xfer buffer ops when use_host_buffer is True
        self.copy_blocks: Optional[CopyBlocksOp] = None

        # Map of engine_id -> kv_caches_base_addr. For TP case, each local
        # rank will still only pull from a single remote TP worker.
@@ -404,6 +489,7 @@ class NixlConnectorWorker:

        # In progress transfers.
        # [req_id -> list[handle]]
        self._recving_metadata: dict[ReqId, ReqMeta] = {}
        self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
        # Track the expiration time of requests that are waiting to be sent.
        self._reqs_to_send: dict[ReqId, float] = {}
@@ -440,6 +526,7 @@ class NixlConnectorWorker:
        self.backend_name = backend.get_name()
        attn_backend = backend_name_to_enum(self.backend_name)
        self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
        self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
        logger.debug("Detected attention backend %s", self.backend_name)

        self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
@@ -529,6 +616,31 @@ class NixlConnectorWorker:
        # Remote rank -> agent name.
        return {p_remote_rank: remote_agent_name}

    def initialize_host_xfer_buffer(
            self, kv_caches: dict[str, torch.Tensor]) -> None:
        """
        Initialize transfer buffer in CPU mem for accelerators
        NOT directly supported by NIXL (e.g., tpu)
        """
        xfer_buffers: dict[str, torch.Tensor] = {}
        try:
            for layer_name, kv_cache in kv_caches.items():
                kv_shape = kv_cache.shape
                kv_dtype = kv_cache.dtype
                xfer_buffers[layer_name] = torch.empty(kv_shape,
                                                       dtype=kv_dtype,
                                                       device="cpu")
        except MemoryError as e:
            logger.error("NIXLConnectorWorker gets %s.", e)
            raise

        self.host_xfer_buffers = xfer_buffers

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        """Assign copy (d2h, h2d) operations when host buffer is used."""
        assert self.use_host_buffer
        self.copy_blocks = copy_operation

    def _background_nixl_handshake(self, req_id: str,
                                   remote_engine_id: EngineId, meta: ReqMeta):
        # Do NIXL handshake in background and add to _ready_requests when done.
@@ -562,47 +674,76 @@ class NixlConnectorWorker:
        _, first_kv_cache = next(iter(kv_caches.items()))
        kv_elem_size = first_kv_cache.element_size()

        if self.use_host_buffer:
            self.initialize_host_xfer_buffer(kv_caches=kv_caches)
            assert len(self.host_xfer_buffers) == len(kv_caches), (
                f"host_buffer: {len(self.host_xfer_buffers)}, "
                f"kv_caches: {len(kv_caches)}")
            xfer_buffers = self.host_xfer_buffers
        else:
            xfer_buffers = kv_caches
            assert not self.host_xfer_buffers, (
                "host_xfer_buffer should not be initialized when "
                f"kv_buffer_device is {self.kv_buffer_device}")

        # TODO(tms): Find a more robust way to detect and handle MLA
        # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
        # KV memory layout is HND, as opposed to the default NHD. Note that it
        # will only affects the strides. For MLA instead, we make require no
        # such thing and resort to the standard layout.
        use_mla = len(first_kv_cache.shape) == 3
        assert use_mla == self.use_mla

        # TODO (NickLucche) not compatible with hybrid allocator. Enforce check
        # once it goes live, as a single kv layout is expected for xfers.
        if use_mla:
            # MLA case.
        if self.device_type == "tpu":
            assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
            assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
            # tpu (v1) kv shape per layer:
            # (num_blocks, block_size, num_kv_heads * 2, head_size)
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 2  # [block_size, latent_dim]
            block_rank = 3  # [block_size, kv_heads, head_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            block_size, kv_latent_dim = block_shape
            self.slot_size_bytes = kv_elem_size * kv_latent_dim
        else:
            # [2 (k and v), num_blocks, ...]
            if self._use_flashinfer:
                # FlashInfer swaps 2<->num_blocks dimensions.
            block_size, n_kv_heads_x_2, head_dim = block_shape
            self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
        elif self.device_type == "cuda":
            assert use_mla == self.use_mla
            # TODO (NickLucche) not compatible with hybrid allocator.
            # Enforce check once it goes live, as a single kv layout
            # is expected for xfers.
            if use_mla:
                # MLA case.
                self.num_blocks = first_kv_cache.shape[0]
                block_rank = 4  # [2, block_size, kv_heads, head_dim]
                block_rank = 2  # [block_size, latent_dim]
                block_shape = first_kv_cache.shape[-block_rank:]
                block_size, kv_latent_dim = block_shape
                self.slot_size_bytes = kv_elem_size * kv_latent_dim
            else:
                self.num_blocks = first_kv_cache.shape[1]
                block_rank = 3  # [block_size, kv_heads, head_dim]
                block_shape = first_kv_cache.shape[-block_rank:]
                block_size, n_kv_heads, head_dim = block_shape[-3:]
                # head size in bytes.
                self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
            assert block_size == self.block_size
                # [2 (k and v), num_blocks, ...]
                if self._use_flashinfer:
                    # FlashInfer swaps 2<->num_blocks dimensions.
                    self.num_blocks = first_kv_cache.shape[0]
                    block_rank = 4  # [2, block_size, kv_heads, head_dim]
                else:
                    self.num_blocks = first_kv_cache.shape[1]
                    block_rank = 3  # [block_size, kv_heads, head_dim]
                block_shape = first_kv_cache.shape[-block_rank:]
                block_size, n_kv_heads, head_dim = block_shape[-3:]
                # head size in bytes.
                self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
            assert block_size == self.block_size
        else:
            raise RuntimeError(
                f"{self.device_type} ({self.backend_name}) is not supported.")

        # TODO(tms): self.block_len needs to be per-layer for sliding window,
        # hybrid attn, etc
        # block size in bytes
        self.block_len = kv_elem_size * math.prod(block_shape)
        logger.info(
            "Registering KV_Caches: use_mla: %s, num_blocks: %s, "
            "block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
            self.num_blocks, block_shape, first_kv_cache.shape)
            "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
            "use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
            "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
            self.use_host_buffer, self.num_blocks, block_shape,
            first_kv_cache.shape)
        self.dst_num_blocks[self.engine_id] = self.num_blocks
        self.kv_caches = kv_caches
        self.device_kv_caches = kv_caches
        kv_caches_base_addr = []
        caches_data = []

@@ -614,19 +755,21 @@ class NixlConnectorWorker:
        # (roughly 8KB vs 5KB).
        # Conversely for FlashInfer, K and V are transferred in the same tensor
        # to better exploit the memory layout (ie num_blocks is the first dim).
        for cache_or_caches in kv_caches.values():
        for cache_or_caches in xfer_buffers.values():
            # Normalize to always be a list of caches
            cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
                else cache_or_caches
            cache_list = [cache_or_caches] if use_mla \
                or self._use_pallas_v1 or self._use_flashinfer \
                else cache_or_caches
            for cache in cache_list:
                base_addr = cache.data_ptr()
                region_len = self.num_blocks * self.block_len
                caches_data.append(
                    (base_addr, region_len, cache.device.index, ""))
                # NOTE: use tp_rank for device_id since multi-node TP
                # is rarely used.
                caches_data.append((base_addr, region_len, self.tp_rank, ""))
                kv_caches_base_addr.append(base_addr)
        self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
        self.num_regions = len(caches_data)
        self.num_layers = len(self.kv_caches.keys())
        self.num_layers = len(xfer_buffers.keys())

        # TODO(mgoin): remove this once we have hybrid memory allocator
        # Optimization for models with local attention (Llama 4)
@@ -648,7 +791,8 @@ class NixlConnectorWorker:
                         self.block_window_per_layer)
            assert len(self.block_window_per_layer) == self.num_layers

        descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
        descs = self.nixl_wrapper.get_reg_descs(caches_data,
                                                self.nixl_memory_type)
        logger.debug("Registering descs: %s", caches_data)
        self.nixl_wrapper.register_memory(descs)
        logger.debug("Done registering descs")
@@ -666,11 +810,13 @@ class NixlConnectorWorker:
            block_offset = block_id * self.block_len
            addr = base_addr + block_offset
            # (addr, len, device id)
            # TODO: does device_id matter to DRAM?
            blocks_data.append((addr, self.block_len, self.tp_rank))
        logger.debug("Created %s blocks for src engine %s and rank %s",
                     len(blocks_data), self.engine_id, self.tp_rank)

        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
        descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
                                                 self.nixl_memory_type)
        # NIXL_INIT_AGENT to be used for preparations of local descs.
        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
            "NIXL_INIT_AGENT", descs)
@@ -755,6 +901,8 @@ class NixlConnectorWorker:
        tp_ratio = divide(self._tp_size[self.engine_id],
                          self._tp_size[engine_id])
        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
        assert not self._use_pallas_v1 or tp_ratio == 1, \
            "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."

        # Handle tp_size>num_kv_heads: replicate KV cache.
        total_num_kv_heads = self.model_config.get_total_num_kv_heads()
@@ -813,13 +961,43 @@ class NixlConnectorWorker:
                         self.tp_rank)

        # Register with NIXL.
        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
        descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
                                                 self.nixl_memory_type)
        self.dst_xfer_side_handles[
            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
                remote_agent_name, descs)

        return remote_agent_name

    def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
        """copy recved kv from host buffer to device."""
        assert self.use_host_buffer
        assert self.copy_blocks is not None

        local_block_ids = meta.local_block_ids
        self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches,
                         local_block_ids, local_block_ids, "h2d")
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "synced recved kv of request[%s] to device kv buffer,"
                "local_block_ids: %s. ", req_id,
                ",".join(map(str, meta.local_block_ids)))

    def save_kv_to_host(self, metadata: NixlConnectorMetadata):
        """copy kv from device to host buffer."""
        assert self.use_host_buffer
        assert self.copy_blocks is not None

        for req_id, meta in metadata.reqs_to_save.items():
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    "save_load_kv for request[%s] to host xfer buffer."
                    "local_block_ids: %s. ", req_id,
                    ",".join(map(str, meta.local_block_ids)))
            # blocking
            self.copy_blocks(self.device_kv_caches, self.host_xfer_buffers,
                             meta.local_block_ids, meta.local_block_ids, "d2h")

    def get_finished(self) -> tuple[set[str], set[str]]:
        """
        Get requests that are done sending or recving on this specific worker.
@@ -834,6 +1012,12 @@ class NixlConnectorWorker:
                "and %s requests done recving", self.tp_rank,
                len(done_sending), len(done_recving))

        if self.use_host_buffer:
            for req_id in done_recving:
                meta = self._recving_metadata.pop(req_id)
                assert meta, f"{req_id} not found in recving_metadata list"
                self.sync_recved_kv_to_device(req_id, meta)

        # Handle timeout to avoid stranding blocks on remote.
        now = time.perf_counter()
        while self._reqs_to_send:
@@ -904,6 +1088,8 @@ class NixlConnectorWorker:
            "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
            remote_engine_id, len(meta.local_block_ids),
            len(meta.remote_block_ids))
        if self.use_host_buffer:
            self._recving_metadata[req_id] = meta
        if remote_engine_id not in self._remote_agents:
            # Initiate handshake with remote engine to exchange metadata.
            with self._handshake_lock:
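To make the sizing logic in register_kv_caches above concrete, here is the arithmetic for an illustrative TPU layer shape; the numbers are made up for the example, not taken from the commit:

import math

# Assumed per-layer TPU kv cache shape: (num_blocks, block_size, num_kv_heads * 2, head_size)
kv_elem_size = 2                      # bf16
num_blocks, block_size = 512, 32
n_kv_heads_x_2, head_dim = 16, 128    # 8 KV heads, K and V fused in one dim

block_shape = (block_size, n_kv_heads_x_2, head_dim)
slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim   # 4096 bytes per token slot
block_len = kv_elem_size * math.prod(block_shape)            # 131072 bytes per block
region_len = num_blocks * block_len                          # bytes registered with NIXL per layer buffer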
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import gc
import time
from contextlib import contextmanager
@@ -23,12 +22,10 @@ from vllm.config import (CompilationLevel, VllmConfig,
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
    prepare_communication_buffer_for_model)
from vllm.forward_context import (DPMetadata, get_forward_context,
                                  set_forward_context)
from vllm.forward_context import DPMetadata, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -66,6 +63,8 @@ from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

from ..sample.logits_processor import LogitsProcessorManager
@@ -88,7 +87,7 @@ else:
logger = init_logger(__name__)


class GPUModelRunner(LoRAModelRunnerMixin):
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

    def __init__(
        self,
@@ -1357,7 +1356,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT

            return self.kv_connector_no_forward(scheduler_output)
            return self.kv_connector_no_forward(scheduler_output,
                                                self.vllm_config)

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,
@@ -1745,52 +1745,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            spec_token_ids = draft_token_ids.tolist()
        return spec_token_ids

    @staticmethod
    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
        # Update KVConnector with the KVConnector metadata forward().
        if has_kv_transfer_group():
            kv_connector = get_kv_transfer_group()
            assert isinstance(kv_connector, KVConnectorBase_V1)
            assert scheduler_output.kv_connector_metadata is not None
            kv_connector.bind_connector_metadata(
                scheduler_output.kv_connector_metadata)

            # Background KV cache transfers happen here.
            # These transfers are designed to be async and the requests
            # involved may be disjoint from the running requests.
            # Do this here to save a collective_rpc.
            kv_connector.start_load_kv(get_forward_context())

    @staticmethod
    def maybe_wait_for_kv_save() -> None:
        if has_kv_transfer_group():
            get_kv_transfer_group().wait_for_save()

    @staticmethod
    def get_finished_kv_transfers(
        scheduler_output: "SchedulerOutput",
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_finished(
                scheduler_output.finished_req_ids)
        return None, None

    def kv_connector_no_forward(
            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
        # KV send/recv even if no work to do.
        with set_forward_context(None, self.vllm_config):
            self.maybe_setup_kv_connector(scheduler_output)
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))

        if not finished_sending and not finished_recving:
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.finished_sending = finished_sending
        output.finished_recving = finished_recving
        return output

    def propose_ngram_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
vllm/v1/worker/kv_connector_model_runner_mixin.py (new file, 70 lines)
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define KV connector functionality mixin for model runners.
"""
import copy
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)


# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
class KVConnectorModelRunnerMixin:

    @staticmethod
    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
        # Update KVConnector with the KVConnector metadata forward().
        if has_kv_transfer_group():
            kv_connector = get_kv_transfer_group()
            assert isinstance(kv_connector, KVConnectorBase_V1)
            assert scheduler_output.kv_connector_metadata is not None
            kv_connector.bind_connector_metadata(
                scheduler_output.kv_connector_metadata)

            # Background KV cache transfers happen here.
            # These transfers are designed to be async and the requests
            # involved may be disjoint from the running requests.
            # Do this here to save a collective_rpc.
            kv_connector.start_load_kv(get_forward_context())

    @staticmethod
    def maybe_wait_for_kv_save() -> None:
        if has_kv_transfer_group():
            get_kv_transfer_group().wait_for_save()

    @staticmethod
    def get_finished_kv_transfers(
        scheduler_output: "SchedulerOutput",
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_finished(
                scheduler_output.finished_req_ids)
        return None, None

    def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
                                vllm_config: VllmConfig) -> ModelRunnerOutput:
        # KV send/recv even if no work to do.
        with set_forward_context(None, vllm_config):
            self.maybe_setup_kv_connector(scheduler_output)
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))

        if not finished_sending and not finished_recving:
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.finished_sending = finished_sending
        output.finished_recving = finished_recving
        return output
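Both the GPU and TPU runners adopt this mixin the same way, so here is a minimal sketch of the intended usage; the runner class and its surrounding attributes are illustrative, only the mixin's methods come from the file above.

# Sketch: how a model runner is expected to use KVConnectorModelRunnerMixin.
class MyModelRunner(KVConnectorModelRunnerMixin):   # hypothetical runner

    def execute_model(self, scheduler_output):
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                return EMPTY_MODEL_RUNNER_OUTPUT
            # No forward pass, but KV send/recv still has to make progress.
            return self.kv_connector_no_forward(scheduler_output, self.vllm_config)

        with set_forward_context(None, self.vllm_config):
            self.maybe_setup_kv_connector(scheduler_output)  # bind metadata, start loads
        ...  # run the real forward pass here
        self.maybe_wait_for_kv_save()
        finished_sending, finished_recving = (
            self.get_finished_kv_transfers(scheduler_output))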
@@ -3,7 +3,7 @@
import bisect
import gc
import time
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from unittest.mock import patch

import numpy as np
@@ -20,6 +20,8 @@ from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (ParallelConfig, VllmConfig,
                         get_layers_from_vllm_config, update_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
@@ -46,6 +48,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
                             LogprobsTensors, ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch

@@ -97,7 +101,7 @@ MIN_NUM_SEQS = 8
# The dummy_run should be comprehensive, ensuring all potential input shapes and
# branch predictions are included as subgraph inputs to facilitate
# pre-compilation.
class TPUModelRunner(LoRAModelRunnerMixin):
class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

    def __init__(
        self,
@@ -971,8 +975,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        # Update cached state
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT

            return self.kv_connector_no_forward(scheduler_output,
                                                self.vllm_config)

        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
@@ -986,6 +994,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        start_index = 0
        combined_selected_tokens: list[torch.Tensor] = []
        combined_logprobs: list[LogprobsLists] = []

        # NOTE: setup current batch's metadata for kv connector.
        # Currently, only verified with NixlConnector
        with set_forward_context(None, self.vllm_config):
            self.maybe_setup_kv_connector(scheduler_output)

        while start_index < self.input_batch.num_reqs:
            attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
                end_index = self._prepare_inputs(scheduler_output, start_index)
@@ -1032,6 +1046,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):

            start_index = end_index

        # NOTE: current kv load and save get h2d/d2h copies involved.
        # Those copies are blocking. Once they become async., kv_save
        # should be called right after each single forward pass,
        # instead of the forwards of the entire input batch.
        self.maybe_wait_for_kv_save()
        finished_sending, finished_recving = (
            self.get_finished_kv_transfers(scheduler_output))

        selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
        if tpu_sampling_metadata.logprobs:

@@ -1126,6 +1148,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            finished_sending=finished_sending,
            finished_recving=finished_recving,
        )

        # Check there are no new graphs compiled - all the graphs should be
@@ -1637,6 +1661,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
            for cache in self.kv_caches:
                xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))

        if has_kv_transfer_group():
            get_kv_transfer_group().register_kv_caches(kv_caches)
            get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)

    def reset_dynamo_cache(self):
        if self.is_multimodal_model:
            compiled_model = self.model.get_language_model().model
@@ -1851,6 +1879,75 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
    return paddings[index]


def _make_src_and_dst_indices(
    src_block_ids: list[int],
    dst_block_ids: list[int],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> tuple[torch.Tensor, torch.Tensor]:
    src_indices = torch.tensor(src_block_ids,
                               device=src_device,
                               dtype=torch.int64)
    dst_indices = torch.tensor(dst_block_ids,
                               device=dst_device,
                               dtype=torch.int64)
    return src_indices, dst_indices


@torch.compile(backend="openxla")
def _insert_blocks_to_tpu(
    cpu_cache: torch.Tensor,
    tpu_cache: torch.Tensor,
    cpu_block_indices: torch.Tensor,
    tpu_block_indices: torch.Tensor,
) -> None:
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
    tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to(
        tpu_cache.device)


@torch.compile(backend="openxla")
def _swap_out_tpu_blocks(
    tpu_cache: torch.Tensor,
    cpu_cache: torch.Tensor,
    tpu_block_indices: torch.Tensor,
    cpu_block_indices: torch.Tensor,
) -> None:
    """ tpu blocks to cpu blocks"""
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
    cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu()


def copy_kv_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Copy kv blocks between different buffers."""
    if not src_kv_caches or not dst_kv_caches or \
        not src_block_ids or not dst_block_ids or \
        len(src_block_ids) != len(dst_block_ids):
        return

    src_device = next(iter(src_kv_caches.values())).device
    dst_device = next(iter(dst_kv_caches.values())).device

    src_indices, dst_indices = _make_src_and_dst_indices(
        src_block_ids=src_block_ids,
        dst_block_ids=dst_block_ids,
        src_device=src_device,
        dst_device=dst_device)

    _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \
        _swap_out_tpu_blocks
    for layer_name in src_kv_caches:
        src_tensor = src_kv_caches[layer_name]
        dst_tensor = dst_kv_caches[layer_name]
        _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)


def _get_padded_num_kv_cache_update_slices(
    num_tokens: int, max_num_reqs: int, page_size: int,
    num_slices_per_kv_cache_update_block: int) -> int:
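The TPU runner registers copy_kv_blocks above as the connector's CopyBlocksOp (see register_kv_caches earlier in this hunk), and the connector then calls it in exactly two shapes. A sketch of those two calls with made-up block ids, assuming device_kv_caches and host_xfer_buffers dicts already exist on a TPU host; this mirrors save_kv_to_host and sync_recved_kv_to_device in the connector:

# Prefill side: stage finished blocks on the host before the NIXL transfer (blocking d2h copy).
copy_kv_blocks(device_kv_caches, host_xfer_buffers,
               src_block_ids=[3, 4, 5], dst_block_ids=[3, 4, 5],
               direction="d2h")

# Decode side: once a receive completes, push the blocks back into the TPU cache (h2d copy).
copy_kv_blocks(host_xfer_buffers, device_kv_caches,
               src_block_ids=[3, 4, 5], dst_block_ids=[3, 4, 5],
               direction="h2d")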
@@ -12,9 +12,11 @@ import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                          has_kv_transfer_group)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
@@ -118,7 +120,7 @@ class TPUWorker:

        # Initialize the distributed environment.
        self._init_tpu_worker_distributed_environment(
            self.parallel_config, self.rank, self.distributed_init_method,
            self.vllm_config, self.rank, self.distributed_init_method,
            self.local_rank)

        # Device initialization should happen after initializing
@@ -242,7 +244,9 @@ class TPUWorker:
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        output = self.model_runner.execute_model(scheduler_output)
        return output if self.is_driver_worker else None
        # every worker's output is needed when kv_transfer_group is setup
        return output if self.is_driver_worker or has_kv_transfer_group(
        ) else None

    def profile(self, is_start: bool = True):
        if self.rank < 1:
@@ -294,7 +298,7 @@ class TPUWorker:

    def _init_tpu_worker_distributed_environment(
        self,
        parallel_config: ParallelConfig,
        vllm_config: VllmConfig,
        rank: int,
        distributed_init_method: Optional[str] = None,
        local_rank: int = -1,
@@ -306,6 +310,7 @@ class TPUWorker:
        # the input objects on CPU. The all-reduce and all-gather ops on TPU
        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
        # own context.
        parallel_config = vllm_config.parallel_config
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
@@ -317,6 +322,8 @@ class TPUWorker:
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

        ensure_kv_transfer_initialized(vllm_config)


try:
    from tpu_commons.worker import TPUWorker as TPUCommonsWorker