@@ -44,18 +44,27 @@ wait_for_disagg_server() {
# You can also adjust --kv-ip and --kv-port for distributed inference.
MODEL=meta-llama/Llama-3.1-8B-Instruct
CONNECTOR_ADDR=connectoripc
PREFILL_WORKER_ADDR=prefillipc
DECODE_WORKER_ADDR=decodeipc
PORT=8000

# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
    --zmq-server-addr testipc0 \
CUDA_VISIBLE_DEVICES=0 python3 -m vllm.entrypoints.disaggregated.worker \
    --model $MODEL \
    --connector-addr $CONNECTOR_ADDR \
    --worker-addr $PREFILL_WORKER_ADDR \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > vllm_disagg_prefill.log 2>&1 &

# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
    --zmq-server-addr testipc1 \
CUDA_VISIBLE_DEVICES=1 python3 -m vllm.entrypoints.disaggregated.worker \
    --model $MODEL \
    --connector-addr $CONNECTOR_ADDR \
    --worker-addr $DECODE_WORKER_ADDR \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --kv-transfer-config \
@@ -63,16 +72,17 @@ CUDA_VISIBLE_DEVICES=1 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \

# launch a proxy server that opens the service at port 8000
# the workflow of this proxy:
# - send the request to prefill vLLM instance (via zmq addr testipc0), change max_tokens
#   to 1
# - after the prefill vLLM finishes prefill, send the request to decode vLLM
#   instance (via zmq addr testipc1)
vllm connect --port 8000 \
    --prefill-addr testipc0 \
    --decode-addr testipc1 &
# - Send req to prefill instance, wait until complete.
# - Send req to decode instance, streaming tokens.
python3 -m vllm.entrypoints.disaggregated.connector \
    --port $PORT \
    --model $MODEL \
    --connector-addr $CONNECTOR_ADDR \
    --prefill-addr $PREFILL_WORKER_ADDR \
    --decode-addr $DECODE_WORKER_ADDR

# wait until prefill, decode instances and proxy are ready
wait_for_server 8000
wait_for_server $PORT
wait_for_disagg_server vllm_disagg_prefill.log
wait_for_disagg_server vllm_disagg_decode.log
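
Once the two workers and the proxy report ready, the proxy exposes the standard OpenAI-compatible completions route on $PORT. The following is a minimal smoke test sketched with only the Python standard library; the host, port, prompt, and sampling values are illustrative assumptions, not part of the script above.

import json
import urllib.request

# Hypothetical client for the proxy launched above (default: localhost:8000).
payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "San Francisco is a",
    "max_tokens": 16,
    "temperature": 0.0,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    # The prefill instance handles the prompt, the decode instance generates
    # the tokens; the proxy stitches the two into one OpenAI-style response.
    print(json.loads(resp.read())["choices"][0]["text"])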
@@ -1,126 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import uvicorn
import uvloop

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.entrypoints.disaggregated.pd_engine import PDEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              CompletionResponse,
                                              ErrorResponse)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser, set_ulimit

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.disaggregated.api_server')

app = FastAPI()


@app.get("/v1/models")
async def show_available_models(raw_request: Request):
    handler: OpenAIServingModels = raw_request.app.state.openai_serving_models
    models_ = await handler.show_available_models()
    return JSONResponse(content=models_.model_dump())


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion
    generator = await handler.create_completion(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, CompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator,
                             media_type="text/event-stream")


@asynccontextmanager
async def pd_engine_client_ctx_manager(
        model_name: str,
        prefill_addr: str,
        decode_addr: str,
        connector_addr: str) -> AsyncIterator[PDEngine]:
    engine = PDEngine(prefill_addr=prefill_addr,
                      decode_addr=decode_addr,
                      connector_addr=connector_addr,
                      model_name=model_name)
    try:
        yield engine
    finally:
        engine.shutdown()


async def main(args, **uvicorn_kwargs):
    logger.info("vLLM Disaggregated Connector Start %s %s", args,
                uvicorn_kwargs)

    # Avoid dropping requests under high concurrency.
    set_ulimit()

    # IPC Paths.
    # NOTE FOR DEVELOPERS: when shifting to TCP, ensure you
    # are not using pickle to avoid RCE security flaw.
    prefill_addr = f"ipc://{args.prefill_addr}"
    decode_addr = f"ipc://{args.decode_addr}"
    connector_addr = f"ipc://{args.connector_addr}"

    # Start Engine.
    async with pd_engine_client_ctx_manager(
            args.model, prefill_addr, decode_addr,
            connector_addr) as engine_client:

        # Initialize App State.
        model_config = await engine_client.get_model_config()
        app.state.openai_serving_models = OpenAIServingModels(
            engine_client=engine_client,
            model_config=model_config,
            base_model_paths=[
                BaseModelPath(name=args.served_model_name or args.model,
                              model_path=args.model)
            ],
        )
        app.state.openai_serving_completion = OpenAIServingCompletion(
            engine_client=engine_client,
            model_config=model_config,
            models=app.state.openai_serving_models,
            request_logger=None,
        )

        # Run Server.
        config = uvicorn.Config(app, host=args.host, port=args.port)
        server = uvicorn.Server(config)
        await server.serve()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible P/D Server.")
    parser.add_argument("--host",
                        type=str,
                        default="0.0.0.0",
                        help="The host of the HTTP server.")
    parser.add_argument("--port",
                        type=int,
                        default=8001,
                        help="The port of the HTTP server.")
    parser.add_argument("--model",
                        type=str,
                        required=True,
                        help="The path to the model.")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The served name of the model.")
    parser.add_argument("--connector-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc connector address.")
    parser.add_argument("--prefill-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc prefill address.")
    parser.add_argument("--decode-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc decode address.")
    args = parser.parse_args()
    uvloop.run(main(args))
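
Note that main() prepends "ipc://" to the bare names taken from the CLI, and the worker below does the same with its --worker-addr and --connector-addr values, so the connector and a worker only rendezvous when the bare strings match. A small illustrative sketch of that convention, reusing the hypothetical names from the launch script:

# Illustrative only: both sides derive the ZMQ IPC endpoint the same way,
# so the bare socket names passed on each command line must agree.
def to_zmq_ipc_endpoint(bare_name: str) -> str:
    return f"ipc://{bare_name}"

# Connector side (this file): --prefill-addr prefillipc
connector_side = to_zmq_ipc_endpoint("prefillipc")
# Prefill worker side (worker.py): --worker-addr prefillipc
worker_side = to_zmq_ipc_endpoint("prefillipc")
assert connector_side == worker_side == "ipc://prefillipc"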
@@ -1,296 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import os
from collections.abc import AsyncGenerator
from typing import Dict, List, Mapping, Optional

import msgspec
import zmq
import zmq.asyncio

from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
from vllm.inputs.data import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import Device, make_zmq_socket

logger = init_logger(__name__)

DEFAULT_MAX_TOKENS = 32000


class PDEngine:
    """
    PDEngine:
    Equivalent of AsyncLLM for P/D. Assumes there is
    a Prefill and Decode service already running.

    * TODO: actually handle errors and failure.
    * TODO: support more than just text input.
    * TODO: move under vllm/v1/engine once past prototype.
    """

    def __init__(
        self,
        prefill_addr: str,
        decode_addr: str,
        connector_addr: str,
        model_name: str,
    ):
        # Request queues.
        self.queues: Dict[str, asyncio.Queue] = {}

        # Serialization encoder.
        self.encoder = msgspec.msgpack.Encoder()

        # ZMQ communication.
        self.ctx = zmq.asyncio.Context()
        self.to_decode = make_zmq_socket(
            self.ctx, f"{decode_addr}", zmq.constants.PUSH)
        self.to_prefill = make_zmq_socket(
            self.ctx, f"{prefill_addr}", zmq.constants.PUSH)
        self.connector_addr = connector_addr
        self.decode_addr = decode_addr
        self.prefill_addr = prefill_addr

        # Background loops (started on first generate()).
        self.output_handler: Optional[asyncio.Task] = None
        self.log_running: Optional[asyncio.Task] = None

        # Dummy: needed for EngineClient Protocol.
        # TODO: refactor EngineClient to avoid needing this.
        self.model_config = ModelConfig(
            model=model_name,
            tokenizer=model_name,
            tokenizer_mode="auto",
            trust_remote_code=False,
            dtype="auto",
            seed=42,
        )

        # Dummy: needed for EngineClient Protocol.
        # TODO: refactor EngineClient to avoid needing this.
        init_kwargs = dict(
            tokenizer_id=self.model_config.tokenizer,
            enable_lora=False,
            max_num_seqs=1024,
            max_loras=0,
            max_input_length=None,
            tokenizer_mode=self.model_config.tokenizer_mode,
            trust_remote_code=self.model_config.trust_remote_code,
            revision=self.model_config.tokenizer_revision,
            truncation_side=self.model_config.truncation_side)
        self.tokenizer = TokenizerGroup(**init_kwargs)

    def shutdown(self):
        if (ctx := self.ctx) is not None:
            ctx.destroy(linger=0)
        if (task := self.log_running) is not None:
            task.cancel()
        if (task := self.output_handler) is not None:
            task.cancel()

        ipc_paths = [
            self.connector_addr, self.decode_addr, self.prefill_addr
        ]
        for path in ipc_paths:
            socket_path = path.replace("ipc://", "")
            if os.path.exists(socket_path):
                os.remove(socket_path)

    async def _run_log_running(self):
        while True:
            logger.info("Running requests: %d", len(self.queues))
            await asyncio.sleep(10.)

    async def _run_output_handler(self):
        """
        Pull responses from Decode + Prefill engines and
        distribute back to the generate() tasks.
        """
        decoder = msgspec.msgpack.Decoder(PDResponse)

        socket: Optional[zmq.asyncio.Socket] = None
        try:
            socket = make_zmq_socket(
                self.ctx, self.connector_addr, zmq.constants.PULL)

            while True:
                response_bytes = await socket.recv()
                response = decoder.decode(response_bytes)
                self.queues[response.request_id].put_nowait(response)
        except:
            # TODO: actually handle failure and shutdown.
            raise
        finally:
            if socket is not None:
                socket.close(linger=0)

    async def _prefill(self,
                       request: PDRequest,
                       q: asyncio.Queue[PDResponse]) -> PDResponse:
        # Send request to the prefill instance.
        req_bytes = self.encoder.encode(request)
        await self.to_prefill.send(req_bytes, copy=False)

        # Wait for the prefill to be done.
        response = await q.get()
        assert response.request_id == request.request_id
        if not response.success:
            # TODO: actual error handling and shutdown.
            raise Exception("Failed Prefill Request.")

        return response

    async def _decode(
        self,
        request: PDRequest,
        q: asyncio.Queue[PDResponse],
    ) -> AsyncGenerator[PDResponse, None]:
        # Send request to the decode instance.
        req_bytes = self.encoder.encode(request)
        await self.to_decode.send(req_bytes, copy=False)

        # Iterate response queue and yield each response to caller.
        finished = False
        while not finished:
            response = await q.get()
            if not response.success:
                # TODO: actual error handling and shutdown.
                raise Exception("Failed Decode Request.")
            finished = response.finish_reason is not None
            yield response

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PDResponse, None]:
        # Start loops on first request.
        if self.output_handler is None:
            self.output_handler = asyncio.create_task(
                self._run_output_handler())
            self.log_running = asyncio.create_task(self._run_log_running())

        # TODO: Expand to support the full matrix.
        if not isinstance(prompt, str):
            raise NotImplementedError(
                "We currently only support text for P/D!")
        if lora_request is not None:
            raise NotImplementedError(
                "We currently do not support LoRA for P/D!")
        if trace_headers is not None:
            raise NotImplementedError(
                "We currently do not support tracing for P/D!")
        if prompt_adapter_request is not None:
            raise NotImplementedError(
                "We currently do not support prompt adapter for P/D!")
        if priority != 0:
            raise NotImplementedError(
                "We currently do not support priority for P/D!")
        if request_id in self.queues:
            raise ValueError(f"Found duplicate request_id: {request_id}!")

        # Queue to gather output from output_handler.
        q: asyncio.Queue[PDResponse] = asyncio.Queue()
        self.queues[request_id] = q

        # (1) Perform the Prefill.
        original_max_tokens = sampling_params.max_tokens
        request = PDRequest(request_id, prompt, sampling_params)
        request.sampling_params.max_tokens = 1
        response = await self._prefill(request, q)
        yield response

        # (2) Perform the Decodes.
        request.sampling_params.max_tokens = original_max_tokens
        async for response in self._decode(request, q):
            yield response

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        raise NotImplementedError

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        raise NotImplementedError

    async def abort(self, request_id: str) -> None:
        raise NotImplementedError

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def get_decoding_config(self) -> DecodingConfig:
        raise NotImplementedError

    async def get_input_preprocessor(self) -> InputPreprocessor:
        raise NotImplementedError

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        if lora_request is not None:
            raise NotImplementedError(
                "LoRA is not yet supported in the PDEngine.")
        return self.tokenizer.get_lora_tokenizer(None)

    async def is_tracing_enabled(self) -> bool:
        return False

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        pass

    async def check_health(self) -> None:
        pass

    async def start_profile(self) -> None:
        raise NotImplementedError

    async def stop_profile(self) -> None:
        raise NotImplementedError

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        raise NotImplementedError

    async def sleep(self, level: int = 1) -> None:
        raise NotImplementedError

    async def wake_up(self) -> None:
        raise NotImplementedError

    async def is_sleeping(self) -> bool:
        return False

    async def add_lora(self, lora_request: LoRARequest) -> None:
        raise NotImplementedError
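
PDEngine serializes PDRequest/PDResponse messages defined in vllm/entrypoints/disaggregated/types.py, which is not part of this diff. The sketch below is a rough, hypothetical reconstruction of the shapes this file appears to assume, inferred from how the fields are accessed here and in worker.py; the real field names, ordering, and defaults may differ.

# Hypothetical reconstruction of the P/D message types (see types.py for the
# real definitions). msgspec.Struct is assumed because both sides encode and
# decode these objects with msgspec's msgpack Encoder/Decoder.
from typing import Optional

import msgspec

from vllm import SamplingParams


class PDRequest(msgspec.Struct):
    request_id: str
    # generate() passes a text prompt, while worker.py reads
    # request.prompt_token_ids, so the real struct may carry token ids too.
    prompt: str
    sampling_params: SamplingParams


class PDResponse(msgspec.Struct):
    request_id: str
    success: bool
    text: Optional[str] = None
    token_ids: Optional[list[int]] = None
    # A non-None finish_reason marks the final chunk of a request.
    finish_reason: Optional[str] = None
    stop_reason: Optional[str] = None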
@@ -1,126 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import signal
from typing import Optional

import msgspec
import uvloop
import zmq
import zmq.asyncio

from vllm.engine.async_llm_engine import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)


async def handle_request(
    request: PDRequest,
    engine: EngineClient,
    socket: zmq.asyncio.Socket,
    encoder: msgspec.msgpack.Encoder,
) -> None:
    request_id = request.request_id
    try:
        # 1) Generate RequestOutputs.
        async for request_output in engine.generate(
                prompt=request.prompt_token_ids,
                sampling_params=request.sampling_params,
                request_id=request_id):

            assert len(request_output.outputs) == 1, (
                "Only support N=1 right now.")
            out = request_output.outputs[0]

            # 2) Convert RequestOutput --> PDResponse.
            response = PDResponse(
                request_id=request_id,
                success=True,
                text=out.text,
                token_ids=out.token_ids,
                finish_reason=out.finish_reason,
                stop_reason=out.stop_reason,
            )
            response_bytes = encoder.encode(response)

            # 3) Send to Connector.
            await socket.send(response_bytes, copy=False)

    except Exception as e:
        # TODO: actual error handling.
        logger.error("Exception in Worker Routine: %s request_id: %s", e,
                     request_id)
        response = PDResponse(request_id=request_id, success=False)
        response_bytes = encoder.encode(response)
        await socket.send(response_bytes, copy=False)


async def run_server(args, engine: EngineClient):
    """Get Requests and Handle Them."""
    running_requests: set[asyncio.Task] = set()
    decoder = msgspec.msgpack.Decoder(PDRequest)
    encoder = msgspec.msgpack.Encoder()

    ctx: Optional[zmq.asyncio.Context] = None
    try:
        # IPC Setup.
        ctx = zmq.asyncio.Context()
        from_connector = make_zmq_socket(
            ctx, f"ipc://{args.worker_addr}", zmq.constants.PULL)
        to_connector = make_zmq_socket(
            ctx, f"ipc://{args.connector_addr}", zmq.constants.PUSH)

        # Main Loop.
        while True:
            # 1) Get request from the Connector.
            pd_request_bytes = await from_connector.recv()
            pd_request = decoder.decode(pd_request_bytes)

            # 2) Launch a coroutine to handle the request.
            task = asyncio.create_task(handle_request(
                pd_request, engine, to_connector, encoder))
            running_requests.add(task)
            task.add_done_callback(running_requests.discard)

    except KeyboardInterrupt:
        logger.debug("Worker server loop interrupted.")

    finally:
        for task in running_requests:
            task.cancel()
        if ctx is not None:
            ctx.destroy(linger=0)


async def main(args) -> None:
    logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Workaround to avoid footguns where uvicorn drops requests
    # with too many concurrent requests active due to ulimit.
    set_ulimit()

    # Interrupt on sigterm during initialization.
    def signal_handler(*_) -> None:
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_async_engine_client(args) as engine:
        await run_server(args, engine)


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument('--connector-addr',
                        type=str,
                        required=True,
                        help='The address of the connector.')
    parser.add_argument('--worker-addr',
                        type=str,
                        required=True,
                        help='The address of the worker.')
    AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    uvloop.run(main(args))
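
The transport between the connector and the workers is simply msgspec-encoded messages over ZMQ PUSH/PULL IPC sockets, with one asyncio task per in-flight request. Below is a self-contained toy sketch of that pattern; it needs no vLLM, and the endpoint name and message type are invented for the demo.

import asyncio

import msgspec
import zmq
import zmq.asyncio


class DemoRequest(msgspec.Struct):
    request_id: str
    prompt: str


async def demo() -> None:
    ctx = zmq.asyncio.Context()
    # Worker side: PULL requests from a stable IPC endpoint.
    pull = ctx.socket(zmq.PULL)
    pull.bind("ipc://demo_worker_addr")
    # Connector side: PUSH requests to that endpoint.
    push = ctx.socket(zmq.PUSH)
    push.connect("ipc://demo_worker_addr")

    encoder = msgspec.msgpack.Encoder()
    decoder = msgspec.msgpack.Decoder(DemoRequest)

    await push.send(encoder.encode(DemoRequest("req-0", "hello")))
    request = decoder.decode(await pull.recv())
    print(request.request_id, request.prompt)

    ctx.destroy(linger=0)


asyncio.run(demo())

Because every PDResponse carries its request_id, the connector can demultiplex replies arriving on a single PULL socket back to the right generate() call, which is why plain PUSH/PULL sockets are sufficient here.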