1. fix mypy issue

Signed-off-by: clark <panf2333@gmail.com>
clark
2025-01-08 23:15:23 +08:00
parent 897db7b93d
commit 187f112ccd
2 changed files with 26 additions and 27 deletions

View File

@@ -4,7 +4,6 @@ import uvicorn
 import zmq
 import zmq.asyncio
 from fastapi import FastAPI, Request
-from starlette.datastructures import Headers
 from fastapi.responses import StreamingResponse
 from contextlib import asynccontextmanager
 # from fastapi.lifespan import Lifespan
@@ -24,7 +23,7 @@ logger = init_logger('vllm.entrypoints.connect')
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # create scoket pool with prefill and decode
+    # create socket pool with prefill and decode
     logger.info("start create_socket_pool")
     app.state.zmqctx = zmq.asyncio.Context()
     app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
@@ -39,7 +38,7 @@ async def lifespan(app: FastAPI):
 app = FastAPI(lifespan=lifespan)
 # create async socket pool with num_sockets use ZMQ_DEALER
-async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context):
+async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context) -> Queue:
     sockets = Queue()
     for i in range(num_sockets):
         sock = zmqctx.socket(zmq.DEALER)
@@ -50,8 +49,8 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Con
         await sockets.put(sock)
     return sockets
-# select a scoket and execute task
-async def execute_task_async(route: str, headers: dict, request: dict, sockets: list):
+# select a socket and execute task
+async def execute_task_async(route: str, headers: dict, request: dict, sockets: Queue):
     sock = await sockets.get()
     try:
         requestBody = json.dumps(request)
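
The annotations added above (create_socket_pool now returning Queue, execute_task_async taking a Queue) describe an asyncio.Queue-backed pool of ZMQ DEALER sockets. Below is a minimal, self-contained sketch of that pattern, assuming Queue is asyncio.Queue (the awaited put/get calls suggest it); the frame layout and header encoding in the send are assumptions, and none of this is the changed file itself.

# Sketch only: a typed DEALER socket pool backed by asyncio.Queue, matching
# the annotations introduced in this commit. Not the code from the repo.
import asyncio
import json
import uuid

import zmq
import zmq.asyncio


async def create_socket_pool(url: str, num_sockets: int,
                             zmqctx: zmq.asyncio.Context) -> asyncio.Queue:
    """Connect num_sockets DEALER sockets to url and park them in a queue."""
    sockets: asyncio.Queue = asyncio.Queue()
    for _ in range(num_sockets):
        sock = zmqctx.socket(zmq.DEALER)
        # A unique identity lets the ROUTER on the other side route replies back.
        sock.setsockopt_string(zmq.IDENTITY, uuid.uuid4().hex)
        sock.connect(url)
        await sockets.put(sock)
    return sockets


async def execute_task_async(route: str, headers: dict, request: dict,
                             sockets: asyncio.Queue) -> bytes:
    """Borrow a socket from the pool, send one request, return the socket."""
    sock = await sockets.get()
    try:
        # Assumed wire format: [route, headers-as-JSON, body-as-JSON].
        await sock.send_multipart([route.encode('utf-8'),
                                   json.dumps(headers).encode('utf-8'),
                                   json.dumps(request).encode('utf-8')])
        # Assumed reply format: [content_type, payload].
        _, payload = await sock.recv_multipart()
        return payload
    finally:
        await sockets.put(sock)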

View File

@@ -1,12 +1,13 @@
-import json
-from typing import Optional
 import zmq
 import zmq.asyncio
 import tempfile
 import uuid
 import httpx
+import json
+import traceback
+from typing import Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
@@ -22,7 +23,6 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
 from vllm.logger import init_logger
-import traceback
 prometheus_multiproc_dir: tempfile.TemporaryDirectory
@@ -54,7 +54,7 @@ def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
     return httpx.Headers(headers_dict)
 async def worker_routine(worker_url: str, app: FastAPI,
-                         context: zmq.asyncio.Context = None, i: int = 0):
+                         context: zmq.asyncio.Context, i: int = 0):
     """Worker routine"""
     try:
         # Socket to talk to dispatcher
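
The signature change above is the usual fix for mypy's implicit-Optional rule: a parameter annotated zmq.asyncio.Context may not default to None. Either spell the Optional out, or drop the default as this commit does. A short illustration with made-up function names, not code from the repo:

# Illustration of the two mypy-clean spellings.
from typing import Optional

import zmq
import zmq.asyncio


def run_worker_explicit(context: Optional[zmq.asyncio.Context] = None) -> None:
    # Explicit Optional keeps the None default; narrow it before use so mypy
    # knows ctx is a real Context afterwards.
    ctx = context if context is not None else zmq.asyncio.Context.instance()
    sock = ctx.socket(zmq.DEALER)
    sock.close()


def run_worker_required(context: zmq.asyncio.Context) -> None:
    # The route taken in this commit: no None default, so no Optional needed
    # and every caller must pass a context explicitly.
    sock = context.socket(zmq.DEALER)
    sock.close()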
@@ -65,46 +65,46 @@ async def worker_routine(worker_url: str, app: FastAPI,
         logger.info(f"{worker_identity} started at {worker_url}")
         while True:
             identity, url, header, body = await socket.recv_multipart()
-            logger.info(f"worker-{i} Received request identity: [{identity} ]")
-            url = url.decode()
-            logger.info(f"worker-{i} Received request url: [{url} ]")
-            header = bytes_to_headers(header)
-            logger.info(f"worker-{i} Received request headers: [{header} ]")
-            body = json.loads(body.decode())
-            logger.info(f"worker-{i} Received request body: [{body} ]")
+            logger.info(f"worker-{i} Received request identity: [{identity.decode()} ]")
+            url_str = url.decode()
+            logger.info(f"worker-{i} Received request url: [{url_str} ]")
+            headers = bytes_to_headers(header)
+            logger.info(f"worker-{i} Received request headers: [{headers} ]")
+            body_json = json.loads(body.decode())
+            logger.info(f"worker-{i} Received request body: [{body_json} ]")
             logger.info(f"worker-{i} Calling OpenAI API")
-            completionRequest = CompletionRequest(**body)
-            createRequest = create_request(url, "POST", body, header)
+            completionRequest = CompletionRequest(**body_json)
+            createRequest = create_request(url_str, "POST", body_json, headers)
             generator = await create_completion(app, completionRequest, createRequest)
             logger.info(f"worker-{i} Received response: [{generator} ]")
             if isinstance(generator, ErrorResponse):
                 content = generator.model_dump_json()
-                context = json.loads(content)
-                context.append("status_code", generator.code)
-                await socket.send_multipart([identity, b"application/json", json.dumps(context).encode()])
+                context_json = json.loads(content)
+                context_json.append("status_code", generator.code)
+                await socket.send_multipart([identity, b"application/json", json.dumps(context_json).encode('utf-8')])
             elif isinstance(generator, CompletionResponse):
-                await socket.send_multipart([identity, b"application/json", JSONResponse.render(content=generator.model_dump())])
+                await socket.send_multipart([identity, b"application/json", json.dumps(generator.model_dump()).encode('utf-8')])
             else:
                 async for chunk in generator:
                     logger.info(f"worker-{i} Sending response chunk: [{chunk} ]")
-                    await socket.send_multipart([identity, b"text/event-stream", chunk.encode()])
+                    await socket.send_multipart([identity, b"text/event-stream", chunk.encode('utf-8')])
     except Exception as e:
         logger.error(f"Error in worker routine: {e} worker-{i}")
         logger.error(traceback.format_exc())
 async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
     handler = completion(app)
-    logger.info(f"zmq requset post: {request}")
+    logger.info(f"zmq request post: {request}")
     if handler is None:
         return base(app).create_error_response(
             message="The model does not support Completions API")
     generator = await handler.create_completion(request, raw_request)
-    logger.info(f"zmq requset end post: {generator}")
+    logger.info(f"zmq request end post: {generator}")
     return generator
-def create_request(path: str, method: str, body: bytes, headers: dict = None):
+def create_request(path: str, method: str, body: dict, headers: httpx.Headers) -> Request:
     scope = {
         'type': 'http',
         'http_version': '1.1',
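
Most of the renames in this hunk (url to url_str, header to headers, body to body_json, context to context_json) address the same mypy rule: a variable keeps the type it was first given, so rebinding it to a value of a different type, or shadowing the context: zmq.asyncio.Context parameter with a parsed dict, is reported as an incompatible assignment. A minimal illustration, not from the repo:

# Why the renames help.
def handle(url: bytes) -> str:
    # url = url.decode()      # mypy flags this: str assigned to a bytes variable
    url_str = url.decode()    # a fresh name gets its own inferred type (str)
    return url_str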
@@ -113,7 +113,7 @@ def create_request(path: str, method: str, body: bytes, headers: dict = None):
         'headers': list(headers.items()) if headers else [],
     }
     if body:
-        scope['body'] = json.dumps(body).encode('utf-8')
+        scope['body'] = json.dumps(body)
     async def receive():
         return {
             'type': 'http.request',
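
create_request builds a raw ASGI scope plus a receive() callable by hand; the tail of the function is outside the hunk shown, but the usual way to turn that pair into the Request object promised by the new -> Request annotation is Starlette's constructor. A sketch under that assumption (note that ASGI delivers the body as bytes through receive(), not through the scope):

# Sketch only; assumes the scope/receive pair is wrapped like this.
import json

from starlette.requests import Request  # FastAPI's Request is this class


def build_request(path: str, body: dict, headers: dict) -> Request:
    raw_body = json.dumps(body).encode('utf-8')
    scope = {
        'type': 'http',
        'http_version': '1.1',
        'method': 'POST',
        'path': path,
        # ASGI headers are (bytes, bytes) pairs.
        'headers': [(k.lower().encode('latin-1'), v.encode('latin-1'))
                    for k, v in headers.items()],
    }

    async def receive():
        # One-shot body; more_body=False tells Starlette the stream is complete.
        return {'type': 'http.request', 'body': raw_body, 'more_body': False}

    return Request(scope, receive=receive)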