[Misc] Better RayExecutor and multiprocessing compatibility (#14705)

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Authored by Cody Yu on 2025-03-20 19:27:46 -07:00; committed by GitHub
parent 11b986b3fb
commit 5df2da5b97
4 changed files with 67 additions and 21 deletions


@@ -26,7 +26,7 @@ from vllm.plugins import load_general_plugins
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, StoreBoolean
+from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
 
 if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
@@ -1245,6 +1245,18 @@ class EngineArgs:
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
         )
 
+        # Get the current placement group if Ray is initialized and
+        # we are in a Ray actor. If so, then the placement group will be
+        # passed to spawned processes.
+        placement_group = None
+        if is_in_ray_actor():
+            import ray
+
+            # This call initializes Ray automatically if it is not initialized,
+            # but we should not do this here.
+            placement_group = ray.util.get_current_placement_group()
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
@@ -1257,6 +1269,7 @@
                 self.tokenizer_pool_extra_config,
             ),
             ray_workers_use_nsight=self.ray_workers_use_nsight,
+            placement_group=placement_group,
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
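
The lookup above runs only when the engine config is being built inside a Ray actor, so it never triggers an implicit ray.init(). A minimal sketch of the same pattern (the helper name _current_placement_group_or_none is illustrative, not part of this commit):

    def _current_placement_group_or_none():
        from vllm.utils import is_in_ray_actor
        if not is_in_ray_actor():
            # Calling ray.util.get_current_placement_group() here could
            # implicitly initialize Ray, which we want to avoid.
            return None
        import ray
        # Safe: Ray is already initialized because we are inside an actor.
        return ray.util.get_current_placement_group()

ParallelConfig then carries the captured placement group, and initialize_ray_cluster (changed below) reuses it instead of creating a new one.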


@@ -16,7 +16,7 @@ import torch
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import _check_multiproc_method, get_mp_context, run_method
+from vllm.utils import _maybe_force_spawn, get_mp_context, run_method
 
 logger = init_logger(__name__)
@@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config):
     in a multiprocessing environment. This should be called by the parent
     process before worker processes are created"""
 
-    _check_multiproc_method()
+    _maybe_force_spawn()
 
     # Configure thread parallelism if OMP_NUM_THREADS isn't set
     #
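
In other words, a parent process calls set_multiprocessing_worker_envs() first and only then creates workers through get_mp_context(). A hedged sketch of that ordering (_start_worker and fn are illustrative names, not vLLM code):

    from vllm.utils import get_mp_context

    def _start_worker(rank: int, fn) -> None:
        # get_mp_context() may override VLLM_WORKER_MULTIPROC_METHOD to
        # "spawn" (see _maybe_force_spawn in the last file below), and the
        # returned multiprocessing context follows that choice.
        ctx = get_mp_context()
        proc = ctx.Process(target=fn, args=(0,), daemon=True)
        proc.start()

With the spawn method the target must be importable from the child process, so fn should be a module-level function rather than a lambda or closure.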


@@ -284,8 +284,9 @@ def initialize_ray_cluster(
     assert_ray_available()
     from vllm.platforms import current_platform
 
     # Connect to a ray cluster.
-    if current_platform.is_rocm() or current_platform.is_xpu():
+    if ray.is_initialized():
+        logger.info("Ray is already initialized. Skipping Ray initialization.")
+    elif current_platform.is_rocm() or current_platform.is_xpu():
         # Try to connect existing ray instance and create a new one if not found
         try:
             ray.init("auto", ignore_reinit_error=True)
@@ -299,19 +300,21 @@
     else:
         ray.init(address=ray_address, ignore_reinit_error=True)
 
-    if parallel_config.placement_group:
-        # Placement group is already set.
-        return
-
     device_str = current_platform.ray_device_key
     if not device_str:
         raise ValueError(
             f"current platform {current_platform.device_name} does not "
             "support ray.")
 
-    # Create placement group for worker processes
+    # Create or get the placement group for worker processes
+    if parallel_config.placement_group:
+        current_placement_group = parallel_config.placement_group
+    else:
+        current_placement_group = ray.util.get_current_placement_group()
+
     if current_placement_group:
+        logger.info("Using the existing placement group")
+
         # We are in a placement group
         bundles = current_placement_group.bundle_specs
         # Verify that we can use the placement group.
@@ -331,6 +334,8 @@ def initialize_ray_cluster(
                 f"Required number of devices: {parallel_config.world_size}. "
                 f"Total number of devices: {device_bundles}.")
     else:
+        logger.info("No current placement group found. "
+                    "Creating a new placement group.")
         num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
         # Log a warning message and delay resource allocation failure response.
         # Avoid immediate rejection to allow user-initiated placement group
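
The new control flow can be summarized as a small sketch (hypothetical standalone helper, not code from this commit): an explicitly configured placement group wins, otherwise the caller's current placement group is used, and only when both are absent does initialize_ray_cluster create a new one further down.

    def _resolve_placement_group(parallel_config):
        import ray
        if parallel_config.placement_group:
            # Passed in explicitly, e.g. captured by EngineArgs inside a Ray actor.
            return parallel_config.placement_group
        # Non-None only when this process itself runs inside a placement group.
        return ray.util.get_current_placement_group()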


@@ -2147,20 +2147,48 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
         ctx.destroy(linger=0)
 
 
-def _check_multiproc_method():
-    if (cuda_is_initialized()
-            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
-        logger.warning("CUDA was previously initialized. We must use "
-                       "the `spawn` multiprocessing start method. Setting "
-                       "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
+def is_in_ray_actor():
+    """Check if we are in a Ray actor."""
+    try:
+        import ray
+        return (ray.is_initialized()
                and ray.get_runtime_context().get_actor_id() is not None)
+    except ImportError:
+        return False
+
+
+def _maybe_force_spawn():
+    """Check if we need to force the use of the `spawn` multiprocessing start
+    method.
+    """
+    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
+        return
+
+    reason = None
+    if cuda_is_initialized():
+        reason = "CUDA is initialized"
+    elif is_in_ray_actor():
+        reason = "In a Ray actor and can only be spawned"
+
+    if reason is not None:
+        logger.warning(
+            "We must use the `spawn` multiprocessing start method. "
+            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
             "See https://docs.vllm.ai/en/latest/getting_started/"
             "troubleshooting.html#python-multiprocessing "
-            "for more information.")
+            "for more information. Reason: %s", reason)
         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
 
 def get_mp_context():
-    _check_multiproc_method()
+    """Get a multiprocessing context with a particular method (spawn or fork).
+    By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
+    determine the multiprocessing method (default is fork). However, under
+    certain conditions, we may enforce spawn and override the value of
+    VLLM_WORKER_MULTIPROC_METHOD.
+    """
+    _maybe_force_spawn()
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)
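
A hedged end-to-end check of the intended behavior (illustrative script, not part of the commit): inside a Ray actor, fork is unsafe, so get_mp_context() should force the spawn start method.

    import os

    import ray

    from vllm.utils import get_mp_context, is_in_ray_actor


    @ray.remote
    class Probe:

        def check(self):
            ctx = get_mp_context()  # triggers _maybe_force_spawn()
            return (is_in_ray_actor(),
                    os.environ.get("VLLM_WORKER_MULTIPROC_METHOD"),
                    ctx.get_start_method())


    ray.init()
    # Expected (assuming the env var was not preset): (True, 'spawn', 'spawn')
    print(ray.get(Probe.remote().check.remote()))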