mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
@ -4,7 +4,6 @@ import uvicorn
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from fastapi import FastAPI, Request
|
||||
from starlette.datastructures import Headers
|
||||
from fastapi.responses import StreamingResponse
|
||||
from contextlib import asynccontextmanager
|
||||
# from fastapi.lifespan import Lifespan
|
||||
@ -24,7 +23,7 @@ logger = init_logger('vllm.entrypoints.connect')
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# create scoket pool with prefill and decode
|
||||
# create socket pool with prefill and decode
|
||||
logger.info("start create_socket_pool")
|
||||
app.state.zmqctx = zmq.asyncio.Context()
|
||||
app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
|
||||
@ -39,7 +38,7 @@ async def lifespan(app: FastAPI):
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# create async socket pool with num_sockets use ZMQ_DEALER
|
||||
async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context):
|
||||
async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context) -> Queue:
|
||||
sockets = Queue()
|
||||
for i in range(num_sockets):
|
||||
sock = zmqctx.socket(zmq.DEALER)
|
||||
@ -50,8 +49,8 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Con
|
||||
await sockets.put(sock)
|
||||
return sockets
|
||||
|
||||
# select a scoket and execute task
|
||||
async def execute_task_async(route: str, headers: dict, request: dict, sockets: list):
|
||||
# select a socket and execute task
|
||||
async def execute_task_async(route: str, headers: dict, request: dict, sockets: Queue):
|
||||
sock = await sockets.get()
|
||||
try:
|
||||
requestBody = json.dumps(request)
|
||||
|
@ -1,12 +1,13 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
import httpx
|
||||
import json
|
||||
import traceback
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@ -22,7 +23,6 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
|
||||
from vllm.logger import init_logger
|
||||
import traceback
|
||||
|
||||
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||
|
||||
@ -54,7 +54,7 @@ def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
|
||||
return httpx.Headers(headers_dict)
|
||||
|
||||
async def worker_routine(worker_url: str, app: FastAPI,
|
||||
context: zmq.asyncio.Context = None, i: int = 0):
|
||||
context: zmq.asyncio.Context, i: int = 0):
|
||||
"""Worker routine"""
|
||||
try:
|
||||
# Socket to talk to dispatcher
|
||||
@ -65,46 +65,46 @@ async def worker_routine(worker_url: str, app: FastAPI,
|
||||
logger.info(f"{worker_identity} started at {worker_url}")
|
||||
while True:
|
||||
identity, url, header, body = await socket.recv_multipart()
|
||||
logger.info(f"worker-{i} Received request identity: [{identity} ]")
|
||||
url = url.decode()
|
||||
logger.info(f"worker-{i} Received request url: [{url} ]")
|
||||
header = bytes_to_headers(header)
|
||||
logger.info(f"worker-{i} Received request headers: [{header} ]")
|
||||
body = json.loads(body.decode())
|
||||
logger.info(f"worker-{i} Received request body: [{body} ]")
|
||||
logger.info(f"worker-{i} Received request identity: [{identity.decode()} ]")
|
||||
url_str = url.decode()
|
||||
logger.info(f"worker-{i} Received request url: [{url_str} ]")
|
||||
headers = bytes_to_headers(header)
|
||||
logger.info(f"worker-{i} Received request headers: [{headers} ]")
|
||||
body_json = json.loads(body.decode())
|
||||
logger.info(f"worker-{i} Received request body: [{body_json} ]")
|
||||
logger.info(f"worker-{i} Calling OpenAI API")
|
||||
completionRequest = CompletionRequest(**body)
|
||||
createRequest = create_request(url, "POST", body, header)
|
||||
completionRequest = CompletionRequest(**body_json)
|
||||
createRequest = create_request(url_str, "POST", body_json, headers)
|
||||
generator = await create_completion(app, completionRequest, createRequest)
|
||||
logger.info(f"worker-{i} Received response: [{generator} ]")
|
||||
if isinstance(generator, ErrorResponse):
|
||||
content = generator.model_dump_json()
|
||||
context = json.loads(content)
|
||||
context.append("status_code", generator.code)
|
||||
await socket.send_multipart([identity, b"application/json", json.dumps(context).encode()])
|
||||
context_json = json.loads(content)
|
||||
context_json.append("status_code", generator.code)
|
||||
await socket.send_multipart([identity, b"application/json", json.dumps(context_json).encode('utf-8')])
|
||||
elif isinstance(generator, CompletionResponse):
|
||||
await socket.send_multipart([identity, b"application/json", JSONResponse.render(content=generator.model_dump())])
|
||||
await socket.send_multipart([identity, b"application/json", json.dumps(generator.model_dump()).encode('utf-8')])
|
||||
else:
|
||||
async for chunk in generator:
|
||||
logger.info(f"worker-{i} Sending response chunk: [{chunk} ]")
|
||||
await socket.send_multipart([identity, b"text/event-stream", chunk.encode()])
|
||||
await socket.send_multipart([identity, b"text/event-stream", chunk.encode('utf-8')])
|
||||
except Exception as e:
|
||||
logger.error(f"Error in worker routine: {e} worker-{i}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
|
||||
handler = completion(app)
|
||||
logger.info(f"zmq requset post: {request}")
|
||||
logger.info(f"zmq request post: {request}")
|
||||
if handler is None:
|
||||
return base(app).create_error_response(
|
||||
message="The model does not support Completions API")
|
||||
|
||||
generator = await handler.create_completion(request, raw_request)
|
||||
logger.info(f"zmq requset end post: {generator}")
|
||||
logger.info(f"zmq request end post: {generator}")
|
||||
return generator
|
||||
|
||||
|
||||
def create_request(path: str, method: str, body: bytes, headers: dict = None):
|
||||
def create_request(path: str, method: str, body: dict, headers: httpx.Headers) -> Request:
|
||||
scope = {
|
||||
'type': 'http',
|
||||
'http_version': '1.1',
|
||||
@ -113,7 +113,7 @@ def create_request(path: str, method: str, body: bytes, headers: dict = None):
|
||||
'headers': list(headers.items()) if headers else [],
|
||||
}
|
||||
if body:
|
||||
scope['body'] = json.dumps(body).encode('utf-8')
|
||||
scope['body'] = json.dumps(body)
|
||||
async def receive():
|
||||
return {
|
||||
'type': 'http.request',
|
||||
|
Reference in New Issue
Block a user