mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
@ -430,17 +430,22 @@ class MPClient(EngineCoreClient):
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
dp_local_size = parallel_config.data_parallel_size_local
|
||||
external_dp_lb = parallel_config.data_parallel_external_lb
|
||||
|
||||
offline_mode = parallel_config.data_parallel_rank_local is not None
|
||||
engine_ranks = ([dp_rank] if (offline_mode or external_dp_lb)
|
||||
else range(dp_rank, dp_rank + dp_local_size))
|
||||
|
||||
# If External DPLB, Client manages local EngineCores.
|
||||
# If Internal DPLB, Client manages local+remote EngineCores.
|
||||
num_ranks = (dp_local_size
|
||||
if parallel_config.data_parallel_external_lb else
|
||||
dp_size)
|
||||
self.engine_ranks_managed = ([dp_rank] if offline_mode else range(
|
||||
dp_rank, dp_rank + num_ranks))
|
||||
assert parallel_config.data_parallel_size_local <= len(
|
||||
engine_ranks)
|
||||
self.engine_ranks_managed)
|
||||
|
||||
# ZMQ identity of each engine that this client will talk to.
|
||||
self.core_engines: list[EngineIdentity] = [
|
||||
index.to_bytes(2, "little") for index in engine_ranks
|
||||
rank.to_bytes(2, "little")
|
||||
for rank in self.engine_ranks_managed
|
||||
]
|
||||
|
||||
# Wait for ready messages from each engine on the input socket.
|
||||
@ -895,8 +900,6 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
return
|
||||
|
||||
assert self.stats_update_address is not None
|
||||
dp_start_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||
dp_end_rank = dp_start_rank + self.vllm_config.parallel_config.data_parallel_size_local
|
||||
|
||||
async def run_engine_stats_update_task():
|
||||
with make_zmq_socket(self.ctx, self.stats_update_address,
|
||||
@ -961,9 +964,9 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
counts, wave, running = msgspec.msgpack.decode(buf)
|
||||
self.current_wave = wave
|
||||
self.engines_running = running
|
||||
# NOTE: counts includes num running for all global
|
||||
# EngineCores, so need to slide for the local ones.
|
||||
self.lb_engines = counts[dp_start_rank:dp_end_rank]
|
||||
# NOTE: counts include all global Cores. Slice
|
||||
# to get get the Core's managed by this client.
|
||||
self.lb_engines = counts[self.engine_ranks_managed]
|
||||
|
||||
resources.stats_update_task = asyncio.create_task(
|
||||
run_engine_stats_update_task())
|
||||
|
Reference in New Issue
Block a user