[V1] DP scale-out (1/N): Use zmq ROUTER/DEALER sockets for input queue (#15906)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill
Date: 2025-04-04 12:56:43 -07:00 (committed by GitHub)
Commit: 651cf0fec1
Parent: 4dc52e1c53
4 changed files with 113 additions and 69 deletions
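
For orientation, a minimal standalone sketch of the ZMQ ROUTER/DEALER identity routing that this change adopts for the engine input queue. It is illustrative only (the inproc path and free-standing sockets are not vLLM code), but the 2-byte little-endian identity mirrors what the hunks below introduce.

import zmq

ctx = zmq.Context()
path = "inproc://engine-input"  # illustrative; vLLM uses per-run IPC paths

# Front-end side: a single ROUTER socket is bound once for all engines.
router = ctx.socket(zmq.ROUTER)
router.bind(path)

# Engine side: a DEALER socket whose identity is the 2-byte engine index.
dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt(zmq.IDENTITY, (1).to_bytes(2, "little"))
dealer.connect(path)

# The engine announces itself; the ROUTER sees [identity, payload].
dealer.send(b"READY")
eng_id, payload = router.recv_multipart()
assert payload == b"READY"

# The front-end reaches a specific engine by prefixing its identity frame.
router.send_multipart([eng_id, b"ADD", b"<serialized request>"])
print(dealer.recv_multipart())  # [b'ADD', b'<serialized request>']

ctx.destroy(linger=0)

Because the ROUTER learns each DEALER's identity from the frames it receives, one bound socket replaces the previous per-engine PUSH/PULL pair and can also carry the startup handshake that used to travel over a multiprocessing pipe.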


@@ -2189,6 +2189,8 @@ def make_zmq_socket(
     ctx: Union[zmq.asyncio.Context, zmq.Context],  # type: ignore[name-defined]
     path: str,
     socket_type: Any,
+    bind: Optional[bool] = None,
+    identity: Optional[bytes] = None,
 ) -> Union[zmq.Socket, zmq.asyncio.Socket]:  # type: ignore[name-defined]
     """Make a ZMQ socket with the proper bind/connect semantics."""
@@ -2207,16 +2209,24 @@ def make_zmq_socket(
     else:
         buf_size = -1  # Use system default buffer size
 
-    if socket_type == zmq.constants.PULL:
-        socket.setsockopt(zmq.constants.RCVHWM, 0)
-        socket.setsockopt(zmq.constants.RCVBUF, buf_size)
+    if bind is None:
+        bind = socket_type != zmq.PUSH
+
+    if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
+        socket.setsockopt(zmq.RCVHWM, 0)
+        socket.setsockopt(zmq.RCVBUF, buf_size)
+
+    if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
+        socket.setsockopt(zmq.SNDHWM, 0)
+        socket.setsockopt(zmq.SNDBUF, buf_size)
+
+    if identity is not None:
+        socket.setsockopt(zmq.IDENTITY, identity)
+
+    if bind:
         socket.bind(path)
-    elif socket_type == zmq.constants.PUSH:
-        socket.setsockopt(zmq.constants.SNDHWM, 0)
-        socket.setsockopt(zmq.constants.SNDBUF, buf_size)
-        socket.connect(path)
     else:
-        raise ValueError(f"Unknown Socket Type: {socket_type}")
+        socket.connect(path)
 
     return socket
@@ -2225,14 +2235,19 @@ def make_zmq_socket(
 def zmq_socket_ctx(
     path: str,
     socket_type: Any,
+    bind: Optional[bool] = None,
     linger: int = 0,
+    identity: Optional[bytes] = None,
 ) -> Iterator[zmq.Socket]:
     """Context manager for a ZMQ socket"""
 
     ctx = zmq.Context()  # type: ignore[attr-defined]
     try:
-        yield make_zmq_socket(ctx, path, socket_type)
+        yield make_zmq_socket(ctx,
+                              path,
+                              socket_type,
+                              bind=bind,
+                              identity=identity)
 
     except KeyboardInterrupt:
         logger.debug("Got Keyboard Interrupt.")
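
A short sketch of how the two extended helpers compose, assuming they are importable from vllm.utils (the hunk headers above do not name the file). The ROUTER side relies on the new bind default, while the DEALER side passes bind=False plus an explicit identity:

import zmq

from vllm.utils import make_zmq_socket, zmq_socket_ctx  # assumed module path

ctx = zmq.Context()
path = "ipc:///tmp/example-input"  # illustrative IPC path

# ROUTER: bind defaults to True because socket_type != zmq.PUSH.
router = make_zmq_socket(ctx, path, zmq.ROUTER)

# DEALER: must connect instead, so bind=False is passed explicitly, and the
# 2-byte engine index becomes the ZMQ identity the ROUTER routes on.
with zmq_socket_ctx(path, zmq.DEALER,
                    bind=False,
                    identity=(1).to_bytes(2, "little")) as dealer:
    dealer.send(b"READY")
    print(router.recv_multipart())  # [b'\x01\x00', b'READY']

router.close(linger=0)
ctx.term()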


@@ -318,6 +318,11 @@ class EngineCoreProc(EngineCore):
     ):
         super().__init__(vllm_config, executor_class, log_stats)
 
+        self.step_fn = (self.step if self.batch_queue is None else
+                        self.step_with_batch_queue)
+
+        self.global_unfinished_reqs = False
+
         # Background Threads and Queues for IO. These enable us to
         # overlap ZMQ socket IO with GPU since they release the GIL,
         # and to overlap some serialization/deserialization with the
@@ -327,22 +332,16 @@ class EngineCoreProc(EngineCore):
                                             Any]] = queue.Queue()
         self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
         threading.Thread(target=self.process_input_socket,
-                         args=(input_path, ),
+                         args=(input_path, engine_index),
                          daemon=True).start()
         threading.Thread(target=self.process_output_socket,
                          args=(output_path, engine_index),
                          daemon=True).start()
 
-        self.global_unfinished_reqs = False
-
-        self.step_fn = (self.step if self.batch_queue is None else
-                        self.step_with_batch_queue)
-
     @staticmethod
     def run_engine_core(*args,
                         dp_rank: int = 0,
                         local_dp_rank: int = 0,
-                        ready_pipe,
                         **kwargs):
         """Launch EngineCore busy loop in background process."""
@@ -377,9 +376,6 @@ class EngineCoreProc(EngineCore):
             else:
                 engine_core = EngineCoreProc(*args, **kwargs)
 
-            # Send Readiness signal to EngineClient.
-            ready_pipe.send({"status": "READY"})
-
             engine_core.run_busy_loop()
 
         except SystemExit:
@@ -476,14 +472,22 @@ class EngineCoreProc(EngineCore):
             and not isinstance(v, p.annotation) else v
             for v, p in zip(args, arg_types))
 
-    def process_input_socket(self, input_path: str):
+    def process_input_socket(self, input_path: str, engine_index: int):
         """Input socket IO thread."""
 
         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
         generic_decoder = MsgpackDecoder()
+        identity = engine_index.to_bytes(length=2, byteorder="little")
 
-        with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
+        with zmq_socket_ctx(input_path,
+                            zmq.DEALER,
+                            identity=identity,
+                            bind=False) as socket:
+
+            # Send ready message to front-end once input socket is connected.
+            socket.send(b'READY')
+
             while True:
                 # (RequestType, RequestData)
                 type_frame, data_frame = socket.recv_multipart(copy=False)
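
A condensed restatement of the engine-side flow added above, with handle_request as a hypothetical stand-in for the msgpack decoding and queueing that EngineCoreProc actually performs:

import zmq


def handle_request(request_type: bytes, payload: bytes) -> None:
    # Stand-in for decoding with MsgpackDecoder and enqueueing the request.
    print(request_type, len(payload))


def input_thread(input_path: str, engine_index: int) -> None:
    ctx = zmq.Context()
    sock = ctx.socket(zmq.DEALER)
    # The 2-byte little-endian engine index doubles as the routing identity.
    sock.setsockopt(zmq.IDENTITY, engine_index.to_bytes(2, "little"))
    sock.connect(input_path)

    # Handshake: the ROUTER on the front-end records this identity.
    sock.send(b"READY")

    while True:
        # The identity envelope is stripped on the DEALER side, leaving the
        # (RequestType, RequestData) pair that the front-end sent.
        type_frame, data_frame = sock.recv_multipart(copy=False)
        handle_request(type_frame.bytes, data_frame.bytes)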


@@ -8,7 +8,7 @@ import threading
 import uuid
 import weakref
 from abc import ABC, abstractmethod
-from collections.abc import Awaitable, Sequence
+from collections.abc import Awaitable
 from concurrent.futures import Future
 from dataclasses import dataclass, field
 from threading import Thread
@@ -35,6 +35,8 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
 
 _R = TypeVar('_R')  # Return type for collective_rpc
 
+STARTUP_POLL_PERIOD_MS = 10000
+
 
 class EngineCoreClient(ABC):
     """
@@ -261,15 +263,13 @@ class CoreEngine:
         vllm_config: VllmConfig,
         executor_class: type[Executor],
         log_stats: bool,
-        ctx: Union[zmq.Context, zmq.asyncio.Context],
+        input_path: str,
         output_path: str,
         index: int = 0,
         local_dp_rank: int = 0,
     ):
-        # Paths and sockets for IPC.
-        input_path = get_open_zmq_ipc_path()
-        self.input_socket = make_zmq_socket(ctx, input_path,
-                                            zmq.constants.PUSH)
+        self.index = index
+        self.identity = index.to_bytes(length=2, byteorder="little")
         try:
             # Start EngineCore in background process.
             self.proc_handle = BackgroundProcHandle(
@@ -291,14 +291,9 @@ class CoreEngine:
             # Ensure socket is closed if process fails to start.
             self.close()
 
-    def send_multipart(self, msg_parts: Sequence):
-        return self.input_socket.send_multipart(msg_parts, copy=False)
-
     def close(self):
         if proc_handle := getattr(self, "proc_handle", None):
             proc_handle.shutdown()
-        if socket := getattr(self, "input_socket", None):
-            socket.close(linger=0)
 
 
 @dataclass
@@ -309,6 +304,7 @@ class BackgroundResources:
     ctx: Union[zmq.Context]
     core_engines: list[CoreEngine] = field(default_factory=list)
     output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
+    input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
     shutdown_path: Optional[str] = None
 
     def __call__(self):
@@ -321,6 +317,8 @@ class BackgroundResources:
         # aren't explicitly closed first.
         if self.output_socket is not None:
             self.output_socket.close(linger=0)
+        if self.input_socket is not None:
+            self.input_socket.close(linger=0)
         if self.shutdown_path is not None:
             # We must ensure that the sync output socket is
             # closed cleanly in its own thread.
@@ -387,21 +385,51 @@ class MPClient(EngineCoreClient):
         # Paths and sockets for IPC.
         self.output_path = get_open_zmq_ipc_path()
+        input_path = get_open_zmq_ipc_path()
+        self.input_socket = make_zmq_socket(self.ctx,
+                                            input_path,
+                                            zmq.ROUTER,
+                                            bind=True)
+        self.resources.input_socket = self.input_socket
 
         new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
-            vllm_config, executor_class, log_stats, self.ctx, self.output_path,
-            index, local_dp_rank)
+            vllm_config, executor_class, log_stats, input_path, self.
+            output_path, index, local_dp_rank)
 
         # Start engine core process(es).
         self._init_core_engines(vllm_config, new_core_engine,
                                 self.resources.core_engines)
 
         # Wait for engine core process(es) to start.
-        for engine in self.resources.core_engines:
-            engine.proc_handle.wait_for_startup()
+        self._wait_for_engine_startup()
 
         self.utility_results: dict[int, AnyFuture] = {}
 
+    def _wait_for_engine_startup(self):
+        # Get a sync handle to the socket which can be sync or async.
+        sync_input_socket = zmq.Socket.shadow(self.input_socket)
+
+        # Wait for engine core process(es) to send ready messages.
+        identities = set(eng.index for eng in self.resources.core_engines)
+        while identities:
+            while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS):
+                logger.info("Waiting for %d core engine proc(s) to start: %s",
                            len(identities), identities)
+            eng_id_bytes, msg = sync_input_socket.recv_multipart()
+            eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
+            if eng_id not in identities:
+                raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
+            if msg != b'READY':
+                raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}")
+            logger.info("Core engine process %d ready.", eng_id)
+            identities.discard(eng_id)
+
+        # Double check that the process are running.
+        for engine in self.resources.core_engines:
+            proc = engine.proc_handle.proc
+            if proc.exitcode is not None:
+                raise RuntimeError(f"Engine proc {proc.name} not running")
+
     def _init_core_engines(
         self,
         vllm_config: VllmConfig,
@@ -494,9 +522,10 @@ class SyncMPClient(MPClient):
         return self.outputs_queue.get()
 
     def _send_input(self, request_type: EngineCoreRequestType, request: Any):
-        # (RequestType, SerializedRequest)
-        msg = (request_type.value, self.encoder.encode(request))
-        self.core_engine.send_multipart(msg)
+        # (Identity, RequestType, SerializedRequest)
+        msg = (self.core_engine.identity, request_type.value,
               self.encoder.encode(request))
+        self.input_socket.send_multipart(msg, copy=False)
 
     def call_utility(self, method: str, *args) -> Any:
         call_id = uuid.uuid1().int >> 64
@@ -625,30 +654,34 @@ class AsyncMPClient(MPClient):
         assert self.outputs_queue is not None
         return await self.outputs_queue.get()
 
-    async def _send_input(self, request_type: EngineCoreRequestType,
-                          request: Any) -> None:
-        await self.core_engine.send_multipart(
-            (request_type.value, self.encoder.encode(request)))
-
-        self._ensure_output_queue_task()
+    def _send_input(self,
+                    request_type: EngineCoreRequestType,
+                    request: Any,
+                    engine: Optional[CoreEngine] = None) -> Awaitable[None]:
+        if engine is None:
+            engine = self.core_engine
+
+        message = (request_type.value, self.encoder.encode(request))
+        return self._send_input_message(message, engine)
+
+    def _send_input_message(self, message: tuple[bytes, bytes],
+                            engine: CoreEngine) -> Awaitable[None]:
+        message = (engine.identity, ) + message  # type: ignore[assignment]
+        return self.input_socket.send_multipart(message, copy=False)
 
     async def call_utility_async(self, method: str, *args) -> Any:
         return await self._call_utility_async(method,
                                               *args,
                                               engine=self.core_engine)
 
-    async def _call_utility_async(
-        self,
-        method: str,
-        *args,
-        engine: CoreEngine,
-    ) -> Any:
+    async def _call_utility_async(self, method: str, *args,
+                                  engine: CoreEngine) -> Any:
         call_id = uuid.uuid1().int >> 64
         future = asyncio.get_running_loop().create_future()
         self.utility_results[call_id] = future
         message = (EngineCoreRequestType.UTILITY.value,
                    self.encoder.encode((call_id, method, args)))
-        await engine.send_multipart(message)
+        await self._send_input_message(message, engine)
+        self._ensure_output_queue_task()
         return await future
@@ -657,6 +690,7 @@ class AsyncMPClient(MPClient):
         # tokenized.
         request.prompt = None
 
         await self._send_input(EngineCoreRequestType.ADD, request)
+        self._ensure_output_queue_task()
 
     async def abort_requests_async(self, request_ids: list[str]) -> None:
         if len(request_ids) > 0:
@@ -761,15 +795,15 @@ class DPAsyncMPClient(AsyncMPClient):
         self.reqs_in_flight[request.request_id] = chosen_engine
         chosen_engine.num_reqs_in_flight += 1
         if self.num_engines_running >= len(self.core_engines):
-            await chosen_engine.send_multipart(msg)
+            await self._send_input_message(msg, chosen_engine)
         else:
             # Send request to chosen engine and dp start loop
             # control message to all other engines.
             self.num_engines_running += len(self.core_engines)
             await asyncio.gather(*[
-                engine.send_multipart(msg if engine is
-                                      chosen_engine else self.start_dp_msg)
-                for engine in self.core_engines
+                self._send_input_message(
+                    msg if engine is chosen_engine else self.start_dp_msg,
+                    engine) for engine in self.core_engines
             ])
 
         self._ensure_output_queue_task()
@@ -794,7 +828,7 @@ class DPAsyncMPClient(AsyncMPClient):
             # sure to start the other engines:
             self.num_engines_running = len(self.core_engines)
             coros = [
-                engine.send_multipart(self.start_dp_msg)
+                self._send_input_message(self.start_dp_msg, engine)
                 for engine in self.core_engines
                 if not engine.num_reqs_in_flight
             ]
@@ -820,5 +854,5 @@ class DPAsyncMPClient(AsyncMPClient):
     async def _abort_requests(self, request_ids: list[str],
                              engine: CoreEngine) -> None:
-        await engine.send_multipart((EngineCoreRequestType.ABORT.value,
                                     self.encoder.encode(request_ids)))
+        await self._send_input(EngineCoreRequestType.ABORT, request_ids,
                               engine)
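
And a condensed sketch of the client-side counterpart shown above. wait_for_engines and send_to_engine are illustrative names rather than vLLM APIs, but they follow the same poll-for-READY and identity-prefix pattern as _wait_for_engine_startup and _send_input_message:

import zmq

STARTUP_POLL_PERIOD_MS = 10000


def wait_for_engines(router: zmq.Socket, expected_indices: set[int]) -> None:
    # Poll until every expected engine has announced READY under its identity.
    pending = set(expected_indices)
    while pending:
        while not router.poll(timeout=STARTUP_POLL_PERIOD_MS):
            print(f"Waiting for {len(pending)} engine(s): {pending}")
        eng_id_bytes, msg = router.recv_multipart()
        eng_id = int.from_bytes(eng_id_bytes, "little")
        if eng_id not in pending or msg != b"READY":
            raise RuntimeError(f"Bad startup message from {eng_id}: {msg!r}")
        pending.discard(eng_id)


def send_to_engine(router: zmq.Socket, engine_index: int,
                   request_type: bytes, payload: bytes) -> None:
    # The leading identity frame selects the engine; the remaining frames are
    # the (RequestType, SerializedRequest) pair its input thread expects.
    identity = engine_index.to_bytes(2, "little")
    router.send_multipart((identity, request_type, payload), copy=False)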


@@ -105,12 +105,9 @@ class BackgroundProcHandle:
         process_kwargs: dict[Any, Any],
     ):
         context = get_mp_context()
-        self.reader, writer = context.Pipe(duplex=False)
 
-        assert ("ready_pipe" not in process_kwargs
-                and "input_path" not in process_kwargs
+        assert ("input_path" not in process_kwargs
                 and "output_path" not in process_kwargs)
-        process_kwargs["ready_pipe"] = writer
         process_kwargs["input_path"] = input_path
         process_kwargs["output_path"] = output_path
@@ -122,12 +119,6 @@ class BackgroundProcHandle:
                      input_path, output_path)
         self.proc.start()
 
-    def wait_for_startup(self):
-        # Wait for startup.
-        if self.reader.recv()["status"] != "READY":
-            raise RuntimeError(f"{self.proc.name} initialization failed. "
-                               "See root cause above.")
-
     def shutdown(self):
         self._finalizer()