Compare commits

...

4 Commits

SHA1        Message                                                                                               Date
f1c9ef3afd  Merge remote-tracking branch 'nm/lwilkinson/fix-flashmla-full-cudagraph' into wide_ep_working_branch  2025-07-27 21:22:09 +00:00
d80a82f961  fix dp plus full cuda-graph (Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>)                    2025-07-27 21:06:56 +00:00
ec1250421a  [BugFix] Harden coordinator startup (Signed-off-by: Nick Hill <nhill@redhat.com>)                     2025-07-25 15:36:38 +01:00
8177e2f02f  [BugFix] Improve internal DP load balancing (Signed-off-by: Nick Hill <nhill@redhat.com>)             2025-07-25 15:34:15 +01:00
6 changed files with 171 additions and 95 deletions

View File

@@ -2827,6 +2827,9 @@ def make_zmq_socket(
     if linger is not None:
         socket.setsockopt(zmq.LINGER, linger)
 
+    if socket_type == zmq.XPUB:
+        socket.setsockopt(zmq.XPUB_VERBOSE, True)
+
     # Determine if the path is a TCP socket with an IPv6 address.
     # Enable IPv6 on the zmq socket if so.
     scheme, host, _ = split_zmq_path(path)

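The added XPUB block above supports the coordinator subscription handshake introduced later in this change; the comment that follows it describes detecting an IPv6 literal in a tcp:// path and enabling IPv6 on the socket. A minimal sketch of that check, assuming a split_zmq_path-style helper built on urlsplit (the helper here is a stand-in, not vLLM's implementation):

import ipaddress
from urllib.parse import urlsplit

import zmq


def split_zmq_path_sketch(path: str) -> tuple[str, str, str]:
    # Stand-in for vLLM's split_zmq_path helper (assumed behaviour).
    parsed = urlsplit(path)
    return parsed.scheme, parsed.hostname or "", str(parsed.port or "")


def maybe_enable_ipv6(socket: zmq.Socket, path: str) -> None:
    scheme, host, _ = split_zmq_path_sketch(path)
    if scheme != "tcp":
        return  # ipc:// and inproc:// paths have no address family to set.
    try:
        is_ipv6 = ipaddress.ip_address(host).version == 6
    except ValueError:
        is_ipv6 = False  # A hostname rather than a literal address.
    if is_ipv6:
        # ZMQ_IPV6 must be set before bind/connect for it to take effect.
        socket.setsockopt(zmq.IPV6, 1)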
View File

@@ -67,6 +67,20 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         self.cg_buf_tile_scheduler_metadata = None
         self.cg_buf_num_splits = None
 
+        device_properties = torch.cuda.get_device_properties(self.device)
+        num_sms = device_properties.multi_processor_count
+
+        if self.compilation_config.full_cuda_graph:
+            self.cg_buf_tile_scheduler_metadata = torch.empty(
+                (num_sms, 8),  # TileSchedulerMetaDataSize == 8
+                device=self.device,
+                dtype=torch.int32,
+            )
+            self.cg_buf_num_splits = torch.empty(
+                (vllm_config.scheduler_config.max_num_seqs + 1),
+                device=self.device,
+                dtype=torch.int32)
+
     def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
@@ -77,28 +91,24 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         )
 
         if self.compilation_config.full_cuda_graph:
-            # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_tile_scheduler_metadata is None:
-                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
-                self.cg_buf_num_splits = num_splits
-            else:
-                assert self.cg_buf_num_splits is not None
+            assert self.cg_buf_tile_scheduler_metadata is not None
+            assert self.cg_buf_num_splits is not None
 
-                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
-                assert (self.cg_buf_tile_scheduler_metadata.size() ==
-                        tile_scheduler_metadata.size())
-                self.cg_buf_tile_scheduler_metadata.\
-                    copy_(tile_scheduler_metadata)
-                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
+            # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
+            assert (self.cg_buf_tile_scheduler_metadata.size() ==
+                    tile_scheduler_metadata.size())
+            self.cg_buf_tile_scheduler_metadata.\
+                copy_(tile_scheduler_metadata)
+            tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
 
-                # Num splits is per-batch, varying size (batch_size,)
-                n = num_splits.size(0)
-                # make sure static buffer is large enough
-                assert n <= self.cg_buf_num_splits.size(0)
-                num_splits_view = self.cg_buf_num_splits[:n]
-                num_splits_view.copy_(num_splits)
-                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
-                num_splits = num_splits_view
+            # Num splits is per-batch, varying size (batch_size,)
+            n = num_splits.size(0)
+            # make sure static buffer is large enough
+            assert n <= self.cg_buf_num_splits.size(0)
+            num_splits_view = self.cg_buf_num_splits[:n]
+            num_splits_view.copy_(num_splits)
+            self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
+            num_splits = num_splits_view
 
         return FlashMLADecodeMetadata(
             block_table=block_table_tensor,

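The two hunks above move buffer allocation into __init__ so that, when full CUDA graphs are enabled, the decode path only copies freshly computed metadata into tensors whose addresses never change between capture and replay. A minimal sketch of that copy-into-preallocated-buffer pattern, with hypothetical class and method names rather than the vLLM ones:

import torch


class StaticDecodeBuffers:
    """Persistent tensors whose addresses stay stable across CUDA graph replays."""

    def __init__(self, num_sms: int, max_num_seqs: int,
                 device: torch.device) -> None:
        # Allocated once, up front, at the largest size ever needed.
        self.tile_metadata = torch.empty((num_sms, 8), dtype=torch.int32,
                                         device=device)
        self.num_splits = torch.empty((max_num_seqs + 1, ), dtype=torch.int32,
                                      device=device)

    def stage(self, tile_metadata: torch.Tensor,
              num_splits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Fixed-size metadata: shapes must match exactly.
        assert self.tile_metadata.shape == tile_metadata.shape
        self.tile_metadata.copy_(tile_metadata)

        # Variable-size metadata: copy into a prefix view and zero the tail.
        n = num_splits.size(0)
        assert n <= self.num_splits.size(0)
        self.num_splits[:n].copy_(num_splits)
        self.num_splits[n:].fill_(0)
        return self.tile_metadata, self.num_splits[:n]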
View File

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import multiprocessing
 import time
 import weakref
@@ -66,18 +67,14 @@ class DPCoordinator:
+        # Assume coordinator is colocated with front-end procs when not in
+        # either external or hybrid DP LB mode.
+        local_only = not (external_lb or hybrid_lb)
         front_publish_address = get_engine_client_zmq_addr(
-            local_only=not external_lb and not hybrid_lb, host=host)
+            local_only=local_only, host=host)
 
         local_only_eng = dp_size == parallel_config.data_parallel_size_local
         back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
         back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
 
-        # When in external LB mode, load stats aren't published, only changes
-        # to request wave / running state, so we don't need to rate-limit the
-        # updates to the front-end proc(s).
-        min_stats_update_interval_ms = 0 if external_lb else 100
-
         context = get_mp_context()
         self.proc: multiprocessing.Process = context.Process(
             target=DPCoordinatorProc.run_coordinator,
@@ -87,7 +84,6 @@
                 "front_publish_address": front_publish_address,
                 "back_output_address": back_output_address,
                 "back_publish_address": back_publish_address,
-                "min_stats_update_interval_ms": min_stats_update_interval_ms,
             },
             daemon=True)
         self.proc.start()
@@ -126,10 +122,6 @@ class DPCoordinatorProc:
         self.stats_update_interval_ms = min_stats_update_interval_ms
 
-        self.current_wave = 0
-        self.engines_running = False
-        self.stats_changed = False
-
     @staticmethod
     def run_coordinator(
         engine_count: int,
@@ -156,6 +148,13 @@
         decoder = MsgpackDecoder(EngineCoreOutputs)
 
+        current_wave = 0
+        engines_running = False
+        stats_changed = False
+
+        last_stats_step = -1
+        last_step_counts: Optional[list[list[int]]] = None
+
         with make_zmq_socket(
                 path=front_publish_address,  # IPC
                 ctx=self.ctx,
@@ -173,6 +172,18 @@
                 bind=True,
         ) as publish_back:
 
+            # Wait until all engines subscribe.
+            for _ in self.engines:
+                if publish_back.recv() != b'\x01':
+                    logger.error(
+                        "DP Coordinator received unexpected message while "
+                        "waiting for engines to subscribe")
+                    return
+
+            # Send ready message to engines.
+            publish_back.send(b"READY")
+            logger.info("All engine subscriptions received by DP coordinator")
+
             poller = zmq.Poller()
             poller.register(publish_front, zmq.POLLIN)
             poller.register(output_back, zmq.POLLIN)
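For context on the wait loop above: with XPUB_VERBOSE enabled (see the make_zmq_socket hunk), an XPUB socket receives one frame beginning with b'\x01' for every new subscription, so counting those frames before sending READY guarantees all engines are attached before the first real publish. A stripped-down sketch of that pattern, independent of the vLLM classes (the address and subscriber count here are made up):

import zmq

NUM_SUBSCRIBERS = 4  # hypothetical engine count
ADDRESS = "tcp://127.0.0.1:5570"  # hypothetical publish address

ctx = zmq.Context()
xpub = ctx.socket(zmq.XPUB)
# Deliver every subscription event, including duplicate subscriptions.
xpub.setsockopt(zmq.XPUB_VERBOSE, True)
xpub.bind(ADDRESS)

# Each SUB that subscribes surfaces here as a frame starting with b'\x01'.
for _ in range(NUM_SUBSCRIBERS):
    event = xpub.recv()
    assert event.startswith(b"\x01"), f"unexpected XPUB event: {event!r}"

# Every subscriber is attached; anything published now reaches them all.
xpub.send(b"READY")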
@@ -180,21 +191,33 @@
             while True:
                 elapsed = int(time.time() * 1000) - last_publish_time
                 # Send at stats_update_interval_ms interval if the stats have
-                # changed, or otherwise every 4 seconds.
+                # changed, or otherwise every 5 seconds.
                 wait_for = (self.stats_update_interval_ms
-                            if self.stats_changed else 4000)
-                events = poller.poll(timeout=max(0, wait_for - elapsed))
+                            if stats_changed else 5000)
+
+                # Wait at least 50ms to ensure we've received all stats for
+                # the current step.
+                min_timeout = 50 if last_step_counts is None else 0
+
+                events = poller.poll(timeout=max(min_timeout, wait_for -
+                                                 elapsed))
                 if not events:
                     # Poller timeout - publish current stats to front-ends.
-                    engine_req_counts_list = self._get_engine_counts()
-                    to_publish = (engine_req_counts_list, self.current_wave,
-                                  self.engines_running)
+                    if last_step_counts is not None:
+                        engine_req_counts_list = last_step_counts
+                        last_step_counts = None
+                    else:
+                        engine_req_counts_list = self._get_engine_counts()
+                        stats_changed = False
+
+                    to_publish = (engine_req_counts_list, current_wave,
+                                  engines_running)
                     publish_front.send(msgspec.msgpack.encode(to_publish))
                     last_publish_time = int(time.time() * 1000)
-                    self.stats_changed = False
                     continue
 
                 events = dict(events)
+                wave_state_changed = False
 
                 if publish_front in events:
                     buffer = publish_front.recv()
@@ -221,7 +244,7 @@
                             # current_wave
                             # we note that 0 is the wave number for the new
                             # engine
-                            self.engines_running = False
+                            engines_running = False
                             logger.info(
                                 "DPCoordinator scaled up from %s to %s "
                                 "engines", current_count, new_engine_count)
@@ -237,15 +260,15 @@
                         # engines are paused, so that we can wake the other
                         # engines.
                         engine_to_exclude, wave = decoded
-                        if not self.engines_running:
-                            if wave < self.current_wave:
+                        if not engines_running:
+                            if wave < current_wave:
                                 # If the wave number is stale, ensure the message
                                 # is handled by all the engines.
                                 engine_to_exclude = None
 
-                            self.engines_running = True
-                            self.stats_changed = True
-                            self._send_start_wave(publish_back, self.current_wave,
+                            engines_running = True
+                            wave_state_changed = True
+                            self._send_start_wave(publish_back, current_wave,
                                                   engine_to_exclude)
 
                 if output_back in events:
@@ -263,36 +286,47 @@
                         # 1. Updated request load stats - update our local
                         # state with these.
                         stats = self.engines[eng_index].request_counts
+                        stats_step = scheduler_stats.step_counter
+                        if stats_changed and stats_step != last_stats_step:
+                            last_step_counts = self._get_engine_counts(
+                                do_copy=True)
+                        elif stats_step < last_stats_step:
+                            logger.warning("Received stats for out-of-order "
+                                           "step from engine %d", eng_index)
                         stats[0] = scheduler_stats.num_waiting_reqs
                         stats[1] = scheduler_stats.num_running_reqs
-                        self.stats_changed = True
+                        last_stats_step = stats_step
+                        stats_changed = True
 
                     if (wave := outputs.wave_complete) is not None:
                         # 2. Notification from rank 0 engine that we've
                         # moved into the global paused state
                         # (engines_running==False).
-                        if self.current_wave <= wave:
+                        if current_wave <= wave:
                             new_wave = wave + 1
                             logger.debug("Moving DP wave from %d to %d.",
-                                         self.current_wave, new_wave)
-                            self.current_wave = new_wave
-                            self.engines_running = False
-                            self.stats_changed = True
+                                         current_wave, new_wave)
+                            current_wave = new_wave
+                            engines_running = False
+                            wave_state_changed = True
                     elif (wave := outputs.start_wave) is not None and (
-                            wave > self.current_wave or
-                            (wave == self.current_wave
-                             and not self.engines_running)):
+                            wave > current_wave or
+                            (wave == current_wave and not engines_running)):
                         # 3. The engine received request for a non-current wave
                         # so we must ensure that other engines progress to the
                         # next wave (race condition handling).
                         logger.debug(
                             "Starting wave %d after notification of "
                             "stale wave request from engine.", wave)
-                        self.current_wave = wave
-                        self.engines_running = True
-                        self.stats_changed = True
+                        current_wave = wave
+                        engines_running = True
+                        wave_state_changed = True
                         self._send_start_wave(publish_back, wave, eng_index)
 
+                if wave_state_changed:
+                    message = (None, current_wave, engines_running)
+                    publish_front.send(msgspec.msgpack.encode(message))
+
     @staticmethod
     def _send_start_wave(socket: zmq.Socket, wave: int,
                          exclude_engine_index: Optional[int]):
@@ -305,6 +339,8 @@ class DPCoordinatorProc:
         socket.send_multipart(
             (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
 
-    def _get_engine_counts(self) -> list[list[int]]:
+    def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
         """Return list of [waiting, running] count lists for each engine."""
+        if do_copy:
+            return [copy.copy(e.request_counts) for e in self.engines]
         return [e.request_counts for e in self.engines]

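The main loop in this file rate-limits what it publishes to the front-end: it polls with a timeout derived from the stats update interval (or a longer idle interval), and only a poll timeout triggers serialization and broadcast of the per-engine counts. A reduced sketch of that timeout-driven publish pattern, with a caller-supplied handler standing in for the real socket handling:

import time

import msgspec
import zmq


def publish_loop(poller: zmq.Poller, publish_front: zmq.Socket,
                 get_counts, handle_event,
                 interval_ms: int = 100, idle_interval_ms: int = 5000) -> None:
    """Publish quickly while stats are changing, slowly when idle."""
    last_publish_ms = 0
    stats_changed = False
    while True:
        elapsed = int(time.time() * 1000) - last_publish_ms
        wait_for = interval_ms if stats_changed else idle_interval_ms
        events = poller.poll(timeout=max(0, wait_for - elapsed))
        if not events:
            # Poll timed out: broadcast a snapshot of the current counts.
            publish_front.send(msgspec.msgpack.encode(get_counts()))
            last_publish_ms = int(time.time() * 1000)
            stats_changed = False
            continue
        for socket, _ in events:
            # handle_event returns True when fresh per-engine counts arrived.
            stats_changed |= bool(handle_event(socket))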
View File

@@ -445,8 +445,11 @@ class EngineCoreProc(EngineCore):
             self.has_coordinator = addresses.coordinator_output is not None
             self.frontend_stats_publish_address = (
                 addresses.frontend_stats_publish_address)
+            logger.debug("Has DP Coordinator: %s, stats publish address: %s",
+                         self.has_coordinator,
+                         self.frontend_stats_publish_address)
             # Only publish request queue stats to coordinator for "internal"
-            # LB mode.
+            # and "hybrid" LB modes.
             self.publish_dp_lb_stats = (
                 self.has_coordinator
                 and not vllm_config.parallel_config.data_parallel_external_lb)
@@ -456,25 +459,38 @@ class EngineCoreProc(EngineCore):
             super().__init__(vllm_config, executor_class, log_stats,
                              executor_fail_callback)
 
+            # Background Threads and Queues for IO. These enable us to
+            # overlap ZMQ socket IO with GPU since they release the GIL,
+            # and to overlap some serialization/deserialization with the
+            # model forward pass.
+            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
+            ready_event = threading.Event()
+            input_thread = threading.Thread(target=self.process_input_sockets,
+                                            args=(addresses.inputs,
+                                                  addresses.coordinator_input,
+                                                  identity, ready_event),
+                                            daemon=True)
+            input_thread.start()
+
+            self.output_thread = threading.Thread(
+                target=self.process_output_sockets,
+                args=(addresses.outputs, addresses.coordinator_output,
+                      self.engine_index),
+                daemon=True)
+            self.output_thread.start()
+
+            # Don't complete handshake until DP coordinator ready message is
+            # received.
+            while not ready_event.wait(10):
+                if not input_thread.is_alive():
+                    raise RuntimeError(
+                        "Input socket thread died during startup")
+                assert addresses.coordinator_input is not None
+                logger.info("Waiting for READY message from DP Coordinator")
+
         self.step_fn = (self.step if self.batch_queue is None else
                         self.step_with_batch_queue)
 
-        # Background Threads and Queues for IO. These enable us to
-        # overlap ZMQ socket IO with GPU since they release the GIL,
-        # and to overlap some serialization/deserialization with the
-        # model forward pass.
-        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-        threading.Thread(target=self.process_input_sockets,
-                         args=(addresses.inputs, addresses.coordinator_input,
-                               identity),
-                         daemon=True).start()
-        self.output_thread = threading.Thread(
-            target=self.process_output_sockets,
-            args=(addresses.outputs, addresses.coordinator_output,
-                  self.engine_index),
-            daemon=True)
-        self.output_thread.start()
 
     @contextmanager
     def _perform_handshakes(
         self,
@@ -489,10 +505,10 @@ class EngineCoreProc(EngineCore):
 
         For DP=1 or offline mode, this is with the colocated front-end process.
 
-        For DP>1 with internal loadbalancing this is with the shared front-end
+        For DP>1 with internal load-balancing this is with the shared front-end
         process which may reside on a different node.
 
-        For DP>1 with external or hybrid loadbalancing, two handshakes are
+        For DP>1 with external or hybrid load-balancing, two handshakes are
         performed:
         - With the rank 0 front-end process which retrieves the
           DP Coordinator ZMQ addresses and DP process group address.
@@ -751,7 +767,7 @@ class EngineCoreProc(EngineCore):
 
     def process_input_sockets(self, input_addresses: list[str],
                               coord_input_address: Optional[str],
-                              identity: bytes):
+                              identity: bytes, ready_event: threading.Event):
         """Input socket IO thread."""
 
         # Msgpack serialization decoding.
@@ -788,9 +804,14 @@
                 # back to us.
                 input_socket.send(b'')
                 poller.register(input_socket, zmq.POLLIN)
 
             if coord_socket is not None:
+                # Wait for ready message from coordinator.
+                assert coord_socket.recv() == b"READY"
                 poller.register(coord_socket, zmq.POLLIN)
 
+            ready_event.set()
+            del ready_event
+
             while True:
                 for input_socket, _ in poller.poll():
                     # (RequestType, RequestData)
@@ -887,7 +908,7 @@ class DPEngineCoreProc(EngineCoreProc):
 
         # Counts forward-passes of the model so that we can synchronize
         # finished with DP peers every N steps.
-        self.counter = 0
+        self.step_counter = 0
 
         self.current_wave = 0
         self.last_counts = (0, 0)
@@ -967,7 +988,7 @@
 
             counts = self.scheduler.get_request_counts()
             if counts != self.last_counts:
                 self.last_counts = counts
-                stats = SchedulerStats(*counts)
+                stats = SchedulerStats(*counts, step_counter=self.step_counter)
                 self.output_queue.put_nowait(
                     (-1, EngineCoreOutputs(scheduler_stats=stats)))
@@ -1014,10 +1035,10 @@
     def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
 
         # Optimization - only perform finish-sync all-reduce every 32 steps.
-        self.counter += 1
-        if self.counter != 32:
+        self.step_counter += 1
+        if self.step_counter != 32:
             return True
 
-        self.counter = 0
+        self.step_counter = 0
         return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                 local_unfinished)

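The constructor change in this file blocks handshake completion until the input-socket thread signals, through an Event, that the coordinator's READY message arrived, while periodically checking that the thread is still alive. A generic sketch of that wait-with-liveness-check pattern, using a hypothetical worker function:

import threading


def start_worker(target) -> threading.Thread:
    ready_event = threading.Event()
    thread = threading.Thread(target=target, args=(ready_event, ), daemon=True)
    thread.start()
    # Wake up periodically so a crashed thread cannot hang startup forever.
    while not ready_event.wait(timeout=10):
        if not thread.is_alive():
            raise RuntimeError("worker thread died during startup")
    return thread


def example_worker(ready_event: threading.Event) -> None:
    # ... socket setup / waiting for the peer's READY message would go here ...
    ready_event.set()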
View File

@@ -986,7 +986,12 @@ class DPAsyncMPClient(AsyncMPClient):
                     counts, wave, running = msgspec.msgpack.decode(buf)
                     self.current_wave = wave
                     self.engines_running = running
-                    self.lb_engines = counts[count_slice]
+                    if counts is not None:
+                        sliced_counts = counts[count_slice]
+                        self.lb_engines = sliced_counts
+                        # TODO: TBD whether to keep this debug log
+                        logger.debug("Received counts: %s (%s)",
+                                     sliced_counts, count_slice)
 
         resources.stats_update_task = asyncio.create_task(
             run_engine_stats_update_task())
@@ -1035,27 +1040,26 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
     def get_core_engine_for_request(
             self, request: EngineCoreRequest) -> EngineIdentity:
         # Engines are in rank order.
+        current_counts = self.lb_engines
         if (eng_index := request.data_parallel_rank) is None:
-            if not self.lb_engines:
+            if not current_counts:
                 return self.core_engine
             # TODO use P2C alg for larger DP sizes
-            num_engines = len(self.lb_engines)
-            min_counts = [sys.maxsize, sys.maxsize]
+            num_engines = len(current_counts)
+            min_score = sys.maxsize
             eng_index = 0
             for i in range(num_engines):
                 # Start from client_index to help with balancing when engines
                 # are empty.
                 idx = (self.client_index + i) % num_engines
-                counts = self.lb_engines[idx]
-                if counts < min_counts:
-                    min_counts = counts
+                waiting, running = current_counts[idx]
+                score = waiting * 4 + running
+                if score < min_score:
+                    min_score = score
                     eng_index = idx
-            # Adjust local counts for better balancing between stats updates
-            # from the coordinator (which happen every 100ms).
-            if min_counts[0]:
-                min_counts[0] += 1
-            else:
-                min_counts[1] += 1
+            # Increment local waiting count for better balancing between stats
+            # updates from the coordinator (which happen every 100ms).
+            current_counts[eng_index][0] += 1
 
         chosen_engine = self.core_engines[eng_index]
         # Record which engine is chosen for this request, to handle aborts.

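The rewritten selection logic scores each engine as waiting * 4 + running, starts the scan at the client's own index so multiple API-server clients do not all pick rank 0 when every engine is idle, and bumps the winner's local waiting count so back-to-back requests spread out between 100ms coordinator updates. A self-contained sketch of that scoring, with made-up counts:

import sys


def pick_engine(counts: list[list[int]], client_index: int) -> int:
    """Pick the least-loaded engine; counts[i] is a [waiting, running] pair."""
    num_engines = len(counts)
    min_score = sys.maxsize
    chosen = 0
    for i in range(num_engines):
        # Rotate the starting point per client so ties break differently.
        idx = (client_index + i) % num_engines
        waiting, running = counts[idx]
        score = waiting * 4 + running  # queued work weighted 4x
        if score < min_score:
            min_score = score
            chosen = idx
    # Optimistically account for the request about to be sent.
    counts[chosen][0] += 1
    return chosen


# Engine 1 wins despite more running requests, because nothing is waiting.
print(pick_engine([[2, 1], [0, 3], [1, 2]], client_index=0))  # -> 1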
View File

@@ -33,6 +33,8 @@ class SchedulerStats:
     num_running_reqs: int = 0
     num_waiting_reqs: int = 0
 
+    step_counter: int = 0
+
     kv_cache_usage: float = 0.0
 
     prefix_cache_stats: PrefixCacheStats = field(