[DP][ray] Support different VLLM_RAY_DP_PACK_STRATEGY (#23849)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Author: Rui Qiao
Date: 2025-10-09 19:53:43 -07:00
Committed by: GitHub
Parent: c6187f55f7
Commit: 757fa4a4da
2 changed files with 86 additions and 33 deletions


@@ -133,6 +133,7 @@ if TYPE_CHECKING:
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MOE_DP_CHUNK_SIZE: int = 256
     VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
+    VLLM_RAY_DP_PACK_STRATEGY: str = "strict"
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_MXFP4_USE_MARLIN: Optional[bool] = None
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
@@ -1000,6 +1001,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
         "VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0"
     )
     == "1",
+    # Strategy to pack the data parallel ranks for Ray.
+    # Available options:
+    # - "fill":
+    #   for DP master node, allocate exactly data-parallel-size-local DP ranks,
+    #   for non-master nodes, allocate as many DP ranks as can fit;
+    # - "strict":
+    #   allocate exactly data-parallel-size-local DP ranks to each picked node;
+    # This environment variable is ignored if data-parallel-backend is not Ray.
+    "VLLM_RAY_DP_PACK_STRATEGY": lambda: os.getenv(
+        "VLLM_RAY_DP_PACK_STRATEGY", "strict"
+    ),
     # Whether to use S3 path for model loading in CI via RunAI Streamer
     "VLLM_CI_USE_S3": lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
     # Use model_redirect to redirect the model name to a local folder.
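
To make the two strategies concrete, here is a small illustrative sketch. It is not part of the commit: pack_dp_ranks, its inputs, and the example cluster are invented to model how "strict" and "fill" distribute DP ranks across nodes, following the comment above (the master node always gets exactly data-parallel-size-local ranks; "strict" only picks nodes that can hold a full local group; "fill" takes whatever fits elsewhere).

# Illustrative only -- not from the commit. pack_dp_ranks and its inputs are
# invented to model how "strict" vs "fill" assign DP ranks to nodes.
def pack_dp_ranks(
    strategy: str,                   # "strict" or "fill"
    dp_size: int,                    # total DP ranks to place
    dp_size_local: int,              # DP ranks required on the master node
    gpus_per_node: dict[str, int],   # node IP -> free GPUs (example input)
    world_size: int,                 # GPUs needed per DP rank (TP * PP)
    master_ip: str,
) -> dict[str, int]:
    """Return node IP -> number of DP ranks packed onto that node."""
    allocation: dict[str, int] = {}
    remaining = dp_size
    for ip, gpus in gpus_per_node.items():
        fits = gpus // world_size            # DP ranks this node could host
        if ip == master_ip:
            # The master node always gets exactly dp_size_local ranks.
            assert fits >= dp_size_local, "master node cannot fit its local ranks"
            take = dp_size_local
        elif strategy == "strict":
            if fits < dp_size_local:
                continue                     # skip nodes that cannot hold a full group
            take = dp_size_local             # exactly dp_size_local per picked node
        else:                                # "fill": as many ranks as the node holds
            take = fits
        take = min(take, remaining)
        if take:
            allocation[ip] = take
            remaining -= take
    assert remaining == 0, "not enough resources for all DP ranks"
    return allocation


nodes = {"10.0.0.1": 2, "10.0.0.2": 1, "10.0.0.3": 4}  # made-up cluster
print(pack_dp_ranks("strict", 4, 2, nodes, world_size=1, master_ip="10.0.0.1"))
# {'10.0.0.1': 2, '10.0.0.3': 2}  -- the 1-GPU node is skipped entirely
print(pack_dp_ranks("fill", 6, 2, nodes, world_size=1, master_ip="10.0.0.1"))
# {'10.0.0.1': 2, '10.0.0.2': 1, '10.0.0.3': 3}  -- leftover capacity is used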

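The remaining hunks rework the placement group creation in CoreEngineActorManager. Whichever packing strategy is selected, each DP rank still gets one STRICT_PACK placement group: world_size device bundles pinned to the chosen node plus a CPU-only bundle. A rough, self-contained sketch of that shape is below; it is not part of the commit, the device key, world size, and node IP are invented example values, and a running Ray cluster is assumed.

# Illustrative only -- mirrors the per-DP-rank bundle layout built in the hunks
# below; the device key, world size, and node IP are made-up example values.
import ray
from ray.util.placement_group import placement_group

ray.init()  # assumes a running (here: local) Ray cluster

device_str = "GPU"    # current_platform.ray_device_key on CUDA-like platforms
world_size = 2        # devices per DP rank (TP * PP)
node_ip = "10.0.0.1"  # node picked by the packing strategy

# world_size device bundles pinned to the node through a tiny "node:<ip>"
# resource share, plus one CPU-only bundle; STRICT_PACK keeps them on one node.
bundles = [{device_str: 1.0, "node:" + node_ip: 0.001}] * world_size + [{"CPU": 1.0}]
pg = placement_group(
    name="dp_rank_0",
    strategy="STRICT_PACK",
    bundles=bundles,
)
ray.get(pg.ready())  # resolves once Ray has reserved the resources

Note that STRICT_PACK only keeps the workers of a single DP rank on one node; the new check in the diff additionally rejects "fill" for the DeepEP backends, which need consecutive EP ranks across DP groups to share a node.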

@@ -15,6 +15,7 @@ from unittest.mock import patch
 import msgspec
 import zmq
 
+from vllm import envs
 from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -337,8 +338,8 @@ class CoreEngineActorManager:
         logger.info("Creating placement groups for data parallel")
         dp_master_ip = vllm_config.parallel_config.data_parallel_master_ip
-        num_pg_to_create = vllm_config.parallel_config.data_parallel_size
-        local_engine_count = vllm_config.parallel_config.data_parallel_size_local
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        dp_size_local = vllm_config.parallel_config.data_parallel_size_local
         available_resources = available_resources_per_node()
         world_size = vllm_config.parallel_config.world_size
@@ -354,44 +355,84 @@
             dp_master_ip,
         )
         device_str = current_platform.ray_device_key
+        if envs.VLLM_RAY_DP_PACK_STRATEGY == "fill" and (
+            envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
+            or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
+        ):
+            raise ValueError(
+                "DeepEP kernels require EP ranks [0,7] (same for [8,15], ...) "
+                "to be on the same node, but VLLM_RAY_DP_PACK_STRATEGY=fill "
+                "does not guarantee that. "
+                "Please use VLLM_RAY_DP_PACK_STRATEGY=strict instead."
+            )
+        logger.info(
+            "Using '%s' DP packing strategy based on VLLM_RAY_DP_PACK_STRATEGY",
+            envs.VLLM_RAY_DP_PACK_STRATEGY,
+        )
+        strict_local_size = envs.VLLM_RAY_DP_PACK_STRATEGY == "strict"
         for node_resources in nodes:
-            if device_str not in node_resources:
-                continue
+            node_ip_keys = [
+                key
+                for key in node_resources
+                if key != "node:__internal_head__" and key.startswith("node:")
+            ]
+            assert len(node_ip_keys) == 1, (
+                "Zero or multiple node IP keys found in node resources: %s",
+                node_ip_keys,
+            )
+            node_ip_key = node_ip_keys[0]
+            node_ip = node_ip_key.split(":")[1]
             # For now, each DP rank can only be assigned to one node
             # TODO(rui): support allocating a single DP rank
             # to multiple nodes
-            available_engine_count = int(node_resources[device_str]) // world_size
-            if dp_master_ip_key in node_resources:
-                assert available_engine_count >= local_engine_count, (
-                    "Not enough resources to allocate DP ranks "
-                    f"on DP master node {dp_master_ip}"
-                )
-                for i in range(local_engine_count):
-                    bundles = [
-                        {device_str: 1.0, "node:" + dp_master_ip: 0.001}
-                    ] * world_size + [{"CPU": 1.0}]
-                    pg = ray.util.placement_group(
-                        name=f"dp_rank_{len(placement_groups)}",
-                        strategy="STRICT_PACK",
-                        bundles=bundles,
+            dp_size_available = (
+                int(node_resources[device_str]) // world_size
+                if device_str in node_resources
+                else 0
+            )
+            if node_ip == dp_master_ip:
+                if dp_size_available < dp_size_local:
+                    raise ValueError(
+                        "Not enough resources to allocate %s DP ranks "
+                        "on DP master node %s, possible to fit %s DP ranks",
+                        dp_size_local,
+                        dp_master_ip,
+                        dp_size_available,
                     )
-                    placement_groups.append(pg)
-                    local_dp_ranks.append(i)
+                dp_size_to_allocate = dp_size_local
+            elif strict_local_size:
+                if dp_size_available < dp_size_local:
+                    logger.info(
+                        "Skipping node %s as %s DP ranks could not fit, "
+                        "possible to fit %s DP ranks",
+                        node_ip,
+                        dp_size_local,
+                        dp_size_available,
+                    )
+                    continue
+                dp_size_to_allocate = dp_size_local
             else:
-                for i in range(available_engine_count):
-                    if len(placement_groups) == num_pg_to_create:
-                        break
-                    bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}]
-                    pg = ray.util.placement_group(
-                        name=f"dp_rank_{len(placement_groups)}",
-                        strategy="STRICT_PACK",
-                        bundles=bundles,
-                    )
-                    placement_groups.append(pg)
-                    local_dp_ranks.append(i)
-        if len(placement_groups) < num_pg_to_create:
+                dp_size_to_allocate = dp_size_available
+            for i in range(dp_size_to_allocate):
+                bundles = [{device_str: 1.0, "node:" + node_ip: 0.001}] * world_size + [
+                    {"CPU": 1.0}
+                ]
+                pg = ray.util.placement_group(
+                    name=f"dp_rank_{len(placement_groups)}",
+                    strategy="STRICT_PACK",
+                    bundles=bundles,
+                )
+                placement_groups.append(pg)
+                local_dp_ranks.append(i)
+        if len(placement_groups) < dp_size:
             raise ValueError(
-                f"Not enough resources to allocate {num_pg_to_create} "
+                f"Not enough resources to allocate {dp_size} "
                 "placement groups, only created "
                 f"{len(placement_groups)} placement groups. "
                 "Available resources: "