mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[V1] DP scale-out (1/N): Use zmq ROUTER/DEALER sockets for input queue (#15906)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -2189,6 +2189,8 @@ def make_zmq_socket(
|
||||
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
|
||||
path: str,
|
||||
socket_type: Any,
|
||||
bind: Optional[bool] = None,
|
||||
identity: Optional[bytes] = None,
|
||||
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
|
||||
"""Make a ZMQ socket with the proper bind/connect semantics."""
|
||||
|
||||
@ -2207,16 +2209,24 @@ def make_zmq_socket(
|
||||
else:
|
||||
buf_size = -1 # Use system default buffer size
|
||||
|
||||
if socket_type == zmq.constants.PULL:
|
||||
socket.setsockopt(zmq.constants.RCVHWM, 0)
|
||||
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
|
||||
if bind is None:
|
||||
bind = socket_type != zmq.PUSH
|
||||
|
||||
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
|
||||
socket.setsockopt(zmq.RCVHWM, 0)
|
||||
socket.setsockopt(zmq.RCVBUF, buf_size)
|
||||
|
||||
if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
|
||||
socket.setsockopt(zmq.SNDHWM, 0)
|
||||
socket.setsockopt(zmq.SNDBUF, buf_size)
|
||||
|
||||
if identity is not None:
|
||||
socket.setsockopt(zmq.IDENTITY, identity)
|
||||
|
||||
if bind:
|
||||
socket.bind(path)
|
||||
elif socket_type == zmq.constants.PUSH:
|
||||
socket.setsockopt(zmq.constants.SNDHWM, 0)
|
||||
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
|
||||
socket.connect(path)
|
||||
else:
|
||||
raise ValueError(f"Unknown Socket Type: {socket_type}")
|
||||
socket.connect(path)
|
||||
|
||||
return socket
|
||||
|
||||
@ -2225,14 +2235,19 @@ def make_zmq_socket(
|
||||
def zmq_socket_ctx(
|
||||
path: str,
|
||||
socket_type: Any,
|
||||
bind: Optional[bool] = None,
|
||||
linger: int = 0,
|
||||
identity: Optional[bytes] = None,
|
||||
) -> Iterator[zmq.Socket]:
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
try:
|
||||
yield make_zmq_socket(ctx, path, socket_type)
|
||||
|
||||
yield make_zmq_socket(ctx,
|
||||
path,
|
||||
socket_type,
|
||||
bind=bind,
|
||||
identity=identity)
|
||||
except KeyboardInterrupt:
|
||||
logger.debug("Got Keyboard Interrupt.")
|
||||
|
||||
|
@ -318,6 +318,11 @@ class EngineCoreProc(EngineCore):
|
||||
):
|
||||
super().__init__(vllm_config, executor_class, log_stats)
|
||||
|
||||
self.step_fn = (self.step if self.batch_queue is None else
|
||||
self.step_with_batch_queue)
|
||||
|
||||
self.global_unfinished_reqs = False
|
||||
|
||||
# Background Threads and Queues for IO. These enable us to
|
||||
# overlap ZMQ socket IO with GPU since they release the GIL,
|
||||
# and to overlap some serialization/deserialization with the
|
||||
@ -327,22 +332,16 @@ class EngineCoreProc(EngineCore):
|
||||
Any]] = queue.Queue()
|
||||
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
|
||||
threading.Thread(target=self.process_input_socket,
|
||||
args=(input_path, ),
|
||||
args=(input_path, engine_index),
|
||||
daemon=True).start()
|
||||
threading.Thread(target=self.process_output_socket,
|
||||
args=(output_path, engine_index),
|
||||
daemon=True).start()
|
||||
|
||||
self.global_unfinished_reqs = False
|
||||
|
||||
self.step_fn = (self.step if self.batch_queue is None else
|
||||
self.step_with_batch_queue)
|
||||
|
||||
@staticmethod
|
||||
def run_engine_core(*args,
|
||||
dp_rank: int = 0,
|
||||
local_dp_rank: int = 0,
|
||||
ready_pipe,
|
||||
**kwargs):
|
||||
"""Launch EngineCore busy loop in background process."""
|
||||
|
||||
@ -377,9 +376,6 @@ class EngineCoreProc(EngineCore):
|
||||
else:
|
||||
engine_core = EngineCoreProc(*args, **kwargs)
|
||||
|
||||
# Send Readiness signal to EngineClient.
|
||||
ready_pipe.send({"status": "READY"})
|
||||
|
||||
engine_core.run_busy_loop()
|
||||
|
||||
except SystemExit:
|
||||
@ -476,14 +472,22 @@ class EngineCoreProc(EngineCore):
|
||||
and not isinstance(v, p.annotation) else v
|
||||
for v, p in zip(args, arg_types))
|
||||
|
||||
def process_input_socket(self, input_path: str):
|
||||
def process_input_socket(self, input_path: str, engine_index: int):
|
||||
"""Input socket IO thread."""
|
||||
|
||||
# Msgpack serialization decoding.
|
||||
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||
generic_decoder = MsgpackDecoder()
|
||||
identity = engine_index.to_bytes(length=2, byteorder="little")
|
||||
|
||||
with zmq_socket_ctx(input_path,
|
||||
zmq.DEALER,
|
||||
identity=identity,
|
||||
bind=False) as socket:
|
||||
|
||||
# Send ready message to front-end once input socket is connected.
|
||||
socket.send(b'READY')
|
||||
|
||||
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
|
||||
while True:
|
||||
# (RequestType, RequestData)
|
||||
type_frame, data_frame = socket.recv_multipart(copy=False)
|
||||
|
@ -8,7 +8,7 @@ import threading
|
||||
import uuid
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from collections.abc import Awaitable
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
@ -35,6 +35,8 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
STARTUP_POLL_PERIOD_MS = 10000
|
||||
|
||||
|
||||
class EngineCoreClient(ABC):
|
||||
"""
|
||||
@ -261,15 +263,13 @@ class CoreEngine:
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
ctx: Union[zmq.Context, zmq.asyncio.Context],
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
index: int = 0,
|
||||
local_dp_rank: int = 0,
|
||||
):
|
||||
# Paths and sockets for IPC.
|
||||
input_path = get_open_zmq_ipc_path()
|
||||
self.input_socket = make_zmq_socket(ctx, input_path,
|
||||
zmq.constants.PUSH)
|
||||
self.index = index
|
||||
self.identity = index.to_bytes(length=2, byteorder="little")
|
||||
try:
|
||||
# Start EngineCore in background process.
|
||||
self.proc_handle = BackgroundProcHandle(
|
||||
@ -291,14 +291,9 @@ class CoreEngine:
|
||||
# Ensure socket is closed if process fails to start.
|
||||
self.close()
|
||||
|
||||
def send_multipart(self, msg_parts: Sequence):
|
||||
return self.input_socket.send_multipart(msg_parts, copy=False)
|
||||
|
||||
def close(self):
|
||||
if proc_handle := getattr(self, "proc_handle", None):
|
||||
proc_handle.shutdown()
|
||||
if socket := getattr(self, "input_socket", None):
|
||||
socket.close(linger=0)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -309,6 +304,7 @@ class BackgroundResources:
|
||||
ctx: Union[zmq.Context]
|
||||
core_engines: list[CoreEngine] = field(default_factory=list)
|
||||
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
|
||||
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
|
||||
shutdown_path: Optional[str] = None
|
||||
|
||||
def __call__(self):
|
||||
@ -321,6 +317,8 @@ class BackgroundResources:
|
||||
# aren't explicitly closed first.
|
||||
if self.output_socket is not None:
|
||||
self.output_socket.close(linger=0)
|
||||
if self.input_socket is not None:
|
||||
self.input_socket.close(linger=0)
|
||||
if self.shutdown_path is not None:
|
||||
# We must ensure that the sync output socket is
|
||||
# closed cleanly in its own thread.
|
||||
@ -387,21 +385,51 @@ class MPClient(EngineCoreClient):
|
||||
|
||||
# Paths and sockets for IPC.
|
||||
self.output_path = get_open_zmq_ipc_path()
|
||||
input_path = get_open_zmq_ipc_path()
|
||||
self.input_socket = make_zmq_socket(self.ctx,
|
||||
input_path,
|
||||
zmq.ROUTER,
|
||||
bind=True)
|
||||
self.resources.input_socket = self.input_socket
|
||||
|
||||
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
|
||||
vllm_config, executor_class, log_stats, self.ctx, self.output_path,
|
||||
index, local_dp_rank)
|
||||
vllm_config, executor_class, log_stats, input_path, self.
|
||||
output_path, index, local_dp_rank)
|
||||
|
||||
# Start engine core process(es).
|
||||
self._init_core_engines(vllm_config, new_core_engine,
|
||||
self.resources.core_engines)
|
||||
|
||||
# Wait for engine core process(es) to start.
|
||||
for engine in self.resources.core_engines:
|
||||
engine.proc_handle.wait_for_startup()
|
||||
self._wait_for_engine_startup()
|
||||
|
||||
self.utility_results: dict[int, AnyFuture] = {}
|
||||
|
||||
def _wait_for_engine_startup(self):
|
||||
# Get a sync handle to the socket which can be sync or async.
|
||||
sync_input_socket = zmq.Socket.shadow(self.input_socket)
|
||||
|
||||
# Wait for engine core process(es) to send ready messages.
|
||||
identities = set(eng.index for eng in self.resources.core_engines)
|
||||
while identities:
|
||||
while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS):
|
||||
logger.info("Waiting for %d core engine proc(s) to start: %s",
|
||||
len(identities), identities)
|
||||
eng_id_bytes, msg = sync_input_socket.recv_multipart()
|
||||
eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
|
||||
if eng_id not in identities:
|
||||
raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
|
||||
if msg != b'READY':
|
||||
raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}")
|
||||
logger.info("Core engine process %d ready.", eng_id)
|
||||
identities.discard(eng_id)
|
||||
|
||||
# Double check that the process are running.
|
||||
for engine in self.resources.core_engines:
|
||||
proc = engine.proc_handle.proc
|
||||
if proc.exitcode is not None:
|
||||
raise RuntimeError(f"Engine proc {proc.name} not running")
|
||||
|
||||
def _init_core_engines(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@ -494,9 +522,10 @@ class SyncMPClient(MPClient):
|
||||
return self.outputs_queue.get()
|
||||
|
||||
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
|
||||
# (RequestType, SerializedRequest)
|
||||
msg = (request_type.value, self.encoder.encode(request))
|
||||
self.core_engine.send_multipart(msg)
|
||||
# (Identity, RequestType, SerializedRequest)
|
||||
msg = (self.core_engine.identity, request_type.value,
|
||||
self.encoder.encode(request))
|
||||
self.input_socket.send_multipart(msg, copy=False)
|
||||
|
||||
def call_utility(self, method: str, *args) -> Any:
|
||||
call_id = uuid.uuid1().int >> 64
|
||||
@ -625,30 +654,34 @@ class AsyncMPClient(MPClient):
|
||||
assert self.outputs_queue is not None
|
||||
return await self.outputs_queue.get()
|
||||
|
||||
async def _send_input(self, request_type: EngineCoreRequestType,
|
||||
request: Any) -> None:
|
||||
await self.core_engine.send_multipart(
|
||||
(request_type.value, self.encoder.encode(request)))
|
||||
def _send_input(self,
|
||||
request_type: EngineCoreRequestType,
|
||||
request: Any,
|
||||
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
|
||||
if engine is None:
|
||||
engine = self.core_engine
|
||||
|
||||
self._ensure_output_queue_task()
|
||||
message = (request_type.value, self.encoder.encode(request))
|
||||
return self._send_input_message(message, engine)
|
||||
|
||||
def _send_input_message(self, message: tuple[bytes, bytes],
|
||||
engine: CoreEngine) -> Awaitable[None]:
|
||||
message = (engine.identity, ) + message # type: ignore[assignment]
|
||||
return self.input_socket.send_multipart(message, copy=False)
|
||||
|
||||
async def call_utility_async(self, method: str, *args) -> Any:
|
||||
return await self._call_utility_async(method,
|
||||
*args,
|
||||
engine=self.core_engine)
|
||||
|
||||
async def _call_utility_async(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
engine: CoreEngine,
|
||||
) -> Any:
|
||||
async def _call_utility_async(self, method: str, *args,
|
||||
engine: CoreEngine) -> Any:
|
||||
call_id = uuid.uuid1().int >> 64
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self.utility_results[call_id] = future
|
||||
message = (EngineCoreRequestType.UTILITY.value,
|
||||
self.encoder.encode((call_id, method, args)))
|
||||
await engine.send_multipart(message)
|
||||
await self._send_input_message(message, engine)
|
||||
self._ensure_output_queue_task()
|
||||
return await future
|
||||
|
||||
@ -657,6 +690,7 @@ class AsyncMPClient(MPClient):
|
||||
# tokenized.
|
||||
request.prompt = None
|
||||
await self._send_input(EngineCoreRequestType.ADD, request)
|
||||
self._ensure_output_queue_task()
|
||||
|
||||
async def abort_requests_async(self, request_ids: list[str]) -> None:
|
||||
if len(request_ids) > 0:
|
||||
@ -761,15 +795,15 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
self.reqs_in_flight[request.request_id] = chosen_engine
|
||||
chosen_engine.num_reqs_in_flight += 1
|
||||
if self.num_engines_running >= len(self.core_engines):
|
||||
await chosen_engine.send_multipart(msg)
|
||||
await self._send_input_message(msg, chosen_engine)
|
||||
else:
|
||||
# Send request to chosen engine and dp start loop
|
||||
# control message to all other engines.
|
||||
self.num_engines_running += len(self.core_engines)
|
||||
await asyncio.gather(*[
|
||||
engine.send_multipart(msg if engine is
|
||||
chosen_engine else self.start_dp_msg)
|
||||
for engine in self.core_engines
|
||||
self._send_input_message(
|
||||
msg if engine is chosen_engine else self.start_dp_msg,
|
||||
engine) for engine in self.core_engines
|
||||
])
|
||||
|
||||
self._ensure_output_queue_task()
|
||||
@ -794,7 +828,7 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
# sure to start the other engines:
|
||||
self.num_engines_running = len(self.core_engines)
|
||||
coros = [
|
||||
engine.send_multipart(self.start_dp_msg)
|
||||
self._send_input_message(self.start_dp_msg, engine)
|
||||
for engine in self.core_engines
|
||||
if not engine.num_reqs_in_flight
|
||||
]
|
||||
@ -820,5 +854,5 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
|
||||
async def _abort_requests(self, request_ids: list[str],
|
||||
engine: CoreEngine) -> None:
|
||||
await engine.send_multipart((EngineCoreRequestType.ABORT.value,
|
||||
self.encoder.encode(request_ids)))
|
||||
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
|
||||
engine)
|
||||
|
@ -105,12 +105,9 @@ class BackgroundProcHandle:
|
||||
process_kwargs: dict[Any, Any],
|
||||
):
|
||||
context = get_mp_context()
|
||||
self.reader, writer = context.Pipe(duplex=False)
|
||||
|
||||
assert ("ready_pipe" not in process_kwargs
|
||||
and "input_path" not in process_kwargs
|
||||
assert ("input_path" not in process_kwargs
|
||||
and "output_path" not in process_kwargs)
|
||||
process_kwargs["ready_pipe"] = writer
|
||||
process_kwargs["input_path"] = input_path
|
||||
process_kwargs["output_path"] = output_path
|
||||
|
||||
@ -122,12 +119,6 @@ class BackgroundProcHandle:
|
||||
input_path, output_path)
|
||||
self.proc.start()
|
||||
|
||||
def wait_for_startup(self):
|
||||
# Wait for startup.
|
||||
if self.reader.recv()["status"] != "READY":
|
||||
raise RuntimeError(f"{self.proc.name} initialization failed. "
|
||||
"See root cause above.")
|
||||
|
||||
def shutdown(self):
|
||||
self._finalizer()
|
||||
|
||||
|
Reference in New Issue
Block a user