Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2025-07-20 20:17:54 +00:00
parent d4ab18f19d
commit 2cf8ff64c7

View File

@ -430,17 +430,22 @@ class MPClient(EngineCoreClient):
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
dp_local_size = parallel_config.data_parallel_size_local
external_dp_lb = parallel_config.data_parallel_external_lb
offline_mode = parallel_config.data_parallel_rank_local is not None
engine_ranks = ([dp_rank] if (offline_mode or external_dp_lb)
else range(dp_rank, dp_rank + dp_local_size))
# If External DPLB, Client manages local EngineCores.
# If Internal DPLB, Client manages local+remote EngineCores.
num_ranks = (dp_local_size
if parallel_config.data_parallel_external_lb else
dp_size)
self.engine_ranks_managed = ([dp_rank] if offline_mode else range(
dp_rank, dp_rank + num_ranks))
assert parallel_config.data_parallel_size_local <= len(
engine_ranks)
self.engine_ranks_managed)
# ZMQ identity of each engine that this client will talk to.
self.core_engines: list[EngineIdentity] = [
index.to_bytes(2, "little") for index in engine_ranks
rank.to_bytes(2, "little")
for rank in self.engine_ranks_managed
]
# Wait for ready messages from each engine on the input socket.
@ -895,8 +900,6 @@ class DPAsyncMPClient(AsyncMPClient):
return
assert self.stats_update_address is not None
dp_start_rank = self.vllm_config.parallel_config.data_parallel_rank
dp_end_rank = dp_start_rank + self.vllm_config.parallel_config.data_parallel_size_local
async def run_engine_stats_update_task():
with make_zmq_socket(self.ctx, self.stats_update_address,
@ -961,9 +964,9 @@ class DPAsyncMPClient(AsyncMPClient):
counts, wave, running = msgspec.msgpack.decode(buf)
self.current_wave = wave
self.engines_running = running
# NOTE: counts includes num running for all global
# EngineCores, so need to slide for the local ones.
self.lb_engines = counts[dp_start_rank:dp_end_rank]
# NOTE: counts include all global Cores. Slice
# to get get the Core's managed by this client.
self.lb_engines = counts[self.engine_ranks_managed]
resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task())