mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
@ -430,17 +430,22 @@ class MPClient(EngineCoreClient):
|
|||||||
dp_size = parallel_config.data_parallel_size
|
dp_size = parallel_config.data_parallel_size
|
||||||
dp_rank = parallel_config.data_parallel_rank
|
dp_rank = parallel_config.data_parallel_rank
|
||||||
dp_local_size = parallel_config.data_parallel_size_local
|
dp_local_size = parallel_config.data_parallel_size_local
|
||||||
external_dp_lb = parallel_config.data_parallel_external_lb
|
|
||||||
|
|
||||||
offline_mode = parallel_config.data_parallel_rank_local is not None
|
offline_mode = parallel_config.data_parallel_rank_local is not None
|
||||||
engine_ranks = ([dp_rank] if (offline_mode or external_dp_lb)
|
|
||||||
else range(dp_rank, dp_rank + dp_local_size))
|
# If External DPLB, Client manages local EngineCores.
|
||||||
|
# If Internal DPLB, Client manages local+remote EngineCores.
|
||||||
|
num_ranks = (dp_local_size
|
||||||
|
if parallel_config.data_parallel_external_lb else
|
||||||
|
dp_size)
|
||||||
|
self.engine_ranks_managed = ([dp_rank] if offline_mode else range(
|
||||||
|
dp_rank, dp_rank + num_ranks))
|
||||||
assert parallel_config.data_parallel_size_local <= len(
|
assert parallel_config.data_parallel_size_local <= len(
|
||||||
engine_ranks)
|
self.engine_ranks_managed)
|
||||||
|
|
||||||
# ZMQ identity of each engine that this client will talk to.
|
# ZMQ identity of each engine that this client will talk to.
|
||||||
self.core_engines: list[EngineIdentity] = [
|
self.core_engines: list[EngineIdentity] = [
|
||||||
index.to_bytes(2, "little") for index in engine_ranks
|
rank.to_bytes(2, "little")
|
||||||
|
for rank in self.engine_ranks_managed
|
||||||
]
|
]
|
||||||
|
|
||||||
# Wait for ready messages from each engine on the input socket.
|
# Wait for ready messages from each engine on the input socket.
|
||||||
@ -895,8 +900,6 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
return
|
return
|
||||||
|
|
||||||
assert self.stats_update_address is not None
|
assert self.stats_update_address is not None
|
||||||
dp_start_rank = self.vllm_config.parallel_config.data_parallel_rank
|
|
||||||
dp_end_rank = dp_start_rank + self.vllm_config.parallel_config.data_parallel_size_local
|
|
||||||
|
|
||||||
async def run_engine_stats_update_task():
|
async def run_engine_stats_update_task():
|
||||||
with make_zmq_socket(self.ctx, self.stats_update_address,
|
with make_zmq_socket(self.ctx, self.stats_update_address,
|
||||||
@ -961,9 +964,9 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
counts, wave, running = msgspec.msgpack.decode(buf)
|
counts, wave, running = msgspec.msgpack.decode(buf)
|
||||||
self.current_wave = wave
|
self.current_wave = wave
|
||||||
self.engines_running = running
|
self.engines_running = running
|
||||||
# NOTE: counts includes num running for all global
|
# NOTE: counts include all global Cores. Slice
|
||||||
# EngineCores, so need to slide for the local ones.
|
# to get the Cores managed by this client.
|
||||||
self.lb_engines = counts[dp_start_rank:dp_end_rank]
|
self.lb_engines = counts[self.engine_ranks_managed]
|
||||||
|
|
||||||
resources.stats_update_task = asyncio.create_task(
|
resources.stats_update_task = asyncio.create_task(
|
||||||
run_engine_stats_update_task())
|
run_engine_stats_update_task())
|
||||||
|
|||||||
Reference in New Issue
Block a user