Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
Robert Shaw
2025-03-22 13:11:50 -04:00
parent b89d89f456
commit a8a621e419
3 changed files with 305 additions and 283 deletions

View File

@ -1,283 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import signal
import sys
import traceback
import uuid
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Union
import uvicorn
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.protocol import (CompletionRequest, ZmqMsgRequest,
ZmqMsgResponse)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.disagg_connector')
TIME_OUT = 5
X_REQUEST_ID_KEY = "X-Request-Id"
CONTENT_TYPE_STREAM = "text/event-stream"
# communication between output handlers and execute_task_async
request_queues: dict[str, asyncio.Queue]
async def log_stats(request_queues: dict[str, asyncio.Queue]):
while True:
logger.info("Running requests: %d", len(request_queues))
await asyncio.sleep(10)
# create async socket use ZMQ_DEALER
async def create_socket(url: str,
zmqctx: zmq.asyncio.Context) -> zmq.asyncio.Socket:
socket = zmqctx.socket(zmq.DEALER)
identity = f"connector-{uuid.uuid4()}"
# unlimited HWM
hwm_limit = 0
socket.setsockopt(zmq.IDENTITY, identity.encode())
socket.setsockopt(zmq.SNDHWM, hwm_limit)
socket.setsockopt(zmq.RCVHWM, hwm_limit)
socket.connect(url)
logger.info("%s started at %s", identity, url)
return socket
@asynccontextmanager
async def lifespan(app: FastAPI):
# create socket pool with prefill and decode
logger.info("start connect zmq server")
app.state.zmqctx = zmq.asyncio.Context()
app.state.prefill_socket = await create_socket(app.state.prefill_addr,
zmqctx=app.state.zmqctx)
logger.info("success create_socke sockets_prefill")
app.state.decode_socket = await create_socket(app.state.decode_addr,
zmqctx=app.state.zmqctx)
logger.info("success create_socket sockets_decode")
global request_queues
request_queues = {}
asyncio.create_task(prefill_handler(app.state.prefill_socket))
asyncio.create_task(decode_handler(app.state.decode_socket))
asyncio.create_task(log_stats(request_queues))
yield
## close zmq context
logger.info("shutdown disagg connector")
logger.info("term zmqctx")
app.state.zmqctx.destroy(linger=0)
app = FastAPI(lifespan=lifespan)
@app.post('/v1/completions')
async def completions(request: CompletionRequest, raw_request: Request,
background_tasks: BackgroundTasks):
try:
# Add the X-Request-Id header to the raw headers list
header = dict(raw_request.headers)
request_id = header.get(X_REQUEST_ID_KEY)
queue: asyncio.Queue[ZmqMsgResponse] = asyncio.Queue()
if request_id is None:
request_id = str(uuid.uuid4())
logger.debug("add X-Request-Id: %s", request_id)
logger.debug("X-Request-Id is: %s", request_id)
request_queues[request_id] = queue
zmq_msg_request = ZmqMsgRequest(request_id=request_id,
type="completions",
body=request)
logger.info("Received request_id: %s, request: %s, header: %s",
request_id, zmq_msg_request.model_dump_json(), header)
original_max_tokens = request.max_tokens
# change max_tokens = 1 to let it only do prefill
request.max_tokens = 1
# finish prefill
try:
prefill_response = await prefill(zmq_msg_request)
if isinstance(prefill_response, JSONResponse
) and prefill_response.status_code != HTTPStatus.OK:
return prefill_response
logger.debug("finish prefill start decode")
request.max_tokens = original_max_tokens
response = await decode(zmq_msg_request)
logger.debug("finish decode")
except Exception as e:
logger.error("Error occurred in disagg prefill proxy server, %s",
e)
response = JSONResponse(
{"error": {
"message": str(e)
}},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
return response
except Exception as e:
exc_info = sys.exc_info()
logger.error("Error occurred in disagg prefill proxy server")
logger.error(e)
logger.error("".join(traceback.format_exception(*exc_info)))
response = JSONResponse({"error": {
"message": str(e)
}}, HTTPStatus.INTERNAL_SERVER_ERROR)
return response
finally:
if request_id is not None:
background_tasks.add_task(cleanup_request_id, request_id)
async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str):
while True:
try:
[body] = await socket.recv_multipart()
response = ZmqMsgResponse.model_validate_json(body)
request_id = response.request_id
logger.debug("%s socket received result: %s", scene,
response.model_dump_json())
if request_id in request_queues:
request_queues[request_id].put_nowait(response)
else:
logger.debug(
"%s socket received but request_id not found discard: %s",
scene, request_id)
except Exception as e:
logger.error(traceback.format_exc())
logger.error("%s handler error: %s", scene, e)
# prefill handler
async def prefill_handler(prefill_socket: zmq.asyncio.Socket):
await socket_recv_handler(prefill_socket, "prefill")
# decode handler
async def decode_handler(decode_socket: zmq.asyncio.Socket):
await socket_recv_handler(decode_socket, "decode")
# select a socket and execute task
async def execute_task_async(zmq_msg_request: ZmqMsgRequest,
socket: zmq.asyncio.Socket):
try:
request_id = zmq_msg_request.request_id
requestBody = zmq_msg_request.model_dump_json()
logger.debug("Sending requestBody: %s", requestBody)
socket.send_multipart([requestBody.encode()])
logger.debug("Sent end")
queue = request_queues[request_id]
while True:
logger.debug("Waiting for reply")
zmq_msg_response: ZmqMsgResponse = await asyncio.wait_for(
queue.get(), TIME_OUT)
logger.debug("Received result: %s",
zmq_msg_response.model_dump_json())
yield zmq_msg_response
if zmq_msg_response.stop:
logger.debug("Received stop: %s", zmq_msg_response.stop)
break
except asyncio.TimeoutError:
logger.error(traceback.format_exc())
yield JSONResponse("timeout", HTTPStatus.REQUEST_TIMEOUT)
finally:
logger.debug("request_id: %s, execute_task_async end", request_id)
async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]:
logger.debug("start prefill")
generator = execute_task_async(zmq_msg_request, app.state.prefill_socket)
async for res in generator:
logger.debug("res: %s", res)
if res.body_type == "response":
return JSONResponse(res.body)
return True
async def generate_stream_response(
fisrt_reply: str,
generator: AsyncGenerator[ZmqMsgResponse]) -> AsyncGenerator[str]:
yield fisrt_reply
async for reply in generator:
yield reply.body
async def decode(
zmq_msg_request: ZmqMsgRequest
) -> Union[JSONResponse, StreamingResponse]:
logger.debug("start decode")
generator = execute_task_async(zmq_msg_request, app.state.decode_socket)
async for res in generator:
logger.debug("res: %s", res)
if res.body_type == "response":
return JSONResponse(res.body)
else:
return StreamingResponse(generate_stream_response(
res.body, generator),
media_type=CONTENT_TYPE_STREAM)
# If the generator is empty, return a default error response
logger.error("No response received from generator")
return JSONResponse({"error": "No response received from generator"},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
def cleanup_request_id(request_id: str):
if request_id in request_queues:
logger.info("del request_id: %s, decode finished", request_id)
del request_queues[request_id]
async def run_disagg_connector(args, **uvicorn_kwargs):
logger.info("vLLM Disaggregate Connector start %s %s", args,
uvicorn_kwargs)
logger.info(args.prefill_addr)
app.state.port = args.port
app.state.prefill_addr = f"ipc://{args.prefill_addr}"
app.state.decode_addr = f"ipc://{args.decode_addr}"
logger.info(
"start connect prefill_addr: %s decode_addr: %s "
"zmq server fastapi port: %s", app.state.prefill_addr,
app.state.decode_addr, app.state.port)
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
# init uvicorn server
config = uvicorn.Config(app, host="0.0.0.0", port=app.state.port)
server = uvicorn.Server(config)
await server.serve()
if __name__ == "__main__":
# NOTE(simon):
# This section should be sync with vllm/entrypoints/cli/connect.py for CLI
# entrypoints.
parser = FlexibleArgumentParser(description="vLLM disagg connect server.")
parser.add_argument("--port",
type=int,
default=8001,
help="The fastapi server port default 8001")
# security concern only support ipc now
parser.add_argument("--prefill-addr",
type=str,
required=True,
help="The zmq ipc prefill address")
parser.add_argument("--decode-addr",
type=str,
required=True,
help="The zmq ipc decode address")
args = parser.parse_args()
uvloop.run(run_disagg_connector(args))

View File

@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import msgspec
from collections.abc import AsyncGenerator
from typing import Dict, Mapping, Optional
import uvicorn
import uvloop
import zmq
import zmq.asyncio
from vllm import SamplingParams
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.inputs.data import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket
logger = init_logger(__name__)
DEFAULT_MAX_TOKENS = 32000
# NOTE FOR DEVELOPERS:
# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE
# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE
# WE RISK REMOTE CODE EXECUTION FROM UNSTRUSTED USERS.
class PDRequest(msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
request_id: str
prompt: str
sampling_params: SamplingParams
# TODO: support multimodal inputs.
class PDResponse(msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
request_id: str
success: bool
delta_text: Optional[str] = None
finish_reason: Optional[str] = None
stop_reason: Optional[str] = None
logprobs = None # TODO
class PDEngine:
"""
PDEngine:
Equiavlent of AsyncLLM for P/D. Assumes there is
a Prefill and Decode service already running.
* TODO: actually handle errors and failure.
* TODO: support more than just text input.
* TODO: move under vllm/v1/engine one past prototype.
"""
def __init__(self, prefill_addr: str, decode_addr: str, connector_addr: str):
# Request queues.
self.queues: Dict[str, asyncio.Queue] = {}
# Serialization encoder.
self.encoder = msgspec.msgpack.Encoder()
# ZMQ communication..
self.ctx = zmq.asyncio.Context()
self.to_decode = make_zmq_socket(
self.ctx, f"{decode_addr}", zmq.constants.PUSH)
self.to_prefill = make_zmq_socket(
self.ctx, f"{prefill_addr}", zmq.constants.PUSH)
self.connector_addr = connector_addr
# Background loops (started on first generate()).
self.output_handler: Optional[asyncio.Task] = None
self.log_running: Optional[asyncio.Task] = None
def shutdown(self):
if (ctx := self.ctx) is not None:
ctx.destroy(linger=0)
if (task := self.log_running) is not None:
task.cancel()
if (task := self.output_handler) is not None:
task.cancel()
async def _run_log_running(self):
logger.info("Running requests: %d", len(self.queues))
await asyncio.sleep(10.)
async def _run_output_handler(self, socket: zmq.asyncio.Socket):
"""
Pull responses from Decode + Prefill engines and
distribute back to the generate() tasks.
"""
decoder = msgspec.msgpack.Decoder(PDResponse)
socket: Optional[zmq.asyncio.Socket] = None
try:
socket = make_zmq_socket(
self.ctx, self.connector_addr, zmq.constants.PULL)
while True:
reponse_bytes = await socket.recv().buffer
response = decoder.decode(reponse_bytes)
self.queues[response.request_id].put_nowait(response)
except:
# TODO: actually handle failure and shutdown.
raise
finally:
if socket is not None:
socket.close(linger=0)
async def _prefill(self,
request: PDRequest,
q: asyncio.Queue[PDResponse]) -> PDResponse:
# Send request to the prefill instance.
req_bytes = self.encoder(request)
await self.to_prefill.send(req_bytes, copy=False)
# Wait for the prefill to be done.
response = await q.get()
assert response.request_id == request.request_id
if not response.success:
# TODO: actual error handling and shutdown.
raise Exception("Failed Prefill Request.")
return response
async def _decode(self,
request: PDRequest,
q: asyncio.Queue[PDResponse]) -> AsyncGenerator[PDResponse]:
# Send request to the decode instance.
req_bytes = self.encoder(request)
await self.to_decode.send(req_bytes, copy=False)
# Iterate response queue and yield each response to caller..
finished = False
while not finished:
response = await q.get()
if not response.success:
# TODO: actual error handling and shutdown.
raise Exception("Failed Decode Request.")
finished = response.finish_reason is not None
yield response
async def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[PDResponse]:
# Start loops on first request.
if self.output_handler is None:
self.output_handler = asyncio.create_task(self._run_output_handler())
self.log_running = asyncio.create_task(self._run_log_running())
# TODO: expand to suppo
if not isinstance(prompt, str):
raise ValueError("We currently only support text inputs!")
if request_id in self.queues:
raise ValueError(f"Found duplicate request_id: {request_id}!")
# Queue to gather output from output_handler.
q: asyncio.Queue[PDResponse] = asyncio.Queue()
self.queues[request_id] = q
# (1) Perform the prefill (max_tokens=1).
original_max_tokens = sampling_params.max_tokens
request = PDRequest(request_id, prompt, sampling_params)
request.sampling_params.max_tokens = 1
response = await self._prefill(request, q)
yield response
# (2) Perform the decodes (original tokens).
request.sampling_params.max_tokens = original_max_tokens
async for response in self._decode(request, q):
yield response
async def beam_search(
self,
prompt: PromptType,
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
raise NotImplementedError
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[PoolingRequestOutput, None]:
raise NotImplementedError
async def abort(self, request_id: str) -> None:
raise NotImplementedError
async def get_model_config(self) -> ModelConfig:
raise NotImplementedError
async def get_decoding_config(self) -> DecodingConfig:
raise NotImplementedError
async def get_input_preprocessor(self) -> InputPreprocessor:
raise NotImplementedError
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
raise NotImplementedError
async def is_tracing_enabled(self) -> bool:
return False
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
pass
async def check_health(self) -> None:
pass
async def start_profile(self) -> None:
raise NotImplementedError
async def stop_profile(self) -> None:
raise NotImplementedError
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
raise NotImplementedError
async def sleep(self, level: int = 1) -> None:
raise NotImplementedError
async def wake_up(self) -> None:
raise NotImplementedError
async def is_sleeping(self) -> bool:
False
async def add_lora(self, lora_request: LoRARequest) -> None:
raise NotImplementedError
async def run_disagg_connector(args, **uvicorn_kwargs):
logger.info("vLLM Connector Start: %s %s", args, uvicorn_kwargs)
# NOTE FOR DEVELOPERS: when we shift this to TCP, we must
# ensure that the serialization is not pickle based to
# avoid RCE issues from untrusted users!!!
app.state.port = args.port
app.state.connector_addr = f"ipc://{args.connector_addr}"
app.state.decode_addr = f"ipc://{args.decode_addr}"
app.state.prefill_addr = f"ipc://{args.prefill_addr}"
# init uvicorn server
config = uvicorn.Config(app, host=args.post, port=args.port)
server = uvicorn.Server(config)
await server.serve()
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--connector-addr",
type=str,
required=True,
help="The zmq ipc connector address")
parser.add_argument("--prefill-addr",
type=str,
required=True,
help="The zmq ipc prefill address")
parser.add_argument("--decode-addr",
type=str,
required=True,
help="The zmq ipc decode address")
parser = make_arg_parser(parser)
args = parser.parse_args()
validate_parsed_serve_args(args)
uvloop.run(run_server(args))