mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
1. replace tcp:// with ipc:// \n 2. fix json response
Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
@ -21,6 +21,7 @@ async def test_connect_completions(session):
|
||||
"repetition_penalty": 1.2,
|
||||
"model": "facebook/opt-125m",
|
||||
"prompt": "Can you introduce vllm?",
|
||||
# "stream": False,
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
"include_usage": True
|
||||
@ -34,13 +35,19 @@ async def test_connect_completions(session):
|
||||
responseText = ""
|
||||
if response.status == 200:
|
||||
transfer_encoding = response.headers.get('Transfer-Encoding')
|
||||
content_type = response.headers.get('Content-Type')
|
||||
print(f"Transfer-Encoding: {transfer_encoding}")
|
||||
if transfer_encoding == 'chunked':
|
||||
async for chunk in response.content.iter_chunked(1024):
|
||||
try:
|
||||
decoded_chunk = chunk.decode('utf-8')
|
||||
print(f"Decoded chunk: {decoded_chunk!r}")
|
||||
responseText += decoded_chunk
|
||||
except UnicodeDecodeError:
|
||||
print(f"Error decoding chunk: {chunk!r}")
|
||||
elif 'application/json' in content_type:
|
||||
responseText = await response.json()
|
||||
print(f"response {responseText!r}")
|
||||
else:
|
||||
# Print the headers and JSON response
|
||||
print("Unexpected Transfer-Encoding: {} {} {}".format(
|
||||
@ -48,18 +55,30 @@ async def test_connect_completions(session):
|
||||
response.json()))
|
||||
else:
|
||||
print(f"Request failed with status code {response.status}")
|
||||
print(
|
||||
f"baseurl {base_url} response data {extract_data(responseText)}"
|
||||
)
|
||||
print(f"baseurl {base_url}")
|
||||
print(f"response data {extract_data(responseText)}")
|
||||
except aiohttp.ClientError as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
def is_json(data):
    """Report whether ``data`` can be parsed as a JSON document.

    Args:
        data: Candidate payload, normally the response body text.

    Returns:
        True if ``json.loads`` accepts ``data``; False for malformed text
        and for input that is not text at all.
    """
    try:
        json.loads(data)
    except (TypeError, ValueError):
        # ValueError covers malformed text (json.JSONDecodeError is a
        # subclass). TypeError covers input that is not str/bytes/bytearray,
        # such as None or an already-parsed dict; a predicate should answer
        # False for those instead of propagating the exception.
        return False
    return True
|
||||
|
||||
def extract_data(responseText):
|
||||
if responseText == "":
|
||||
return ""
|
||||
if is_json(responseText):
|
||||
return responseText
|
||||
reply = ""
|
||||
for data in responseText.split("\n\n"):
|
||||
if data.startswith('data: '):
|
||||
content = data[6:]
|
||||
if content == "[DONE]":
|
||||
print("DONE")
|
||||
break
|
||||
try:
|
||||
json_data = json.loads(content)
|
||||
choices = json_data["choices"]
|
||||
|
@ -6,20 +6,23 @@ import uuid
|
||||
# from fastapi.lifespan import Lifespan
|
||||
from asyncio import Queue
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import uvicorn
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# default prefill and decode url
|
||||
url_prefill = "tcp://localhost:8110"
|
||||
# default prefill and decode addr
|
||||
fastapi_port = 8001
|
||||
prefill_addr = "ipc://localhost:7010"
|
||||
socket_prefill_num = 5
|
||||
url_decode = "tcp://localhost:8220"
|
||||
decode_addr = "ipc://localhost:7020"
|
||||
socket_decode_num = 5
|
||||
context_type_json = "application/json"
|
||||
|
||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||
logger = init_logger('vllm.entrypoints.connect')
|
||||
@ -77,16 +80,44 @@ async def execute_task_async(route: str, headers: dict, request: dict,
|
||||
while True:
|
||||
logger.info("Waiting for reply")
|
||||
[contentType, reply] = await sock.recv_multipart()
|
||||
logger.info("Received result: %s, %s", contentType, reply)
|
||||
reply = reply.decode()
|
||||
yield f"{reply}"
|
||||
if "[DONE]" in reply:
|
||||
contentType_str = contentType.decode()
|
||||
reply_str = reply.decode()
|
||||
logger.info("Received result: %s, %s", contentType_str, reply_str)
|
||||
yield (contentType_str, reply_str)
|
||||
if context_type_json == contentType_str:
|
||||
logger.info("Received %s message, return socket",
|
||||
contentType_str)
|
||||
break
|
||||
if "[DONE]" in reply_str:
|
||||
logger.info("Received stop signal, return socket")
|
||||
break
|
||||
finally:
|
||||
await sockets.put(sock)
|
||||
|
||||
|
||||
async def generate_stream_response(fisrt_reply: str,
                                   generator: AsyncGenerator):
    """Yield ``fisrt_reply`` and then every remaining reply of ``generator``.

    Args:
        fisrt_reply: Reply text emitted before anything is read from
            ``generator``.
        generator: Async generator producing two-item tuples; only the
            second item of each tuple is forwarded.

    Yields:
        ``fisrt_reply`` first, then the second element of each tuple in the
        order the generator produces them.
    """
    yield fisrt_reply
    async for item in generator:
        # The first tuple element is not needed downstream; drop it.
        _content_type, chunk = item
        yield chunk
|
||||
|
||||
|
||||
async def decode(route: str, header: dict, original_request_data: dict):
    """Run the decode stage and wrap its output in an HTTP response.

    The first message produced by ``execute_task_async`` decides the
    response type: a message whose content type equals ``context_type_json``
    becomes a ``JSONResponse``; any other message is used as the first chunk
    of a ``text/event-stream`` response, with the remaining messages
    streamed after it.

    Args:
        route: Passed through unchanged to ``execute_task_async``.
        header: Passed through unchanged to ``execute_task_async``.
        original_request_data: Passed through unchanged to
            ``execute_task_async``.

    Returns:
        A ``JSONResponse`` or a ``StreamingResponse``. If the generator
        produces no message at all, the function falls through and returns
        ``None``.
    """
    # Local import keeps this block self-contained regardless of what the
    # module imports at the top.
    import json

    logger.info("start decode")
    generator = execute_task_async(route, header, original_request_data,
                                   app.state.sockets_decode)
    logger.info("finish decode")

    async for contentType, reply in generator:
        logger.info("contentType: %s, reply: %s", contentType, reply)
        if context_type_json == contentType:
            # Nothing more will be read from the generator. Close it now so
            # its ``finally`` clause hands the socket back to the pool
            # immediately instead of whenever the suspended generator is
            # garbage-collected.
            await generator.aclose()
            # ``reply`` is already serialized JSON text. JSONResponse
            # serializes its content again, so passing the raw string would
            # deliver a quoted JSON string rather than an object.
            try:
                payload = json.loads(reply)
            except ValueError:
                # Not valid JSON after all: send the text as before.
                payload = reply
            return JSONResponse(payload)
        # Streaming case: the generator stays open and is drained by the
        # response body iterator.
        return StreamingResponse(generate_stream_response(reply, generator),
                                 media_type="text/event-stream")
|
||||
|
||||
|
||||
@app.post('/v1/connect/completions')
|
||||
async def chat_completions(request: Request):
|
||||
try:
|
||||
@ -108,11 +139,9 @@ async def chat_completions(request: Request):
|
||||
app.state.sockets_prefill):
|
||||
continue
|
||||
|
||||
# return decode
|
||||
return StreamingResponse(execute_task_async(route, header,
|
||||
original_request_data,
|
||||
app.state.sockets_decode),
|
||||
media_type="text/event-stream")
|
||||
logger.info("finish prefill start decode")
|
||||
response = await decode(route, header, original_request_data)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
import sys
|
||||
@ -127,13 +156,14 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
|
||||
logger.info("vLLM Disaggregate Connector start %s %s", args,
|
||||
uvicorn_kwargs)
|
||||
logger.info(args.prefill_addr)
|
||||
|
||||
app.state.prefill_addr = (f"tcp://{args.prefill_addr}" if args.prefill_addr
|
||||
is not None else url_prefill)
|
||||
app.state.decode_addr = (f"tcp://{args.decode_addr}"
|
||||
if args.decode_addr is not None else url_decode)
|
||||
logger.info("start connect url_prefill: %s url_decode: %s",
|
||||
app.state.prefill_addr, app.state.decode_addr)
|
||||
app.state.port = args.port if args.port is not None else fastapi_port
|
||||
app.state.prefill_addr = (f"ipc://{args.prefill_addr}" if args.prefill_addr
|
||||
is not None else decode_addr)
|
||||
app.state.decode_addr = (f"ipc://{args.decode_addr}"
|
||||
if args.decode_addr is not None else decode_addr)
|
||||
logger.info(
|
||||
"start connect prefill_addr: %s decode_addr: %s zmq server port: %s",
|
||||
app.state.prefill_addr, app.state.decode_addr, app.state.port)
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm while initializing
|
||||
@ -141,11 +171,11 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
# init uvicorn server
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8001)
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=app.state.port)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# url = 'tcp://127.0.0.1:5555'
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
uvicorn.run(app, host="0.0.0.0", port=fastapi_port)
|
||||
|
@ -82,22 +82,22 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
|
||||
"""Server routine"""
|
||||
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
|
||||
zmq_server_port)
|
||||
url_worker = "inproc://workers"
|
||||
url_client = f"tcp://0.0.0.0:{zmq_server_port}"
|
||||
workers_addr = "inproc://workers"
|
||||
clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
|
||||
# Prepare our context and sockets
|
||||
context = zmq.asyncio.Context()
|
||||
|
||||
# Socket to talk to clients
|
||||
clients = context.socket(zmq.ROUTER)
|
||||
clients.bind(url_client)
|
||||
logger.info("ZMQ Server ROUTER started at %s", url_client)
|
||||
clients.bind(clients_addr)
|
||||
logger.info("ZMQ Server ROUTER started at %s", clients_addr)
|
||||
# Socket to talk to workers
|
||||
workers = context.socket(zmq.DEALER)
|
||||
workers.bind(url_worker)
|
||||
logger.info("ZMQ Worker DEALER started at %s", url_worker)
|
||||
workers.bind(workers_addr)
|
||||
logger.info("ZMQ Worker DEALER started at %s", workers_addr)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(worker_routine(url_worker, app, context, i))
|
||||
asyncio.create_task(worker_routine(workers_addr, app, context, i))
|
||||
for i in range(5)
|
||||
]
|
||||
proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
|
||||
|
@ -53,7 +53,7 @@ def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
|
||||
headers_dict = json.loads(bytes_data.decode())
|
||||
return httpx.Headers(headers_dict)
|
||||
|
||||
async def worker_routine(worker_url: str, app: FastAPI,
|
||||
async def worker_routine(worker_addr: str, app: FastAPI,
|
||||
context: zmq.asyncio.Context, i: int = 0):
|
||||
"""Worker routine"""
|
||||
try:
|
||||
@ -61,8 +61,8 @@ async def worker_routine(worker_url: str, app: FastAPI,
|
||||
socket = context.socket(zmq.DEALER)
|
||||
worker_identity = f"worker-{i}-{uuid.uuid4()}"
|
||||
socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
|
||||
socket.connect(worker_url)
|
||||
logger.info("%s started at %s", worker_identity, worker_url)
|
||||
socket.connect(worker_addr)
|
||||
logger.info("%s started at %s", worker_identity, worker_addr)
|
||||
while True:
|
||||
identity, url, header, body = await socket.recv_multipart()
|
||||
logger.info("worker-%d Received request identity: [ %s ]",
|
||||
@ -81,15 +81,16 @@ async def worker_routine(worker_url: str, app: FastAPI,
|
||||
createRequest = create_request(url_str, "POST", body_json, headers)
|
||||
generator = await create_completion(app, completionRequest,
|
||||
createRequest)
|
||||
logger.info("worker-%d Received response: [ %s ]", i, generator)
|
||||
context_type_json = b"application/json"
|
||||
if isinstance(generator, ErrorResponse):
|
||||
content = generator.model_dump_json()
|
||||
context_json = json.loads(content)
|
||||
context_json.append("status_code", generator.code)
|
||||
await socket.send_multipart([identity, b"application/json",
|
||||
await socket.send_multipart([identity, context_type_json,
|
||||
json.dumps(context_json).encode('utf-8')])
|
||||
elif isinstance(generator, CompletionResponse):
|
||||
await socket.send_multipart([identity, b"application/json",
|
||||
await socket.send_multipart([identity,
|
||||
context_type_json,
|
||||
json.dumps(generator.model_dump()).encode('utf-8')])
|
||||
else:
|
||||
async for chunk in generator:
|
||||
|
Reference in New Issue
Block a user