Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2025-07-20 02:15:19 +00:00
parent 14f13ed690
commit b90d33163c
6 changed files with 43 additions and 16 deletions

View File

@ -1091,13 +1091,15 @@ class EngineArgs:
# but we should not do this here.
placement_group = ray.util.get_current_placement_group()
data_parallel_external_lb = self.data_parallel_rank is not None
if data_parallel_external_lb:
assert self.data_parallel_size_local in (1, None), (
"data_parallel_size_local must be 1 when data_parallel_rank "
"is set")
data_parallel_size_local = 1
elif self.data_parallel_size_local is not None:
# data_parallel_external_lb = self.data_parallel_rank is not None
# if data_parallel_external_lb:
# assert self.data_parallel_size_local in (1, None), (
# "data_parallel_size_local must be 1 when data_parallel_rank "
# "is set")
# data_parallel_size_local = 1
# elif self.data_parallel_size_local is not None:
data_parallel_external_lb = False
if self.data_parallel_size_local is not None:
data_parallel_size_local = self.data_parallel_size_local
else:
# Local DP size defaults to global DP size if not set.

View File

@ -45,11 +45,11 @@ class ServeSubcommand(CLISubcommand):
if args.headless or args.api_server_count < 1:
run_headless(args)
else:
if args.data_parallel_start_rank:
raise ValueError(
"data_parallel_start_rank is only applicable "
"in headless mode. "
"Add --headless flag to enable headless mode.")
# if args.data_parallel_start_rank:
# raise ValueError(
# "data_parallel_start_rank is only applicable "
# "in headless mode. "
# "Add --headless flag to enable headless mode.")
if args.api_server_count > 1:
run_multi_api_server(args)
else:

View File

@ -303,7 +303,7 @@ def download_weights_from_hf(
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
# tqdm_class=DisabledTqdm,
revision=revision,
local_files_only=local_only,
)

View File

@ -411,10 +411,12 @@ class EngineCoreProc(EngineCore):
identity = self.engine_index.to_bytes(length=2, byteorder="little")
self.engines_running = False
logger.info("======= HANDSHAKING:")
with self._perform_handshakes(handshake_address, identity,
local_client, vllm_config,
client_handshake_address) as addresses:
self.client_count = len(addresses.outputs)
logger.info(f"{addresses.outputs=}")
# Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None
@ -482,16 +484,21 @@ class EngineCoreProc(EngineCore):
"""
input_ctx = zmq.Context()
is_local = local_client and client_handshake_address is None
logger.info(f"HS: {handshake_address=}, {is_local=}")
handshake = self._perform_handshake(input_ctx, handshake_address,
identity, is_local, vllm_config,
vllm_config.parallel_config)
logger.info(f"DONE HS: {handshake=}")
if client_handshake_address is None:
with handshake as addresses:
yield addresses
else:
logger.info(f"HS: {client_handshake_address=}, {local_client=}")
local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, local_client,
vllm_config)
logger.info(f"DONE HS: {local_handshake=}")
with handshake as addresses, local_handshake as client_addresses:
addresses.inputs = client_addresses.inputs
addresses.outputs = client_addresses.outputs
@ -517,6 +524,8 @@ class EngineCoreProc(EngineCore):
linger=5000,
bind=False) as handshake_socket:
# Register engine with front-end.
logger.info(f"calling startup_handshake: {handshake_address=}")
logger.info(f"calling startup_handshake: {local_client=}")
addresses = self.startup_handshake(handshake_socket, local_client,
parallel_config_to_update)
yield addresses

View File

@ -405,12 +405,15 @@ class MPClient(EngineCoreClient):
"stats_update_address")
else:
# Engines are managed by this client.
print(f"{vllm_config.parallel_config=}")
with launch_core_engines(vllm_config, executor_class,
log_stats) as (engine_manager,
coordinator,
addresses):
self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager
print("========================================")
print(f"{vllm_config.parallel_config=}")
(input_address, ) = addresses.inputs
(output_address, ) = addresses.outputs

View File

@ -555,6 +555,8 @@ def launch_core_engines(
# sends requests only to colocated engines.
client_local_only = offline_mode or external_dp_lb or (local_engine_count
== dp_size)
# HACK: handle case with one pod per node.
client_local_only = True
# Set up input and output addresses.
addresses = EngineZmqAddresses(
@ -601,11 +603,17 @@ def launch_core_engines(
if offline_mode or (external_dp_lb and dp_rank > 0):
assert local_engine_count == 1
engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
else:
elif dp_rank == 0:
engines_to_handshake = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
else:
# Just handshake with local engines.
engines_to_handshake = [
CoreEngine(index=i, local=True) for i in
range(dp_rank, dp_rank + local_engine_count)
]
# Whether the started engines will handshake only with co-located
# front-end processes. In external_dp_lb mode, ranks > 0 handshake with
@ -616,7 +624,8 @@ def launch_core_engines(
handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port)
if external_dp_lb and dp_rank > 0:
# if external_dp_lb and dp_rank > 0:
if dp_rank > 0:
assert not handshake_local_only
local_handshake_address = get_open_zmq_ipc_path()
client_handshake_address = local_handshake_address
@ -624,15 +633,18 @@ def launch_core_engines(
local_handshake_address = handshake_address
client_handshake_address = None
print(f"{local_handshake_address=}")
with zmq_socket_ctx(local_handshake_address, zmq.ROUTER,
bind=True) as handshake_socket:
from vllm.v1.engine.core import EngineCoreProc
print(f"{client_handshake_address=}")
print(f"{handshake_address=}")
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
# both be 0. << todo: update
local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
@ -650,6 +662,7 @@ def launch_core_engines(
yield local_engine_manager, coordinator, addresses
# Now wait for engines to start.
print(f"{engines_to_handshake=}")
wait_for_engine_startup(
handshake_socket,
addresses,