1. replace tpc:// with ipc:// \n 2. fix json response

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark
2025-01-11 20:44:53 +08:00
parent 2c31e4c3ea
commit 7fbf70db57
4 changed files with 88 additions and 38 deletions

View File

@ -21,6 +21,7 @@ async def test_connect_completions(session):
"repetition_penalty": 1.2,
"model": "facebook/opt-125m",
"prompt": "Can you introduce vllm?",
# "stream": False,
"stream": True,
"stream_options": {
"include_usage": True
@ -34,13 +35,19 @@ async def test_connect_completions(session):
responseText = ""
if response.status == 200:
transfer_encoding = response.headers.get('Transfer-Encoding')
content_type = response.headers.get('Content-Type')
print(f"Transfer-Encoding: {transfer_encoding}")
if transfer_encoding == 'chunked':
async for chunk in response.content.iter_chunked(1024):
try:
decoded_chunk = chunk.decode('utf-8')
print(f"Decoded chunk: {decoded_chunk!r}")
responseText += decoded_chunk
except UnicodeDecodeError:
print(f"Error decoding chunk: {chunk!r}")
elif 'application/json' in content_type:
responseText = await response.json()
print(f"response {responseText!r}")
else:
# Print the headers and JSON response
print("Unexpected Transfer-Encoding: {} {} {}".format(
@ -48,18 +55,30 @@ async def test_connect_completions(session):
response.json()))
else:
print(f"Request failed with status code {response.status}")
print(
f"baseurl {base_url} response data {extract_data(responseText)}"
)
print(f"baseurl {base_url}")
print(f"response data {extract_data(responseText)}")
except aiohttp.ClientError as e:
print(f"Error: {e}")
def is_json(data):
try:
json.loads(data)
return True
except ValueError:
return False
def extract_data(responseText):
if responseText == "":
return ""
if is_json(responseText):
return responseText
reply = ""
for data in responseText.split("\n\n"):
if data.startswith('data: '):
content = data[6:]
if content == "[DONE]":
print("DONE")
break
try:
json_data = json.loads(content)
choices = json_data["choices"]

View File

@ -6,20 +6,23 @@ import uuid
# from fastapi.lifespan import Lifespan
from asyncio import Queue
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import uvicorn
import zmq
import zmq.asyncio
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.logger import init_logger
# default prefill and decode url
url_prefill = "tcp://localhost:8110"
# default prefill and decode addr
fastapi_port = 8001
prefill_addr = "ipc://localhost:7010"
socket_prefill_num = 5
url_decode = "tcp://localhost:8220"
decode_addr = "ipc://localhost:7020"
socket_decode_num = 5
context_type_json = "application/json"
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.connect')
@ -77,16 +80,44 @@ async def execute_task_async(route: str, headers: dict, request: dict,
while True:
logger.info("Waiting for reply")
[contentType, reply] = await sock.recv_multipart()
logger.info("Received result: %s, %s", contentType, reply)
reply = reply.decode()
yield f"{reply}"
if "[DONE]" in reply:
contentType_str = contentType.decode()
reply_str = reply.decode()
logger.info("Received result: %s, %s", contentType_str, reply_str)
yield (contentType_str, reply_str)
if context_type_json == contentType_str:
logger.info("Received %s message, return socket",
contentType_str)
break
if "[DONE]" in reply_str:
logger.info("Received stop signal, return socket")
break
finally:
await sockets.put(sock)
async def generate_stream_response(fisrt_reply: str,
generator: AsyncGenerator):
yield fisrt_reply
async for _, reply in generator:
yield reply
async def decode(route: str, header: dict, original_request_data: dict):
logger.info("start decode")
generator = execute_task_async(route, header, original_request_data,
app.state.sockets_decode)
logger.info("finish decode")
async for contentType, reply in generator:
logger.info("contentType: %s, reply: %s", contentType, reply)
if context_type_json == contentType:
return JSONResponse(reply)
else:
return StreamingResponse(generate_stream_response(
reply, generator),
media_type="text/event-stream")
@app.post('/v1/connect/completions')
async def chat_completions(request: Request):
try:
@ -108,11 +139,9 @@ async def chat_completions(request: Request):
app.state.sockets_prefill):
continue
# return decode
return StreamingResponse(execute_task_async(route, header,
original_request_data,
app.state.sockets_decode),
media_type="text/event-stream")
logger.info("finish prefill start decode")
response = await decode(route, header, original_request_data)
return response
except Exception as e:
import sys
@ -127,13 +156,14 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
logger.info("vLLM Disaggregate Connector start %s %s", args,
uvicorn_kwargs)
logger.info(args.prefill_addr)
app.state.prefill_addr = (f"tcp://{args.prefill_addr}" if args.prefill_addr
is not None else url_prefill)
app.state.decode_addr = (f"tcp://{args.decode_addr}"
if args.decode_addr is not None else url_decode)
logger.info("start connect url_prefill: %s url_decode: %s",
app.state.prefill_addr, app.state.decode_addr)
app.state.port = args.port if args.port is not None else fastapi_port
app.state.prefill_addr = (f"ipc://{args.prefill_addr}" if args.prefill_addr
is not None else decode_addr)
app.state.decode_addr = (f"ipc://{args.decode_addr}"
if args.decode_addr is not None else decode_addr)
logger.info(
"start connect prefill_addr: %s decode_addr: %s zmq server port: %s",
app.state.prefill_addr, app.state.decode_addr, app.state.port)
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
@ -141,11 +171,11 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
signal.signal(signal.SIGTERM, signal_handler)
# init uvicorn server
config = uvicorn.Config(app, host="0.0.0.0", port=8001)
config = uvicorn.Config(app, host="0.0.0.0", port=app.state.port)
server = uvicorn.Server(config)
await server.serve()
if __name__ == "__main__":
# url = 'tcp://127.0.0.1:5555'
uvicorn.run(app, host="0.0.0.0", port=8001)
uvicorn.run(app, host="0.0.0.0", port=fastapi_port)

View File

@ -82,22 +82,22 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
"""Server routine"""
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
zmq_server_port)
url_worker = "inproc://workers"
url_client = f"tcp://0.0.0.0:{zmq_server_port}"
workers_addr = "inproc://workers"
clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
# Prepare our context and sockets
context = zmq.asyncio.Context()
# Socket to talk to clients
clients = context.socket(zmq.ROUTER)
clients.bind(url_client)
logger.info("ZMQ Server ROUTER started at %s", url_client)
clients.bind(clients_addr)
logger.info("ZMQ Server ROUTER started at %s", clients_addr)
# Socket to talk to workers
workers = context.socket(zmq.DEALER)
workers.bind(url_worker)
logger.info("ZMQ Worker DEALER started at %s", url_worker)
workers.bind(workers_addr)
logger.info("ZMQ Worker DEALER started at %s", workers_addr)
tasks = [
asyncio.create_task(worker_routine(url_worker, app, context, i))
asyncio.create_task(worker_routine(workers_addr, app, context, i))
for i in range(5)
]
proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)

View File

@ -53,7 +53,7 @@ def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
headers_dict = json.loads(bytes_data.decode())
return httpx.Headers(headers_dict)
async def worker_routine(worker_url: str, app: FastAPI,
async def worker_routine(worker_addr: str, app: FastAPI,
context: zmq.asyncio.Context, i: int = 0):
"""Worker routine"""
try:
@ -61,8 +61,8 @@ async def worker_routine(worker_url: str, app: FastAPI,
socket = context.socket(zmq.DEALER)
worker_identity = f"worker-{i}-{uuid.uuid4()}"
socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
socket.connect(worker_url)
logger.info("%s started at %s", worker_identity, worker_url)
socket.connect(worker_addr)
logger.info("%s started at %s", worker_identity, worker_addr)
while True:
identity, url, header, body = await socket.recv_multipart()
logger.info("worker-%d Received request identity: [ %s ]",
@ -81,15 +81,16 @@ async def worker_routine(worker_url: str, app: FastAPI,
createRequest = create_request(url_str, "POST", body_json, headers)
generator = await create_completion(app, completionRequest,
createRequest)
logger.info("worker-%d Received response: [ %s ]", i, generator)
context_type_json = b"application/json"
if isinstance(generator, ErrorResponse):
content = generator.model_dump_json()
context_json = json.loads(content)
context_json.append("status_code", generator.code)
await socket.send_multipart([identity, b"application/json",
await socket.send_multipart([identity, context_type_json,
json.dumps(context_json).encode('utf-8')])
elif isinstance(generator, CompletionResponse):
await socket.send_multipart([identity, b"application/json",
await socket.send_multipart([identity,
context_type_json,
json.dumps(generator.model_dump()).encode('utf-8')])
else:
async for chunk in generator: