diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 88d70acb79..986d1b4074 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -26,7 +26,7 @@ from vllm.plugins import load_general_plugins
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, StoreBoolean
+from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
 
 if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
@@ -1245,6 +1245,18 @@ class EngineArgs:
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
         )
+
+        # Get the current placement group if Ray is initialized and
+        # we are in a Ray actor. If so, then the placement group will be
+        # passed to spawned processes.
+        placement_group = None
+        if is_in_ray_actor():
+            import ray
+
+            # This call initializes Ray automatically if it is not initialized,
+            # but we should not do this here.
+            placement_group = ray.util.get_current_placement_group()
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
@@ -1257,6 +1269,7 @@ class EngineArgs:
                 self.tokenizer_pool_extra_config,
             ),
             ray_workers_use_nsight=self.ray_workers_use_nsight,
+            placement_group=placement_group,
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py
index a4a5b3f938..380b672c36 100644
--- a/vllm/executor/multiproc_worker_utils.py
+++ b/vllm/executor/multiproc_worker_utils.py
@@ -16,7 +16,7 @@ import torch
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import _check_multiproc_method, get_mp_context, run_method
+from vllm.utils import _maybe_force_spawn, get_mp_context, run_method
 
 logger = init_logger(__name__)
 
@@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config):
     in a multiprocessing environment. This should be called by the parent
     process before worker processes are created"""
 
-    _check_multiproc_method()
+    _maybe_force_spawn()
 
     # Configure thread parallelism if OMP_NUM_THREADS isn't set
     #
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index a7042ca8df..b7222f26f6 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -284,8 +284,9 @@ def initialize_ray_cluster(
     assert_ray_available()
     from vllm.platforms import current_platform
 
-    # Connect to a ray cluster.
-    if current_platform.is_rocm() or current_platform.is_xpu():
+    if ray.is_initialized():
+        logger.info("Ray is already initialized. Skipping Ray initialization.")
+    elif current_platform.is_rocm() or current_platform.is_xpu():
         # Try to connect existing ray instance and create a new one if not found
         try:
             ray.init("auto", ignore_reinit_error=True)
@@ -299,19 +300,21 @@ def initialize_ray_cluster(
     else:
         ray.init(address=ray_address, ignore_reinit_error=True)
 
-    if parallel_config.placement_group:
-        # Placement group is already set.
-        return
-
     device_str = current_platform.ray_device_key
     if not device_str:
         raise ValueError(
             f"current platform {current_platform.device_name} does not "
             "support ray.")
 
-    # Create placement group for worker processes
-    current_placement_group = ray.util.get_current_placement_group()
+    # Create or get the placement group for worker processes
+    if parallel_config.placement_group:
+        current_placement_group = parallel_config.placement_group
+    else:
+        current_placement_group = ray.util.get_current_placement_group()
+
     if current_placement_group:
+        logger.info("Using the existing placement group")
+
         # We are in a placement group
         bundles = current_placement_group.bundle_specs
         # Verify that we can use the placement group.
@@ -331,6 +334,8 @@ def initialize_ray_cluster(
                 f"Required number of devices: {parallel_config.world_size}. "
                 f"Total number of devices: {device_bundles}.")
     else:
+        logger.info("No current placement group found. "
+                    "Creating a new placement group.")
         num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
         # Log a warning message and delay resource allocation failure response.
         # Avoid immediate rejection to allow user-initiated placement group
diff --git a/vllm/utils.py b/vllm/utils.py
index 9bc081890b..cb375f8ff3 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -2147,20 +2147,48 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
         ctx.destroy(linger=0)
 
 
-def _check_multiproc_method():
-    if (cuda_is_initialized()
-            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
-        logger.warning("CUDA was previously initialized. We must use "
-                       "the `spawn` multiprocessing start method. Setting "
-                       "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
-                       "See https://docs.vllm.ai/en/latest/getting_started/"
-                       "troubleshooting.html#python-multiprocessing "
-                       "for more information.")
+def is_in_ray_actor():
+    """Check if we are in a Ray actor."""
+
+    try:
+        import ray
+        return (ray.is_initialized()
+                and ray.get_runtime_context().get_actor_id() is not None)
+    except ImportError:
+        return False
+
+
+def _maybe_force_spawn():
+    """Check if we need to force the use of the `spawn` multiprocessing start
+    method.
+    """
+    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
+        return
+
+    reason = None
+    if cuda_is_initialized():
+        reason = "CUDA is initialized"
+    elif is_in_ray_actor():
+        reason = "In a Ray actor and can only be spawned"
+
+    if reason is not None:
+        logger.warning(
+            "We must use the `spawn` multiprocessing start method. "
+            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
+            "See https://docs.vllm.ai/en/latest/getting_started/"
+            "troubleshooting.html#python-multiprocessing "
+            "for more information. Reason: %s", reason)
         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
 
 def get_mp_context():
-    _check_multiproc_method()
+    """Get a multiprocessing context with a particular method (spawn or fork).
+    By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
+    determine the multiprocessing method (default is fork). However, under
+    certain conditions, we may enforce spawn and override the value of
+    VLLM_WORKER_MULTIPROC_METHOD.
+    """
+    _maybe_force_spawn()
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)
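
Usage sketch (not part of the patch): the scenario these changes target is constructing a vLLM engine from inside a Ray actor that was itself scheduled into a placement group. The actor class, model name, and resource bundle below are illustrative assumptions, not taken from the diff; only the is_in_ray_actor() / _maybe_force_spawn() behavior and the placement-group plumbing come from the changes above.

# usage_sketch.py -- illustrative only; assumes Ray, vLLM, and one free GPU.
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote(num_gpus=1)
class InferenceActor:
    def __init__(self):
        # This constructor runs inside a Ray actor, so is_in_ray_actor() is
        # True: create_engine_config() records the actor's current placement
        # group in ParallelConfig, and get_mp_context() forces the `spawn`
        # start method for any worker processes it creates.
        from vllm import LLM
        self.llm = LLM(model="facebook/opt-125m")

    def generate(self, prompt):
        return [out.outputs[0].text for out in self.llm.generate([prompt])]


if __name__ == "__main__":
    ray.init()
    # Reserve resources up front and schedule the actor into the group; when
    # the Ray distributed executor is used, the engine created inside the
    # actor reuses this placement group instead of creating a new one.
    pg = placement_group([{"GPU": 1, "CPU": 1}])
    ray.get(pg.ready())
    actor = InferenceActor.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg)).remote()
    print(ray.get(actor.generate.remote("Hello, my name is")))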