### What this PR does / why we need it?
This PR is a cherry-pick from v0.9.1:
https://github.com/vllm-project/vllm-ascend/pull/1953
It introduces a new load-balance proxy server example for disaggregated prefill/decode (PD). Compared with the original round-robin toy_proxy, it supports a simple token- and KV-cache-aware load-balancing routing strategy for the disaggregated PD system.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Tested on a real workload and with unit tests.
- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a
---------
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
@@ -0,0 +1,518 @@
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py

# SPDX-License-Identifier: Apache-2.0
#
# Tutorial: Using the Load Balance Proxy Server Example
#
# This proxy server is designed to distribute requests between multiple
# "prefiller" and "decoder" backend servers for large language model inference.
# It is useful for scaling out inference workloads and balancing load across
# multiple backend instances.
#
# Features:
# - Load balances requests to multiple prefiller and decoder servers.
# - Supports OpenAI-compatible /v1/completions and /v1/chat/completions endpoints.
# - Streams responses from backend servers to clients.
#
# Prerequisites:
# - Python 3.8+
# - Install dependencies:
#     pip install fastapi httpx uvicorn vllm
#
# Step 1: Start Your Backend Servers
# ----------------------------------
# You need to have at least one prefiller and one decoder backend running.
# These can be mock servers or actual vLLM servers.
#
# For testing, you can use mock servers or start real vLLM instances, e.g.:
#
#   vllm serve --host 0.0.0.0 --port 8100 ...  # Prefiller 1
#   vllm serve --host 0.0.0.0 --port 8101 ...  # Prefiller 2
#   vllm serve --host 0.0.0.0 --port 8200 ...  # Decoder 1
#   vllm serve --host 0.0.0.0 --port 8201 ...  # Decoder 2
#
# Step 2: Start the Proxy Server
# ------------------------------
# Run the proxy server, specifying the host/port for each prefiller and decoder:
#
#   python load_balance_proxy_server_example.py \
#       --host 0.0.0.0 --port 9000 \
#       --prefiller-hosts 127.0.0.1 127.0.0.1 \
#       --prefiller-ports 8100 8101 \
#       --decoder-hosts 127.0.0.1 127.0.0.1 \
#       --decoder-ports 8200 8201
#
# This will start the proxy on port 9000, load balancing between two prefiller
# and two decoder servers.
#
# Step 3: Send a Request to the Proxy
# -----------------------------------
# You can now send OpenAI-compatible requests to the proxy. For example:
#
#   curl -X POST http://localhost:9000/v1/completions \
#     -H "Content-Type: application/json" \
#     -d '{
#       "model": "your-model",
#       "prompt": "The quick brown fox jumps over the lazy dog",
#       "max_tokens": 16
#     }'
#
# Or for chat completions:
#
#   curl -X POST http://localhost:9000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{
#       "model": "your-model",
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "max_tokens": 16
#     }'
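#
# Equivalently, from Python with the OpenAI client (a minimal sketch, not part
# of this example's dependencies; assumes `pip install openai`, and
# "your-model" is a placeholder):
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
#   resp = client.chat.completions.create(
#       model="your-model",
#       messages=[{"role": "user", "content": "Hello!"}],
#       max_tokens=16)
#   print(resp.choices[0].message.content)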
#
# Step 4: Health Check
# --------------------
# To check if the proxy is running and see how many backend instances are
# connected, use:
#
#   curl http://localhost:9000/healthcheck
#
# This will return a JSON object with the status and the number of prefiller
# and decoder instances.
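# For example, with the two prefillers and two decoders above it responds
# with something like:
#
#   {"status": "ok", "prefill_instances": 2, "decode_instances": 2}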
#
# Notes:
# - You can scale the number of prefiller and decoder servers as needed.
# - The proxy routes each request to the least-loaded prefiller and decoder,
#   based on token- and KV-cache-aware scores (see ProxyState below).
# - For production, ensure your backend servers are robust and secure.
#
# For more details, see the code and comments in this file.


import argparse
import asyncio
import heapq
import os
import sys
from contextlib import asynccontextmanager
from typing import List

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger

logger = init_logger(__name__)

# Use uvloop for a faster event loop if available
try:
    import uvloop
    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass


class ServerState:
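    """One backend server (prefiller or decoder) and its load counters."""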

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.url = f'http://{host}:{port}/v1'
        self.client = httpx.AsyncClient(timeout=None,
                                        base_url=self.url,
                                        limits=httpx.Limits(
                                            max_connections=100000,
                                            max_keepalive_connections=100000))
        self.active_tokens = 0
        self.active_kv_cache = 0  # Only for prefiller
        self.active_requests = 0  # Number of active requests
        self.aborted_requests = set()  # Track aborted requests
        # Removed individual server lock - will use global locks instead


class ProxyState:
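    """Router state: backend pools, per-backend load heaps, and request ids."""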

    def __init__(self, prefiller_instances, decoder_instances):
        self.prefillers: List[ServerState] = [
            ServerState(h, p) for h, p in prefiller_instances
        ]
        self.decoders: List[ServerState] = [
            ServerState(h, p) for h, p in decoder_instances
        ]
        self.req_to_prefiller = {}
        self.req_id_lock = asyncio.Lock()
        self.req_id_counter = 0
        # Removed selection locks - no longer needed for synchronous methods

        # Initialize priority queues for efficient server selection
        # Each entry is (priority_score, server_index, server_reference)
        # Lower priority score = higher priority (less loaded)
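        # Example (illustrative): a prefiller with active_tokens=300 and
        # active_kv_cache=100 gets priority 300 + 0.3 * 100 = 330, while one
        # with active_tokens=50 and no outstanding KV cache gets 50, so the
        # second one is popped first by select_prefiller().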
        self.prefiller_heap = [(0, i, server)
                               for i, server in enumerate(self.prefillers)]
        self.decoder_heap = [(0, i, server)
                             for i, server in enumerate(self.decoders)]
        heapq.heapify(self.prefiller_heap)
        heapq.heapify(self.decoder_heap)

    def _update_prefiller_priority(self, server_idx: int):
        """Update the priority of a prefiller server in the heap."""
        server = self.prefillers[server_idx]
        # Priority based on active_tokens and active_kv_cache
        priority = server.active_tokens + server.active_kv_cache * 0.3
        # Remove old entry and add new one
        self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
                               if i != server_idx]
        heapq.heappush(self.prefiller_heap,
                       (priority, server_idx, server))  # type: ignore

    def _update_decoder_priority(self, server_idx: int):
        """Update the priority of a decoder server in the heap."""
        server = self.decoders[server_idx]
        priority = server.active_tokens
        # Remove old entry and add new one
        self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
                             if i != server_idx]
        heapq.heappush(self.decoder_heap,
                       (priority, server_idx, server))  # type: ignore

    def abort_prefiller_request(self, server_idx: int,
                                request_id):  # Changed to synchronous
        """
        Mark a request as aborted. This helps release the KV cache on the
        prefiller node.
        """
        # No lock needed - atomic operation
        self.prefillers[server_idx].aborted_requests.add(request_id)

    def acquire_aborted_prefiller_requests(
            self, server_idx: int):  # Changed to synchronous
        """
        Get the set of aborted requests and clear it.
        This is used to release the KV cache on the prefiller node.
        """
        # No lock needed - atomic operation
        aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
        self.prefillers[server_idx].aborted_requests.clear()
        return aborted_requests

    async def next_req_id(self):
        async with self.req_id_lock:
            self.req_id_counter += 1
            return str(self.req_id_counter)

    def select_prefiller(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.prefiller_heap:
            raise RuntimeError("No prefiller servers available")

        priority, chosen, server = heapq.heappop(self.prefiller_heap)

        # Update the chosen server atomically
        self.prefillers[chosen].active_tokens += token_count
        self.prefillers[chosen].active_kv_cache += token_count

        # Update priority and re-add to heap
        self._update_prefiller_priority(chosen)

        return chosen

    def release_prefiller(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.prefillers[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def release_prefiller_kv(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        if self.prefillers[idx].active_kv_cache > 0:
            self.prefillers[idx].active_kv_cache -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def select_decoder(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.decoder_heap:
            raise RuntimeError("No decoder servers available")

        priority, chosen, server = heapq.heappop(self.decoder_heap)

        # Update the chosen server atomically
        self.decoders[chosen].active_tokens += token_count

        # Update priority and re-add to heap
        self._update_decoder_priority(chosen)

        return chosen

    def release_decoder(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.decoders[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_decoder_priority(idx)

    # Omni_infer's calculate_input_scores function
    def calculate_prefill_scores(self, request_length: int) -> float:
        length_score = request_length / 4.0
        input_score = length_score * 0.0345 + 120.0745
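        # Worked example (illustrative): a 4096-byte request body gives
        # 4096 / 4.0 * 0.0345 + 120.0745 ≈ 155.4.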
        return input_score

    def calculate_decode_scores(self, request_length: int) -> float:
        return request_length


proxy_state = None


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--prefiller-ports",
                        type=int,
                        nargs="+",
                        default=[8001])
    parser.add_argument("--decoder-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
    parser.add_argument("--max-retries",
                        type=int,
                        default=3,
                        help="Maximum number of retries for HTTP requests")
    parser.add_argument(
        "--retry-delay",
        type=float,
        default=0.001,
        help="Base delay (seconds) for exponential backoff retries")
    args = parser.parse_args()
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports")
    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError(
            "Number of decoder hosts must match number of decoder ports")
    args.prefiller_instances = list(
        zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
    return args


@asynccontextmanager
async def lifespan(app: FastAPI):
    global proxy_state
    proxy_state = ProxyState(global_args.prefiller_instances,
                             global_args.decoder_instances)
    print(
        f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
    )
    yield
    for p in proxy_state.prefillers:
        await p.client.aclose()
    for d in proxy_state.decoders:
        await d.client.aclose()


app = FastAPI(lifespan=lifespan)


async def send_request_to_service(client: httpx.AsyncClient,
                                  prefiller_id: int,
                                  endpoint: str,
                                  req_data: dict,
                                  request_id: str,
                                  max_retries: int = 3,
                                  base_delay: float = 0.2):
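    """
    Send the prefill pass of a request to the selected prefiller backend.

    The forwarded request is forced to be non-streaming with max_tokens=1 so
    the prefiller only builds the KV cache, and any aborted request ids queued
    for this prefiller are piggybacked via kv_transfer_params so their KV
    cache can be released. Transient failures are retried with backoff.
    """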
    aborted_requests = proxy_state.acquire_aborted_prefiller_requests(
        prefiller_id)
    req_data = req_data.copy()
    req_data['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    if "stream_options" in req_data:
        del req_data["stream_options"]
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            response = await client.post(endpoint,
                                         json=req_data,
                                         headers=headers)
            response.raise_for_status()
            return response
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            logger.warning(
                f"Attempt {attempt} failed for {endpoint}: {str(e)}")
            last_exc = e
            if attempt < max_retries:
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for {endpoint}.")
                raise last_exc


async def stream_service_response_with_retry(client: httpx.AsyncClient,
                                             endpoint: str,
                                             req_data: dict,
                                             request_id: str,
                                             max_retries: int = 3,
                                             base_delay: float = 0.2):
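    # Transient failures are retried with exponential backoff:
    # base_delay * 2**(attempt - 1) seconds between attempts
    # (0.2 s, then 0.4 s with the defaults above).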
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    for attempt in range(1, max_retries + 1):
        try:
            async with client.stream("POST",
                                     endpoint,
                                     json=req_data,
                                     headers=headers) as response:
                response.raise_for_status()
                first_chunk_sent = False
                async for chunk in response.aiter_bytes():
                    first_chunk_sent = True
                    yield chunk
                return  # Success, exit after streaming
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            if attempt < max_retries:
                logger.warning(
                    f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
                )
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for streaming {endpoint}."
                )
                raise e
        except Exception as e:
            # If any chunk has been sent, do not retry, just log and drop
            if 'first_chunk_sent' in locals() and first_chunk_sent:
                logger.error(
                    f"Streaming to client interrupted after response started: {str(e)}"
                )
                return
            else:
                if attempt < max_retries:
                    logger.warning(
                        f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
                    )
                    await asyncio.sleep(base_delay * (2**(attempt - 1)))
                else:
                    logger.error(
                        f"All {max_retries} attempts failed for streaming {endpoint}."
                    )
                    raise e


async def _handle_completions(api: str, request: Request):
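    """
    Proxy one OpenAI-compatible completion/chat request.

    Flow: score the request body, pick the least-loaded prefiller, run the
    prefill pass, forward the returned kv_transfer_params to the least-loaded
    decoder, then stream the decoder's output back to the client, releasing
    the token and KV-cache accounting as each stage finishes.
    """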
    try:
        req_data = await request.json()
        req_body = await request.body()
        request_length = len(req_body)
        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
        logger.debug(
            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
        )
        request_id = await proxy_state.next_req_id()
        # Select prefiller
        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
        prefiller = proxy_state.prefillers[prefiller_idx]
        # Send request to prefiller
        response = await send_request_to_service(
            prefiller.client,
            prefiller_idx,
            api,
            req_data,
            request_id,
            max_retries=global_args.max_retries,
            base_delay=global_args.retry_delay)
        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
        response_json = response.json()
        kv_transfer_params = response_json.get('kv_transfer_params', {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params
        # Select decoder
        decoder_score = proxy_state.calculate_decode_scores(request_length)
        logger.debug("Decoder score: %f", decoder_score)
        # Use the prefiller's kv_transfer_params to select decoder
        decoder_idx = proxy_state.select_decoder(decoder_score)
        decoder = proxy_state.decoders[decoder_idx]
        logger.debug("Using %s %s", prefiller.url, decoder.url)
        # Stream response from decoder
        released_kv = False

        async def generate_stream():
            nonlocal released_kv
            # Only one await per chunk, minimal logic in loop
            try:
                async for chunk in stream_service_response_with_retry(
                        decoder.client,
                        api,
                        req_data,
                        request_id=request_id,
                        max_retries=global_args.max_retries,
                        base_delay=global_args.retry_delay):
                    if not released_kv and chunk:
                        proxy_state.release_prefiller_kv(
                            prefiller_idx, prefiller_score)
                        released_kv = True
                    yield chunk
            except Exception as e:
                logger.error(
                    f"Error during streaming from decoder {decoder.url}: {str(e)}; "
                    f"the aborted request {request_id} will be reported to the "
                    f"target prefiller the next time a request is dispatched to it"
                )
                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
                proxy_state.release_prefiller_kv(prefiller_idx,
                                                 prefiller_score)

            # After streaming done, release tokens
            proxy_state.release_decoder(decoder_idx, decoder_score)

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")
    except Exception as e:
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              f" - {api} endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


@app.post("/v1/completions")
async def handle_completions(request: Request):
    return await _handle_completions("/completions", request)


@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    return await _handle_completions("/chat/completions", request)


@app.get("/healthcheck")
async def healthcheck():
    return {
        "status": "ok",
        "prefill_instances": len(proxy_state.prefillers),
        "decode_instances": len(proxy_state.decoders)
    }


if __name__ == '__main__':
    global global_args
    global_args = parse_args()
    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)
@@ -1,275 +0,0 @@
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py

# SPDX-License-Identifier: Apache-2.0

import argparse
import itertools
import os
import uuid
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger

logger = init_logger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.
    """
    # Startup: Initialize client pools for prefiller and decoder services
    app.state.prefill_clients = []
    app.state.decode_clients = []
    limit = httpx.Limits(max_connections=100000,
                         max_keepalive_connections=100000)

    # Create prefill clients
    for i, (host, port) in enumerate(global_args.prefiller_instances):
        prefiller_base_url = f'http://{host}:{port}/v1'
        app.state.prefill_clients.append({
            'client':
            httpx.AsyncClient(timeout=None,
                              base_url=prefiller_base_url,
                              limits=limit),
            'host':
            host,
            'port':
            port,
            'id':
            i
        })

    # Create decode clients
    for i, (host, port) in enumerate(global_args.decoder_instances):
        decoder_base_url = f'http://{host}:{port}/v1'
        app.state.decode_clients.append({
            'client':
            httpx.AsyncClient(timeout=None,
                              base_url=decoder_base_url,
                              limits=limit),
            'host':
            host,
            'port':
            port,
            'id':
            i
        })

    # Initialize round-robin iterators
    app.state.prefill_iterator = itertools.cycle(
        range(len(app.state.prefill_clients)))
    app.state.decode_iterator = itertools.cycle(
        range(len(app.state.decode_clients)))

    print(f"Initialized {len(app.state.prefill_clients)} prefill clients "
          f"and {len(app.state.decode_clients)} decode clients.")

    yield

    # Shutdown: Close all clients
    for client_info in app.state.prefill_clients:
        await client_info['client'].aclose()

    for client_info in app.state.decode_clients:
        await client_info['client'].aclose()


# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")

    # For prefiller instances
    parser.add_argument("--prefiller-hosts",
                        "--prefiller-host",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--prefiller-ports",
                        "--prefiller-port",
                        type=int,
                        nargs="+",
                        default=[8100])

    # For decoder instances
    parser.add_argument("--decoder-hosts",
                        "--decoder-host",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--decoder-ports",
                        "--decoder-port",
                        type=int,
                        nargs="+",
                        default=[8200])

    args = parser.parse_args()

    # Validate and pair hosts with ports
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports")

    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError(
            "Number of decoder hosts must match number of decoder ports")

    # Create tuples of (host, port) for each service type
    args.prefiller_instances = list(
        zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))

    return args


def get_next_client(app, service_type: str):
    """
    Get the next client in round-robin fashion.

    Args:
        app: The FastAPI app instance
        service_type: Either 'prefill' or 'decode'

    Returns:
        The next client to use
    """
    if service_type == 'prefill':
        client_idx = next(app.state.prefill_iterator)
        return app.state.prefill_clients[client_idx]
    elif service_type == 'decode':
        client_idx = next(app.state.decode_iterator)
        return app.state.decode_clients[client_idx]
    else:
        raise ValueError(f"Unknown service type: {service_type}")


async def send_request_to_service(client_info: dict, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    Send a request to a service using a client from the pool.
    """
    req_data = req_data.copy()
    req_data['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    if "stream_options" in req_data:
        del req_data["stream_options"]
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }

    response = await client_info['client'].post(endpoint,
                                                json=req_data,
                                                headers=headers)
    response.raise_for_status()

    return response


async def stream_service_response(client_info: dict, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    Asynchronously stream response from a service using a client from the pool.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }

    async with client_info['client'].stream("POST",
                                            endpoint,
                                            json=req_data,
                                            headers=headers) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk


async def _handle_completions(api: str, request: Request):
    try:
        req_data = await request.json()
        request_id = str(uuid.uuid4())

        # Get the next prefill client in round-robin fashion
        prefill_client_info = get_next_client(request.app, 'prefill')

        # Send request to prefill service
        response = await send_request_to_service(prefill_client_info, api,
                                                 req_data, request_id)

        # Extract the needed fields
        response_json = response.json()
        kv_transfer_params = response_json.get('kv_transfer_params', {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params

        # Get the next decode client in round-robin fashion
        decode_client_info = get_next_client(request.app, 'decode')

        logger.debug("Using %s %s", prefill_client_info, decode_client_info)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(decode_client_info,
                                                       api,
                                                       req_data,
                                                       request_id=request_id):
                yield chunk

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              f" - {api} endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


@app.post("/v1/completions")
async def handle_completions(request: Request):
    return await _handle_completions("/completions", request)


@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    return await _handle_completions("/chat/completions", request)


@app.get("/healthcheck")
async def healthcheck():
    """Simple endpoint to check if the server is running."""
    return {
        "status": "ok",
        "prefill_instances": len(app.state.prefill_clients),
        "decode_instances": len(app.state.decode_clients)
    }


if __name__ == '__main__':
    global global_args
    global_args = parse_args()

    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)