Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2025-07-20 20:17:54 +00:00
parent d4ab18f19d
commit 2cf8ff64c7

View File

@ -430,17 +430,22 @@ class MPClient(EngineCoreClient):
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank dp_rank = parallel_config.data_parallel_rank
dp_local_size = parallel_config.data_parallel_size_local dp_local_size = parallel_config.data_parallel_size_local
external_dp_lb = parallel_config.data_parallel_external_lb
offline_mode = parallel_config.data_parallel_rank_local is not None offline_mode = parallel_config.data_parallel_rank_local is not None
engine_ranks = ([dp_rank] if (offline_mode or external_dp_lb)
else range(dp_rank, dp_rank + dp_local_size)) # If External DPLB, Client manages local EngineCores.
# If Internal DPLB, Client manages local+remote EngineCores.
num_ranks = (dp_local_size
if parallel_config.data_parallel_external_lb else
dp_size)
self.engine_ranks_managed = ([dp_rank] if offline_mode else range(
dp_rank, dp_rank + num_ranks))
assert parallel_config.data_parallel_size_local <= len( assert parallel_config.data_parallel_size_local <= len(
engine_ranks) self.engine_ranks_managed)
# ZMQ identity of each engine that this client will talk to. # ZMQ identity of each engine that this client will talk to.
self.core_engines: list[EngineIdentity] = [ self.core_engines: list[EngineIdentity] = [
index.to_bytes(2, "little") for index in engine_ranks rank.to_bytes(2, "little")
for rank in self.engine_ranks_managed
] ]
# Wait for ready messages from each engine on the input socket. # Wait for ready messages from each engine on the input socket.
@ -895,8 +900,6 @@ class DPAsyncMPClient(AsyncMPClient):
return return
assert self.stats_update_address is not None assert self.stats_update_address is not None
dp_start_rank = self.vllm_config.parallel_config.data_parallel_rank
dp_end_rank = dp_start_rank + self.vllm_config.parallel_config.data_parallel_size_local
async def run_engine_stats_update_task(): async def run_engine_stats_update_task():
with make_zmq_socket(self.ctx, self.stats_update_address, with make_zmq_socket(self.ctx, self.stats_update_address,
@ -961,9 +964,9 @@ class DPAsyncMPClient(AsyncMPClient):
counts, wave, running = msgspec.msgpack.decode(buf) counts, wave, running = msgspec.msgpack.decode(buf)
self.current_wave = wave self.current_wave = wave
self.engines_running = running self.engines_running = running
# NOTE: counts includes num running for all global # NOTE: counts include all global Cores. Slice
# EngineCores, so need to slide for the local ones. # to get the Cores managed by this client.
self.lb_engines = counts[dp_start_rank:dp_end_rank] self.lb_engines = counts[self.engine_ranks_managed]
resources.stats_update_task = asyncio.create_task( resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task()) run_engine_stats_update_task())