Compare commits

...

50 Commits

Author SHA1 Message Date
220d694080 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-24 01:00:20 +00:00
70e06dd574 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-24 00:46:46 +00:00
7954461d4c updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-23 23:03:42 +00:00
a10da86677 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-23 22:56:53 +00:00
284d5df45b added __init__.py
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-23 22:50:20 +00:00
d5b0db449e added __init__.py
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-23 22:44:36 +00:00
66349c33a1 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-23 22:36:57 +00:00
28d0396ff1 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-23 21:54:04 +00:00
2f29ae383a added files
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-23 21:45:01 +00:00
cf64b0e6a7 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-23 21:44:14 +00:00
f51f182d64 pre-commit
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-23 20:18:50 +00:00
79e465f557 fix pre-commit
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-23 09:38:55 -04:00
2ba687d39f updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 18:52:06 -04:00
5d57896e2c updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 18:51:53 -04:00
f6f008ca1d cleanup
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 18:51:21 -04:00
24cbbe4778 updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 18:50:48 -04:00
2fec6e0b5c working?
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 18:45:00 -04:00
47a3f26b2a updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 18:36:52 -04:00
144162fc8c Merge branch 'main' into rob-fixes 2025-03-22 17:12:53 -04:00
522279ebb9 Stash
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 17:12:21 -04:00
85687b43e7 updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 17:00:46 -04:00
120bbdfd82 updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 15:58:51 -04:00
2ceb7bc534 updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 13:25:05 -04:00
9f7fb5ec84 updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 13:22:00 -04:00
a8a621e419 updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
2025-03-22 13:11:50 -04:00
b89d89f456 fix rebase
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:32:21 +08:00
8355358fb3 add unlimited HWM
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:20:12 +08:00
c0b1443345 fix mypy
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:20:12 +08:00
d35dace985 refactor zmq msg to object
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:20:12 +08:00
912031ceb5 refactor disagg
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:20:12 +08:00
4f13e89143 fix SIM105
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:18:19 +08:00
b9a7dbe769 remove default socket address value
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:18:19 +08:00
0cb2e05256 change log level and fix some comments
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:18:19 +08:00
d6945ecdf0 change disagg_prefill example to use zmq
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:18:19 +08:00
298298f97d remove invalid zmq benchmark code
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:18:19 +08:00
6c8fae82dd run format
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:18:19 +08:00
16ed827378 add benchmark shell
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:18:08 +08:00
8fa9df7987 run format.sh
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:17:57 +08:00
27c1afe88b fix ThreadProxy
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:17:57 +08:00
ee6607332e create proxy sockets in the proxy function for thread safety
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:17:57 +08:00
7fbf70db57 1. replace tpc:// with ipc:// \n 2. fix json response
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:17:57 +08:00
2c31e4c3ea Run yapf and ruff
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:17:57 +08:00
187f112ccd 1. fix mypy issue
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:17:44 +08:00
897db7b93d Replace zmq.asyncio.Context().term() with zmq.asyncio.Context().destroy(linger=0) for immediate termination
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:17:44 +08:00
b7ffb43792 update disagg_connect test_request.py
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:17:44 +08:00
6e1fba8a73 1. connect_parser set --prefill-addr and --decode-addr are required
2.To more accurately reflect its purpose, we will rename connect.py to disagg_connector.py.

Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:17:44 +08:00
bfde1688e7 add /v1/completions stream support
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:17:44 +08:00
905424ed65 add identity url headers
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:15:42 +08:00
5d20f389d6 add vllm connect cmd
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:15:42 +08:00
2a0cb78016 add test py
Signed-off-by: clark <panf2333@gmail.com>
2025-03-21 08:15:42 +08:00
8 changed files with 871 additions and 0 deletions

View File

@@ -0,0 +1,123 @@
#!/bin/bash
# This file demonstrates the example usage of disaggregated prefilling with ZMQ
# We will launch 2 vllm instances (1 for prefill and 1 for decode),
# and then transfer the KV cache between them.
set -xe
echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo "Cleanup complete. Exiting."
exit 0
}
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# A function that waits for the vLLM API server (proxy) to start
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# A function that waits for a vLLM disaggregated worker to start
wait_for_disagg_server() {
local log_file=$1
timeout 1200 bash -c "
until grep -q 'PDWorker is ready' $log_file; do
sleep 1
done" && return 0 || return 1
}
# You can also adjust --kv-ip and --kv-port for distributed inference.
MODEL=meta-llama/Llama-3.1-8B-Instruct
CONTROLLER_ADDR=controller.ipc
PREFILL_WORKER_ADDR=prefill.ipc
DECODE_WORKER_ADDR=decode.ipc
PORT=8001
# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 python3 ../../vllm/entrypoints/disaggregated/worker.py \
--model $MODEL \
--controller-addr $CONTROLLER_ADDR \
--worker-addr $PREFILL_WORKER_ADDR \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > vllm_disagg_prefill.log 2>&1 &
# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 python3 ../../vllm/entrypoints/disaggregated/worker.py \
--model $MODEL \
--controller-addr $CONTROLLER_ADDR \
--worker-addr $DECODE_WORKER_ADDR \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > vllm_disagg_decode.log 2>&1 &
# launch a proxy server that opens the service at port $PORT (8001)
# the workflow of this proxy:
# - Send req to prefill instance, wait until complete.
# - Send req to decode instance, streaming tokens.
python3 ../../vllm/entrypoints/disaggregated/api_server.py \
--port $PORT \
--model $MODEL \
--controller-addr $CONTROLLER_ADDR \
--prefill-addr $PREFILL_WORKER_ADDR \
--decode-addr $DECODE_WORKER_ADDR &
# wait until prefill, decode instances and proxy are ready
wait_for_server $PORT
wait_for_disagg_server vllm_disagg_prefill.log
wait_for_disagg_server vllm_disagg_decode.log
# serve two example requests
output1=$(curl -X POST -s http://localhost:8001/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-3.1-8B-Instruct",
"prompt": "San Francisco is a",
"max_tokens": 10,
"temperature": 0
}')
output2=$(curl -X POST -s http://localhost:8001/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-3.1-8B-Instruct",
"prompt": "Santa Clara is a",
"max_tokens": 10,
"temperature": 0
}')
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo ""
sleep 1
# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"
echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""

View File

View File

@@ -0,0 +1,364 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
from collections.abc import AsyncGenerator, Mapping
from typing import Optional, Union
import msgspec
import zmq
import zmq.asyncio
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.disaggregated.protocol import (PDGenerationRequest,
PDGenerationResponse, PDRequestType,
PDResponseType)
from vllm.engine.protocol import EngineClient
from vllm.inputs.data import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import Device
logger = init_logger(__name__)
DEFAULT_MAX_TOKENS = 32000
class PDController(EngineClient):
"""
Controller that schedules work on the PDWorkers.
Conforms to the EngineClient protocol so it can
be wrapped by the OpenAI Server.
Two Phases:
* Send request to prefill worker, await ack.
* Send request to decode worker, stream responses.
KVSync happens directly between Engines,
handled by vLLM KVCacheTransfer.
[ OpenAI Server ]
|
[ PDController ]
|
[ zmq ]
|
[ PDWorker ] [ PDWorker ]
| |
[ Engine ] <-kv-> [ Engine ]
After PR #12957, we will support xPyD, so we will
also need to implement a scheduler and service
discovery for the workers.
This PDController may be implemented as a K8s
controller. This is intended to be a prototype.
* TODO: better error handling
* TODO: support logprobs, multimodal, etc.
"""
def __init__(self, prefill_addr: str, decode_addr: str,
controller_addr: str, model_name: str):
# Request queues.
self.queues: dict[str, asyncio.Queue] = {}
# Serialization encoder.
self.encoder = msgspec.msgpack.Encoder()
# ZMQ communication.
# TODO: once https://github.com/vllm-project/vllm/pull/12957
# lands, do service discovery to scale out workers.
self.ctx = zmq.asyncio.Context()
self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
self.to_prefill.connect(prefill_addr)
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
self.to_decode.connect(decode_addr)
self.controller_addr = controller_addr
self.ipc_paths = [prefill_addr, decode_addr, controller_addr]
# Background loops (started on first generate()).
self.output_handler: Optional[asyncio.Task] = None
self.log_running: Optional[asyncio.Task] = None
# Dummy: needed for EngineClient Protocol.
# TODO: refactor OAI Server to avoid needing this.
self.model_config = ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=False,
dtype="auto",
task="generate",
seed=42)
# Dummy: needed for EngineClient Protocol.
# TODO: refactor OAI Server to avoid needing this.
self.tokenizer = TokenizerGroup(
**dict(tokenizer_id=self.model_config.tokenizer,
enable_lora=False,
max_num_seqs=1024,
max_loras=0,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision,
truncation_side=self.model_config.truncation_side))
def shutdown(self):
if (ctx := self.ctx) is not None:
ctx.destroy(linger=0)
if (task := self.log_running) is not None:
task.cancel()
if (task := self.output_handler) is not None:
task.cancel()
for path in self.ipc_paths:
socket_path = path.replace("ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
async def _run_log_running(self):
while True:
    logger.info("Running requests: %d", len(self.queues))
    await asyncio.sleep(10.)
async def _run_output_handler(self):
"""
Pull responses from Decode + Prefill engines and
distribute back to the generate() tasks.
"""
decoder = msgspec.msgpack.Decoder(PDGenerationResponse)
socket: Optional[zmq.asyncio.Socket] = None
try:
socket = self.ctx.socket(zmq.constants.PULL)
socket.bind(self.controller_addr)
while True:
res_type, res_data = await socket.recv_multipart()
if res_type == PDResponseType.FAILURE:
raise Exception("Failure Response from PDWorker.")
elif res_type == PDResponseType.GENERATION:
response = decoder.decode(res_data)
logger.debug("Got Response: %s", response.request_id)
self.queues[response.request_id].put_nowait(response)
else:
raise Exception("Unknown response type.")
except Exception as e:
# TODO: distinguish between fatal and non-fatal errors.
for q in self.queues.values():
q.put_nowait(e)
raise e
finally:
if socket is not None:
socket.close(linger=0)
async def _run_prefill(
self,
request: PDGenerationRequest,
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
):
# Send request to the prefill instance.
req_bytes = self.encoder.encode(request)
msg = (PDRequestType.GENERATION, req_bytes)
await self.to_prefill.send_multipart(msg, copy=False)
# Await completion of the prefill.
response = await q.get()
if isinstance(response, Exception):
raise response
logger.debug("Prefill Response: %s", request.request_id)
async def _run_decode(
self,
request: PDGenerationRequest,
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
) -> AsyncGenerator[PDGenerationResponse, None]:
# Send request to the decode instance.
req_bytes = self.encoder.encode(request)
msg = (PDRequestType.GENERATION, req_bytes)
await self.to_decode.send_multipart(msg, copy=False)
# Iterate response queue and yield each response to caller.
finished = False
while not finished:
response = await q.get()
if isinstance(response, Exception):
raise response
logger.debug("Decode Response: %s", request.request_id)
finished = response.finish_reason is not None
yield response
def _to_request_output(
self,
response: PDGenerationResponse,
prompt_token_ids: list[int],
) -> RequestOutput:
finished = response.finish_reason is not None
return RequestOutput(
request_id=response.request_id,
prompt=None,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
outputs=[
CompletionOutput(index=0,
text=response.text,
token_ids=response.token_ids,
cumulative_logprob=None,
logprobs=None,
finish_reason=response.finish_reason,
stop_reason=response.stop_reason)
],
finished=finished,
)
async def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
# Start loops on first request.
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())
self.log_running = asyncio.create_task(self._run_log_running())
# TODO: Expand to support the full matrix.
if "prompt_token_ids" not in prompt:
raise NotImplementedError(
"We currently only support TokensPrompt for P/D!")
if lora_request is not None:
raise NotImplementedError(
"We currently do not support LoRA for P/D!")
if trace_headers is not None:
raise NotImplementedError(
"We currently do not support tracing for P/D!")
if prompt_adapter_request is not None:
raise NotImplementedError(
"We currently do not support prompt adapter for P/D!")
if priority != 0:
raise NotImplementedError(
"We currently do not support priority for P/D!")
if request_id in self.queues:
raise ValueError(f"Found duplicate request_id: {request_id}!")
# Queue to gather output from output_handler.
q = asyncio.Queue()
self.queues[request_id] = q
# (1) Perform the Prefill.
original_max_tokens = sampling_params.max_tokens
request = PDGenerationRequest(
request_id=request_id,
prompt_token_ids=prompt["prompt_token_ids"],
sampling_params=sampling_params)
request.sampling_params.max_tokens = 1
await self._run_prefill(request, q)
# (2) Perform the Decodes.
request.sampling_params.max_tokens = original_max_tokens
async for pd_response in self._run_decode(request, q):
yield self._to_request_output(pd_response,
prompt["prompt_token_ids"])
async def beam_search(
self,
prompt: PromptType,
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
raise NotImplementedError
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[PoolingRequestOutput, None]:
raise NotImplementedError
async def abort(self, request_id: str) -> None:
raise NotImplementedError
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def get_decoding_config(self) -> DecodingConfig:
raise NotImplementedError
async def get_input_preprocessor(self) -> InputPreprocessor:
raise NotImplementedError
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if lora_request is not None:
raise NotImplementedError(
"LoRA is not yet supported in the PDEngine.")
return self.tokenizer.get_lora_tokenizer(None)
async def is_tracing_enabled(self) -> bool:
return False
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[list[SamplerOutput]] = None,
) -> None:
pass
async def check_health(self) -> None:
pass
async def start_profile(self) -> None:
raise NotImplementedError
async def stop_profile(self) -> None:
raise NotImplementedError
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
raise NotImplementedError
async def sleep(self, level: int = 1) -> None:
raise NotImplementedError
async def wake_up(self) -> None:
raise NotImplementedError
async def is_sleeping(self) -> bool:
return False
async def add_lora(self, lora_request: LoRARequest) -> None:
raise NotImplementedError
@property
def errored(self) -> bool:
return False
def dead_error(self) -> Exception:
return Exception("PDController has failed.")
def is_running(self) -> bool:
return True
def is_stopped(self) -> bool:
return False
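As an aside, a minimal sketch of driving PDController directly (outside the OpenAI server) is shown below; the ipc:// paths, model name, token IDs, and request id are illustrative placeholders, and it assumes the prefill and decode PDWorkers from the example script are already running.

# Hypothetical standalone driver for PDController (illustrative sketch only).
import asyncio

from vllm import SamplingParams
from vllm.disaggregated.pd_controller import PDController


async def demo():
    controller = PDController(prefill_addr="ipc://prefill.ipc",
                              decode_addr="ipc://decode.ipc",
                              controller_addr="ipc://controller.ipc",
                              model_name="meta-llama/Llama-3.1-8B-Instruct")
    try:
        # Only TokensPrompt-style inputs are supported for P/D right now,
        # so the prompt must carry pre-tokenized prompt_token_ids.
        prompt = {"prompt_token_ids": [128000, 9906, 1917]}  # placeholder IDs
        params = SamplingParams(max_tokens=16, temperature=0.0)
        async for output in controller.generate(prompt, params,
                                                request_id="demo-0"):
            print(output.outputs[0].text)
    finally:
        controller.shutdown()


if __name__ == "__main__":
    asyncio.run(demo())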

View File

@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
import msgspec
import zmq
import zmq.asyncio
from vllm.disaggregated.protocol import (PDAbortRequest, PDGenerationRequest,
PDGenerationResponse, PDRequestType,
PDResponseType)
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
logger = init_logger(__name__)
class PDWorker:
def __init__(
self,
engine: EngineClient,
worker_addr: str,
controller_addr: str,
):
"""
PDWorker
* Wrapper around AsyncLLM that converts PDRequests
to PDResponses and sends them back to the PDController.
* Leverages ZMQ for communication with the PDController. We may
expand this in the future.
"""
# Engine.
self.engine = engine
# ZMQ IPC.
self.worker_addr = f"ipc://{worker_addr}"
self.controller_addr = f"ipc://{controller_addr}"
self.ctx = zmq.asyncio.Context()
self.from_controller = self.ctx.socket(zmq.constants.PULL)
self.from_controller.bind(self.worker_addr)
self.to_controller = self.ctx.socket(zmq.constants.PUSH)
self.to_controller.connect(self.controller_addr)
self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest)
self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest)
self.encoder = msgspec.msgpack.Encoder()
# Active Requests.
self.running_requests: set[asyncio.Task] = set()
def shutdown(self):
if hasattr(self, "ctx"):
self.ctx.destroy()
if hasattr(self, "running_requests"):
for running_request in self.running_requests:
running_request.cancel()
if hasattr(self, "controller_addr"):
ipc_paths = [self.worker_addr, self.controller_addr]
for ipc_path in ipc_paths:
socket_path = ipc_path.replace("ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
async def run_busy_loop(self):
"""
main loop:
1) wait for a request from the PDController
2) handle the request
"""
logger.info("PDWorker is ready to handle requests.")
poller = zmq.asyncio.Poller()
poller.register(self.from_controller, zmq.POLLIN)
while True:
# 1) Get request from the Connector.
req_type, req_data = await self.from_controller.recv_multipart()
# 2) Handle the request.
await self._handle_request(req_type, req_data)
async def _handle_request(self, req_type: bytes, req_data: bytes):
"""
request handler:
1) parse the request type
2) call the appropriate handler for the request type
"""
if req_type == PDRequestType.GENERATION:
req = self.decode_generation.decode(req_data)
await self._generation_handler(req)
elif req_type == PDRequestType.ABORT:
req = self.decode_abort.decode(req_data)
await self._abort_handler(req)
else:
raise Exception(f"Unknown Request Type: {req_type}.")
async def _generation_handler(self, req: PDGenerationRequest):
"""
Handle a PDGenerationRequest by launching a task.
"""
task = asyncio.create_task(self._generate(req))
self.running_requests.add(task)
task.add_done_callback(self.running_requests.discard)
async def _abort_handler(self, req: PDAbortRequest):
"""
Handle a PDAbortRequest by cancelling the running task.
The _generate coro aborts in the Engine.
"""
# TODO: Convert running_requests set() into a dict(), keyed
# by request_id. Cancel the task when an abort comes in.
# Then update the _generate coroutine to handle a
# cancel error by aborting in the Engine.
pass
async def _generate(self, req: PDGenerationRequest):
"""
Handle a single PDGenerationRequest:
* 1) submit request to AsyncLLM
* 2) iterate the RequestOutputs
* 3) convert RequestOutput --> PDResponse
* 4) serialize and send to PDController
"""
request_id = req.request_id
# 1) Submit request to Engine.
generator = self.engine.generate(
prompt={"prompt_token_ids": req.prompt_token_ids},
sampling_params=req.sampling_params,
request_id=request_id)
# 2) Iterate RequestOutputs.
async for request_output in generator:
# 3) Convert RequestOutput --> PDResponse.
response = PDGenerationResponse.from_request_output(request_output)
# 4) Serialize and send to PDController.
response_bytes = self.encoder.encode(response)
msg = (PDResponseType.GENERATION, response_bytes)
await self.to_controller.send_multipart(msg, copy=False)
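The abort path above is left as a TODO. As a rough sketch of the pattern the TODO describes — track tasks in a dict keyed by request_id, cancel on abort, and have the generation coroutine abort in the engine when cancelled — here is a self-contained asyncio illustration with a dummy engine standing in for AsyncLLM; none of these names are part of the PR.

# Hypothetical sketch of the cancel-then-abort flow (not part of this PR).
import asyncio


class DummyEngine:

    async def generate(self, request_id: str):
        try:
            while True:  # stand-in for streaming RequestOutputs
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            # On cancellation, abort the request inside the engine.
            await self.abort(request_id)
            raise

    async def abort(self, request_id: str):
        print(f"aborted {request_id} in engine")


async def main():
    engine = DummyEngine()
    running: dict[str, asyncio.Task] = {}  # keyed by request_id, not a set()
    request_id = "demo-0"

    # What _generation_handler would do: launch and track the task.
    task = asyncio.create_task(engine.generate(request_id))
    running[request_id] = task
    task.add_done_callback(lambda _: running.pop(request_id, None))

    await asyncio.sleep(0.3)

    # What _abort_handler would do when a PDAbortRequest arrives.
    if (t := running.get(request_id)) is not None:
        t.cancel()
    await asyncio.gather(task, return_exceptions=True)


if __name__ == "__main__":
    asyncio.run(main())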

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import msgspec
from vllm import SamplingParams
from vllm.outputs import RequestOutput
# NOTE FOR DEVELOPERS:
# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE
# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE
# WE RISK REMOTE CODE EXECUTION FROM UNTRUSTED USERS.
class PDRequestType:
GENERATION = b'\x00'
ABORT = b'\x01'
class PDGenerationRequest(msgspec.Struct):
request_id: str
prompt_token_ids: list[int]
sampling_params: SamplingParams
# TODO: support multimodal inputs.
class PDAbortRequest(msgspec.Struct):
request_id: str
class PDResponseType:
GENERATION = b'\x00'
FAILURE = b'\x01'
class PDGenerationResponse(msgspec.Struct):
request_id: str
text: str
token_ids: list[int]
finish_reason: Optional[str] = None
stop_reason: Optional[str] = None
# TODO: support full protocol.
logprobs = None
@classmethod
def from_request_output(
cls, request_output: RequestOutput) -> "PDGenerationResponse":
assert len(request_output.outputs) == 1, "Only support N=1 right now."
out = request_output.outputs[0]
return PDGenerationResponse(
request_id=request_output.request_id,
text=out.text,
token_ids=out.token_ids,
finish_reason=out.finish_reason,
stop_reason=out.stop_reason,
)
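For illustration, the controller/worker wire format defined here is a two-frame ZMQ multipart message of (type byte, msgpack payload). The round-trip below exercises just the msgspec encode/decode path with no sockets; it assumes SamplingParams is msgspec-serializable, as the encoder usage in PDController implies, and the token IDs are placeholders.

# Minimal sketch of the controller/worker wire format (no sockets involved).
import msgspec

from vllm import SamplingParams
from vllm.disaggregated.protocol import PDGenerationRequest, PDRequestType

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(PDGenerationRequest)

request = PDGenerationRequest(request_id="demo-0",
                              prompt_token_ids=[1, 2, 3],
                              sampling_params=SamplingParams(max_tokens=16))

# PDController._run_prefill()/_run_decode() send this two-frame message.
msg = (PDRequestType.GENERATION, encoder.encode(request))

# PDWorker._handle_request() dispatches on the type byte and decodes the rest.
req_type, req_data = msg
assert req_type == PDRequestType.GENERATION
decoded = decoder.decode(req_data)
assert decoded.request_id == "demo-0"
assert decoded.prompt_token_ids == [1, 2, 3]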

View File

@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
"""
Toy API Server for prototyping.
Once the PDController is more mature and we clean up
the OpenAI Server a bit, we can put the PDController
directly inside and launch with vllm serve.
"""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.disaggregated.pd_controller import PDController
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse,
ErrorResponse)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser, set_ulimit
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.disaggregated.api_server')
app = FastAPI()
@app.get("/v1/models")
async def show_available_models(raw_request: Request):
handler: OpenAIServingModels = raw_request.app.state.openai_serving_models
models_ = await handler.show_available_models()
return JSONResponse(content=models_.model_dump())
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion # noqa: E501
generator = await handler.create_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@asynccontextmanager
async def controller_ctx(prefill_addr: str, decode_addr: str,
controller_addr: str,
model_name: str) -> AsyncIterator[PDController]:
c = PDController(prefill_addr, decode_addr, controller_addr, model_name)
yield c
c.shutdown()
async def main(args, **uvicorn_kwargs):
logger.info("vLLM Disaggregated Connector Start %s %s", args,
uvicorn_kwargs)
# Avoid dropping requests under high concurrency.
set_ulimit()
# IPC Paths.
prefill_addr = f"ipc://{args.prefill_addr}"
decode_addr = f"ipc://{args.decode_addr}"
controller_addr = f"ipc://{args.controller_addr}"
# Start Engine.
async with controller_ctx(prefill_addr=prefill_addr,
decode_addr=decode_addr,
controller_addr=controller_addr,
model_name=args.model) as engine_client:
# Initialize App State.
model_config = await engine_client.get_model_config()
app.state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=[
BaseModelPath(name=args.served_model_name or args.model,
model_path=args.model)
],
)
app.state.openai_serving_completion = OpenAIServingCompletion(
engine_client=engine_client,
model_config=model_config,
models=app.state.openai_serving_models,
request_logger=None,
)
# Run Server.
config = uvicorn.Config(app, host=args.host, port=args.port)
server = uvicorn.Server(config)
await server.serve()
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible P/D Server.")
parser.add_argument("--host",
type=str,
default="0.0.0.0",
help="The host of the HTTP server.")
parser.add_argument("--port",
type=int,
default=8000,
help="The port of the HTTP server.")
parser.add_argument("--model",
type=str,
required=True,
help="The path to the model.")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The served name of the model.")
parser.add_argument("--controller-addr",
type=str,
required=True,
help="The zmq ipc controller address")
parser.add_argument("--prefill-addr",
type=str,
required=True,
help="The zmq ipc prefill address")
parser.add_argument("--decode-addr",
type=str,
required=True,
help="The zmq ipc decode address")
args = parser.parse_args()
uvloop.run(main(args))

View File

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
import uvloop
from vllm.disaggregated.pd_worker import PDWorker
from vllm.engine.async_llm_engine import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.disaggregated.worker")
async def run(args, engine: EngineClient):
worker = PDWorker(engine=engine,
worker_addr=args.worker_addr,
controller_addr=args.controller_addr)
try:
await worker.run_busy_loop()
finally:
worker.shutdown()
async def main(args) -> None:
logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
logger.info("Args: %s", args)
async with build_async_engine_client(args) as engine:
await run(args, engine)
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument('--controller-addr',
type=str,
required=True,
help='The address of the controller.')
parser.add_argument('--worker-addr',
type=str,
required=True,
help='The address of the worker.')
parser.add_argument('--disable-frontend-multiprocessing',
action="store_true",
help='Disable MQLLMEngine for AsyncLLMEngine.')
AsyncEngineArgs.add_cli_args(parser)
uvloop.run(main(parser.parse_args()))