Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
Robert Shaw
2025-03-22 17:12:21 -04:00
parent 85687b43e7
commit 522279ebb9
4 changed files with 22 additions and 560 deletions

View File

@@ -44,18 +44,27 @@ wait_for_disagg_server() {
# You can also adjust --kv-ip and --kv-port for distributed inference.
MODEL=meta-llama/Llama-3.1-8B-Instruct
CONNECTOR_ADDR=connectoripc
PREFILL_WORKER_ADDR=prefillipc
DECODE_WORKER_ADDR=decodeipc
PORT=8000
# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
--zmq-server-addr testipc0 \
CUDA_VISIBLE_DEVICES=0 python3 -m vllm.entrypoints.disaggregated.worker \
--model $MODEL \
--connector-addr $CONNECTOR_ADDR \
--worker-addr $PREFILL_WORKER_ADDR \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > vllm_disagg_prefill.log 2>&1 &
# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
--zmq-server-addr testipc1 \
CUDA_VISIBLE_DEVICES=1 python3 -m vllm.entrypoints.disaggregated.worker \
--model $MODEL \
--connector-addr $CONNECTOR_ADDR \
--worker-addr $DECODE_WORKER_ADDR \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
@@ -63,16 +72,17 @@ CUDA_VISIBLE_DEVICES=1 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
# launch a proxy server that opens the service at port 8000
# the workflow of this proxy:
# - send the request to prefill vLLM instance (via zmq addr testipc0), change max_tokens
# to 1
# - after the prefill vLLM finishes prefill, send the request to decode vLLM
# instance (via zmq addr testipc1)
vllm connect --port 8000 \
--prefill-addr testipc0 \
--decode-addr testipc1 &
# - Send req to prefill instance, wait until complete.
# - Send req to decode instance, streaming tokens.
python3 -m vllm.entrypoints.disaggregated.connector \
--port $PORT \
--model $MODEL \
--connector-addr $CONNECTOR_ADDR \
--prefill-addr $PREFILL_WORKER_ADDR \
--decode-addr $DECODE_WORKER_ADDR
# wait until prefill, decode instances and proxy are ready
wait_for_server 8000
wait_for_server $PORT
wait_for_disagg_server vllm_disagg_prefill.log
wait_for_disagg_server vllm_disagg_decode.log
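
Once both workers and the proxy report ready, the connector exposes a standard OpenAI-compatible completions endpoint. Below is a minimal stdlib-only smoke test; it assumes the proxy is reachable at localhost:8000 and that the served model name matches $MODEL above.

# Minimal smoke test against the P/D proxy (assumes localhost:8000).
import json
import urllib.request

payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "The capital of France is",
    "max_tokens": 16,
    "temperature": 0.0,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["text"])

From the client's perspective the split is invisible: the prompt is prefilled on GPU 0, the KV cache is handed over via the PyNcclConnector, and tokens are decoded on GPU 1.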

View File

@@ -1,126 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import uvicorn
import uvloop
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.disaggregated.pd_engine import PDEngine
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              CompletionResponse,
                                              ErrorResponse)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser, set_ulimit
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.disaggregated.api_server')
app = FastAPI()
@app.get("/v1/models")
async def show_available_models(raw_request: Request):
handler: OpenAIServingModels = raw_request.app.state.openai_serving_models
models_ = await handler.show_available_models()
return JSONResponse(content=models_.model_dump())
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion
generator = await handler.create_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@asynccontextmanager
async def pd_engine_client_ctx_manager(
model_name: str,
prefill_addr: str,
decode_addr: str,
connector_addr: str) -> AsyncIterator[PDEngine]:
    engine = PDEngine(prefill_addr=prefill_addr,
                      decode_addr=decode_addr,
                      connector_addr=connector_addr,
                      model_name=model_name)
    try:
        yield engine
    finally:
        engine.shutdown()
async def main(args, **uvicorn_kwargs):
logger.info("vLLM Disaggregate Connector Start %s %s", args,
uvicorn_kwargs)
# Avoid dropping requests under high concurrency.
set_ulimit()
# IPC Paths.
# NOTE FOR DEVELOPERS: when shifting to TCP, ensure you
# are not using pickle to avoid RCE security flaw.
prefill_addr = f"ipc://{args.prefill_addr}"
decode_addr = f"ipc://{args.decode_addr}"
connector_addr = f"ipc://{args.connector_addr}"
# Start Engine.
    async with pd_engine_client_ctx_manager(
            args.model, prefill_addr, decode_addr,
            connector_addr) as engine_client:
# Initialize App State.
model_config = await engine_client.get_model_config()
app.state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=[BaseModelPath(
name=args.served_model_name or args.model,
model_path=args.model)
],
)
app.state.openai_serving_completion = OpenAIServingCompletion(
engine_client=engine_client,
model_config=model_config,
models=app.state.openai_serving_models,
request_logger=None,
)
# Run Server.
        config = uvicorn.Config(app, host=args.host, port=args.port)
server = uvicorn.Server(config)
await server.serve()
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible P/D Server.")
parser.add_argument("--host",
type=str,
default="0.0.0.0",
help="The host of the HTTP server.")
parser.add_argument("--port",
type=int,
default=8001,
help="The port of the HTTP server.")
parser.add_argument("--model",
type=str,
required=True,
help="The path to the model.")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The served name of the model.")
parser.add_argument("--connector-addr",
type=str,
required=True,
help="The zmq ipc connector address")
parser.add_argument("--prefill-addr",
type=str,
required=True,
help="The zmq ipc prefill address")
parser.add_argument("--decode-addr",
type=str,
required=True,
help="The zmq ipc decode address")
args = parser.parse_args()
uvloop.run(main(args))
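
The non-streaming path returns a single JSONResponse; for streaming requests, create_completion hands back an async generator that is wrapped in a StreamingResponse with SSE framing. A client sketch for that path, assuming the usual OpenAI-style "data: ..." lines emitted by vLLM's completion serving layer:

# Streaming client sketch for /v1/completions (assumes localhost:8000).
import json
import urllib.request

payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "Write a haiku about GPUs:",
    "max_tokens": 32,
    "stream": True,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for raw_line in resp:
        line = raw_line.decode().strip()
        if not line.startswith("data: "):
            continue
        chunk = line[len("data: "):]
        if chunk == "[DONE]":
            break
        print(json.loads(chunk)["choices"][0]["text"], end="", flush=True)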

View File

@@ -1,296 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import msgspec
import os
from collections.abc import AsyncGenerator
from typing import Dict, List, Mapping, Optional
import zmq
import zmq.asyncio
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
from vllm.inputs.data import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import Device, make_zmq_socket
logger = init_logger(__name__)
DEFAULT_MAX_TOKENS = 32000
class PDEngine:
"""
PDEngine:
Equiavlent of AsyncLLM for P/D. Assumes there is
a Prefill and Decode service already running.
* TODO: actually handle errors and failure.
* TODO: support more than just text input.
* TODO: move under vllm/v1/engine one past prototype.
"""
def __init__(
self,
prefill_addr: str,
decode_addr: str,
connector_addr: str,
model_name: str
):
# Request queues.
self.queues: Dict[str, asyncio.Queue] = {}
# Serialization encoder.
self.encoder = msgspec.msgpack.Encoder()
        # ZMQ communication.
        self.ctx = zmq.asyncio.Context()
        self.to_decode = make_zmq_socket(
            self.ctx, decode_addr, zmq.constants.PUSH)
        self.to_prefill = make_zmq_socket(
            self.ctx, prefill_addr, zmq.constants.PUSH)
self.connector_addr = connector_addr
self.decode_addr = decode_addr
self.prefill_addr = prefill_addr
# Background loops (started on first generate()).
self.output_handler: Optional[asyncio.Task] = None
self.log_running: Optional[asyncio.Task] = None
# Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this.
self.model_config = ModelConfig(
model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=False,
dtype="auto",
seed=42
)
# Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this.
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=False,
max_num_seqs=1024,
max_loras=0,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision,
truncation_side=self.model_config.truncation_side)
self.tokenizer = TokenizerGroup(**init_kwargs)
def shutdown(self):
if (ctx := self.ctx) is not None:
ctx.destroy(linger=0)
if (task := self.log_running) is not None:
task.cancel()
if (task := self.output_handler) is not None:
task.cancel()
ipc_paths = [
self.connector_addr, self.decode_addr, self.prefill_addr
]
for path in ipc_paths:
socket_path = path.replace("ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
    async def _run_log_running(self):
        while True:
            logger.info("Running requests: %d", len(self.queues))
            await asyncio.sleep(10.)
    async def _run_output_handler(self):
        """
        Pull responses from the Decode + Prefill engines and
        distribute them back to the generate() tasks.
        """
        decoder = msgspec.msgpack.Decoder(PDResponse)
        socket: Optional[zmq.asyncio.Socket] = None
        try:
            socket = make_zmq_socket(
                self.ctx, self.connector_addr, zmq.constants.PULL)
            while True:
                response_bytes = await socket.recv()
                response = decoder.decode(response_bytes)
                self.queues[response.request_id].put_nowait(response)
        except Exception:
            # TODO: actually handle failure and shutdown.
            raise
        finally:
            if socket is not None:
                socket.close(linger=0)
async def _prefill(self,
request: PDRequest,
q: asyncio.Queue[PDResponse]) -> PDResponse:
# Send request to the prefill instance.
        req_bytes = self.encoder.encode(request)
await self.to_prefill.send(req_bytes, copy=False)
# Wait for the prefill to be done.
response = await q.get()
assert response.request_id == request.request_id
if not response.success:
# TODO: actual error handling and shutdown.
raise Exception("Failed Prefill Request.")
return response
    async def _decode(
        self,
        request: PDRequest,
        q: asyncio.Queue[PDResponse],
    ) -> AsyncGenerator[PDResponse, None]:
        # Send request to the decode instance.
        req_bytes = self.encoder.encode(request)
        await self.to_decode.send(req_bytes, copy=False)

        # Iterate the response queue and yield each response to the caller.
finished = False
while not finished:
response = await q.get()
if not response.success:
# TODO: actual error handling and shutdown.
raise Exception("Failed Decode Request.")
finished = response.finish_reason is not None
yield response
async def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
    ) -> AsyncGenerator[PDResponse, None]:
# Start loops on first request.
if self.output_handler is None:
self.output_handler = asyncio.create_task(self._run_output_handler())
self.log_running = asyncio.create_task(self._run_log_running())
# TODO: Expand to support the full matrix.
if not isinstance(prompt, str):
raise NotImplementedError(
"We currently only support text for P/D!")
        if lora_request is not None:
            raise NotImplementedError(
                "We currently do not support LoRA for P/D!")
        if trace_headers is not None:
            raise NotImplementedError(
                "We currently do not support tracing for P/D!")
        if prompt_adapter_request is not None:
            raise NotImplementedError(
                "We currently do not support prompt adapter for P/D!")
        if priority != 0:
            raise NotImplementedError(
                "We currently do not support priority for P/D!")
if request_id in self.queues:
raise ValueError(f"Found duplicate request_id: {request_id}!")
# Queue to gather output from output_handler.
q: asyncio.Queue[PDResponse] = asyncio.Queue()
self.queues[request_id] = q
# (1) Perform the Prefill.
original_max_tokens = sampling_params.max_tokens
request = PDRequest(request_id, prompt, sampling_params)
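        # Prefill only needs to populate the KV cache (emitting one token),
        # so cap max_tokens at 1 for the handoff; the decode step below
        # restores the caller's original budget.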
request.sampling_params.max_tokens = 1
response = await self._prefill(request, q)
yield response
# (2) Perform the Decodes.
request.sampling_params.max_tokens = original_max_tokens
async for response in self._decode(request, q):
yield response
async def beam_search(
self,
prompt: PromptType,
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
raise NotImplementedError
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[PoolingRequestOutput, None]:
raise NotImplementedError
async def abort(self, request_id: str) -> None:
raise NotImplementedError
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def get_decoding_config(self) -> DecodingConfig:
raise NotImplementedError
async def get_input_preprocessor(self) -> InputPreprocessor:
raise NotImplementedError
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if lora_request is not None:
raise NotImplementedError(
"LoRA is not yet supported in the PDEngine.")
return self.tokenizer.get_lora_tokenizer(None)
async def is_tracing_enabled(self) -> bool:
return False
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
pass
async def check_health(self) -> None:
pass
async def start_profile(self) -> None:
raise NotImplementedError
async def stop_profile(self) -> None:
raise NotImplementedError
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
raise NotImplementedError
async def sleep(self, level: int = 1) -> None:
raise NotImplementedError
async def wake_up(self) -> None:
raise NotImplementedError
async def is_sleeping(self) -> bool:
        return False
async def add_lora(self, lora_request: LoRARequest) -> None:
raise NotImplementedError
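
For completeness, here is a minimal sketch of driving PDEngine directly instead of through the FastAPI server. The ipc:// paths and the request_id are illustrative assumptions; they must match the addresses the prefill and decode workers were actually launched with.

# Direct-usage sketch for the PDEngine prototype (paths are assumptions).
import asyncio

from vllm import SamplingParams
from vllm.entrypoints.disaggregated.pd_engine import PDEngine

async def demo() -> None:
    engine = PDEngine(
        prefill_addr="ipc://prefillipc",
        decode_addr="ipc://decodeipc",
        connector_addr="ipc://connectoripc",
        model_name="meta-llama/Llama-3.1-8B-Instruct",
    )
    try:
        params = SamplingParams(max_tokens=32, temperature=0.0)
        # The first yielded response is the prefill handoff (one token);
        # the remaining responses stream from the decode worker.
        async for resp in engine.generate("The capital of France is",
                                          params, request_id="demo-0"):
            if resp.text:
                print(resp.text, end="", flush=True)
    finally:
        engine.shutdown()

asyncio.run(demo())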

View File

@@ -1,126 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import msgspec
import signal
import uvloop
from typing import Optional
import zmq
import zmq.asyncio
from vllm.engine.async_llm_engine import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
async def handle_request(
request: PDRequest,
engine: EngineClient,
socket: zmq.asyncio.Socket,
    encoder: msgspec.msgpack.Encoder,
) -> None:
request_id = request.request_id
try:
# 1) Generate RequestOutputs.
async for request_output in engine.generate(
                prompt=request.prompt,
sampling_params=request.sampling_params,
request_id=request_id):
            assert len(request_output.outputs) == 1, (
                "Only support N=1 right now.")
out = request_output.outputs[0]
# 2) Convert RequestOutput --> PDResponse.
response = PDResponse(
request_id=request_id,
success=True,
text=out.text,
token_ids=out.token_ids,
                finish_reason=out.finish_reason,
stop_reason=out.stop_reason,
)
            response_bytes = encoder.encode(response)
# 3) Send to Connector.
await socket.send(response_bytes, copy=False)
except Exception as e:
# TODO: actual error handling.
logger.error("Exception in Worker Routine: %s request_id: %s", e,
request_id)
response = PDResponse(request_id=request_id, success=False)
        response_bytes = encoder.encode(response)
        await socket.send(response_bytes, copy=False)
async def run_server(args, engine: EngineClient):
"""Get Requests and Handle Them."""
running_requests: set[asyncio.Task] = set()
    decoder = msgspec.msgpack.Decoder(PDRequest)
    encoder = msgspec.msgpack.Encoder()
ctx: Optional[zmq.asyncio.Context] = None
try:
# IPC Setup.
ctx = zmq.asyncio.Context()
        from_connector = make_zmq_socket(
            ctx, f"ipc://{args.worker_addr}", zmq.constants.PULL)
to_connector = make_zmq_socket(
ctx, f"ipc://{args.connector_addr}", zmq.constants.PUSH)
# Main Loop.
while True:
# 1) Get request from the Connector.
            pd_request_bytes = await from_connector.recv()
            pd_request = decoder.decode(pd_request_bytes)
# 2) Launch a coroutine to handle the request.
task = asyncio.create_task(handle_request(
pd_request, engine, to_connector, encoder))
running_requests.add(task)
task.add_done_callback(running_requests.discard)
except KeyboardInterrupt:
logger.debug("Worker server loop interrupted.")
finally:
for task in running_requests:
task.cancel()
if ctx is not None:
ctx.destroy(linger=0)
async def main(args) -> None:
logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
logger.info("args: %s", args)
    # Workaround to avoid dropped requests under high concurrency
    # due to a low default ulimit.
set_ulimit()
# Interrupt on sigterm during initialization.
def signal_handler(*_) -> None:
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine:
await run_server(args, engine)
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument('--connector-addr',
type=str,
required=True,
help='The address of the connector.')
parser.add_argument('--worker-addr',
type=str,
required=True,
help='The address of the worker.')
AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
uvloop.run(main(args))
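
The PDRequest and PDResponse types come from vllm.entrypoints.disaggregated.types, which this diff does not show. The following sketch is inferred from how pd_engine.py and worker.py use the fields; the real module may differ.

# Inferred sketch of vllm.entrypoints.disaggregated.types (not in this diff).
# Field names and order are reverse-engineered from usage above.
from typing import List, Optional

import msgspec

from vllm import SamplingParams

class PDRequest(msgspec.Struct):
    # Constructed positionally as PDRequest(request_id, prompt, sampling_params);
    # the prototype only carries text prompts.
    request_id: str
    prompt: str
    sampling_params: SamplingParams

class PDResponse(msgspec.Struct):
    request_id: str
    success: bool
    text: Optional[str] = None
    token_ids: Optional[List[int]] = None
    finish_reason: Optional[str] = None  # None until the request finishes
    stop_reason: Optional[str] = None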