[Misc] Getting and passing ray runtime_env to workers (#22040)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
Rui Qiao
2025-08-01 23:54:40 -07:00
committed by GitHub
parent d3a6f2120b
commit 4ac8437352
6 changed files with 77 additions and 13 deletions

View File

@ -36,3 +36,36 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
assert deep_compare(normal_config_dict, empty_config_dict), (
"Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
" should be equivalent")
def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
    """Check that the engine config picks up the Ray runtime env.

    Outside of Ray, `parallel_config.ray_runtime_env` must be None. When the
    config is built inside a Ray task launched with an explicit runtime env,
    that runtime env must be exposed on the resulting parallel config.
    """

    # In testing, this helper needs to be nested inside the test function
    # because Ray does not see the test module.
    def create_config():
        engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
                                 trust_remote_code=True)
        return engine_args.create_engine_config()

    # Ray has not been initialized yet, so no runtime env is expected.
    config = create_config()
    parallel_config = config.parallel_config
    assert parallel_config.ray_runtime_env is None

    import ray

    ray.init()
    try:
        runtime_env = {
            "env_vars": {
                "TEST_ENV_VAR": "test_value",
            },
        }

        # Build the config inside a Ray task that carries the runtime env.
        config_ref = ray.remote(create_config).options(
            runtime_env=runtime_env).remote()

        config = ray.get(config_ref)
        parallel_config = config.parallel_config
        assert parallel_config.ray_runtime_env is not None
        assert parallel_config.ray_runtime_env.env_vars().get(
            "TEST_ENV_VAR") == "test_value"
    finally:
        # Always tear Ray down, even when an assertion above fails, so a
        # failure here does not leave a live Ray instance for later tests.
        ray.shutdown()

View File

@ -57,6 +57,7 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
if TYPE_CHECKING:
from _typeshed import DataclassInstance
from ray.runtime_env import RuntimeEnv
from ray.util.placement_group import PlacementGroup
from transformers.configuration_utils import PretrainedConfig
@ -74,6 +75,7 @@ if TYPE_CHECKING:
else:
DataclassInstance = Any
PlacementGroup = Any
RuntimeEnv = Any
PretrainedConfig = Any
ExecutorBase = Any
QuantizationConfig = Any
@ -2098,6 +2100,9 @@ class ParallelConfig:
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
ray_runtime_env: Optional["RuntimeEnv"] = None
"""Ray runtime environment to pass to distributed workers."""
placement_group: Optional["PlacementGroup"] = None
"""ray distributed model workers placement group."""

View File

@ -36,6 +36,7 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_ray_initialized
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
@ -1099,6 +1100,15 @@ class EngineArgs:
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
)
ray_runtime_env = None
if is_ray_initialized():
# Ray Serve LLM calls `create_engine_config` in the context
# of a Ray task, therefore we check is_ray_initialized()
# as opposed to is_in_ray_actor().
import ray
ray_runtime_env = ray.get_runtime_context().runtime_env
logger.info("Using ray runtime env: %s", ray_runtime_env)
# Get the current placement group if Ray is initialized and
# we are in a Ray actor. If so, then the placement group will be
# passed to spawned processes.
@ -1211,6 +1221,7 @@ class EngineArgs:
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
ray_workers_use_nsight=self.ray_workers_use_nsight,
ray_runtime_env=ray_runtime_env,
placement_group=placement_group,
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,

View File

@ -295,9 +295,12 @@ def initialize_ray_cluster(
logger.warning(
"No existing RAY instance detected. "
"A new instance will be launched with current node resources.")
ray.init(address=ray_address, num_gpus=parallel_config.world_size)
ray.init(address=ray_address,
num_gpus=parallel_config.world_size,
runtime_env=parallel_config.ray_runtime_env)
else:
ray.init(address=ray_address)
ray.init(address=ray_address,
runtime_env=parallel_config.ray_runtime_env)
device_str = current_platform.ray_device_key
if not device_str:

22
vllm/ray/lazy_utils.py Normal file
View File

@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def is_ray_initialized():
    """Report whether Ray is initialized.

    Ray is imported lazily; an ImportError is treated as "not initialized"
    and yields False.
    """
    initialized = False
    try:
        import ray
        initialized = ray.is_initialized()
    except ImportError:
        # Fall through with the default of False.
        pass
    return initialized
def is_in_ray_actor():
    """Report whether the current process is running inside a Ray actor.

    The result is truthy only when Ray is initialized and the runtime
    context reports a non-None actor id. Ray is imported lazily; an
    ImportError yields False.
    """
    in_actor = False
    try:
        import ray
        # The actor id is only queried when Ray is already initialized.
        in_actor = (ray.is_initialized()
                    and ray.get_runtime_context().get_actor_id() is not None)
    except ImportError:
        pass
    return in_actor

View File

@ -72,6 +72,7 @@ from typing_extensions import Never, ParamSpec, TypeIs, assert_never
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.ray.lazy_utils import is_in_ray_actor
if TYPE_CHECKING:
from argparse import Namespace
@ -2835,17 +2836,6 @@ def zmq_socket_ctx(
ctx.destroy(linger=linger)
def is_in_ray_actor():
    """Tell whether this code is executing inside a Ray actor.

    Truthy only when Ray is initialized and the runtime context has a
    non-None actor id; an ImportError while importing or calling Ray
    yields False.
    """
    try:
        import ray
        ray_is_up = ray.is_initialized()
        # Short-circuit: the runtime context is consulted only if Ray is up.
        return (ray_is_up
                and ray.get_runtime_context().get_actor_id() is not None)
    except ImportError:
        return False
def _maybe_force_spawn():
"""Check if we need to force the use of the `spawn` multiprocessing start
method.