@@ -0,0 +1,207 @@
#!/bin/bash

# Requirement: 2x GPUs.

# Model: meta-llama/Meta-Llama-3.1-8B-Instruct
# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8/10/12, 100 requests
# Resource: 2x GPU
# Approaches:
# 1. Chunked prefill: 2 vllm instances (1 GPU each), load-balanced by a round-robin proxy
# 2. Disaggregated prefill (HTTP proxy): 1 prefilling instance and 1 decoding instance
# 3. Disaggregated prefill (ZMQ connector): 1 prefilling instance and 1 decoding instance
#    Prefilling instance: only produces the KV cache (max_output_token=1)
#    Decoding instance: receives the KV cache via the kv-transfer connector, so prefilling is skipped
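
# For reference: every benchmark request below is an OpenAI-style /v1/completions
# call against port 8000 (the round-robin proxy or the disagg connector). A hand-run
# example (the prompt is illustrative only; real prompts are sampled from
# sonnet_4x.txt by benchmark_serving.py):
#
#   curl http://localhost:8000/v1/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "prompt": "Shall I compare thee to a summer day", "max_tokens": 6}'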

set -ex

kill_gpu_processes() {
  # kill all vLLM-related processes on the GPUs and free the benchmark ports
  pgrep pt_main_thread | xargs -r kill -9
  pgrep python3 | xargs -r kill -9
  for port in 7010 7011 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
  sleep 1
}

wait_for_server() {
  # wait for the vllm server on the given port to start;
  # return 1 if it is not reachable within the 1200s timeout
  local port=$1
  timeout 1200 bash -c "
    until curl -s localhost:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

launch_chunked_prefill() {
  model="meta-llama/Meta-Llama-3.1-8B-Instruct"
  gpu_memory_utilization=0.6
  max_model_len=10000
  VLLM_LOGGING_LEVEL=DEBUG CUDA_VISIBLE_DEVICES=0 python3 \
    -m vllm.entrypoints.openai.api_server \
    --model $model \
    --port 8100 \
    --max-model-len $max_model_len \
    --enable-chunked-prefill \
    --gpu-memory-utilization $gpu_memory_utilization &
  VLLM_LOGGING_LEVEL=DEBUG CUDA_VISIBLE_DEVICES=1 python3 \
    -m vllm.entrypoints.openai.api_server \
    --model $model \
    --port 8200 \
    --max-model-len $max_model_len \
    --enable-chunked-prefill \
    --gpu-memory-utilization $gpu_memory_utilization &
  wait_for_server 8100
  wait_for_server 8200
  python3 ../round_robin_proxy.py &
  sleep 1
}
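
# Note: ../round_robin_proxy.py is assumed to listen on port 8000 and alternate
# requests between the two chunked-prefill servers on 8100 and 8200, so the
# benchmark client only ever talks to port 8000.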

launch_disagg_prefill_http() {
  model="meta-llama/Meta-Llama-3.1-8B-Instruct"
  # disagg prefill
  VLLM_LOGGING_LEVEL=DEBUG CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0 python3 \
    -m vllm.entrypoints.openai.api_server \
    --model $model \
    --port 8100 \
    --max-model-len 10000 \
    --gpu-memory-utilization 0.6 \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

  VLLM_LOGGING_LEVEL=DEBUG CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=1 python3 \
    -m vllm.entrypoints.openai.api_server \
    --model $model \
    --port 8200 \
    --max-model-len 10000 \
    --gpu-memory-utilization 0.6 \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

  wait_for_server 8100
  wait_for_server 8200
  python3 ../disagg_prefill_proxy_server.py &
  sleep 1
}
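
# Note: the kv_producer instance (port 8100) computes the prefill and ships the
# KV cache to the kv_consumer instance (port 8200) via PyNcclConnector.
# ../disagg_prefill_proxy_server.py is assumed to listen on port 8000, forward
# each request to the prefill server first (with max_tokens=1), and then stream
# the actual completion from the decode server.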

launch_disagg_prefill_zmq() {
  model="meta-llama/Meta-Llama-3.1-8B-Instruct"
  gpu_memory_utilization=0.6
  max_model_len=10000
  # disagg prefill
  VLLM_LOGGING_LEVEL=DEBUG CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0 python3 \
    -m vllm.entrypoints.openai.api_server \
    --model $model \
    --port 8100 \
    --zmq-server-port 7010 \
    --max-model-len $max_model_len \
    --gpu-memory-utilization $gpu_memory_utilization \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

  VLLM_LOGGING_LEVEL=DEBUG CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=1 python3 \
    -m vllm.entrypoints.openai.api_server \
    --model $model \
    --port 8200 \
    --zmq-server-port 7011 \
    --max-model-len $max_model_len \
    --gpu-memory-utilization $gpu_memory_utilization \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

  python3 \
    -m vllm.entrypoints.disagg_connector \
    --port 8000 \
    --prefill-addr 127.0.0.1:7010 \
    --decode-addr 127.0.0.1:7011 &

  wait_for_server 8100
  wait_for_server 8200
  wait_for_server 8000
  sleep 1
}
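
# Note: unlike the HTTP-proxy setup above, here each vLLM server also exposes a
# ZMQ port (7010 for prefill, 7011 for decode), and vllm.entrypoints.disagg_connector
# on port 8000 bridges the OpenAI-compatible HTTP API to those ZMQ sockets.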

benchmark() {
  results_folder="./results"
  model="meta-llama/Meta-Llama-3.1-8B-Instruct"
  dataset_name="sonnet"
  dataset_path="../../sonnet_4x.txt"
  num_prompts=100
  qps=$1
  prefix_len=50
  input_len=1024
  output_len=$2
  tag=$3

  python3 ../../benchmark_serving.py \
    --backend vllm \
    --model $model \
    --dataset-name $dataset_name \
    --dataset-path $dataset_path \
    --sonnet-input-len $input_len \
    --sonnet-output-len "$output_len" \
    --sonnet-prefix-len $prefix_len \
    --num-prompts $num_prompts \
    --port 8000 \
    --save-result \
    --result-dir $results_folder \
    --result-filename "$tag"-qps-"$qps".json \
    --request-rate "$qps"

  sleep 2
}
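
# Example: `benchmark 2 6 chunked_prefill` sends 100 sonnet prompts at QPS 2 with
# 6 output tokens each and writes results/chunked_prefill-qps-2.json.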

main() {

  (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
  (which jq) || (apt-get -y install jq)
  (which socat) || (apt-get -y install socat)
  (which lsof) || (apt-get -y install lsof)

  pip install quart httpx matplotlib aiohttp datasets

  cd "$(dirname "$0")"
  cd ../..

  # create sonnet_4x.txt so that the sonnet dataset is long enough to sample
  # 1024-token inputs from
  echo "" > sonnet_4x.txt
  for _ in {1..4}
  do
    cat sonnet.txt >> sonnet_4x.txt
  done
  cd disagg_benchmarks/zmq

  rm -rf results
  mkdir results
  mkdir results/http_zmq_chunk
  mkdir results/http_zmq

  default_output_len=6

  export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

  launch_chunked_prefill
  for qps in 2 4 6 8 10 12; do
    benchmark $qps $default_output_len chunked_prefill
  done
  kill_gpu_processes

  launch_disagg_prefill_http
  for qps in 2 4 6 8 10 12; do
    benchmark $qps $default_output_len disagg_prefill_http
  done
  kill_gpu_processes

  launch_disagg_prefill_zmq
  for qps in 2 4 6 8 10 12; do
    benchmark $qps $default_output_len disagg_prefill_zmq
  done
  kill_gpu_processes

  python3 visualize_benchmark_results_zmq_http.py
}
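
# Expected artifacts: results/<tag>-qps-<qps>.json for each approach and QPS,
# plus comparison plots under results/http_zmq_chunk/ (all three approaches)
# and results/http_zmq/ (the two disaggregated variants only).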

main "$@"
@@ -11,7 +11,7 @@ import aiohttp
 # 3. python test_request.py
 async def test_connect_completions(session):
     try:
-        base_url = "http://localhost:8001/v1/connect/completions"
+        base_url = "http://localhost:8001/v1/completions"
         body = {
             "temperature": 0.5,
             "top_p": 0.9,
@@ -0,0 +1,72 @@
import json

import matplotlib.pyplot as plt
import pandas as pd

if __name__ == "__main__":

    data = []
    for name in ['disagg_prefill_http', 'disagg_prefill_zmq', 'chunked_prefill']:
        for qps in [2, 4, 6, 8, 10, 12]:
            with open(f"results/{name}-qps-{qps}.json") as f:
                x = json.load(f)
                x['name'] = name
                x['qps'] = qps
                data.append(x)

    df = pd.DataFrame.from_dict(data)
    dis_http_df = df[df['name'] == 'disagg_prefill_http']
    dis_zmq_df = df[df['name'] == 'disagg_prefill_zmq']
    chu_df = df[df['name'] == 'chunked_prefill']

    plt.style.use('bmh')
    plt.rcParams['font.size'] = 20

    for key in [
            'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms',
            'median_itl_ms', 'p99_itl_ms'
    ]:

        fig, ax = plt.subplots(figsize=(11, 7))
        plt.plot(dis_http_df['qps'],
                 dis_http_df[key],
                 label='disagg_prefill_http',
                 marker='o',
                 linewidth=4)
        plt.plot(dis_zmq_df['qps'],
                 dis_zmq_df[key],
                 label='disagg_prefill_zmq',
                 marker='o',
                 linewidth=4)
        plt.plot(chu_df['qps'],
                 chu_df[key],
                 label='chunked_prefill',
                 marker='o',
                 linewidth=4)
        ax.legend()

        ax.set_xlabel('QPS')
        ax.set_ylabel(key)
        ax.set_ylim(bottom=0)
        fig.savefig(f'results/http_zmq_chunk/{key}.png')
        plt.close(fig)

        fig1, ax1 = plt.subplots(figsize=(11, 7))
        plt.plot(dis_http_df['qps'],
                 dis_http_df[key],
                 label='disagg_prefill_http',
                 marker='o',
                 linewidth=4)
        plt.plot(dis_zmq_df['qps'],
                 dis_zmq_df[key],
                 label='disagg_prefill_zmq',
                 marker='o',
                 linewidth=4)
        ax1.legend()

        ax1.set_xlabel('QPS')
        ax1.set_ylabel(key)
        ax1.set_ylim(bottom=0)
        fig1.savefig(f'results/http_zmq/{key}.png')
        plt.close(fig1)
@@ -11,25 +11,27 @@ from contextlib import asynccontextmanager
 from typing import AsyncGenerator

 import uvicorn
+import uvloop
 import zmq
 import zmq.asyncio
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, StreamingResponse

 from vllm.logger import init_logger
+from vllm.utils import FlexibleArgumentParser

 # default prefill and decode addr
-time_out = 3
-fastapi_port = 8001
+time_out = 180
+fastapi_port = 8000
 prefill_addr = "ipc://localhost:7010"
-socket_prefill_num = 20
+socket_prefill_num = 100
 decode_addr = "ipc://localhost:7020"
-socket_decode_num = 20
+socket_decode_num = 100
 context_type_json = "application/json"
 context_type_error = "error"

 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
-logger = init_logger('vllm.entrypoints.connect')
+logger = init_logger('vllm.entrypoints.disagg_connector')


 @asynccontextmanager
@@ -146,7 +148,7 @@ async def decode(route: str, header: dict, original_request_data: dict):
                              media_type="text/event-stream")


-@app.post('/v1/connect/completions')
+@app.post('/v1/completions')
 async def chat_completions(request: Request):
     try:
         # Add the X-Request-Id header to the raw headers list
@@ -210,5 +212,25 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:


 if __name__ == "__main__":
-    # url = 'tcp://127.0.0.1:5555'
-    uvicorn.run(app, host="0.0.0.0", port=fastapi_port)
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+    parser = FlexibleArgumentParser(
+        description="vLLM disagg zmq server.")
+    parser.add_argument("--port",
+                        type=int,
+                        default=8000,
+                        help="The fastapi server port")
+    parser.add_argument("--prefill-addr",
+                        type=str,
+                        required=True,
+                        help="The prefill address IP:PORT")
+    parser.add_argument("--decode-addr",
+                        type=str,
+                        required=True,
+                        help="The decode address IP:PORT")
+
+    args = parser.parse_args()
+
+    uvloop.run(run_disagg_connector(args))
+
+    # uvicorn.run(app, host="0.0.0.0", port=fastapi_port)
@@ -91,7 +91,7 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
     try:
         tasks = [
             asyncio.create_task(worker_routine(workers_addr, app, context, i))
-            for i in range(20)
+            for i in range(100)
         ]
         logger.info("zmq tasks: %s", tasks)
         # thread safety proxy create socket in the background:
@@ -94,8 +94,6 @@ async def worker_routine(worker_addr: str, app: FastAPI,
                 json.dumps(generator.model_dump()).encode('utf-8')])
             else:
                 async for chunk in generator:
-                    logger.info("worker-%d Sending response chunk: [ %s ]",
-                                i, chunk)
                     await socket.send_multipart([identity,
                                                  b"text/event-stream",
                                                  chunk.encode('utf-8')])