Files
vllm-ascend/examples/disaggregated_prefill/p2p_disaggrefated_prefill_proxy.py
Mengqing Cao 6eddbd2521 [CI/UT][PD Disaggreate] Initialize PD Disaggreate UT (#889)
Initialize PD Disaggreate UT

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
2025-05-29 10:17:12 +08:00

194 lines
6.8 KiB
Python

import os
import socket
import threading
import uuid
import aiohttp
import msgpack # type: ignore
import zmq
from quart import Quart, make_response, request
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
def _listen_for_register(poller, router_socket):
while True:
socks = dict(poller.poll())
if router_socket in socks:
remote_address, message = router_socket.recv_multipart()
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data = msgpack.loads(message)
if data["type"] == "P":
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_instances[
data["http_address"]] = data["zmq_address"]
print(
"Get a prefill register with http_addr %s and zmq_addr %s",
data["http_address"],
data["zmq_address"],
)
elif data["type"] == "D":
global decode_instances
global decode_cv
with decode_cv:
decode_instances[
data["http_address"]] = data["zmq_address"]
print(
"Get a decode register with http_addr %s and zmq_addr %s",
data["http_address"],
data["zmq_address"],
)
else:
print(
"Unexpected, Received message from %s, data: %s",
remote_address,
data,
)
def start_service_discovery(hostname, port):
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
context = zmq.Context() # type: ignore
router_socket = context.socket(zmq.ROUTER) # type: ignore
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller() # type: ignore
poller.register(router_socket, zmq.POLLIN) # type: ignore
_listener_thread = threading.Thread(target=_listen_for_register,
args=[poller, router_socket],
daemon=True)
_listener_thread.start()
return _listener_thread
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
async def forward_request(url, data, request_id):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
async with session.post(url=url, json=data,
headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1
global prefill_instances
global prefill_cv
with prefill_cv:
if len(prefill_instances) > 1:
print(
"Found more than 1 Prefill instances. Currently we only support 1P1D, so only"
f"the first Prefill instance({list(prefill_instances.keys())[0]}) will be used!"
)
if len(prefill_instances) == 0:
res_str = (
"No Prefill instances has been registered to proxy. Please confirm that you have successfully"
" and correctly started a Prefill vLLM instance.")
print(res_str)
response = await make_response(res_str)
return response
# prefill_addr, prefill_zmq_addr = random.choice(
# list(prefill_instances.items()))
prefill_addr, prefill_zmq_addr = list(prefill_instances.items())[0]
print(
"handle_request, prefill_addr: %s, zmq_addr: %s",
prefill_addr,
prefill_zmq_addr,
)
global decode_instances
global decode_cv
with decode_cv:
if len(decode_instances) > 1:
print(
"Found more than 1 Decode instances. Currently we only support 1P1D, so only"
f"the first Decode instance({list(decode_instances.keys())[0]}) will be used!"
)
if len(decode_instances) == 0:
res_str = (
"No Decode instances has been registered to proxy. Please confirm that you have successfully"
" and correctly started a Decode vLLM instance.")
print(res_str)
response = await make_response(res_str)
return response
# decode_addr, decode_zmq_addr = random.choice(
# list(decode_instances.items()))
decode_addr, decode_zmq_addr = list(decode_instances.items())[0]
print(
"handle_request, decode_addr: %s, zmq_addr: %s",
decode_addr,
decode_zmq_addr,
)
request_id = f"___prefill_addr_{prefill_addr}___decode_addr_{decode_addr}_{random_uuid()}"
# finish prefill
async for _ in forward_request(f"http://{prefill_addr}/v1/completions",
prefill_request, request_id):
continue
# return decode
generator = forward_request(
f"http://{decode_addr}/v1/completions",
original_request_data,
request_id,
)
response = await make_response(generator)
response.timeout = None
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="args of disaggregated-prefill proxy")
parser.add_argument("--http-port", type=int, default=10001)
parser.add_argument("--register-port", type=int, default=10002)
args = parser.parse_args()
t = start_service_discovery("0.0.0.0", args.register_port)
app.run(host="0.0.0.0", port=args.http_port)
t.join()