Implementation of simple load balance routing proxy server (#1953) (#2124)

### What this PR does / why we need it?
The PR is the cherry-pick from v0.9.1
https://github.com/vllm-project/vllm-ascend/pull/1953

This PR introduce a new load balance proxy server example implementation
for disaggregated pd, which support simple token&kv_cache aware load
balance routing strategy for the disaggregated pd system compared with
origin round robin toy_proxy.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
tested on real workload and unittest

- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
Pleaplusone
2025-08-04 10:35:53 +08:00
committed by GitHub
parent af04ee9e7a
commit 4b3a210c33
2 changed files with 518 additions and 275 deletions

View File

@ -0,0 +1,518 @@
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
# SPDX-License-Identifier: Apache-2.0
#
# Tutorial: Using the Load Balance Proxy Server Example
#
# This proxy server is designed to distribute requests between multiple
# "prefiller" and "decoder" backend servers for large language model inference.
# It is useful for scaling out inference workloads and balancing load across
# multiple backend instances.
#
# Features:
# - Load balances requests to multiple prefiller and decoder servers.
# - Supports OpenAI-compatible /v1/completions and /v1/chat/completions endpoints.
# - Streams responses from backend servers to clients.
#
# Prerequisites:
# - Python 3.8+
# - Install dependencies:
# pip install fastapi httpx uvicorn vllm
#
# Step 1: Start Your Backend Servers
# ----------------------------------
# You need to have at least one prefiller and one decoder backend running.
# These can be mock servers or actual vLLM servers.
#
# For testing, you can use the provided mock server:
#
# vllm serve --host 0.0.0.0 --port 8100 ... # Prefiller 1
# vllm serve --host 0.0.0.0 --port 8101 ... # Prefiller 2
# vllm serve --host 0.0.0.0 --port 8200 ... # Decoder 1
# vllm serve --host 0.0.0.0 --port 8201 ... # Decoder 2
#
# Step 2: Start the Proxy Server
# ------------------------------
# Run the proxy server, specifying the host/port for each prefiller and decoder:
#
# python load_balance_proxy_server_example.py \
# --host 0.0.0.0 --port 9000 \
# --prefiller-hosts 127.0.0.1 127.0.0.1 \
# --prefiller-ports 8100 8101 \
# --decoder-hosts 127.0.0.1 127.0.0.1 \
# --decoder-ports 8200 8201
#
# This will start the proxy on port 9000, load balancing between two prefiller
# and two decoder servers.
#
# Step 3: Send a Request to the Proxy
# -----------------------------------
# You can now send OpenAI-compatible requests to the proxy. For example:
#
# curl -X POST http://localhost:9000/v1/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "your-model",
# "prompt": "The quick brown fox jumps over the lazy dog",
# "max_tokens": 16
# }'
#
# Or for chat completions:
#
# curl -X POST http://localhost:9000/v1/chat/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "your-model",
# "messages": [{"role": "user", "content": "Hello!"}],
# "max_tokens": 16
# }'
#
# Step 4: Health Check
# --------------------
# To check if the proxy is running and see how many backend instances are
# connected, use:
#
# curl http://localhost:9000/healthcheck
#
# This will return a JSON object with the status and the number of prefiller
# and decoder instances.
#
# Notes:
# - You can scale the number of prefiller and decoder servers as needed.
# - The proxy will round-robin requests to balance load.
# - For production, ensure your backend servers are robust and secure.
#
# For more details, see the code and comments in this file.
import argparse
import asyncio
import heapq
import os
import sys
from contextlib import asynccontextmanager
from typing import List
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger
logger = init_logger(__name__)
# Add uvloop for faster event loop if available
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
class ServerState:
def __init__(self, host, port):
self.host = host
self.port = port
self.url = f'http://{host}:{port}/v1'
self.client = httpx.AsyncClient(timeout=None,
base_url=self.url,
limits=httpx.Limits(
max_connections=100000,
max_keepalive_connections=100000))
self.active_tokens = 0
self.active_kv_cache = 0 # Only for prefiller
self.active_requests = 0 # Number of active requests
self.aborted_requests = set() # Track aborted requests
# Removed individual server lock - will use global locks instead
class ProxyState:
def __init__(self, prefiller_instances, decoder_instances):
self.prefillers: List[ServerState] = [
ServerState(h, p) for h, p in prefiller_instances
]
self.decoders: List[ServerState] = [
ServerState(h, p) for h, p in decoder_instances
]
self.req_to_prefiller = {}
self.req_id_lock = asyncio.Lock()
self.req_id_counter = 0
# Removed selection locks - no longer needed for synchronous methods
# Initialize priority queues for efficient server selection
# Each entry is (priority_score, server_index, server_reference)
# Lower priority score = higher priority (less loaded)
self.prefiller_heap = [(0, i, server)
for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server)
for i, server in enumerate(self.decoders)]
heapq.heapify(self.prefiller_heap)
heapq.heapify(self.decoder_heap)
def _update_prefiller_priority(self, server_idx: int):
"""Update the priority of a prefiller server in the heap."""
server = self.prefillers[server_idx]
# Priority based on active_tokens and active_kv_cache
priority = server.active_tokens + server.active_kv_cache * 0.3
# Remove old entry and add new one
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
if i != server_idx]
heapq.heappush(self.prefiller_heap,
(priority, server_idx, server)) # type: ignore
def _update_decoder_priority(self, server_idx: int):
"""Update the priority of a decoder server in the heap."""
server = self.decoders[server_idx]
priority = server.active_tokens
# Remove old entry and add new one
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
if i != server_idx]
heapq.heappush(self.decoder_heap,
(priority, server_idx, server)) # type: ignore
def abort_prefiller_request(self, server_idx: int,
request_id): # Changed to synchronous
"""
Mark a request as aborted. This will helps to release kv cache in
prefiller node.
"""
# No lock needed - atomic operation
self.prefillers[server_idx].aborted_requests.add(request_id)
def aquire_aborted_prefiller_requests(
self, server_idx: int): # Changed to synchronous
"""
Get the set of aborted requests and clear it.
This is used to release kv cache in prefiller node.
"""
# No lock needed - atomic operation
aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
self.prefillers[server_idx].aborted_requests.clear()
return aborted_requests
async def next_req_id(self):
async with self.req_id_lock:
self.req_id_counter += 1
return str(self.req_id_counter)
def select_prefiller(self, token_count): # Changed to synchronous
# No lock needed - entire function is atomic
if not self.prefiller_heap:
raise RuntimeError("No prefiller servers available")
priority, chosen, server = heapq.heappop(self.prefiller_heap)
# Update the chosen server atomically
self.prefillers[chosen].active_tokens += token_count
self.prefillers[chosen].active_kv_cache += token_count
# Update priority and re-add to heap
self._update_prefiller_priority(chosen)
return chosen
def release_prefiller(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
self.prefillers[idx].active_tokens -= token_count
# Update priority queue after releasing
self._update_prefiller_priority(idx)
def release_prefiller_kv(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
if self.prefillers[idx].active_kv_cache > 0:
self.prefillers[idx].active_kv_cache -= token_count
# Update priority queue after releasing
self._update_prefiller_priority(idx)
def select_decoder(self, token_count): # Changed to synchronous
# No lock needed - entire function is atomic
if not self.decoder_heap:
raise RuntimeError("No decoder servers available")
priority, chosen, server = heapq.heappop(self.decoder_heap)
# Update the chosen server atomically
self.decoders[chosen].active_tokens += token_count
# Update priority and re-add to heap
self._update_decoder_priority(chosen)
return chosen
def release_decoder(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
self.decoders[idx].active_tokens -= token_count
# Update priority queue after releasing
self._update_decoder_priority(idx)
# Omni_infer's calculate_input_scores function
def calculate_prefill_scores(self, request_length: int) -> float:
length_score = request_length / 4.0
input_score = length_score * 0.0345 + 120.0745
return input_score
def calculate_decode_scores(self, request_length: int) -> float:
return request_length
proxy_state = None
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--prefiller-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
type=int,
nargs="+",
default=[8001])
parser.add_argument("--decoder-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
parser.add_argument("--max-retries",
type=int,
default=3,
help="Maximum number of retries for HTTP requests")
parser.add_argument(
"--retry-delay",
type=float,
default=0.001,
help="Base delay (seconds) for exponential backoff retries")
args = parser.parse_args()
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
@asynccontextmanager
async def lifespan(app: FastAPI):
global proxy_state
proxy_state = ProxyState(global_args.prefiller_instances,
global_args.decoder_instances)
print(
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
)
yield
for p in proxy_state.prefillers:
await p.client.aclose()
for d in proxy_state.decoders:
await d.client.aclose()
app = FastAPI(lifespan=lifespan)
async def send_request_to_service(client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
prefiller_id)
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
"aborted_request": list(aborted_requests),
}
req_data["stream"] = False
req_data["max_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
last_exc = None
for attempt in range(1, max_retries + 1):
try:
response = await client.post(endpoint,
json=req_data,
headers=headers)
response.raise_for_status()
return response
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.warning(
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
last_exc = e
if attempt < max_retries:
await asyncio.sleep(base_delay * (2**(attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for {endpoint}.")
raise last_exc
async def stream_service_response_with_retry(client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
for attempt in range(1, max_retries + 1):
try:
async with client.stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
response.raise_for_status()
first_chunk_sent = False
async for chunk in response.aiter_bytes():
first_chunk_sent = True
yield chunk
return # Success, exit after streaming
except (httpx.RequestError, httpx.HTTPStatusError) as e:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
raise e
except Exception as e:
# If any chunk has been sent, do not retry, just log and drop
if 'first_chunk_sent' in locals() and first_chunk_sent:
logger.error(
f"Streaming to client interrupted after response started: {str(e)}"
)
return
else:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
raise e
async def _handle_completions(api: str, request: Request):
try:
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
# Send request to prefiller
response = await send_request_to_service(
prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
# Stream response from decoder
released_kv = False
async def generate_stream():
nonlocal released_kv
# Only one await per chunk, minimal logic in loop
try:
async for chunk in stream_service_response_with_retry(
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
prefiller_idx, prefiller_score)
released_kv = True
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
)
proxy_state.abort_prefiller_request(prefiller_idx, request_id)
proxy_state.release_prefiller_kv(prefiller_idx,
prefiller_score)
# After streaming done, release tokens
proxy_state.release_decoder(decoder_idx, decoder_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
f" - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.post("/v1/completions")
async def handle_completions(request: Request):
return await _handle_completions("/completions", request)
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
return await _handle_completions("/chat/completions", request)
@app.get("/healthcheck")
async def healthcheck():
return {
"status": "ok",
"prefill_instances": len(proxy_state.prefillers),
"decode_instances": len(proxy_state.decoders)
}
if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)

View File

@ -1,275 +0,0 @@
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
# SPDX-License-Identifier: Apache-2.0
import argparse
import itertools
import os
import uuid
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger
logger = init_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize client pools for prefiller and decoder services
app.state.prefill_clients = []
app.state.decode_clients = []
limit = httpx.Limits(max_connections=100000,
max_keepalive_connections=100000)
# Create prefill clients
for i, (host, port) in enumerate(global_args.prefiller_instances):
prefiller_base_url = f'http://{host}:{port}/v1'
app.state.prefill_clients.append({
'client':
httpx.AsyncClient(timeout=None,
base_url=prefiller_base_url,
limits=limit),
'host':
host,
'port':
port,
'id':
i
})
# Create decode clients
for i, (host, port) in enumerate(global_args.decoder_instances):
decoder_base_url = f'http://{host}:{port}/v1'
app.state.decode_clients.append({
'client':
httpx.AsyncClient(timeout=None,
base_url=decoder_base_url,
limits=limit),
'host':
host,
'port':
port,
'id':
i
})
# Initialize round-robin iterators
app.state.prefill_iterator = itertools.cycle(
range(len(app.state.prefill_clients)))
app.state.decode_iterator = itertools.cycle(
range(len(app.state.decode_clients)))
print(f"Initialized {len(app.state.prefill_clients)} prefill clients "
f"and {len(app.state.decode_clients)} decode clients.")
yield
# Shutdown: Close all clients
for client_info in app.state.prefill_clients:
await client_info['client'].aclose()
for client_info in app.state.decode_clients:
await client_info['client'].aclose()
# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
# For prefiller instances
parser.add_argument("--prefiller-hosts",
"--prefiller-host",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
"--prefiller-port",
type=int,
nargs="+",
default=[8100])
# For decoder instances
parser.add_argument("--decoder-hosts",
"--decoder-host",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--decoder-ports",
"--decoder-port",
type=int,
nargs="+",
default=[8200])
args = parser.parse_args()
# Validate and pair hosts with ports
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
# Create tuples of (host, port) for each service type
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
def get_next_client(app, service_type: str):
"""
Get the next client in round-robin fashion.
Args:
app: The FastAPI app instance
service_type: Either 'prefill' or 'decode'
Returns:
The next client to use
"""
if service_type == 'prefill':
client_idx = next(app.state.prefill_iterator)
return app.state.prefill_clients[client_idx]
elif service_type == 'decode':
client_idx = next(app.state.decode_iterator)
return app.state.decode_clients[client_idx]
else:
raise ValueError(f"Unknown service type: {service_type}")
async def send_request_to_service(client_info: dict, endpoint: str,
req_data: dict, request_id: str):
"""
Send a request to a service using a client from the pool.
"""
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None
}
req_data["stream"] = False
req_data["max_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
response = await client_info['client'].post(endpoint,
json=req_data,
headers=headers)
response.raise_for_status()
return response
async def stream_service_response(client_info: dict, endpoint: str,
req_data: dict, request_id: str):
"""
Asynchronously stream response from a service using a client from the pool.
"""
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async with client_info['client'].stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
async def _handle_completions(api: str, request: Request):
try:
req_data = await request.json()
request_id = str(uuid.uuid4())
# Get the next prefill client in round-robin fashion
prefill_client_info = get_next_client(request.app, 'prefill')
# Send request to prefill service
response = await send_request_to_service(prefill_client_info, api,
req_data, request_id)
# Extract the needed fields
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Get the next decode client in round-robin fashion
decode_client_info = get_next_client(request.app, 'decode')
logger.debug("Using %s %s", prefill_client_info, decode_client_info)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(decode_client_info,
api,
req_data,
request_id=request_id):
yield chunk
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
f" - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.post("/v1/completions")
async def handle_completions(request: Request):
return await _handle_completions("/completions", request)
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
return await _handle_completions("/chat/completions", request)
@app.get("/healthcheck")
async def healthcheck():
"""Simple endpoint to check if the server is running."""
return {
"status": "ok",
"prefill_instances": len(app.state.prefill_clients),
"decode_instances": len(app.state.decode_clients)
}
if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)