### What this PR does / why we need it?

This PR fixes a recompute out-of-memory bug in the decode instance. When recomputation happens during decode, KV cache usage may exceed the pre-allocated memory and cause an OOM. We therefore propose a new scheduling strategy: when the decode instance cannot allocate a new block for a running request, we stop the request that would otherwise be preempted. Stopped requests are recognized by the proxy and sent to a prefill instance again to compute the KV cache, then directed back to a decode instance. This is a temporary fix; the long-term strategy is to use CPU offload in the decode instance.

### Does this PR introduce _any_ user-facing change?

An extra ascend configuration option **recompute_scheduler_enable = True** is added to enable this strategy. The default value is False.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
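
For illustration, a hedged sketch of how the option might be enabled on the decode instance, assuming it is consumed through vLLM's `--additional-config` JSON the way other vllm-ascend options are (the exact key placement may differ):

```bash
vllm serve <model> --additional-config '{"recompute_scheduler_enable": true}'
```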
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
# SPDX-License-Identifier: Apache-2.0
#
# Tutorial: Using the Load Balance Proxy Server Example
#
# This proxy server is designed to distribute requests between multiple
# "prefiller" and "decoder" backend servers for large language model inference.
# It is useful for scaling out inference workloads and balancing load across
# multiple backend instances.
#
# Features:
# - Load balances requests to multiple prefiller and decoder servers.
# - Supports OpenAI-compatible /v1/completions and /v1/chat/completions endpoints.
# - Streams responses from backend servers to clients.
#
# Prerequisites:
# - Python 3.8+
# - Install dependencies:
#     pip install fastapi httpx uvicorn vllm
#
# Step 1: Start Your Backend Servers
# ----------------------------------
# You need at least one prefiller and one decoder backend running. These can
# be mock servers or actual vLLM servers, for example:
#
#   vllm serve --host 0.0.0.0 --port 8100 ...  # Prefiller 1
#   vllm serve --host 0.0.0.0 --port 8101 ...  # Prefiller 2
#   vllm serve --host 0.0.0.0 --port 8200 ...  # Decoder 1
#   vllm serve --host 0.0.0.0 --port 8201 ...  # Decoder 2
#
# Step 2: Start the Proxy Server
# ------------------------------
# Run the proxy server, specifying the host/port for each prefiller and decoder:
#
#   python load_balance_proxy_server_example.py \
#       --host 0.0.0.0 --port 9000 \
#       --prefiller-hosts 127.0.0.1 127.0.0.1 \
#       --prefiller-ports 8100 8101 \
#       --decoder-hosts 127.0.0.1 127.0.0.1 \
#       --decoder-ports 8200 8201
#
# This starts the proxy on port 9000, load balancing between two prefiller
# and two decoder servers.
#
# Step 3: Send a Request to the Proxy
# -----------------------------------
# You can now send OpenAI-compatible requests to the proxy. For example:
#
#   curl -X POST http://localhost:9000/v1/completions \
#       -H "Content-Type: application/json" \
#       -d '{
#             "model": "your-model",
#             "prompt": "The quick brown fox jumps over the lazy dog",
#             "max_tokens": 16
#           }'
#
# Or for chat completions:
#
#   curl -X POST http://localhost:9000/v1/chat/completions \
#       -H "Content-Type: application/json" \
#       -d '{
#             "model": "your-model",
#             "messages": [{"role": "user", "content": "Hello!"}],
#             "max_tokens": 16
#           }'
#
# Step 4: Health Check
# --------------------
# To check whether the proxy is running and see how many backend instances
# are connected, use:
#
#   curl http://localhost:9000/healthcheck
#
# This returns a JSON object with the status and the number of prefiller
# and decoder instances.
#
# Notes:
# - You can scale the number of prefiller and decoder servers as needed.
# - The proxy dispatches each request to the least-loaded server (tracked by
#   active tokens and outstanding KV cache), not strict round-robin.
# - For production, ensure your backend servers are robust and secure.
#
# For more details, see the code and comments in this file.
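
# As a usage illustration, the same request can be sent from Python with the
# OpenAI client (a hedged sketch; the `openai` package and the model name
# "your-model" are assumptions, not dependencies of this server):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
#   resp = client.completions.create(model="your-model",
#                                    prompt="The quick brown fox",
#                                    max_tokens=16)
#   print(resp.choices[0].text)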

import argparse
import asyncio
import functools
import heapq
import json
import os
import sys
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger

logger = init_logger(__name__)

# Use uvloop for a faster event loop if available
try:
    import uvloop

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass


class ServerState:

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.url = f'http://{host}:{port}/v1'
        self.client = httpx.AsyncClient(timeout=None,
                                        base_url=self.url,
                                        limits=httpx.Limits(
                                            max_connections=100000,
                                            max_keepalive_connections=100000))
        self.active_tokens = 0
        self.active_kv_cache = 0  # Only for prefiller
        self.active_requests = 0  # Number of active requests
        self.aborted_requests = set()  # Track aborted requests
        # Removed individual server lock - will use global locks instead


class ProxyState:

    def __init__(self, prefiller_instances, decoder_instances):
        self.prefillers: List[ServerState] = [
            ServerState(h, p) for h, p in prefiller_instances
        ]
        self.decoders: List[ServerState] = [
            ServerState(h, p) for h, p in decoder_instances
        ]
        self.req_to_prefiller = {}
        self.req_id_lock = asyncio.Lock()
        # Removed selection locks - no longer needed for synchronous methods

        # Initialize priority queues for efficient server selection.
        # Each entry is (priority_score, server_index, server_reference);
        # lower priority score = higher priority (less loaded).
        self.prefiller_heap = [(0, i, server)
                               for i, server in enumerate(self.prefillers)]
        self.decoder_heap = [(0, i, server)
                             for i, server in enumerate(self.decoders)]
        heapq.heapify(self.prefiller_heap)
        heapq.heapify(self.decoder_heap)

    def _update_prefiller_priority(self, server_idx: int):
        """Update the priority of a prefiller server in the heap."""
        server = self.prefillers[server_idx]
        # Priority based on active_tokens and active_kv_cache
        priority = server.active_tokens + server.active_kv_cache * 0.3
        # Remove the old entry. Filtering a heap-ordered list does not
        # preserve the heap invariant, so re-heapify before pushing.
        self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
                               if i != server_idx]
        heapq.heapify(self.prefiller_heap)
        heapq.heappush(self.prefiller_heap,
                       (priority, server_idx, server))  # type: ignore

    def _update_decoder_priority(self, server_idx: int):
        """Update the priority of a decoder server in the heap."""
        server = self.decoders[server_idx]
        priority = server.active_tokens
        # Remove the old entry and re-heapify, as above.
        self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
                             if i != server_idx]
        heapq.heapify(self.decoder_heap)
        heapq.heappush(self.decoder_heap,
                       (priority, server_idx, server))  # type: ignore
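
    # Worked example of the priority above: a prefiller with 1000 active
    # tokens and 500 tokens of still-held KV cache gets priority
    # 1000 + 500 * 0.3 = 1150; select_prefiller() then pops whichever
    # server currently has the lowest such score.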

    def abort_prefiller_request(self, server_idx: int,
                                request_id):  # Changed to synchronous
        """
        Mark a request as aborted. This helps release the KV cache held
        for that request on the prefiller node.
        """
        # No lock needed - atomic operation
        self.prefillers[server_idx].aborted_requests.add(request_id)

    def acquire_aborted_prefiller_requests(
            self, server_idx: int):  # Changed to synchronous
        """
        Get the set of aborted requests and clear it.
        This is used to release the KV cache on the prefiller node.
        """
        # No lock needed - atomic operation
        aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
        self.prefillers[server_idx].aborted_requests.clear()
        return aborted_requests

    async def next_req_id(self):
        async with self.req_id_lock:
            return str(uuid.uuid4())

    def select_prefiller(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.prefiller_heap:
            raise RuntimeError("No prefiller servers available")

        priority, chosen, server = heapq.heappop(self.prefiller_heap)

        # Update the chosen server atomically
        self.prefillers[chosen].active_tokens += token_count
        self.prefillers[chosen].active_kv_cache += token_count

        # Update priority and re-add to heap
        self._update_prefiller_priority(chosen)

        return chosen

    def release_prefiller(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.prefillers[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def release_prefiller_kv(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        if self.prefillers[idx].active_kv_cache > 0:
            self.prefillers[idx].active_kv_cache -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def select_decoder(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.decoder_heap:
            raise RuntimeError("No decoder servers available")

        priority, chosen, server = heapq.heappop(self.decoder_heap)

        # Update the chosen server atomically
        self.decoders[chosen].active_tokens += token_count

        # Update priority and re-add to heap
        self._update_decoder_priority(chosen)

        return chosen

    def release_decoder(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.decoders[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_decoder_priority(idx)

    # Omni_infer's calculate_input_scores function
    def calculate_prefill_scores(self, request_length: int) -> float:
        length_score = request_length / 4.0
        input_score = length_score * 0.0345 + 120.0745
        return input_score

    def calculate_decode_scores(self, request_length: int) -> float:
        return request_length
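

# For a sense of scale of the scores above (a worked example, not a tuned
# recommendation): a 4000-byte request body gives
# length_score = 4000 / 4.0 = 1000 and
# prefill score = 1000 * 0.0345 + 120.0745 = 154.5745, while its decode
# score is simply 4000.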

proxy_state = None


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--prefiller-ports",
                        type=int,
                        nargs="+",
                        default=[8001])
    parser.add_argument("--decoder-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
    parser.add_argument("--max-retries",
                        type=int,
                        default=3,
                        help="Maximum number of retries for HTTP requests")
    parser.add_argument(
        "--retry-delay",
        type=float,
        default=0.001,
        help="Base delay (seconds) for exponential backoff retries")
    args = parser.parse_args()
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports")
    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError(
            "Number of decoder hosts must match number of decoder ports")
    args.prefiller_instances = list(
        zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
    return args


@asynccontextmanager
async def lifespan(app: FastAPI):
    global proxy_state
    proxy_state = ProxyState(global_args.prefiller_instances,
                             global_args.decoder_instances)
    print(
        f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
    )
    yield
    for p in proxy_state.prefillers:
        await p.client.aclose()
    for d in proxy_state.decoders:
        await d.client.aclose()


async def listen_for_disconnect(request: Request) -> None:
    """Return once a disconnect message is received."""
    while True:
        message = await request.receive()
        if message["type"] == "http.disconnect":
            break


def with_cancellation(handler_func):
    """Race the handler against a client-disconnect listener and cancel
    whichever does not finish first."""

    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        request = kwargs["request"]
        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
        done, pending = await asyncio.wait([handler_task, cancellation_task],
                                           return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        if handler_task in done:
            return handler_task.result()
        return None

    return wrapper


app = FastAPI(lifespan=lifespan)


async def send_request_to_service(client: httpx.AsyncClient,
                                  prefiller_id: int,
                                  endpoint: str,
                                  req_data: dict,
                                  request_id: str,
                                  max_retries: int = 3,
                                  base_delay: float = 0.2):
    # Piggyback the set of requests aborted on this prefiller onto the next
    # request so the prefiller can release their KV cache.
    aborted_requests = proxy_state.acquire_aborted_prefiller_requests(
        prefiller_id)
    req_data = req_data.copy()
    req_data['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    req_data["min_tokens"] = 1
    if "stream_options" in req_data:
        del req_data["stream_options"]
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            response = await client.post(endpoint,
                                         json=req_data,
                                         headers=headers)
            response.raise_for_status()
            return response
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            logger.warning(
                f"Attempt {attempt} failed for {endpoint}: {str(e)}")
            last_exc = e
            if attempt < max_retries:
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for {endpoint}.")
                raise last_exc
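

# Retry timing, worked through: with base_delay = 0.2 and max_retries = 3 the
# sleeps between attempts are 0.2 s and 0.4 s (base_delay * 2**(attempt - 1)).
# Note the CLI default --retry-delay is 0.001, which makes the backoff
# effectively immediate unless overridden.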


async def stream_service_response_with_retry(client: httpx.AsyncClient,
                                             endpoint: str,
                                             req_data: dict,
                                             request_id: str,
                                             max_retries: int = 3,
                                             base_delay: float = 0.2):
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    first_chunk_sent = False
    for attempt in range(1, max_retries + 1):
        try:
            async with client.stream("POST",
                                     endpoint,
                                     json=req_data,
                                     headers=headers) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    first_chunk_sent = True
                    yield chunk
                return  # Success, exit after streaming
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            # Once a chunk has reached the client, retrying would duplicate
            # output, so log and drop instead of retrying.
            if first_chunk_sent:
                logger.error(
                    f"Streaming to client interrupted after response started: {str(e)}"
                )
                return
            if attempt < max_retries:
                logger.warning(
                    f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
                )
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for streaming {endpoint}."
                )
                raise e
        except Exception as e:
            # If any chunk has been sent, do not retry, just log and drop
            if first_chunk_sent:
                logger.error(
                    f"Streaming to client interrupted after response started: {str(e)}"
                )
                return
            if attempt < max_retries:
                logger.warning(
                    f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
                )
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for streaming {endpoint}."
                )
                raise e


async def _handle_select_instance(api: str, req_data: Any,
                                  request_length: int):
    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
    logger.debug(
        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
    )
    request_id = await proxy_state.next_req_id()
    # Select prefiller
    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
    prefiller = proxy_state.prefillers[prefiller_idx]
    # Send request to prefiller
    response = await send_request_to_service(
        prefiller.client,
        prefiller_idx,
        api,
        req_data,
        request_id,
        max_retries=global_args.max_retries,
        base_delay=global_args.retry_delay)
    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
    response_json = response.json()
    kv_transfer_params = response_json.get('kv_transfer_params', {})
    if kv_transfer_params:
        req_data["kv_transfer_params"] = kv_transfer_params
    # Select decoder
    decoder_score = proxy_state.calculate_decode_scores(request_length)
    logger.debug("Decoder score: %f", decoder_score)
    # Use the prefiller's kv_transfer_params to select decoder
    decoder_idx = proxy_state.select_decoder(decoder_score)
    decoder = proxy_state.decoders[decoder_idx]
    logger.debug("Using %s %s", prefiller.url, decoder.url)
    return InstanceInfo(request_id=request_id,
                        prefiller_idx=prefiller_idx,
                        prefiller_score=prefiller_score,
                        prefiller=prefiller,
                        decoder=decoder,
                        decoder_idx=decoder_idx,
                        decoder_score=decoder_score)


@dataclass
class InstanceInfo:
    request_id: str
    prefiller_idx: int
    prefiller_score: float
    prefiller: ServerState
    decoder_idx: int
    decoder_score: float
    decoder: ServerState


async def _handle_completions(api: str, request: Request):
    try:
        req_data = await request.json()
        req_body = await request.body()
        request_length = len(req_body)
        instance_info = await _handle_select_instance(api, req_data,
                                                      request_length)
        stream_flag = bool(req_data.get("stream", False))
        chat_flag = "messages" in req_data

        if "prompt" in req_data:
            origin_prompt = req_data["prompt"]
        elif chat_flag:
            messages = req_data["messages"]
            origin_prompt = messages[0].get("content", "")
        else:
            origin_prompt = ""
        # Refer to vLLM sampling_params: max_tokens default value
        origin_max_tokens = req_data.get("max_tokens", 16)

        async def generate_stream():
            nonlocal instance_info
            generated_token = ""
            released_kv = False
            retry_count = 0
            retry = True
            completion_tokens = 0
            # Only one await per chunk, minimal logic in loop
            try:
                while retry:
                    retry = False
                    async for chunk in stream_service_response_with_retry(
                            instance_info.decoder.client,
                            api,
                            req_data,
                            request_id=instance_info.request_id,
                            max_retries=global_args.max_retries,
                            base_delay=global_args.retry_delay):
                        # The first decoder chunk means the prefilled KV
                        # cache has been consumed; release it.
                        if not released_kv and chunk:
                            proxy_state.release_prefiller_kv(
                                instance_info.prefiller_idx,
                                instance_info.prefiller_score)
                            released_kv = True
                        chunk_str = chunk.decode("utf-8").strip()
                        if not chunk_str:
                            continue
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: "):]
                        try:
                            chunk_json = json.loads(chunk_str)
                        except json.JSONDecodeError:
                            # Non-JSON chunk (e.g. the SSE [DONE]
                            # sentinel): forward it as-is.
                            logger.warning(
                                f"Skipping chunk: {chunk_str}")
                            yield chunk
                            continue
                        choices = chunk_json.get("choices", [])
                        if not choices:
                            yield chunk
                            continue

                        choice = choices[0]
                        delta = choice.get("delta") or {}
                        message = choice.get("message") or {}
                        content = (
                            delta.get("content")
                            or message.get("content")
                            or choice.get("text")
                            or ""
                        )
                        generated_token += content

                        stop_reason = choice.get("stop_reason")
                        usage = chunk_json.get("usage", {})
                        completion_tokens = (completion_tokens + 1) if stream_flag else \
                            (completion_tokens + usage.get("completion_tokens", 0))
                        if stop_reason == "recomputed":
                            # The decoder stopped this request to avoid OOM.
                            # Re-prefill with the prompt plus whatever was
                            # generated so far, then resume on a decoder.
                            retry = True
                            retry_count += 1
                            if chat_flag:
                                messages[0][
                                    "content"] = origin_prompt + generated_token
                            else:
                                req_data[
                                    "prompt"] = origin_prompt + generated_token
                            req_data[
                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
                            tmp_request_length = len(
                                json.dumps(req_data).encode("utf-8"))
                            instance_info = await _handle_select_instance(
                                api, req_data, tmp_request_length)
                            break
                        if retry_count > 0 and not stream_flag:
                            # Non-streaming response after a recompute:
                            # splice the full accumulated text back in.
                            if chat_flag:
                                choice["message"][
                                    "content"] = generated_token
                            else:
                                choice["text"] = generated_token
                            chunk = json.dumps(chunk_json).encode("utf-8")
                        yield chunk
            except Exception as e:
                logger.error(
                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)}; "
                    f"the aborted request {instance_info.request_id} will be routed to the "
                    "target prefiller when a new request is ready to dispatch to it")
                proxy_state.abort_prefiller_request(
                    instance_info.prefiller_idx, instance_info.request_id)
                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
                                                 instance_info.prefiller_score)

            # After streaming is done, release the tokens
            proxy_state.release_decoder(instance_info.decoder_idx,
                                        instance_info.decoder_score)

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")
    except Exception as e:
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              f" - {api} endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise
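

# When the decode instance stops a request for recomputation, the final chunk
# it emits carries stop_reason == "recomputed" in its first choice; this is
# the signal generate_stream() above looks for. Sketch of the shape (other
# fields follow the usual OpenAI-compatible chunk layout):
#
#   {"choices": [{"text": "...", "stop_reason": "recomputed", ...}], ...}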


@app.post("/v1/completions")
@with_cancellation
async def handle_completions(request: Request):
    return await _handle_completions("/completions", request)


@app.post("/v1/chat/completions")
@with_cancellation
async def handle_chat_completions(request: Request):
    return await _handle_completions("/chat/completions", request)


@app.get("/healthcheck")
async def healthcheck():
    return {
        "status": "ok",
        "prefill_instances": len(proxy_state.prefillers),
        "decode_instances": len(proxy_state.decoders)
    }
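

# Example healthcheck response for the two-prefiller/two-decoder setup from
# the tutorial above:
#
#   {"status": "ok", "prefill_instances": 2, "decode_instances": 2}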


if __name__ == '__main__':
    global global_args
    global_args = parse_args()
    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)