[BugFix] Harden distributed DP startup (#21538)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-08-01 22:40:45 +01:00
committed by GitHub
parent d84b97a3e3
commit 881e1af43a
3 changed files with 56 additions and 20 deletions

View File

@ -2794,6 +2794,9 @@ def make_zmq_socket(
if linger is not None:
socket.setsockopt(zmq.LINGER, linger)
if socket_type == zmq.XPUB:
socket.setsockopt(zmq.XPUB_VERBOSE, True)
# Determine if the path is a TCP socket with an IPv6 address.
# Enable IPv6 on the zmq socket if so.
scheme, host, _ = split_zmq_path(path)

View File

@ -172,6 +172,18 @@ class DPCoordinatorProc:
bind=True,
) as publish_back:
# Wait until all engines subscribe.
for _ in self.engines:
if publish_back.recv() != b'\x01':
logger.error(
"DP Coordinator received unexpected message while "
"waiting for engines to subscribe")
return
# Send ready message to engines.
publish_back.send(b"READY")
logger.info("All engine subscriptions received by DP coordinator")
poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN)

View File

@ -461,8 +461,11 @@ class EngineCoreProc(EngineCore):
self.has_coordinator = addresses.coordinator_output is not None
self.frontend_stats_publish_address = (
addresses.frontend_stats_publish_address)
logger.debug("Has DP Coordinator: %s, stats publish address: %s",
self.has_coordinator,
self.frontend_stats_publish_address)
# Only publish request queue stats to coordinator for "internal"
# LB mode.
# and "hybrid" LB modes.
self.publish_dp_lb_stats = (
self.has_coordinator
and not vllm_config.parallel_config.data_parallel_external_lb)
@ -472,25 +475,38 @@ class EngineCoreProc(EngineCore):
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
ready_event = threading.Event()
input_thread = threading.Thread(target=self.process_input_sockets,
args=(addresses.inputs,
addresses.coordinator_input,
identity, ready_event),
daemon=True)
input_thread.start()
self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output,
self.engine_index),
daemon=True)
self.output_thread.start()
# Don't complete handshake until DP coordinator ready message is
# received.
while not ready_event.wait(timeout=10):
if not input_thread.is_alive():
raise RuntimeError(
"Input socket thread died during startup")
assert addresses.coordinator_input is not None
logger.info("Waiting for READY message from DP Coordinator...")
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
threading.Thread(target=self.process_input_sockets,
args=(addresses.inputs, addresses.coordinator_input,
identity),
daemon=True).start()
self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output,
self.engine_index),
daemon=True)
self.output_thread.start()
@contextmanager
def _perform_handshakes(
self,
@ -505,10 +521,10 @@ class EngineCoreProc(EngineCore):
For DP=1 or offline mode, this is with the colocated front-end process.
For DP>1 with internal loadbalancing this is with the shared front-end
For DP>1 with internal load-balancing this is with the shared front-end
process which may reside on a different node.
For DP>1 with external or hybrid loadbalancing, two handshakes are
For DP>1 with external or hybrid load-balancing, two handshakes are
performed:
- With the rank 0 front-end process which retrieves the
DP Coordinator ZMQ addresses and DP process group address.
@ -772,7 +788,7 @@ class EngineCoreProc(EngineCore):
def process_input_sockets(self, input_addresses: list[str],
coord_input_address: Optional[str],
identity: bytes):
identity: bytes, ready_event: threading.Event):
"""Input socket IO thread."""
# Msgpack serialization decoding.
@ -809,9 +825,14 @@ class EngineCoreProc(EngineCore):
# back to us.
input_socket.send(b'')
poller.register(input_socket, zmq.POLLIN)
if coord_socket is not None:
# Wait for ready message from coordinator.
assert coord_socket.recv() == b"READY"
poller.register(coord_socket, zmq.POLLIN)
ready_event.set()
del ready_event
while True:
for input_socket, _ in poller.poll():
# (RequestType, RequestData)