mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
@ -1,283 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from typing import Union
|
||||
|
||||
import uvicorn
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from fastapi import BackgroundTasks, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (CompletionRequest, ZmqMsgRequest,
|
||||
ZmqMsgResponse)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||
logger = init_logger('vllm.entrypoints.disagg_connector')
|
||||
|
||||
TIME_OUT = 5
|
||||
X_REQUEST_ID_KEY = "X-Request-Id"
|
||||
CONTENT_TYPE_STREAM = "text/event-stream"
|
||||
|
||||
# communication between output handlers and execute_task_async
|
||||
request_queues: dict[str, asyncio.Queue]
|
||||
|
||||
|
||||
async def log_stats(request_queues: dict[str, asyncio.Queue]):
    """Periodically log how many requests are currently tracked.

    Runs forever: each iteration logs the number of entries in
    ``request_queues`` and then sleeps for ten seconds.

    Args:
        request_queues: Mapping of request id to its response queue. Only
            its length is read.
    """
    report_interval_s = 10
    while True:
        in_flight = len(request_queues)
        logger.info("Running requests: %d", in_flight)
        await asyncio.sleep(report_interval_s)
|
||||
|
||||
|
||||
# create async socket use ZMQ_DEALER
|
||||
async def create_socket(url: str,
                        zmqctx: zmq.asyncio.Context) -> zmq.asyncio.Socket:
    """Create a DEALER socket with a unique identity and connect it to ``url``.

    Args:
        url: ZMQ endpoint to connect to.
        zmqctx: Asyncio ZMQ context the socket is created from.

    Returns:
        The connected socket.
    """
    identity = f"connector-{uuid.uuid4()}"
    # unlimited HWM
    hwm_limit = 0

    socket = zmqctx.socket(zmq.DEALER)
    for option, value in (
        (zmq.IDENTITY, identity.encode()),
        (zmq.SNDHWM, hwm_limit),
        (zmq.RCVHWM, hwm_limit),
    ):
        socket.setsockopt(option, value)

    socket.connect(url)
    logger.info("%s started at %s", identity, url)
    return socket
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: set up ZMQ sockets and background handlers.

    On startup it creates the ZMQ context, connects the prefill and decode
    sockets (addresses are read from ``app.state``), resets the global
    ``request_queues`` mapping and starts the receive/log tasks. On shutdown
    it cancels those tasks and destroys the ZMQ context.

    Args:
        app: The FastAPI application whose ``state`` carries the addresses
            and receives the created context and sockets.
    """
    global request_queues

    # create socket pool with prefill and decode
    logger.info("start connect zmq server")
    app.state.zmqctx = zmq.asyncio.Context()
    app.state.prefill_socket = await create_socket(app.state.prefill_addr,
                                                   zmqctx=app.state.zmqctx)
    logger.info("success create_socke sockets_prefill")
    app.state.decode_socket = await create_socket(app.state.decode_addr,
                                                  zmqctx=app.state.zmqctx)
    logger.info("success create_socket sockets_decode")

    request_queues = {}
    # Keep strong references: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    background_tasks = [
        asyncio.create_task(prefill_handler(app.state.prefill_socket)),
        asyncio.create_task(decode_handler(app.state.decode_socket)),
        asyncio.create_task(log_stats(request_queues)),
    ]
    try:
        yield
    finally:
        # Runs even when the application exits with an error.
        logger.info("shutdown disagg connector")
        # Stop the handlers before tearing down the sockets they read from.
        for task in background_tasks:
            task.cancel()
        await asyncio.gather(*background_tasks, return_exceptions=True)
        ## close zmq context
        logger.info("term zmqctx")
        app.state.zmqctx.destroy(linger=0)
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.post('/v1/completions')
async def completions(request: CompletionRequest, raw_request: Request,
                      background_tasks: BackgroundTasks):
    """Serve a completion by running prefill first and then decode.

    The request is first sent to the prefill service with ``max_tokens``
    forced to 1, then sent to the decode service with the original
    ``max_tokens``. The per-request queue is registered in
    ``request_queues`` and removed by a background task once the response
    has been sent.

    Args:
        request: Parsed completion request body.
        raw_request: Underlying HTTP request, used for its headers.
        background_tasks: FastAPI background tasks; receives the cleanup.

    Returns:
        The decode response, a non-OK prefill response, or a JSON error
        response with status 500 when anything raises.
    """
    # Bound before the try block so the finally clause can always read it.
    request_id = None
    try:
        # Look the id up on the Headers object itself, which is
        # case-insensitive. A plain dict built from it carries the
        # lower-cased names delivered by the ASGI server, so a lookup of
        # "X-Request-Id" on that dict never matches.
        request_id = raw_request.headers.get(X_REQUEST_ID_KEY)
        header = dict(raw_request.headers)
        queue: asyncio.Queue[ZmqMsgResponse] = asyncio.Queue()
        if request_id is None:
            request_id = str(uuid.uuid4())
            logger.debug("add X-Request-Id: %s", request_id)
        logger.debug("X-Request-Id is: %s", request_id)
        request_queues[request_id] = queue
        zmq_msg_request = ZmqMsgRequest(request_id=request_id,
                                        type="completions",
                                        body=request)
        logger.info("Received request_id: %s, request: %s, header: %s",
                    request_id, zmq_msg_request.model_dump_json(), header)
        original_max_tokens = request.max_tokens
        # change max_tokens = 1 to let it only do prefill
        request.max_tokens = 1
        # finish prefill
        try:
            prefill_response = await prefill(zmq_msg_request)
            if isinstance(prefill_response, JSONResponse
                          ) and prefill_response.status_code != HTTPStatus.OK:
                return prefill_response
            logger.debug("finish prefill start decode")
            request.max_tokens = original_max_tokens
            response = await decode(zmq_msg_request)
            logger.debug("finish decode")
        except Exception as e:
            logger.error("Error occurred in disagg prefill proxy server, %s",
                         e)
            response = JSONResponse(
                {"error": {
                    "message": str(e)
                }},
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
        return response

    except Exception as e:
        exc_info = sys.exc_info()
        logger.error("Error occurred in disagg prefill proxy server")
        logger.error(e)
        logger.error("".join(traceback.format_exception(*exc_info)))
        response = JSONResponse({"error": {
            "message": str(e)
        }}, HTTPStatus.INTERNAL_SERVER_ERROR)
        return response
    finally:
        # Deferred to a background task so the queue stays registered while
        # the response is being sent.
        if request_id is not None:
            background_tasks.add_task(cleanup_request_id, request_id)
|
||||
|
||||
|
||||
async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str):
    """Receive responses from ``socket`` and route them to request queues.

    Each received single-frame message is parsed as a ``ZmqMsgResponse`` and
    put on the queue registered in ``request_queues`` under its
    ``request_id``. Messages for unknown ids are discarded. Errors are
    logged and the loop continues, unless the socket has been closed.

    Args:
        socket: Asyncio ZMQ socket to read from.
        scene: Label used in log messages (e.g. "prefill" or "decode").
    """
    while True:
        try:
            [body] = await socket.recv_multipart()
            response = ZmqMsgResponse.model_validate_json(body)
            request_id = response.request_id
            logger.debug("%s socket received result: %s", scene,
                         response.model_dump_json())
            queue = request_queues.get(request_id)
            if queue is not None:
                queue.put_nowait(response)
            else:
                logger.debug(
                    "%s socket received but request_id not found discard: %s",
                    scene, request_id)
        except Exception as e:
            logger.error(traceback.format_exc())
            logger.error("%s handler error: %s", scene, e)
            if socket.closed:
                # A closed socket fails on every further recv; stop instead
                # of spinning in a tight error-logging loop.
                return
|
||||
|
||||
|
||||
# prefill handler
|
||||
async def prefill_handler(prefill_socket: zmq.asyncio.Socket):
    """Run the shared receive loop on the prefill socket."""
    await socket_recv_handler(socket=prefill_socket, scene="prefill")
|
||||
|
||||
|
||||
# decode handler
|
||||
async def decode_handler(decode_socket: zmq.asyncio.Socket):
    """Run the shared receive loop on the decode socket."""
    await socket_recv_handler(socket=decode_socket, scene="decode")
|
||||
|
||||
|
||||
# select a socket and execute task
|
||||
async def execute_task_async(zmq_msg_request: ZmqMsgRequest,
                             socket: zmq.asyncio.Socket):
    """Send a request over ``socket`` and yield its responses as they arrive.

    Responses are read from the queue registered in ``request_queues`` under
    the request's id. Iteration ends after a response whose ``stop`` flag is
    set. If no response arrives within ``TIME_OUT`` seconds, a
    ``JSONResponse("timeout", 408)`` is yielded instead and iteration ends.

    Args:
        zmq_msg_request: The request to serialize and send.
        socket: Asyncio ZMQ socket to send the request on.

    Yields:
        ``ZmqMsgResponse`` objects, or a ``JSONResponse`` on timeout.
    """
    # Read outside the try block so the finally clause can always log it.
    request_id = zmq_msg_request.request_id
    try:
        requestBody = zmq_msg_request.model_dump_json()
        logger.debug("Sending requestBody: %s", requestBody)
        # Asyncio sockets return a future from send_multipart; await it so
        # send failures surface here instead of being silently dropped.
        await socket.send_multipart([requestBody.encode()])
        logger.debug("Sent end")
        queue = request_queues[request_id]
        while True:
            logger.debug("Waiting for reply")
            zmq_msg_response: ZmqMsgResponse = await asyncio.wait_for(
                queue.get(), TIME_OUT)
            logger.debug("Received result: %s",
                         zmq_msg_response.model_dump_json())
            yield zmq_msg_response
            if zmq_msg_response.stop:
                logger.debug("Received stop: %s", zmq_msg_response.stop)
                break
    except asyncio.TimeoutError:
        logger.error(traceback.format_exc())
        yield JSONResponse("timeout", HTTPStatus.REQUEST_TIMEOUT)
    finally:
        logger.debug("request_id: %s, execute_task_async end", request_id)
|
||||
|
||||
|
||||
async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]:
    """Run the request on the prefill service and wait for it to finish.

    Args:
        zmq_msg_request: Request to send over the prefill socket.

    Returns:
        A ``JSONResponse`` when the service answers with a "response" body
        or when the wait timed out; ``True`` when the response stream ends
        without such a body.
    """
    logger.debug("start prefill")
    generator = execute_task_async(zmq_msg_request, app.state.prefill_socket)
    async for res in generator:
        logger.debug("res: %s", res)
        if isinstance(res, JSONResponse):
            # On timeout the generator yields a ready-made JSONResponse,
            # which has no ``body_type``; hand it back to the caller as is.
            return res
        if res.body_type == "response":
            return JSONResponse(res.body)
    return True
|
||||
|
||||
|
||||
async def generate_stream_response(
        fisrt_reply: str,
        generator: AsyncGenerator[ZmqMsgResponse]) -> AsyncGenerator[str]:
    """Yield ``fisrt_reply`` followed by the body of every remaining reply.

    Args:
        fisrt_reply: Chunk to emit before consuming ``generator``.
        generator: Source of the remaining replies; each item's ``body`` is
            emitted in order.
    """
    yield fisrt_reply
    async for message in generator:
        chunk = message.body
        yield chunk
|
||||
|
||||
|
||||
async def decode(
    zmq_msg_request: ZmqMsgRequest
) -> Union[JSONResponse, StreamingResponse]:
    """Run the request on the decode service and build the HTTP response.

    Args:
        zmq_msg_request: Request to send over the decode socket.

    Returns:
        A ``JSONResponse`` when the first reply has body type "response" or
        the wait timed out, a ``StreamingResponse`` over the first and all
        remaining replies otherwise, or a 500 ``JSONResponse`` when no reply
        is received at all.
    """
    logger.debug("start decode")
    generator = execute_task_async(zmq_msg_request, app.state.decode_socket)

    # Only the first reply is inspected here; it decides the response kind.
    async for res in generator:
        logger.debug("res: %s", res)
        if isinstance(res, JSONResponse):
            # On timeout the generator yields a ready-made JSONResponse,
            # which has no ``body_type``; return it unchanged.
            return res
        if res.body_type == "response":
            return JSONResponse(res.body)
        return StreamingResponse(generate_stream_response(
            res.body, generator),
                                 media_type=CONTENT_TYPE_STREAM)

    # If the generator is empty, return a default error response
    logger.error("No response received from generator")
    return JSONResponse({"error": "No response received from generator"},
                        status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
def cleanup_request_id(request_id: str):
    """Remove the queue registered for ``request_id``, if there is one.

    Args:
        request_id: Key to drop from the global ``request_queues`` mapping.
    """
    if request_id not in request_queues:
        return
    logger.info("del request_id: %s, decode finished", request_id)
    del request_queues[request_id]
|
||||
|
||||
|
||||
async def run_disagg_connector(args, **uvicorn_kwargs):
    """Configure the app from ``args`` and serve it with uvicorn.

    Stores the port and the ``ipc://`` prefill/decode addresses on
    ``app.state``, installs a SIGTERM handler that raises
    ``KeyboardInterrupt``, and runs the uvicorn server on ``0.0.0.0``.

    Args:
        args: Parsed CLI arguments; ``port``, ``prefill_addr`` and
            ``decode_addr`` are read.
        **uvicorn_kwargs: Only logged by this function.
    """
    logger.info("vLLM Disaggregate Connector start %s %s", args,
                uvicorn_kwargs)
    logger.info(args.prefill_addr)

    state = app.state
    state.port = args.port
    state.prefill_addr = f"ipc://{args.prefill_addr}"
    state.decode_addr = f"ipc://{args.decode_addr}"
    logger.info(
        "start connect prefill_addr: %s decode_addr: %s "
        "zmq server fastapi port: %s", state.prefill_addr,
        state.decode_addr, state.port)

    def _raise_on_sigterm(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _raise_on_sigterm)

    # init uvicorn server
    server = uvicorn.Server(
        uvicorn.Config(app, host="0.0.0.0", port=state.port))
    await server.serve()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # NOTE(simon):
    # This section should be sync with vllm/entrypoints/cli/connect.py for CLI
    # entrypoints.
    parser = FlexibleArgumentParser(description="vLLM disagg connect server.")
    parser.add_argument(
        "--port",
        type=int,
        default=8001,
        help="The fastapi server port default 8001",
    )
    # security concern only support ipc now
    for role in ("prefill", "decode"):
        parser.add_argument(
            f"--{role}-addr",
            type=str,
            required=True,
            help=f"The zmq ipc {role} address",
        )

    args = parser.parse_args()

    uvloop.run(run_disagg_connector(args))
|
0
vllm/entrypoints/disaggregated/__init__.py
Normal file
0
vllm/entrypoints/disaggregated/__init__.py
Normal file
305
vllm/entrypoints/disaggregated/connector.py
Normal file
305
vllm/entrypoints/disaggregated/connector.py
Normal file
@ -0,0 +1,305 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import msgspec
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Dict, Mapping, Optional
|
||||
|
||||
import uvicorn
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import DecodingConfig, ModelConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.entrypoints.openai.api_server import run_server
|
||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||
validate_parsed_serve_args)
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
DEFAULT_MAX_TOKENS = 32000
|
||||
|
||||
# NOTE FOR DEVELOPERS:
|
||||
# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE
|
||||
# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE
|
||||
# WE RISK REMOTE CODE EXECUTION FROM UNTRUSTED USERS.
|
||||
|
||||
class PDRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Request message encoded and sent to the prefill and decode services."""

    # TODO: support multimodal inputs.
    request_id: str
    prompt: str
    sampling_params: SamplingParams
|
||||
|
||||
class PDResponse(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Response message decoded from the prefill and decode services."""

    # Id of the request this response belongs to.
    request_id: str
    # False makes the consumer raise for this request.
    success: bool
    delta_text: Optional[str] = None
    # A non-None value marks the last response of a decode stream.
    finish_reason: Optional[str] = None
    stop_reason: Optional[str] = None
    logprobs = None  # TODO
|
||||
|
||||
|
||||
class PDEngine:
    """
    PDEngine:
        Equivalent of AsyncLLM for P/D. Assumes there is
        a Prefill and Decode service already running.

        * TODO: actually handle errors and failure.
        * TODO: support more than just text input.
        * TODO: move under vllm/v1/engine once past prototype.
    """

    def __init__(self, prefill_addr: str, decode_addr: str,
                 connector_addr: str):
        """Create the outgoing sockets; background loops start lazily.

        Args:
            prefill_addr: ZMQ address requests for prefill are pushed to.
            decode_addr: ZMQ address requests for decode are pushed to.
            connector_addr: ZMQ address responses are pulled from.
        """
        # Request queues.
        self.queues: Dict[str, asyncio.Queue] = {}

        # Serialization encoder.
        self.encoder = msgspec.msgpack.Encoder()

        # ZMQ communication.
        self.ctx = zmq.asyncio.Context()
        self.to_decode = make_zmq_socket(self.ctx, f"{decode_addr}",
                                         zmq.constants.PUSH)
        self.to_prefill = make_zmq_socket(self.ctx, f"{prefill_addr}",
                                          zmq.constants.PUSH)
        self.connector_addr = connector_addr

        # Background loops (started on first generate()).
        self.output_handler: Optional[asyncio.Task] = None
        self.log_running: Optional[asyncio.Task] = None

    def shutdown(self):
        """Cancel the background loops and destroy the ZMQ context."""
        # Cancel the tasks first so they do not keep using sockets of a
        # context that is being destroyed.
        if (task := self.log_running) is not None:
            task.cancel()
        if (task := self.output_handler) is not None:
            task.cancel()
        if (ctx := self.ctx) is not None:
            ctx.destroy(linger=0)

    async def _run_log_running(self):
        """Log the number of in-flight requests every ten seconds."""
        while True:
            logger.info("Running requests: %d", len(self.queues))
            await asyncio.sleep(10.)

    async def _run_output_handler(
            self, socket: Optional[zmq.asyncio.Socket] = None):
        """
        Pull responses from Decode + Prefill engines and
        distribute back to the generate() tasks.

        Args:
            socket: Ignored. The handler always creates its own PULL socket
                on ``self.connector_addr``; the parameter is kept (now
                optional) only for backward compatibility.
        """
        decoder = msgspec.msgpack.Decoder(PDResponse)

        pull_socket: Optional[zmq.asyncio.Socket] = None
        try:
            pull_socket = make_zmq_socket(self.ctx, self.connector_addr,
                                          zmq.constants.PULL)

            while True:
                # copy=False returns a Frame, whose buffer is decoded
                # without an intermediate bytes copy.
                frame = await pull_socket.recv(copy=False)
                response = decoder.decode(frame.buffer)
                queue = self.queues.get(response.request_id)
                if queue is None:
                    # The request already finished or was never registered;
                    # do not let one stray message kill the handler.
                    logger.warning(
                        "Dropping response for unknown request_id: %s",
                        response.request_id)
                    continue
                queue.put_nowait(response)
        # TODO: actually handle failure and shutdown.
        finally:
            if pull_socket is not None:
                pull_socket.close(linger=0)

    async def _prefill(self, request: PDRequest,
                       q: asyncio.Queue[PDResponse]) -> PDResponse:
        """Send ``request`` to the prefill instance and await its response.

        Raises:
            Exception: If the response reports ``success`` as false.
        """
        # Send request to the prefill instance.
        req_bytes = self.encoder.encode(request)
        await self.to_prefill.send(req_bytes, copy=False)

        # Wait for the prefill to be done.
        response = await q.get()
        assert response.request_id == request.request_id
        if not response.success:
            # TODO: actual error handling and shutdown.
            raise Exception("Failed Prefill Request.")

        return response

    async def _decode(
            self, request: PDRequest,
            q: asyncio.Queue[PDResponse]) -> AsyncGenerator[PDResponse]:
        """Send ``request`` to the decode instance and yield its responses.

        Iteration ends after the first response with a ``finish_reason``.

        Raises:
            Exception: If a response reports ``success`` as false.
        """
        # Send request to the decode instance.
        req_bytes = self.encoder.encode(request)
        await self.to_decode.send(req_bytes, copy=False)

        # Iterate response queue and yield each response to caller.
        finished = False
        while not finished:
            response = await q.get()
            if not response.success:
                # TODO: actual error handling and shutdown.
                raise Exception("Failed Decode Request.")
            finished = response.finish_reason is not None
            yield response

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PDResponse]:
        """Run prefill then decode for ``prompt`` and yield every response.

        Args:
            prompt: Text prompt; only ``str`` is supported.
            sampling_params: Sampling parameters. ``max_tokens`` is
                temporarily set to 1 for the prefill step and restored
                afterwards.
            request_id: Unique id; must not collide with a running request.
            lora_request: Unused.
            trace_headers: Unused.
            prompt_adapter_request: Unused.
            priority: Unused.

        Raises:
            ValueError: If ``prompt`` is not a string or ``request_id`` is
                already in use.
        """
        # Start loops on first request.
        if self.output_handler is None:
            self.output_handler = asyncio.create_task(
                self._run_output_handler())
            self.log_running = asyncio.create_task(self._run_log_running())

        # TODO: expand to support more than text inputs.
        if not isinstance(prompt, str):
            raise ValueError("We currently only support text inputs!")
        if request_id in self.queues:
            raise ValueError(f"Found duplicate request_id: {request_id}!")

        # Queue to gather output from output_handler.
        q: asyncio.Queue[PDResponse] = asyncio.Queue()
        self.queues[request_id] = q

        original_max_tokens = sampling_params.max_tokens
        request = PDRequest(request_id, prompt, sampling_params)
        try:
            # (1) Perform the prefill (max_tokens=1).
            request.sampling_params.max_tokens = 1
            response = await self._prefill(request, q)
            yield response

            # (2) Perform the decodes (original tokens).
            request.sampling_params.max_tokens = original_max_tokens
            async for response in self._decode(request, q):
                yield response
        finally:
            # Never leave the caller's params stuck at max_tokens=1, and
            # always unregister the queue so finished/failed requests do not
            # accumulate or block reuse of the id.
            sampling_params.max_tokens = original_max_tokens
            self.queues.pop(request_id, None)

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Not implemented."""
        raise NotImplementedError

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Not implemented."""
        raise NotImplementedError

    async def abort(self, request_id: str) -> None:
        """Not implemented."""
        raise NotImplementedError

    async def get_model_config(self) -> ModelConfig:
        """Not implemented."""
        raise NotImplementedError

    async def get_decoding_config(self) -> DecodingConfig:
        """Not implemented."""
        raise NotImplementedError

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Not implemented."""
        raise NotImplementedError

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Not implemented."""
        raise NotImplementedError

    async def is_tracing_enabled(self) -> bool:
        """Always report tracing as disabled."""
        return False

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[list[SamplerOutput]] = None,
    ) -> None:
        """No-op."""
        pass

    async def check_health(self) -> None:
        """No-op."""
        pass

    async def start_profile(self) -> None:
        """Not implemented."""
        raise NotImplementedError

    async def stop_profile(self) -> None:
        """Not implemented."""
        raise NotImplementedError

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Not implemented."""
        raise NotImplementedError

    async def sleep(self, level: int = 1) -> None:
        """Not implemented."""
        raise NotImplementedError

    async def wake_up(self) -> None:
        """Not implemented."""
        raise NotImplementedError

    async def is_sleeping(self) -> bool:
        """Always report the engine as awake."""
        return False

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Not implemented."""
        raise NotImplementedError
|
||||
|
||||
|
||||
async def run_disagg_connector(args, **uvicorn_kwargs):
    """Configure the app from ``args`` and serve it with uvicorn.

    Stores the port and the ``ipc://`` connector/decode/prefill addresses on
    ``app.state`` and runs the uvicorn server.

    Args:
        args: Parsed CLI arguments; ``host``, ``port``, ``connector_addr``,
            ``decode_addr`` and ``prefill_addr`` are read.
        **uvicorn_kwargs: Only logged by this function.
    """
    logger.info("vLLM Connector Start: %s %s", args, uvicorn_kwargs)

    # NOTE FOR DEVELOPERS: when we shift this to TCP, we must
    # ensure that the serialization is not pickle based to
    # avoid RCE issues from untrusted users!!!
    # NOTE(review): `app` is not defined or imported in the visible part of
    # this module; confirm where it is expected to come from.
    app.state.port = args.port
    app.state.connector_addr = f"ipc://{args.connector_addr}"
    app.state.decode_addr = f"ipc://{args.decode_addr}"
    app.state.prefill_addr = f"ipc://{args.prefill_addr}"

    # init uvicorn server
    # `args.post` was a typo for the host option; presumably `--host` is
    # provided by make_arg_parser. TODO confirm.
    config = uvicorn.Config(app, host=args.host, port=args.port)
    server = uvicorn.Server(config)
    await server.serve()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    # One required ZMQ IPC address option per endpoint.
    for role in ("connector", "prefill", "decode"):
        parser.add_argument(
            f"--{role}-addr",
            type=str,
            required=True,
            help=f"The zmq ipc {role} address",
        )
    parser = make_arg_parser(parser)

    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))
|
Reference in New Issue
Block a user