1. replace tcp:// with ipc:// \n 2. fix json response

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark
2025-01-11 20:44:53 +08:00
parent 2c31e4c3ea
commit 7fbf70db57
4 changed files with 88 additions and 38 deletions

View File

@ -21,6 +21,7 @@ async def test_connect_completions(session):
"repetition_penalty": 1.2, "repetition_penalty": 1.2,
"model": "facebook/opt-125m", "model": "facebook/opt-125m",
"prompt": "Can you introduce vllm?", "prompt": "Can you introduce vllm?",
# "stream": False,
"stream": True, "stream": True,
"stream_options": { "stream_options": {
"include_usage": True "include_usage": True
@ -34,13 +35,19 @@ async def test_connect_completions(session):
responseText = "" responseText = ""
if response.status == 200: if response.status == 200:
transfer_encoding = response.headers.get('Transfer-Encoding') transfer_encoding = response.headers.get('Transfer-Encoding')
content_type = response.headers.get('Content-Type')
print(f"Transfer-Encoding: {transfer_encoding}")
if transfer_encoding == 'chunked': if transfer_encoding == 'chunked':
async for chunk in response.content.iter_chunked(1024): async for chunk in response.content.iter_chunked(1024):
try: try:
decoded_chunk = chunk.decode('utf-8') decoded_chunk = chunk.decode('utf-8')
print(f"Decoded chunk: {decoded_chunk!r}")
responseText += decoded_chunk responseText += decoded_chunk
except UnicodeDecodeError: except UnicodeDecodeError:
print(f"Error decoding chunk: {chunk!r}") print(f"Error decoding chunk: {chunk!r}")
elif 'application/json' in content_type:
responseText = await response.json()
print(f"response {responseText!r}")
else: else:
# Print the headers and JSON response # Print the headers and JSON response
print("Unexpected Transfer-Encoding: {} {} {}".format( print("Unexpected Transfer-Encoding: {} {} {}".format(
@ -48,18 +55,30 @@ async def test_connect_completions(session):
response.json())) response.json()))
else: else:
print(f"Request failed with status code {response.status}") print(f"Request failed with status code {response.status}")
print( print(f"baseurl {base_url}")
f"baseurl {base_url} response data {extract_data(responseText)}" print(f"response data {extract_data(responseText)}")
)
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
print(f"Error: {e}") print(f"Error: {e}")
def is_json(data):
    """Return True when *data* is JSON, either as text or already parsed.

    Args:
        data: Value to inspect. Usually the accumulated response text, but
            it may also be the object produced by ``response.json()``.

    Returns:
        True if *data* is a dict/list or a str/bytes value that
        ``json.loads`` can parse; False for anything else.
    """
    # The non-streaming branch of the caller stores the result of
    # ``response.json()``, which is already parsed. ``json.loads`` raises
    # TypeError (not ValueError) for such objects, so accept them up front.
    if isinstance(data, (dict, list)):
        return True
    try:
        json.loads(data)
    except (TypeError, ValueError):
        # ValueError: malformed JSON text (JSONDecodeError subclasses it).
        # TypeError: input that is neither str, bytes nor bytearray.
        return False
    return True
def extract_data(responseText): def extract_data(responseText):
if responseText == "":
return ""
if is_json(responseText):
return responseText
reply = "" reply = ""
for data in responseText.split("\n\n"): for data in responseText.split("\n\n"):
if data.startswith('data: '): if data.startswith('data: '):
content = data[6:] content = data[6:]
if content == "[DONE]":
print("DONE")
break
try: try:
json_data = json.loads(content) json_data = json.loads(content)
choices = json_data["choices"] choices = json_data["choices"]

View File

@ -6,20 +6,23 @@ import uuid
# from fastapi.lifespan import Lifespan # from fastapi.lifespan import Lifespan
from asyncio import Queue from asyncio import Queue
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator
import uvicorn import uvicorn
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from vllm.logger import init_logger from vllm.logger import init_logger
# default prefill and decode url # default prefill and decode addr
url_prefill = "tcp://localhost:8110" fastapi_port = 8001
prefill_addr = "ipc://localhost:7010"
socket_prefill_num = 5 socket_prefill_num = 5
url_decode = "tcp://localhost:8220" decode_addr = "ipc://localhost:7020"
socket_decode_num = 5 socket_decode_num = 5
context_type_json = "application/json"
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.connect') logger = init_logger('vllm.entrypoints.connect')
@ -77,16 +80,44 @@ async def execute_task_async(route: str, headers: dict, request: dict,
while True: while True:
logger.info("Waiting for reply") logger.info("Waiting for reply")
[contentType, reply] = await sock.recv_multipart() [contentType, reply] = await sock.recv_multipart()
logger.info("Received result: %s, %s", contentType, reply) contentType_str = contentType.decode()
reply = reply.decode() reply_str = reply.decode()
yield f"{reply}" logger.info("Received result: %s, %s", contentType_str, reply_str)
if "[DONE]" in reply: yield (contentType_str, reply_str)
if context_type_json == contentType_str:
logger.info("Received %s message, return socket",
contentType_str)
break
if "[DONE]" in reply_str:
logger.info("Received stop signal, return socket") logger.info("Received stop signal, return socket")
break break
finally: finally:
await sockets.put(sock) await sockets.put(sock)
async def generate_stream_response(fisrt_reply: str,
                                   generator: AsyncGenerator):
    """Stream a first chunk, then the payload of every remaining item.

    Args:
        fisrt_reply: Chunk to emit before anything is read from *generator*.
        generator: Async generator yielding two-element items; only the
            second element of each item is forwarded.

    Yields:
        *fisrt_reply*, followed by the second element of each item.
    """
    yield fisrt_reply
    async for item in generator:
        # Each item is a pair; the first element is not needed here.
        _content_type, chunk = item
        yield chunk
async def decode(route: str, header: dict, original_request_data: dict):
    """Run the decode stage and wrap its output in an HTTP response.

    The first item from ``execute_task_async`` decides the response type:
    a JSON content type yields a single ``JSONResponse``; anything else
    yields a ``StreamingResponse`` that emits the first chunk and then the
    remaining chunks.

    Args:
        route: Route forwarded to ``execute_task_async``.
        header: Request headers forwarded to ``execute_task_async``.
        original_request_data: Request body forwarded to
            ``execute_task_async``.

    Returns:
        A ``JSONResponse`` or ``StreamingResponse``; ``None`` if the
        generator produces no item at all.

    Raises:
        ValueError: If a reply tagged as JSON is not valid JSON text.
    """
    # Local import so this function does not depend on the module's imports.
    import json

    logger.info("start decode")
    generator = execute_task_async(route, header, original_request_data,
                                   app.state.sockets_decode)
    # NOTE: the generator is lazy; no work has happened yet at this point.
    logger.info("finish decode")
    async for contentType, reply in generator:
        logger.info("contentType: %s, reply: %s", contentType, reply)
        if context_type_json == contentType:
            # Nothing more will be read from the generator, so close it now.
            # Its ``finally`` block then returns the socket to the pool
            # immediately instead of at garbage collection.
            await generator.aclose()
            # ``reply`` is JSON text. Passing it directly to JSONResponse
            # would serialise it again into a quoted string, so parse first.
            return JSONResponse(json.loads(reply))
        # Streaming case: hand the partially consumed generator to the
        # response, which drains it and so releases the socket.
        return StreamingResponse(generate_stream_response(reply, generator),
                                 media_type="text/event-stream")
@app.post('/v1/connect/completions') @app.post('/v1/connect/completions')
async def chat_completions(request: Request): async def chat_completions(request: Request):
try: try:
@ -108,11 +139,9 @@ async def chat_completions(request: Request):
app.state.sockets_prefill): app.state.sockets_prefill):
continue continue
# return decode logger.info("finish prefill start decode")
return StreamingResponse(execute_task_async(route, header, response = await decode(route, header, original_request_data)
original_request_data, return response
app.state.sockets_decode),
media_type="text/event-stream")
except Exception as e: except Exception as e:
import sys import sys
@ -127,13 +156,14 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
logger.info("vLLM Disaggregate Connector start %s %s", args, logger.info("vLLM Disaggregate Connector start %s %s", args,
uvicorn_kwargs) uvicorn_kwargs)
logger.info(args.prefill_addr) logger.info(args.prefill_addr)
app.state.port = args.port if args.port is not None else fastapi_port
app.state.prefill_addr = (f"tcp://{args.prefill_addr}" if args.prefill_addr app.state.prefill_addr = (f"ipc://{args.prefill_addr}" if args.prefill_addr
is not None else url_prefill) is not None else decode_addr)
app.state.decode_addr = (f"tcp://{args.decode_addr}" app.state.decode_addr = (f"ipc://{args.decode_addr}"
if args.decode_addr is not None else url_decode) if args.decode_addr is not None else decode_addr)
logger.info("start connect url_prefill: %s url_decode: %s", logger.info(
app.state.prefill_addr, app.state.decode_addr) "start connect prefill_addr: %s decode_addr: %s zmq server port: %s",
app.state.prefill_addr, app.state.decode_addr, app.state.port)
def signal_handler(*_) -> None: def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing # Interrupt server on sigterm while initializing
@ -141,11 +171,11 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
# init uvicorn server # init uvicorn server
config = uvicorn.Config(app, host="0.0.0.0", port=8001) config = uvicorn.Config(app, host="0.0.0.0", port=app.state.port)
server = uvicorn.Server(config) server = uvicorn.Server(config)
await server.serve() await server.serve()
if __name__ == "__main__": if __name__ == "__main__":
# url = 'tcp://127.0.0.1:5555' # url = 'tcp://127.0.0.1:5555'
uvicorn.run(app, host="0.0.0.0", port=8001) uvicorn.run(app, host="0.0.0.0", port=fastapi_port)

View File

@ -82,22 +82,22 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
"""Server routine""" """Server routine"""
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg, logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
zmq_server_port) zmq_server_port)
url_worker = "inproc://workers" workers_addr = "inproc://workers"
url_client = f"tcp://0.0.0.0:{zmq_server_port}" clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
# Prepare our context and sockets # Prepare our context and sockets
context = zmq.asyncio.Context() context = zmq.asyncio.Context()
# Socket to talk to clients # Socket to talk to clients
clients = context.socket(zmq.ROUTER) clients = context.socket(zmq.ROUTER)
clients.bind(url_client) clients.bind(clients_addr)
logger.info("ZMQ Server ROUTER started at %s", url_client) logger.info("ZMQ Server ROUTER started at %s", clients_addr)
# Socket to talk to workers # Socket to talk to workers
workers = context.socket(zmq.DEALER) workers = context.socket(zmq.DEALER)
workers.bind(url_worker) workers.bind(workers_addr)
logger.info("ZMQ Worker DEALER started at %s", url_worker) logger.info("ZMQ Worker DEALER started at %s", workers_addr)
tasks = [ tasks = [
asyncio.create_task(worker_routine(url_worker, app, context, i)) asyncio.create_task(worker_routine(workers_addr, app, context, i))
for i in range(5) for i in range(5)
] ]
proxy_task = asyncio.to_thread(zmq.proxy, clients, workers) proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)

View File

@ -53,7 +53,7 @@ def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
headers_dict = json.loads(bytes_data.decode()) headers_dict = json.loads(bytes_data.decode())
return httpx.Headers(headers_dict) return httpx.Headers(headers_dict)
async def worker_routine(worker_url: str, app: FastAPI, async def worker_routine(worker_addr: str, app: FastAPI,
context: zmq.asyncio.Context, i: int = 0): context: zmq.asyncio.Context, i: int = 0):
"""Worker routine""" """Worker routine"""
try: try:
@ -61,8 +61,8 @@ async def worker_routine(worker_url: str, app: FastAPI,
socket = context.socket(zmq.DEALER) socket = context.socket(zmq.DEALER)
worker_identity = f"worker-{i}-{uuid.uuid4()}" worker_identity = f"worker-{i}-{uuid.uuid4()}"
socket.setsockopt(zmq.IDENTITY, worker_identity.encode()) socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
socket.connect(worker_url) socket.connect(worker_addr)
logger.info("%s started at %s", worker_identity, worker_url) logger.info("%s started at %s", worker_identity, worker_addr)
while True: while True:
identity, url, header, body = await socket.recv_multipart() identity, url, header, body = await socket.recv_multipart()
logger.info("worker-%d Received request identity: [ %s ]", logger.info("worker-%d Received request identity: [ %s ]",
@ -81,15 +81,16 @@ async def worker_routine(worker_url: str, app: FastAPI,
createRequest = create_request(url_str, "POST", body_json, headers) createRequest = create_request(url_str, "POST", body_json, headers)
generator = await create_completion(app, completionRequest, generator = await create_completion(app, completionRequest,
createRequest) createRequest)
logger.info("worker-%d Received response: [ %s ]", i, generator) context_type_json = b"application/json"
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
content = generator.model_dump_json() content = generator.model_dump_json()
context_json = json.loads(content) context_json = json.loads(content)
context_json.append("status_code", generator.code) context_json.append("status_code", generator.code)
await socket.send_multipart([identity, b"application/json", await socket.send_multipart([identity, context_type_json,
json.dumps(context_json).encode('utf-8')]) json.dumps(context_json).encode('utf-8')])
elif isinstance(generator, CompletionResponse): elif isinstance(generator, CompletionResponse):
await socket.send_multipart([identity, b"application/json", await socket.send_multipart([identity,
context_type_json,
json.dumps(generator.model_dump()).encode('utf-8')]) json.dumps(generator.model_dump()).encode('utf-8')])
else: else:
async for chunk in generator: async for chunk in generator: