Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)

Compare commits: ci/build/2 ... rob-fixes (50 commits)
Commit SHA1s:
220d694080
70e06dd574
7954461d4c
a10da86677
284d5df45b
d5b0db449e
66349c33a1
28d0396ff1
2f29ae383a
cf64b0e6a7
f51f182d64
79e465f557
2ba687d39f
5d57896e2c
f6f008ca1d
24cbbe4778
2fec6e0b5c
47a3f26b2a
144162fc8c
522279ebb9
85687b43e7
120bbdfd82
2ceb7bc534
9f7fb5ec84
a8a621e419
b89d89f456
8355358fb3
c0b1443345
d35dace985
912031ceb5
4f13e89143
b9a7dbe769
0cb2e05256
d6945ecdf0
298298f97d
6c8fae82dd
16ed827378
8fa9df7987
27c1afe88b
ee6607332e
7fbf70db57
2c31e4c3ea
187f112ccd
897db7b93d
b7ffb43792
6e1fba8a73
bfde1688e7
905424ed65
5d20f389d6
2a0cb78016
examples/online_serving/disaggregated_prefill_zmq.sh (new file, 123 lines)
@@ -0,0 +1,123 @@
#!/bin/bash
# This file demonstrates example usage of disaggregated prefilling with ZMQ.
# We will launch 2 vllm instances (1 for prefill and 1 for decode),
# and then transfer the KV cache between them.

set -xe

echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT

# Cleanup function
cleanup() {
    echo "Caught Ctrl+C, cleaning up..."
    # Cleanup commands
    pgrep python | xargs kill -9
    pkill -f python
    echo "Cleanup complete. Exiting."
    exit 0
}

export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

# Wait until a vLLM server starts responding on the given port.
wait_for_server() {
    local port=$1
    timeout 1200 bash -c "
        until curl -s localhost:${port}/v1/completions > /dev/null; do
            sleep 1
        done" && return 0 || return 1
}

# Wait until a disaggregated PDWorker reports ready in its log file.
wait_for_disagg_server() {
    local log_file=$1
    timeout 1200 bash -c "
        until grep -q 'PDWorker is ready' $log_file; do
            sleep 1
        done" && return 0 || return 1
}

# You can also adjust --kv-ip and --kv-port for distributed inference.
MODEL=meta-llama/Llama-3.1-8B-Instruct
CONTROLLER_ADDR=controller.ipc
PREFILL_WORKER_ADDR=prefill.ipc
DECODE_WORKER_ADDR=decode.ipc
PORT=8001

# Prefilling instance, which is the KV producer.
CUDA_VISIBLE_DEVICES=0 python3 ../../vllm/entrypoints/disaggregated/worker.py \
    --model $MODEL \
    --controller-addr $CONTROLLER_ADDR \
    --worker-addr $PREFILL_WORKER_ADDR \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > vllm_disagg_prefill.log 2>&1 &

# Decoding instance, which is the KV consumer.
CUDA_VISIBLE_DEVICES=1 python3 ../../vllm/entrypoints/disaggregated/worker.py \
    --model $MODEL \
    --controller-addr $CONTROLLER_ADDR \
    --worker-addr $DECODE_WORKER_ADDR \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > vllm_disagg_decode.log 2>&1 &

# Launch a proxy server that exposes the OpenAI-compatible API on port $PORT.
# The workflow of this proxy:
#   - Send req to prefill instance, wait until complete.
#   - Send req to decode instance, streaming tokens.
python3 ../../vllm/entrypoints/disaggregated/api_server.py \
    --port $PORT \
    --model $MODEL \
    --controller-addr $CONTROLLER_ADDR \
    --prefill-addr $PREFILL_WORKER_ADDR \
    --decode-addr $DECODE_WORKER_ADDR &

# Wait until prefill, decode instances and the proxy are ready.
wait_for_server $PORT
wait_for_disagg_server vllm_disagg_prefill.log
wait_for_disagg_server vllm_disagg_decode.log

# Serve two example requests.
output1=$(curl -X POST -s http://localhost:8001/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "San Francisco is a",
        "max_tokens": 10,
        "temperature": 0
    }')

output2=$(curl -X POST -s http://localhost:8001/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "Santa Clara is a",
        "max_tokens": 10,
        "temperature": 0
    }')

# Cleanup commands
pgrep python | xargs kill -9
pkill -f python

echo ""
sleep 1

# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"

echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
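For illustration only, not part of this change: once the proxy above is listening, the same kind of completion request can be issued from Python with just the standard library. The endpoint, port, and model name are taken from the script above; the rest is an assumed sketch, not PR code.

import json
import urllib.request

# Assumes the proxy started by the example script is on localhost:8001.
payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "San Francisco is a",
    "max_tokens": 10,
    "temperature": 0,
}
req = urllib.request.Request(
    "http://localhost:8001/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    # Standard OpenAI-style completion response.
    print(json.load(resp)["choices"][0]["text"])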
vllm/disaggregated/__init__.py (new empty file)
vllm/disaggregated/pd_controller.py (new file, 364 lines)
@@ -0,0 +1,364 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import os
from collections.abc import AsyncGenerator, Mapping
from typing import Optional, Union

import msgspec
import zmq
import zmq.asyncio

from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.disaggregated.protocol import (PDGenerationRequest,
                                         PDGenerationResponse, PDRequestType,
                                         PDResponseType)
from vllm.engine.protocol import EngineClient
from vllm.inputs.data import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import Device

logger = init_logger(__name__)

DEFAULT_MAX_TOKENS = 32000


class PDController(EngineClient):
    """
    Controller that schedules work on the PDWorkers.

    Conforms to the EngineClient protocol so it can
    be wrapped with the OpenAI Server.

    Two Phases:
      * Send request to prefill worker, await ack.
      * Send request to decode worker, stream responses.

    KV sync happens directly between Engines,
    handled by vLLM KVCacheTransfer.

          [ OpenAI Server ]
                  |
          [ PDController ]
                  |
               [ zmq ]
                  |
      [ PDWorker ]    [ PDWorker ]
           |               |
      [ Engine ]  <-kv->  [ Engine ]

    After PR #12957, we will support xPyD, so we will
    also need to implement a scheduler and service
    discovery for the workers.

    This PDController may be implemented as a K8s
    controller. This is intended to be a prototype.

    * TODO: better error handling
    * TODO: support logprobs, multimodal, etc.
    """

    def __init__(self, prefill_addr: str, decode_addr: str,
                 controller_addr: str, model_name: str):
        # Request queues.
        self.queues: dict[str, asyncio.Queue] = {}

        # Serialization encoder.
        self.encoder = msgspec.msgpack.Encoder()

        # ZMQ communication.
        # TODO: once https://github.com/vllm-project/vllm/pull/12957
        # lands, do service discovery to scale out workers.
        self.ctx = zmq.asyncio.Context()
        self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
        self.to_prefill.connect(prefill_addr)
        self.to_decode = self.ctx.socket(zmq.constants.PUSH)
        self.to_decode.connect(decode_addr)
        self.controller_addr = controller_addr
        self.ipc_paths = [prefill_addr, decode_addr, controller_addr]

        # Background loops (started on first generate()).
        self.output_handler: Optional[asyncio.Task] = None
        self.log_running: Optional[asyncio.Task] = None

        # Dummy: needed for EngineClient Protocol.
        # TODO: refactor OAI Server to avoid needing this.
        self.model_config = ModelConfig(model=model_name,
                                        tokenizer=model_name,
                                        tokenizer_mode="auto",
                                        trust_remote_code=False,
                                        dtype="auto",
                                        task="generate",
                                        seed=42)

        # Dummy: needed for EngineClient Protocol.
        # TODO: refactor OAI Server to avoid needing this.
        self.tokenizer = TokenizerGroup(
            **dict(tokenizer_id=self.model_config.tokenizer,
                   enable_lora=False,
                   max_num_seqs=1024,
                   max_loras=0,
                   max_input_length=None,
                   tokenizer_mode=self.model_config.tokenizer_mode,
                   trust_remote_code=self.model_config.trust_remote_code,
                   revision=self.model_config.tokenizer_revision,
                   truncation_side=self.model_config.truncation_side))

    def shutdown(self):
        if (ctx := self.ctx) is not None:
            ctx.destroy(linger=0)
        if (task := self.log_running) is not None:
            task.cancel()
        if (task := self.output_handler) is not None:
            task.cancel()

        for path in self.ipc_paths:
            socket_path = path.replace("ipc://", "")
            if os.path.exists(socket_path):
                os.remove(socket_path)

    async def _run_log_running(self):
        while True:
            logger.info("Running requests: %d", len(self.queues))
            await asyncio.sleep(10.)

    async def _run_output_handler(self):
        """
        Pull responses from Decode + Prefill engines and
        distribute back to the generate() tasks.
        """
        decoder = msgspec.msgpack.Decoder(PDGenerationResponse)

        socket: Optional[zmq.asyncio.Socket] = None
        try:
            socket = self.ctx.socket(zmq.constants.PULL)
            socket.bind(self.controller_addr)

            while True:
                res_type, res_data = await socket.recv_multipart()
                if res_type == PDResponseType.FAILURE:
                    raise Exception("Failure Response from PDWorker.")
                elif res_type == PDResponseType.GENERATION:
                    response = decoder.decode(res_data)
                    logger.debug("Got Response: %s", response.request_id)
                    self.queues[response.request_id].put_nowait(response)
                else:
                    raise Exception("Unknown response type.")
        except Exception as e:
            # TODO: distinguish between fatal and non-fatal errors.
            for q in self.queues.values():
                q.put_nowait(e)
            raise e
        finally:
            if socket is not None:
                socket.close(linger=0)

    async def _run_prefill(
        self,
        request: PDGenerationRequest,
        q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
    ):
        # Send request to the prefill instance.
        req_bytes = self.encoder.encode(request)
        msg = (PDRequestType.GENERATION, req_bytes)
        await self.to_prefill.send_multipart(msg, copy=False)

        # Await completion of the prefill.
        response = await q.get()
        if isinstance(response, Exception):
            raise response
        logger.debug("Prefill Response: %s", request.request_id)

    async def _run_decode(
        self,
        request: PDGenerationRequest,
        q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
    ) -> AsyncGenerator[PDGenerationResponse]:
        # Send request to the decode instance.
        req_bytes = self.encoder.encode(request)
        msg = (PDRequestType.GENERATION, req_bytes)
        await self.to_decode.send_multipart(msg, copy=False)

        # Iterate response queue and yield each response to caller.
        finished = False
        while not finished:
            response = await q.get()
            if isinstance(response, Exception):
                raise response
            logger.debug("Decode Response: %s", request.request_id)
            finished = response.finish_reason is not None
            yield response

    def _to_request_output(
        self,
        response: PDGenerationResponse,
        prompt_token_ids: list[int],
    ) -> RequestOutput:
        finished = response.finish_reason is not None
        return RequestOutput(
            request_id=response.request_id,
            prompt=None,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(index=0,
                                 text=response.text,
                                 token_ids=response.token_ids,
                                 cumulative_logprob=None,
                                 logprobs=None,
                                 finish_reason=response.finish_reason,
                                 stop_reason=response.stop_reason)
            ],
            finished=finished,
        )

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput]:
        # Start loops on first request.
        if self.output_handler is None:
            self.output_handler = asyncio.create_task(
                self._run_output_handler())
            self.log_running = asyncio.create_task(self._run_log_running())

        # TODO: Expand to support the full matrix.
        if "prompt_token_ids" not in prompt:
            raise NotImplementedError(
                "We currently only support TokensPrompt for P/D!")
        if lora_request is not None:
            raise NotImplementedError(
                "We currently do not support LoRA for P/D!")
        if trace_headers is not None:
            raise NotImplementedError(
                "We currently do not support tracing for P/D!")
        if prompt_adapter_request is not None:
            raise NotImplementedError(
                "We currently do not support prompt adapter for P/D!")
        if priority != 0:
            raise NotImplementedError(
                "We currently do not support priority for P/D!")
        if request_id in self.queues:
            raise ValueError(f"Found duplicate request_id: {request_id}!")

        # Queue to gather output from output_handler.
        q = asyncio.Queue()
        self.queues[request_id] = q

        # (1) Perform the Prefill.
        original_max_tokens = sampling_params.max_tokens
        request = PDGenerationRequest(
            request_id=request_id,
            prompt_token_ids=prompt["prompt_token_ids"],
            sampling_params=sampling_params)
        request.sampling_params.max_tokens = 1
        await self._run_prefill(request, q)

        # (2) Perform the Decodes.
        request.sampling_params.max_tokens = original_max_tokens
        async for pd_response in self._run_decode(request, q):
            yield self._to_request_output(pd_response,
                                          prompt["prompt_token_ids"])

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        raise NotImplementedError

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        raise NotImplementedError

    async def abort(self, request_id: str) -> None:
        raise NotImplementedError

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def get_decoding_config(self) -> DecodingConfig:
        raise NotImplementedError

    async def get_input_preprocessor(self) -> InputPreprocessor:
        raise NotImplementedError

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        if lora_request is not None:
            raise NotImplementedError(
                "LoRA is not yet supported in the PDEngine.")
        return self.tokenizer.get_lora_tokenizer(None)

    async def is_tracing_enabled(self) -> bool:
        return False

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[list[SamplerOutput]] = None,
    ) -> None:
        pass

    async def check_health(self) -> None:
        pass

    async def start_profile(self) -> None:
        raise NotImplementedError

    async def stop_profile(self) -> None:
        raise NotImplementedError

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        raise NotImplementedError

    async def sleep(self, level: int = 1) -> None:
        raise NotImplementedError

    async def wake_up(self) -> None:
        raise NotImplementedError

    async def is_sleeping(self) -> bool:
        return False

    async def add_lora(self, lora_request: LoRARequest) -> None:
        raise NotImplementedError

    @property
    def errored(self) -> bool:
        return False

    def dead_error(self) -> Exception:
        return Exception("PDController has failed.")

    def is_running(self) -> bool:
        return True

    def is_stopped(self) -> bool:
        return False
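For illustration only, not part of this change: a minimal sketch of driving PDController directly, assuming the prefill and decode PDWorkers from the example script are already running and bound to the ipc:// endpoints below. Only a TokensPrompt is supported, so the prompt is supplied as token IDs; the token IDs and request ID here are made-up values.

import asyncio

from vllm import SamplingParams
from vllm.disaggregated.pd_controller import PDController


async def demo():
    # Assumes prefill.ipc / decode.ipc / controller.ipc are the endpoints
    # used when launching the workers, as in the example script above.
    controller = PDController(prefill_addr="ipc://prefill.ipc",
                              decode_addr="ipc://decode.ipc",
                              controller_addr="ipc://controller.ipc",
                              model_name="meta-llama/Llama-3.1-8B-Instruct")
    try:
        # Prefill runs first (max_tokens forced to 1), then decode streams.
        async for out in controller.generate(
                prompt={"prompt_token_ids": [9906, 1917]},  # made-up tokens
                sampling_params=SamplingParams(max_tokens=16),
                request_id="demo-0"):
            print(out.outputs[0].text)
    finally:
        controller.shutdown()


asyncio.run(demo())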
vllm/disaggregated/pd_worker.py (new file, 143 lines)
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import os

import msgspec
import zmq
import zmq.asyncio

from vllm.disaggregated.protocol import (PDAbortRequest, PDGenerationRequest,
                                          PDGenerationResponse, PDRequestType,
                                          PDResponseType)
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)


class PDWorker:

    def __init__(
        self,
        engine: EngineClient,
        worker_addr: str,
        controller_addr: str,
    ):
        """
        PDWorker
          * Wrapper around AsyncLLM to handle converting PDRequests
            to PDResponses and sending them back to the PDController.
          * Leverages ZMQ for communication with the PDController. We may
            expand this in the future.
        """
        # Engine.
        self.engine = engine

        # ZMQ IPC.
        self.worker_addr = f"ipc://{worker_addr}"
        self.controller_addr = f"ipc://{controller_addr}"
        self.ctx = zmq.asyncio.Context()
        self.from_controller = self.ctx.socket(zmq.constants.PULL)
        self.from_controller.bind(self.worker_addr)
        self.to_controller = self.ctx.socket(zmq.constants.PUSH)
        self.to_controller.connect(self.controller_addr)
        self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest)
        self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest)
        self.encoder = msgspec.msgpack.Encoder()

        # Active Requests.
        self.running_requests: set[asyncio.Task] = set()

    def shutdown(self):
        if hasattr(self, "ctx"):
            self.ctx.destroy()

        if hasattr(self, "running_requests"):
            for running_request in self.running_requests:
                running_request.cancel()

        if hasattr(self, "controller_addr"):
            ipc_paths = [self.worker_addr, self.controller_addr]
            for ipc_path in ipc_paths:
                socket_path = ipc_path.replace("ipc://", "")
                if os.path.exists(socket_path):
                    os.remove(socket_path)

    async def run_busy_loop(self):
        """
        Main loop:
          1) wait for a request from the PDController
          2) handle the request
        """
        logger.info("PDWorker is ready to handle requests.")

        poller = zmq.asyncio.Poller()
        poller.register(self.from_controller, zmq.POLLIN)

        while True:
            # 1) Get request from the Controller.
            req_type, req_data = await self.from_controller.recv_multipart()

            # 2) Handle the request.
            await self._handle_request(req_type, req_data)

    async def _handle_request(self, req_type: bytes, req_data: bytes):
        """
        Request handler:
          1) parse the request type
          2) call the appropriate handler for the request type
        """
        if req_type == PDRequestType.GENERATION:
            req = self.decode_generation.decode(req_data)
            await self._generation_handler(req)
        elif req_type == PDRequestType.ABORT:
            req = self.decode_abort.decode(req_data)
            await self._abort_handler(req)
        else:
            raise Exception(f"Unknown Request Type: {req_type}.")

    async def _generation_handler(self, req: PDGenerationRequest):
        """
        Handle a PDGenerationRequest by launching a task.
        """
        task = asyncio.create_task(self._generate(req))
        self.running_requests.add(task)
        task.add_done_callback(self.running_requests.discard)

    async def _abort_handler(self, req: PDAbortRequest):
        """
        Handle a PDAbortRequest by cancelling the running task.
        The _generate coro aborts in the Engine.
        """
        # TODO: Convert running_requests set() into a dict(), keyed
        # by request_id. Cancel the task when an abort comes in.
        # Then update the _generate coroutine to handle a
        # cancel error by aborting in the Engine.
        pass

    async def _generate(self, req: PDGenerationRequest):
        """
        Handle a single PDGenerationRequest:
          * 1) submit request to AsyncLLM
          * 2) iterate the RequestOutputs
          * 3) convert RequestOutput --> PDResponse
          * 4) serialize and send to PDController
        """
        request_id = req.request_id

        # 1) Submit request to Engine.
        generator = self.engine.generate(
            prompt={"prompt_token_ids": req.prompt_token_ids},
            sampling_params=req.sampling_params,
            request_id=request_id)

        # 2) Iterate RequestOutputs.
        async for request_output in generator:
            # 3) Convert RequestOutput --> PDResponse.
            response = PDGenerationResponse.from_request_output(request_output)

            # 4) Serialize and send to PDController.
            response_bytes = self.encoder.encode(response)
            msg = (PDResponseType.GENERATION, response_bytes)
            await self.to_controller.send_multipart(msg, copy=False)
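For illustration only, not part of this change: _abort_handler above is a stub, and its inline TODO describes the intended design (track tasks by request_id, cancel on abort, abort in the engine when the generate task is cancelled). A rough sketch of that idea, under the assumption that EngineClient.abort() is the right call to release engine-side state, might look like this:

import asyncio

from vllm.disaggregated.pd_worker import PDWorker


class AbortablePDWorker(PDWorker):
    """Hypothetical subclass sketching the abort design from the TODO."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # request_id -> task, instead of a plain set of tasks.
        self.tasks_by_id: dict[str, asyncio.Task] = {}

    async def _generation_handler(self, req):
        task = asyncio.create_task(self._generate(req))
        self.tasks_by_id[req.request_id] = task
        task.add_done_callback(
            lambda _: self.tasks_by_id.pop(req.request_id, None))

    async def _abort_handler(self, req):
        # Cancel the in-flight generate task for this request, if any.
        if (task := self.tasks_by_id.get(req.request_id)) is not None:
            task.cancel()

    async def _generate(self, req):
        try:
            await super()._generate(req)
        except asyncio.CancelledError:
            # Release resources held by the engine for this request.
            await self.engine.abort(req.request_id)
            raise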
vllm/disaggregated/protocol.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import msgspec

from vllm import SamplingParams
from vllm.outputs import RequestOutput

# NOTE FOR DEVELOPERS:
# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE
# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE, OTHERWISE
# WE RISK REMOTE CODE EXECUTION FROM UNTRUSTED USERS.


class PDRequestType:
    GENERATION = b'\x00'
    ABORT = b'\x01'


class PDGenerationRequest(msgspec.Struct):
    request_id: str
    prompt_token_ids: list[int]
    sampling_params: SamplingParams
    # TODO: support multimodal inputs.


class PDAbortRequest(msgspec.Struct):
    request_id: str


class PDResponseType:
    GENERATION = b'\x00'
    FAILURE = b'\x01'


class PDGenerationResponse(msgspec.Struct):
    request_id: str
    text: str
    token_ids: list[int]
    finish_reason: Optional[str] = None
    stop_reason: Optional[str] = None
    # TODO: support full protocol.
    logprobs = None

    @classmethod
    def from_request_output(
            cls, request_output: RequestOutput) -> "PDGenerationResponse":
        assert len(request_output.outputs) == 1, "Only support N=1 right now."
        out = request_output.outputs[0]
        return PDGenerationResponse(
            request_id=request_output.request_id,
            text=out.text,
            token_ids=out.token_ids,
            finish_reason=out.finish_reason,
            stop_reason=out.stop_reason,
        )
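For illustration only, not part of this change: a sketch of what one request looks like on the wire. Each message is a two-frame multipart, a one-byte type tag plus a msgpack payload encoded with msgspec (never pickle, per the note above); the request ID and token IDs below are made-up values.

import msgspec

from vllm import SamplingParams
from vllm.disaggregated.protocol import PDGenerationRequest, PDRequestType

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(PDGenerationRequest)

request = PDGenerationRequest(request_id="req-0",
                              prompt_token_ids=[1, 2, 3],
                              sampling_params=SamplingParams(max_tokens=8))
# Frame 0 is the type tag, frame 1 is the msgpack payload.
frames = (PDRequestType.GENERATION, encoder.encode(request))

# On the receiving side, the type tag selects the decoder.
assert frames[0] == PDRequestType.GENERATION
decoded = decoder.decode(frames[1])
assert decoded.request_id == "req-0"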
vllm/entrypoints/disaggregated/__init__.py (new empty file)
vllm/entrypoints/disaggregated/api_server.py (new file, 136 lines)
@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
"""
Toy API Server for prototyping.

Once the PDController is more mature and we clean up
the OpenAI Server a bit, we can put the PDController
directly inside and launch with vllm serve.
"""

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.disaggregated.pd_controller import PDController
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              CompletionResponse,
                                              ErrorResponse)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser, set_ulimit

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.disaggregated.api_server')

app = FastAPI()


@app.get("/v1/models")
async def show_available_models(raw_request: Request):
    handler: OpenAIServingModels = raw_request.app.state.openai_serving_models
    models_ = await handler.show_available_models()
    return JSONResponse(content=models_.model_dump())


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion  # noqa: E501
    generator = await handler.create_completion(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, CompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


@asynccontextmanager
async def controller_ctx(prefill_addr: str, decode_addr: str,
                         controller_addr: str,
                         model_name: str) -> AsyncIterator[PDController]:
    c = PDController(prefill_addr, decode_addr, controller_addr, model_name)
    yield c
    c.shutdown()


async def main(args, **uvicorn_kwargs):
    logger.info("vLLM Disaggregated Connector Start %s %s", args,
                uvicorn_kwargs)

    # Avoid dropping requests under high concurrency.
    set_ulimit()

    # IPC Paths.
    prefill_addr = f"ipc://{args.prefill_addr}"
    decode_addr = f"ipc://{args.decode_addr}"
    controller_addr = f"ipc://{args.controller_addr}"

    # Start Engine.
    async with controller_ctx(prefill_addr=prefill_addr,
                              decode_addr=decode_addr,
                              controller_addr=controller_addr,
                              model_name=args.model) as engine_client:

        # Initialize App State.
        model_config = await engine_client.get_model_config()
        app.state.openai_serving_models = OpenAIServingModels(
            engine_client=engine_client,
            model_config=model_config,
            base_model_paths=[
                BaseModelPath(name=args.served_model_name or args.model,
                              model_path=args.model)
            ],
        )
        app.state.openai_serving_completion = OpenAIServingCompletion(
            engine_client=engine_client,
            model_config=model_config,
            models=app.state.openai_serving_models,
            request_logger=None,
        )

        # Run Server.
        config = uvicorn.Config(app, host=args.host, port=args.port)
        server = uvicorn.Server(config)
        await server.serve()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible P/D Server.")
    parser.add_argument("--host",
                        type=str,
                        default="0.0.0.0",
                        help="The host of the HTTP server.")
    parser.add_argument("--port",
                        type=int,
                        default=8000,
                        help="The port of the HTTP server.")
    parser.add_argument("--model",
                        type=str,
                        required=True,
                        help="The path to the model.")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The served name of the model.")
    parser.add_argument("--controller-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc controller address.")
    parser.add_argument("--prefill-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc prefill address.")
    parser.add_argument("--decode-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc decode address.")
    args = parser.parse_args()
    uvloop.run(main(args))
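For illustration only, not part of this change: when a request sets "stream": true, the /v1/completions route above returns an event-stream (decode tokens are streamed back through the PDController). A minimal standard-library reader for that stream, assuming the proxy from the example script is on localhost:8001, might look like this:

import json
import urllib.request

payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "San Francisco is a",
    "max_tokens": 10,
    "temperature": 0,
    "stream": True,
}
req = urllib.request.Request(
    "http://localhost:8001/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for raw in resp:
        line = raw.decode().strip()
        # OpenAI-style SSE: each event is prefixed with "data: ".
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        print(json.loads(data)["choices"][0]["text"], end="", flush=True)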
vllm/entrypoints/disaggregated/worker.py (new file, 48 lines)
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0

import uvloop

from vllm.disaggregated.pd_worker import PDWorker
from vllm.engine.async_llm_engine import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("vllm.entrypoints.disaggregated.worker")


async def run(args, engine: EngineClient):
    worker = PDWorker(engine=engine,
                      worker_addr=args.worker_addr,
                      controller_addr=args.controller_addr)
    try:
        await worker.run_busy_loop()
    finally:
        worker.shutdown()


async def main(args) -> None:
    logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
    logger.info("Args: %s", args)

    async with build_async_engine_client(args) as engine:
        await run(args, engine)


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument('--controller-addr',
                        type=str,
                        required=True,
                        help='The address of the controller.')
    parser.add_argument('--worker-addr',
                        type=str,
                        required=True,
                        help='The address of the worker.')
    parser.add_argument('--disable-frontend-multiprocessing',
                        action="store_true",
                        help='Disable MQLLMEngine for AsyncLLMEngine.')
    AsyncEngineArgs.add_cli_args(parser)
    uvloop.run(main(parser.parse_args()))