[Misc] Getting and passing ray runtime_env to workers (#22040)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
@ -36,3 +36,36 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
|
||||
assert deep_compare(normal_config_dict, empty_config_dict), (
|
||||
"Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
|
||||
" should be equivalent")
|
||||
|
||||
|
||||
def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
|
||||
# In testing, this method needs to be nested inside as ray does not
|
||||
# see the test module.
|
||||
def create_config():
|
||||
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
|
||||
trust_remote_code=True)
|
||||
return engine_args.create_engine_config()
|
||||
|
||||
config = create_config()
|
||||
parallel_config = config.parallel_config
|
||||
assert parallel_config.ray_runtime_env is None
|
||||
|
||||
import ray
|
||||
ray.init()
|
||||
|
||||
runtime_env = {
|
||||
"env_vars": {
|
||||
"TEST_ENV_VAR": "test_value",
|
||||
},
|
||||
}
|
||||
|
||||
config_ref = ray.remote(create_config).options(
|
||||
runtime_env=runtime_env).remote()
|
||||
|
||||
config = ray.get(config_ref)
|
||||
parallel_config = config.parallel_config
|
||||
assert parallel_config.ray_runtime_env is not None
|
||||
assert parallel_config.ray_runtime_env.env_vars().get(
|
||||
"TEST_ENV_VAR") == "test_value"
|
||||
|
||||
ray.shutdown()
|
||||
|
@ -57,6 +57,7 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
@ -74,6 +75,7 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
DataclassInstance = Any
|
||||
PlacementGroup = Any
|
||||
RuntimeEnv = Any
|
||||
PretrainedConfig = Any
|
||||
ExecutorBase = Any
|
||||
QuantizationConfig = Any
|
||||
@ -2098,6 +2100,9 @@ class ParallelConfig:
|
||||
ray_workers_use_nsight: bool = False
|
||||
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
|
||||
|
||||
ray_runtime_env: Optional["RuntimeEnv"] = None
|
||||
"""Ray runtime environment to pass to distributed workers."""
|
||||
|
||||
placement_group: Optional["PlacementGroup"] = None
|
||||
"""ray distributed model workers placement group."""
|
||||
|
||||
|
@ -36,6 +36,7 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.plugins import load_general_plugins
|
||||
from vllm.ray.lazy_utils import is_ray_initialized
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
@ -1099,6 +1100,15 @@ class EngineArgs:
|
||||
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
|
||||
)
|
||||
|
||||
ray_runtime_env = None
|
||||
if is_ray_initialized():
|
||||
# Ray Serve LLM calls `create_engine_config` in the context
|
||||
# of a Ray task, therefore we check is_ray_initialized()
|
||||
# as opposed to is_in_ray_actor().
|
||||
import ray
|
||||
ray_runtime_env = ray.get_runtime_context().runtime_env
|
||||
logger.info("Using ray runtime env: %s", ray_runtime_env)
|
||||
|
||||
# Get the current placement group if Ray is initialized and
|
||||
# we are in a Ray actor. If so, then the placement group will be
|
||||
# passed to spawned processes.
|
||||
@ -1211,6 +1221,7 @@ class EngineArgs:
|
||||
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
||||
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
||||
ray_workers_use_nsight=self.ray_workers_use_nsight,
|
||||
ray_runtime_env=ray_runtime_env,
|
||||
placement_group=placement_group,
|
||||
distributed_executor_backend=self.distributed_executor_backend,
|
||||
worker_cls=self.worker_cls,
|
||||
|
@ -295,9 +295,12 @@ def initialize_ray_cluster(
|
||||
logger.warning(
|
||||
"No existing RAY instance detected. "
|
||||
"A new instance will be launched with current node resources.")
|
||||
ray.init(address=ray_address, num_gpus=parallel_config.world_size)
|
||||
ray.init(address=ray_address,
|
||||
num_gpus=parallel_config.world_size,
|
||||
runtime_env=parallel_config.ray_runtime_env)
|
||||
else:
|
||||
ray.init(address=ray_address)
|
||||
ray.init(address=ray_address,
|
||||
runtime_env=parallel_config.ray_runtime_env)
|
||||
|
||||
device_str = current_platform.ray_device_key
|
||||
if not device_str:
|
||||
|
22
vllm/ray/lazy_utils.py
Normal file
22
vllm/ray/lazy_utils.py
Normal file
@ -0,0 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
def is_ray_initialized():
|
||||
"""Check if Ray is initialized."""
|
||||
try:
|
||||
import ray
|
||||
return ray.is_initialized()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def is_in_ray_actor():
|
||||
"""Check if we are in a Ray actor."""
|
||||
|
||||
try:
|
||||
import ray
|
||||
return (ray.is_initialized()
|
||||
and ray.get_runtime_context().get_actor_id() is not None)
|
||||
except ImportError:
|
||||
return False
|
@ -72,6 +72,7 @@ from typing_extensions import Never, ParamSpec, TypeIs, assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
from vllm.ray.lazy_utils import is_in_ray_actor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
@ -2835,17 +2836,6 @@ def zmq_socket_ctx(
|
||||
ctx.destroy(linger=linger)
|
||||
|
||||
|
||||
def is_in_ray_actor():
|
||||
"""Check if we are in a Ray actor."""
|
||||
|
||||
try:
|
||||
import ray
|
||||
return (ray.is_initialized()
|
||||
and ray.get_runtime_context().get_actor_id() is not None)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _maybe_force_spawn():
|
||||
"""Check if we need to force the use of the `spawn` multiprocessing start
|
||||
method.
|
||||
|
Reference in New Issue
Block a user