mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
1. replace tpc:// with ipc:// \n 2. fix json response
Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
@ -21,6 +21,7 @@ async def test_connect_completions(session):
|
|||||||
"repetition_penalty": 1.2,
|
"repetition_penalty": 1.2,
|
||||||
"model": "facebook/opt-125m",
|
"model": "facebook/opt-125m",
|
||||||
"prompt": "Can you introduce vllm?",
|
"prompt": "Can you introduce vllm?",
|
||||||
|
# "stream": False,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"stream_options": {
|
"stream_options": {
|
||||||
"include_usage": True
|
"include_usage": True
|
||||||
@ -34,13 +35,19 @@ async def test_connect_completions(session):
|
|||||||
responseText = ""
|
responseText = ""
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
transfer_encoding = response.headers.get('Transfer-Encoding')
|
transfer_encoding = response.headers.get('Transfer-Encoding')
|
||||||
|
content_type = response.headers.get('Content-Type')
|
||||||
|
print(f"Transfer-Encoding: {transfer_encoding}")
|
||||||
if transfer_encoding == 'chunked':
|
if transfer_encoding == 'chunked':
|
||||||
async for chunk in response.content.iter_chunked(1024):
|
async for chunk in response.content.iter_chunked(1024):
|
||||||
try:
|
try:
|
||||||
decoded_chunk = chunk.decode('utf-8')
|
decoded_chunk = chunk.decode('utf-8')
|
||||||
|
print(f"Decoded chunk: {decoded_chunk!r}")
|
||||||
responseText += decoded_chunk
|
responseText += decoded_chunk
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
print(f"Error decoding chunk: {chunk!r}")
|
print(f"Error decoding chunk: {chunk!r}")
|
||||||
|
elif 'application/json' in content_type:
|
||||||
|
responseText = await response.json()
|
||||||
|
print(f"response {responseText!r}")
|
||||||
else:
|
else:
|
||||||
# Print the headers and JSON response
|
# Print the headers and JSON response
|
||||||
print("Unexpected Transfer-Encoding: {} {} {}".format(
|
print("Unexpected Transfer-Encoding: {} {} {}".format(
|
||||||
@ -48,18 +55,30 @@ async def test_connect_completions(session):
|
|||||||
response.json()))
|
response.json()))
|
||||||
else:
|
else:
|
||||||
print(f"Request failed with status code {response.status}")
|
print(f"Request failed with status code {response.status}")
|
||||||
print(
|
print(f"baseurl {base_url}")
|
||||||
f"baseurl {base_url} response data {extract_data(responseText)}"
|
print(f"response data {extract_data(responseText)}")
|
||||||
)
|
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
def is_json(data):
|
||||||
|
try:
|
||||||
|
json.loads(data)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
def extract_data(responseText):
|
def extract_data(responseText):
|
||||||
|
if responseText == "":
|
||||||
|
return ""
|
||||||
|
if is_json(responseText):
|
||||||
|
return responseText
|
||||||
reply = ""
|
reply = ""
|
||||||
for data in responseText.split("\n\n"):
|
for data in responseText.split("\n\n"):
|
||||||
if data.startswith('data: '):
|
if data.startswith('data: '):
|
||||||
content = data[6:]
|
content = data[6:]
|
||||||
|
if content == "[DONE]":
|
||||||
|
print("DONE")
|
||||||
|
break
|
||||||
try:
|
try:
|
||||||
json_data = json.loads(content)
|
json_data = json.loads(content)
|
||||||
choices = json_data["choices"]
|
choices = json_data["choices"]
|
||||||
|
|||||||
@ -6,20 +6,23 @@ import uuid
|
|||||||
# from fastapi.lifespan import Lifespan
|
# from fastapi.lifespan import Lifespan
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
# default prefill and decode url
|
# default prefill and decode addr
|
||||||
url_prefill = "tcp://localhost:8110"
|
fastapi_port = 8001
|
||||||
|
prefill_addr = "ipc://localhost:7010"
|
||||||
socket_prefill_num = 5
|
socket_prefill_num = 5
|
||||||
url_decode = "tcp://localhost:8220"
|
decode_addr = "ipc://localhost:7020"
|
||||||
socket_decode_num = 5
|
socket_decode_num = 5
|
||||||
|
context_type_json = "application/json"
|
||||||
|
|
||||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||||
logger = init_logger('vllm.entrypoints.connect')
|
logger = init_logger('vllm.entrypoints.connect')
|
||||||
@ -77,16 +80,44 @@ async def execute_task_async(route: str, headers: dict, request: dict,
|
|||||||
while True:
|
while True:
|
||||||
logger.info("Waiting for reply")
|
logger.info("Waiting for reply")
|
||||||
[contentType, reply] = await sock.recv_multipart()
|
[contentType, reply] = await sock.recv_multipart()
|
||||||
logger.info("Received result: %s, %s", contentType, reply)
|
contentType_str = contentType.decode()
|
||||||
reply = reply.decode()
|
reply_str = reply.decode()
|
||||||
yield f"{reply}"
|
logger.info("Received result: %s, %s", contentType_str, reply_str)
|
||||||
if "[DONE]" in reply:
|
yield (contentType_str, reply_str)
|
||||||
|
if context_type_json == contentType_str:
|
||||||
|
logger.info("Received %s message, return socket",
|
||||||
|
contentType_str)
|
||||||
|
break
|
||||||
|
if "[DONE]" in reply_str:
|
||||||
logger.info("Received stop signal, return socket")
|
logger.info("Received stop signal, return socket")
|
||||||
break
|
break
|
||||||
finally:
|
finally:
|
||||||
await sockets.put(sock)
|
await sockets.put(sock)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_stream_response(fisrt_reply: str,
|
||||||
|
generator: AsyncGenerator):
|
||||||
|
yield fisrt_reply
|
||||||
|
async for _, reply in generator:
|
||||||
|
yield reply
|
||||||
|
|
||||||
|
|
||||||
|
async def decode(route: str, header: dict, original_request_data: dict):
|
||||||
|
logger.info("start decode")
|
||||||
|
generator = execute_task_async(route, header, original_request_data,
|
||||||
|
app.state.sockets_decode)
|
||||||
|
logger.info("finish decode")
|
||||||
|
|
||||||
|
async for contentType, reply in generator:
|
||||||
|
logger.info("contentType: %s, reply: %s", contentType, reply)
|
||||||
|
if context_type_json == contentType:
|
||||||
|
return JSONResponse(reply)
|
||||||
|
else:
|
||||||
|
return StreamingResponse(generate_stream_response(
|
||||||
|
reply, generator),
|
||||||
|
media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
@app.post('/v1/connect/completions')
|
@app.post('/v1/connect/completions')
|
||||||
async def chat_completions(request: Request):
|
async def chat_completions(request: Request):
|
||||||
try:
|
try:
|
||||||
@ -108,11 +139,9 @@ async def chat_completions(request: Request):
|
|||||||
app.state.sockets_prefill):
|
app.state.sockets_prefill):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# return decode
|
logger.info("finish prefill start decode")
|
||||||
return StreamingResponse(execute_task_async(route, header,
|
response = await decode(route, header, original_request_data)
|
||||||
original_request_data,
|
return response
|
||||||
app.state.sockets_decode),
|
|
||||||
media_type="text/event-stream")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import sys
|
import sys
|
||||||
@ -127,13 +156,14 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
|
|||||||
logger.info("vLLM Disaggregate Connector start %s %s", args,
|
logger.info("vLLM Disaggregate Connector start %s %s", args,
|
||||||
uvicorn_kwargs)
|
uvicorn_kwargs)
|
||||||
logger.info(args.prefill_addr)
|
logger.info(args.prefill_addr)
|
||||||
|
app.state.port = args.port if args.port is not None else fastapi_port
|
||||||
app.state.prefill_addr = (f"tcp://{args.prefill_addr}" if args.prefill_addr
|
app.state.prefill_addr = (f"ipc://{args.prefill_addr}" if args.prefill_addr
|
||||||
is not None else url_prefill)
|
is not None else decode_addr)
|
||||||
app.state.decode_addr = (f"tcp://{args.decode_addr}"
|
app.state.decode_addr = (f"ipc://{args.decode_addr}"
|
||||||
if args.decode_addr is not None else url_decode)
|
if args.decode_addr is not None else decode_addr)
|
||||||
logger.info("start connect url_prefill: %s url_decode: %s",
|
logger.info(
|
||||||
app.state.prefill_addr, app.state.decode_addr)
|
"start connect prefill_addr: %s decode_addr: %s zmq server port: %s",
|
||||||
|
app.state.prefill_addr, app.state.decode_addr, app.state.port)
|
||||||
|
|
||||||
def signal_handler(*_) -> None:
|
def signal_handler(*_) -> None:
|
||||||
# Interrupt server on sigterm while initializing
|
# Interrupt server on sigterm while initializing
|
||||||
@ -141,11 +171,11 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
|
|||||||
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
# init uvicorn server
|
# init uvicorn server
|
||||||
config = uvicorn.Config(app, host="0.0.0.0", port=8001)
|
config = uvicorn.Config(app, host="0.0.0.0", port=app.state.port)
|
||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
await server.serve()
|
await server.serve()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# url = 'tcp://127.0.0.1:5555'
|
# url = 'tcp://127.0.0.1:5555'
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
uvicorn.run(app, host="0.0.0.0", port=fastapi_port)
|
||||||
|
|||||||
@ -82,22 +82,22 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
|
|||||||
"""Server routine"""
|
"""Server routine"""
|
||||||
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
|
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
|
||||||
zmq_server_port)
|
zmq_server_port)
|
||||||
url_worker = "inproc://workers"
|
workers_addr = "inproc://workers"
|
||||||
url_client = f"tcp://0.0.0.0:{zmq_server_port}"
|
clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
|
||||||
# Prepare our context and sockets
|
# Prepare our context and sockets
|
||||||
context = zmq.asyncio.Context()
|
context = zmq.asyncio.Context()
|
||||||
|
|
||||||
# Socket to talk to clients
|
# Socket to talk to clients
|
||||||
clients = context.socket(zmq.ROUTER)
|
clients = context.socket(zmq.ROUTER)
|
||||||
clients.bind(url_client)
|
clients.bind(clients_addr)
|
||||||
logger.info("ZMQ Server ROUTER started at %s", url_client)
|
logger.info("ZMQ Server ROUTER started at %s", clients_addr)
|
||||||
# Socket to talk to workers
|
# Socket to talk to workers
|
||||||
workers = context.socket(zmq.DEALER)
|
workers = context.socket(zmq.DEALER)
|
||||||
workers.bind(url_worker)
|
workers.bind(workers_addr)
|
||||||
logger.info("ZMQ Worker DEALER started at %s", url_worker)
|
logger.info("ZMQ Worker DEALER started at %s", workers_addr)
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
asyncio.create_task(worker_routine(url_worker, app, context, i))
|
asyncio.create_task(worker_routine(workers_addr, app, context, i))
|
||||||
for i in range(5)
|
for i in range(5)
|
||||||
]
|
]
|
||||||
proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
|
proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
|
||||||
|
|||||||
@ -53,7 +53,7 @@ def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
|
|||||||
headers_dict = json.loads(bytes_data.decode())
|
headers_dict = json.loads(bytes_data.decode())
|
||||||
return httpx.Headers(headers_dict)
|
return httpx.Headers(headers_dict)
|
||||||
|
|
||||||
async def worker_routine(worker_url: str, app: FastAPI,
|
async def worker_routine(worker_addr: str, app: FastAPI,
|
||||||
context: zmq.asyncio.Context, i: int = 0):
|
context: zmq.asyncio.Context, i: int = 0):
|
||||||
"""Worker routine"""
|
"""Worker routine"""
|
||||||
try:
|
try:
|
||||||
@ -61,8 +61,8 @@ async def worker_routine(worker_url: str, app: FastAPI,
|
|||||||
socket = context.socket(zmq.DEALER)
|
socket = context.socket(zmq.DEALER)
|
||||||
worker_identity = f"worker-{i}-{uuid.uuid4()}"
|
worker_identity = f"worker-{i}-{uuid.uuid4()}"
|
||||||
socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
|
socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
|
||||||
socket.connect(worker_url)
|
socket.connect(worker_addr)
|
||||||
logger.info("%s started at %s", worker_identity, worker_url)
|
logger.info("%s started at %s", worker_identity, worker_addr)
|
||||||
while True:
|
while True:
|
||||||
identity, url, header, body = await socket.recv_multipart()
|
identity, url, header, body = await socket.recv_multipart()
|
||||||
logger.info("worker-%d Received request identity: [ %s ]",
|
logger.info("worker-%d Received request identity: [ %s ]",
|
||||||
@ -81,15 +81,16 @@ async def worker_routine(worker_url: str, app: FastAPI,
|
|||||||
createRequest = create_request(url_str, "POST", body_json, headers)
|
createRequest = create_request(url_str, "POST", body_json, headers)
|
||||||
generator = await create_completion(app, completionRequest,
|
generator = await create_completion(app, completionRequest,
|
||||||
createRequest)
|
createRequest)
|
||||||
logger.info("worker-%d Received response: [ %s ]", i, generator)
|
context_type_json = b"application/json"
|
||||||
if isinstance(generator, ErrorResponse):
|
if isinstance(generator, ErrorResponse):
|
||||||
content = generator.model_dump_json()
|
content = generator.model_dump_json()
|
||||||
context_json = json.loads(content)
|
context_json = json.loads(content)
|
||||||
context_json.append("status_code", generator.code)
|
context_json.append("status_code", generator.code)
|
||||||
await socket.send_multipart([identity, b"application/json",
|
await socket.send_multipart([identity, context_type_json,
|
||||||
json.dumps(context_json).encode('utf-8')])
|
json.dumps(context_json).encode('utf-8')])
|
||||||
elif isinstance(generator, CompletionResponse):
|
elif isinstance(generator, CompletionResponse):
|
||||||
await socket.send_multipart([identity, b"application/json",
|
await socket.send_multipart([identity,
|
||||||
|
context_type_json,
|
||||||
json.dumps(generator.model_dump()).encode('utf-8')])
|
json.dumps(generator.model_dump()).encode('utf-8')])
|
||||||
else:
|
else:
|
||||||
async for chunk in generator:
|
async for chunk in generator:
|
||||||
|
|||||||
Reference in New Issue
Block a user