[core] clean up executor class hierarchy between v1 and v0 (#12171)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2025-01-18 14:35:15 +08:00
Committed by: GitHub
Parent: 02798ecabe
Commit: 6d0e3d3724
6 changed files with 61 additions and 798 deletions
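
The net effect of the refactor: the v1 executors become thin mixin classes that combine a v0 executor backend with the v1 Executor interface, and the v1-only methods are written once against collective_rpc in the shared base classes. Below is a minimal, runnable sketch of that pattern; the class names mirror the diff that follows, but the worker plumbing and the _FakeWorker demo are hypothetical and heavily simplified.

# Illustrative sketch only -- not code from this commit.
from typing import Any, Callable, List, Optional, Union

class ExecutorBase:
    """Logic shared by v0 and v1; each backend implements collective_rpc."""

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None) -> List[Any]:
        raise NotImplementedError

class Executor(ExecutorBase):
    """v1-only methods, written once in terms of collective_rpc."""

    def determine_available_memory(self) -> int:  # in bytes
        # Take the minimum across workers so the chosen KV cache size
        # fits on every worker.
        return min(self.collective_rpc("determine_available_memory"))

class UniProcExecutorV0(ExecutorBase):
    """Stand-in for the v0 single-process backend."""

    def __init__(self, worker) -> None:
        self.driver_worker = worker

    def collective_rpc(self, method, timeout=None, args=(), kwargs=None):
        # One in-process worker: call it directly, wrap the result in a list.
        return [getattr(self.driver_worker, method)(*args, **(kwargs or {}))]

class UniProcExecutor(UniProcExecutorV0, Executor):
    # v1 executor = v0 backend + v1 interface; no duplicated method bodies.
    pass

class _FakeWorker:  # hypothetical worker, for the demo only
    def determine_available_memory(self) -> int:
        return 8 * 1024**3

print(UniProcExecutor(_FakeWorker()).determine_available_memory())  # 8589934592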

View File

@@ -79,16 +79,6 @@ class ExecutorBase(ABC):
b = min([r[1] for r in results])
return a, b
def initialize(self, num_gpu_blocks: int) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
For V1 compatibility.
"""
logger.info("# GPU blocks: %d", num_gpu_blocks)
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
self.collective_rpc("compile_or_warm_up_model")
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""

View File

@@ -1,63 +1,92 @@
from abc import ABC, abstractmethod
from typing import Type
from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import ( # noqa
UniProcExecutor as UniProcExecutorV0)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
class Executor(ABC):
"""Abstract class for executors."""
class Executor(ExecutorBase):
"""
Abstract class for v1 executors, mainly defining methods specific to v1.
Methods shared by v0 and v1 are defined in ExecutorBase."""
@staticmethod
def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
executor_class: Type[Executor]
parallel_config = vllm_config.parallel_config
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
parallel_config.distributed_executor_backend)
if distributed_executor_backend is None:
# If the user does not specify the distributed executor backend,
# we will choose the backend based on the world size.
if parallel_config.world_size > 1:
distributed_executor_backend = "mp"
else:
distributed_executor_backend = "uni"
if distributed_executor_backend == "ray":
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor)
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor
elif distributed_executor_backend == "uni":
executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher":
# TODO: make v1 scheduling deterministic
# to support external launcher
executor_class = ExecutorWithExternalLauncher
else:
assert (distributed_executor_backend is None)
from vllm.v1.executor.uniproc_executor import UniprocExecutor
executor_class = UniprocExecutor
raise ValueError("Unknown distributed executor backend: "
f"{distributed_executor_backend}")
return executor_class
@abstractmethod
def __init__(self, vllm_config: VllmConfig) -> None:
raise NotImplementedError
@abstractmethod
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
raise NotImplementedError
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
self.collective_rpc("compile_or_warm_up_model")
@abstractmethod
def determine_available_memory(self) -> int: # in bytes
raise NotImplementedError
output = self.collective_rpc("determine_available_memory")
# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return min(output)
@abstractmethod
def get_kv_cache_spec(self) -> KVCacheSpec:
raise NotImplementedError
output = self.collective_rpc("get_kv_cache_spec")
for x in output:
assert x == output[0]
return output[0]
@abstractmethod
def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
raise NotImplementedError
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
return output[0]
@abstractmethod
def profile(self, is_start: bool = True):
raise NotImplementedError
self.collective_rpc("profile", args=(is_start, ))
@abstractmethod
def shutdown(self):
pass
@abstractmethod
def check_health(self) -> None:
raise NotImplementedError
class UniProcExecutor(UniProcExecutorV0, Executor):
pass
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
pass
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
pass
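
Downstream, a backend is picked through this factory and then constructed from the config. A hedged usage sketch: the build_executor wrapper is hypothetical, and it assumes the constructor takes vllm_config, as in the v0 ExecutorBase.

from vllm.config import VllmConfig
from vllm.v1.executor.abstract import Executor

def build_executor(vllm_config: VllmConfig) -> Executor:
    # world_size > 1 defaults to "mp", a single process defaults to "uni";
    # "ray" and "external_launcher" must be requested explicitly.
    executor_class = Executor.get_class(vllm_config)
    return executor_class(vllm_config=vllm_config)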

View File

@@ -25,8 +25,6 @@ from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@@ -37,7 +35,7 @@ POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
class MultiprocExecutor(Executor):
def __init__(self, vllm_config: VllmConfig) -> None:
def _init_executor(self) -> None:
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
self._finalizer = weakref.finalize(self, self.shutdown)
@@ -55,9 +53,6 @@ class MultiprocExecutor(Executor):
signal.signal(signal.SIGUSR1, sigusr1_handler)
self.vllm_config = vllm_config
self.parallel_config = vllm_config.parallel_config
self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
assert self.world_size == tensor_parallel_size, (
@@ -82,7 +77,8 @@ class MultiprocExecutor(Executor):
# Create workers
self.workers: List[WorkerProcHandle] = []
for rank in range(self.world_size):
worker = WorkerProc.make_worker_process(vllm_config, rank, rank,
worker = WorkerProc.make_worker_process(self.vllm_config, rank,
rank,
distributed_init_method,
scheduler_output_handle)
self.workers.append(worker)
@@ -93,34 +89,6 @@ class MultiprocExecutor(Executor):
for w in self.workers:
w.worker_response_mq.wait_until_ready()
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
self.collective_rpc("compile_or_warm_up_model")
def determine_available_memory(self) -> int:
"""
Determine the available memory (in bytes) for KV cache by invoking the
underlying worker.
"""
memory_sizes = self.collective_rpc("determine_available_memory")
# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return min(memory_sizes)
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Get all kv cache needed by the model by invoking the underlying worker.
"""
kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
return kv_cache_specs[0]
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
@@ -172,18 +140,6 @@ class MultiprocExecutor(Executor):
# Re-raise any other exceptions
raise e
def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
model_output = self.collective_rpc("execute_model",
args=(scheduler_output, ))[0]
return model_output
def profile(self, is_start: bool = True):
self.collective_rpc("profile", args=(is_start, ))
return
def _ensure_worker_termination(self):
"""Ensure that all worker processes are terminated. Assumes workers have
received termination requests. Waits for processing, then sends
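
The method bodies deleted here now live once on the v1 Executor base class shown earlier: each v1 operation is a collective_rpc fan-out followed by a small reduction over the per-worker results. A standalone sketch of the three reduction conventions, using hypothetical per-worker values:

# Hypothetical per-worker results -- illustrates the reductions only.
specs = [{"num_layers": 32, "block_size": 16}] * 4    # get_kv_cache_spec
assert all(s == specs[0] for s in specs)
kv_cache_spec = specs[0]        # all workers must agree, so take the first

memory = [7 << 30, 8 << 30, 7 << 30, 8 << 30]         # determine_available_memory
budget = min(memory)            # the smallest worker bounds the shared KV cache

outputs = ["model_runner_output_from_rank_0", None, None, None]   # execute_model
result = outputs[0]             # only the rank-0 result is used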

View File

@@ -1,344 +0,0 @@
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (RayWorkerWrapper,
initialize_ray_cluster, ray)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
class RayExecutor(Executor):
def __init__(self, vllm_config: VllmConfig) -> None:
self.vllm_config = vllm_config
self.parallel_config = vllm_config.parallel_config
self.model_config = vllm_config.model_config
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
initialize_ray_cluster(self.parallel_config)
placement_group = self.parallel_config.placement_group
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
# A list of workers to run a model.
self.workers: List[RayWorkerWrapper] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
# Create the workers.
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
# Skip bundles that don't have GPUs,
# as each worker needs one GPU.
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=1,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
self.workers.append(worker)
logger.debug("workers: %s", self.workers)
worker_ips = [
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
for worker in self.workers
]
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
worker_to_ip = dict(zip(self.workers, worker_ips))
def sort_by_driver_then_worker_ip(worker):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the worker is on a node with a smaller IP address, it
should be placed first. This is simply a tiebreaker to make
sure the workers are sorted in a deterministic way.
"""
ip = worker_to_ip[worker]
return (ip != driver_ip, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids")
node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids = [int(x) for x in gpu_ids]
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
all_ips = set(worker_ips)
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP` or "
"`HOST_IP` environment variable, make sure it is unique for"
" each node.")
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id])),
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
"VLLM_USE_V1":
str(int(envs.VLLM_USE_V1)),
**({
"VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
} if envs.VLLM_ATTENTION_BACKEND is not None else {})
}, ) for (node_id, _) in worker_node_and_gpu_ids]
self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)
self._run_workers("update_environment_variables",
all_args=self._get_env_vars_to_be_updated())
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("initialize")
self._run_workers("load_model")
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
runtime_env.update({
"nsight": {
"t": "cuda,cudnn,cublas",
"o": "'worker_process_%p'",
"cuda-graph-trace": "node",
}
})
return ray_remote_kwargs
def _get_env_vars_to_be_updated(self):
return self._env_vars_for_all_workers
def _get_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
"""
Return worker init args for a given rank.
"""
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
)
def determine_available_memory(self) -> int:
"""
Determine the available GPU memory in bytes.
This invokes `determine_available_memory` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
"""
memory_sizes = self._run_workers("determine_available_memory")
# Since we use a shared centralized controller, we take the minimum
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return min(memory_sizes)
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the KV cache in all workers.
"""
self._run_workers("initialize_cache", kv_cache_config)
self._run_workers("compile_or_warm_up_model")
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Get all kv cache needed by the model
This invokes `get_kv_cache_spec` on each worker and asserts that
they are identical. The KVCacheSpec is then returned.
"""
kv_cache_specs = self._run_workers("get_kv_cache_spec")
assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
return kv_cache_specs[0]
def _run_workers(
self,
method: str,
*args,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
**kwargs,
) -> Any:
"""
Runs the given method on all workers. Can be used in the following
ways:
Args:
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 0, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 0, None)
ray_worker_refs = [
worker.execute_method.remote( # type: ignore[attr-defined]
method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
return ray.get(ray_worker_refs)
def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag()
# Only the first worker (with rank 0) returns the execution result.
# Others return None.
output = ray.get(self.forward_dag.execute(scheduler_output))[0]
return output
def profile(self, is_start=True):
raise NotImplementedError
def shutdown(self):
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.forward_dag = None
def check_health(self) -> None:
logger.debug("Called check_health.")
def _check_ray_compiled_graph_installation(self):
import pkg_resources
from packaging import version
required_version = version.parse("2.39")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version:
raise ValueError(f"Ray version {required_version} is "
f"required, but found {current_version}")
import importlib.util
raycg = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
if raycg is None:
raise ValueError("Ray Compiled Graph is not installed. "
"Run `pip install ray[adag]` to install it.")
cupy_spec = importlib.util.find_spec("cupy")
if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
raise ValueError(
"cupy is not installed but required since "
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set."
"Run `pip install ray[adag]` and check cupy installation.")
def _compiled_ray_dag(self):
assert self.parallel_config.use_ray
self._check_ray_compiled_graph_installation()
from ray.dag import InputNode, MultiOutputNode
with InputNode() as input_batches:
outputs = [
worker.execute_model.bind( # type: ignore[attr-defined]
input_batches) for worker in self.workers
]
forward_dag = MultiOutputNode(outputs)
return forward_dag.experimental_compile()
def __del__(self):
self.shutdown()
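
One detail worth noting in the deleted Ray executor is the worker ordering: sort_by_driver_then_worker_ip returns a tuple, so Python's lexicographic tuple comparison places driver-node workers first, then workers on nodes with fewer workers, and finally breaks ties by IP. A tiny standalone illustration with hypothetical IPs:

# Hypothetical IPs; mirrors the (ip != driver_ip, ip_counts[ip], ip) key above.
driver_ip = "10.0.0.1"
worker_ips = ["10.0.0.2", "10.0.0.1", "10.0.0.2", "10.0.0.3"]
ip_counts = {ip: worker_ips.count(ip) for ip in worker_ips}

ordered = sorted(worker_ips, key=lambda ip: (ip != driver_ip, ip_counts[ip], ip))
print(ordered)  # ['10.0.0.1', '10.0.0.3', '10.0.0.2', '10.0.0.2']

Comparing IP strings is not a numeric ordering, but it is deterministic, which is all the tiebreaker needs to provide.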

View File

@@ -1,280 +0,0 @@
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import get_ip
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
PG_WAIT_TIMEOUT = 60
try:
import ray
from ray.util import placement_group_table
from ray.util.placement_group import PlacementGroup
try:
from ray._private.state import available_resources_per_node
except ImportError:
# Ray 2.9.x doesn't expose `available_resources_per_node`
from ray._private.state import state as _state
available_resources_per_node = _state._available_resources_per_node
class RayWorkerWrapper(WorkerWrapperBase):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# The compiled DAG runs the main execution in a different
# thread that calls cuda.set_device. This flag indicates whether
# set_device has been called on that thread. It will be removed soon.
self.compiled_dag_cuda_device_set = False
def get_node_ip(self) -> str:
return get_ip()
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
node_id = ray.get_runtime_context().get_node_id()
device_key = current_platform.ray_device_key
if not device_key:
raise RuntimeError("current platform %s does not support ray.",
current_platform.device_name)
gpu_ids = ray.get_runtime_context().get_accelerator_ids(
)[device_key]
return node_id, gpu_ids
def setup_device_if_necessary(self):
# TODO(swang): This is needed right now because Ray CG executes
# on a background thread, so we need to reset torch's current
# device.
# We can remove this API after it is fixed in compiled graph.
import torch
assert self.worker is not None, "Worker is not initialized"
if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized"
output = self.worker.model_runner.execute_model(scheduler_output)
return output
ray_import_err = None
except ImportError as e:
ray = None # type: ignore
ray_import_err = e
RayWorkerWrapper = None # type: ignore
def ray_is_available() -> bool:
"""Returns True if Ray is available."""
return ray is not None
def assert_ray_available():
"""
Raise an exception if Ray is not available.
"""
if ray is None:
raise ValueError("Failed to import Ray, please install Ray with "
"`pip install ray`.") from ray_import_err
def _verify_bundles(placement_group: "PlacementGroup",
parallel_config: ParallelConfig, device_str: str):
"""
Verify a given placement group has bundles located in the right place.
There are 2 rules.
- Warn if all tensor parallel workers cannot fit in a single node.
- Fail if driver node is not included in a placement group.
Args:
placement_group: The placement group to verify.
parallel_config: The parallel configuration.
device_str: The required device.
"""
assert ray.is_initialized(), (
"Ray is not initialized although distributed-executor-backend is ray.")
pg_data = placement_group_table(placement_group)
# bundle_idx -> node_id
bundle_to_node_ids = pg_data["bundles_to_node_id"]
# bundle_idx -> bundle (e.g., {"GPU": 1})
bundles = pg_data["bundles"]
# node_id -> List of bundle (e.g., {"GPU": 1})
node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
for bundle_idx, node_id in bundle_to_node_ids.items():
node_id_to_bundle[node_id].append(bundles[bundle_idx])
driver_node_id = ray.get_runtime_context().get_node_id()
if driver_node_id not in node_id_to_bundle:
raise RuntimeError(
f"driver node id {driver_node_id} is not included in a placement "
f"group {placement_group.id}. Node id -> bundles "
f"{node_id_to_bundle}. "
"You don't have enough GPUs available in a current node. Check "
"`ray status` to see if you have available GPUs in a node "
f"{driver_node_id} before starting an vLLM engine.")
for node_id, bundles in node_id_to_bundle.items():
if len(bundles) < parallel_config.tensor_parallel_size:
logger.warning(
"tensor_parallel_size=%d "
"is bigger than a reserved number of %ss (%d "
"%ss) in a node %s. Tensor parallel workers can be "
"spread out to 2+ nodes which can degrade the performance "
"unless you have fast interconnect across nodes, like "
"Infiniband. To resolve this issue, make sure you have more "
"than %d GPUs available at each node.",
parallel_config.tensor_parallel_size, device_str, len(bundles),
device_str, node_id, parallel_config.tensor_parallel_size)
def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
"""Wait until a placement group is ready.
It prints informative log messages if the placement group is
not created in time.
"""
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
placement_group_specs = current_placement_group.bundle_specs
s = time.time()
pg_ready_ref = current_placement_group.ready()
wait_interval = 10
while time.time() - s < PG_WAIT_TIMEOUT:
ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval)
if len(ready) > 0:
break
# Exponential backoff for warning print.
wait_interval *= 2
logger.info(
"Waiting for creating a placement group of specs for "
"%d seconds. specs=%s. Check "
"`ray status` to see if you have enough resources.",
int(time.time() - s), placement_group_specs)
try:
ray.get(pg_ready_ref, timeout=0)
except ray.exceptions.GetTimeoutError:
raise ValueError(
"Cannot provide a placement group of "
f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
"`ray status` to make sure the cluster has enough resources."
) from None
def initialize_ray_cluster(
parallel_config: ParallelConfig,
ray_address: Optional[str] = None,
):
"""Initialize the distributed cluster with Ray.
It will connect to the Ray cluster and create a placement group
for the workers, which includes the specification of the resources
for each distributed worker.
Args:
parallel_config: The configurations for parallel execution.
ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
"""
assert_ray_available()
# Connect to a ray cluster.
if current_platform.is_rocm() or current_platform.is_xpu():
# Try to connect to an existing Ray instance, and create a new one if not found
try:
ray.init("auto")
except ConnectionError:
logger.warning(
"No existing RAY instance detected. "
"A new instance will be launched with current node resources.")
ray.init(address=ray_address,
ignore_reinit_error=True,
num_gpus=parallel_config.world_size)
else:
ray.init(address=ray_address, ignore_reinit_error=True)
if parallel_config.placement_group:
# Placement group is already set.
return
device_str = current_platform.ray_device_key
if not device_str:
raise ValueError(
f"current platform {current_platform.device_name} does not "
"support ray.")
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
# We are in a placement group
bundles = current_placement_group.bundle_specs
# Verify that we can use the placement group.
device_bundles = 0
for bundle in bundles:
bundle_devices = bundle.get(device_str, 0)
if bundle_devices > 1:
raise ValueError(
"Placement group bundle cannot have more than 1 "
f"{device_str}.")
if bundle_devices:
device_bundles += 1
if parallel_config.world_size > device_bundles:
raise ValueError(
f"The number of required {device_str}s exceeds the total "
f"number of available {device_str}s in the placement group."
f"Required number of devices: {parallel_config.world_size}. "
f"Total number of devices: {device_bundles}.")
else:
num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
if parallel_config.world_size > num_devices_in_cluster:
raise ValueError(
f"The number of required {device_str}s exceeds the total "
f"number of available {device_str}s in the placement group.")
# Create a new placement group
placement_group_specs: List[Dict[str, float]] = ([{
device_str: 1.0
} for _ in range(parallel_config.world_size)])
# The vLLM engine is also a worker that executes the model on an
# accelerator, so it requires a device on the current node. Check that
# the current node has at least one device.
current_ip = get_ip()
current_node_id = ray.get_runtime_context().get_node_id()
current_node_resource = available_resources_per_node()[current_node_id]
if current_node_resource.get(device_str, 0) < 1:
raise ValueError(
f"Current node has no {device_str} available. "
f"{current_node_resource=}. vLLM engine cannot start without "
f"{device_str}. Make sure you have at least 1 {device_str} "
f"available in a node {current_node_id=} {current_ip=}.")
# This way, at least one bundle is required to be created on the
# current node.
placement_group_specs[0][f"node:{current_ip}"] = 0.001
# By default, Ray packs resources as much as possible.
current_placement_group = ray.util.placement_group(
placement_group_specs, strategy="PACK")
_wait_until_pg_ready(current_placement_group)
assert current_placement_group is not None
_verify_bundles(current_placement_group, parallel_config, device_str)
# Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group
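
The placement-group setup deleted above reserves one device-only bundle per worker and pins the first bundle to the driver node by requesting a sliver of Ray's built-in node:<ip> resource. A hedged sketch of the resulting spec for a hypothetical 4-worker GPU configuration with the driver at 10.0.0.1:

# Hypothetical values; shows the shape of placement_group_specs built above.
world_size = 4
device_str = "GPU"
current_ip = "10.0.0.1"

placement_group_specs = [{device_str: 1.0} for _ in range(world_size)]
# Pin the first bundle to the driver node so the engine process has a local device.
placement_group_specs[0][f"node:{current_ip}"] = 0.001

print(placement_group_specs)
# [{'GPU': 1.0, 'node:10.0.0.1': 0.001}, {'GPU': 1.0}, {'GPU': 1.0}, {'GPU': 1.0}]

With strategy="PACK", Ray then packs the remaining bundles onto as few nodes as possible.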

View File

@@ -1,88 +0,0 @@
import os
from typing import Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
logger = init_logger(__name__)
class UniprocExecutor(Executor):
def __init__(self, vllm_config: VllmConfig) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.worker: Worker = self._create_worker()
self.worker.init_device()
self.worker.load_model()
def _create_worker(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Worker:
"""Return worker init args for a given rank."""
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return Worker(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
)
def determine_available_memory(self) -> int:
"""Determine the available memory (in bytes) for KV cache by invoking
the underlying worker.
"""
return self.worker.determine_available_memory()
def get_kv_cache_spec(self) -> KVCacheSpec:
"""Get all kv cache needed by the model by invoking the underlying
worker.
"""
return self.worker.get_kv_cache_spec()
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
self.worker.initialize_cache(kv_cache_config)
self.worker.compile_or_warm_up_model()
def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
output = self.worker.execute_model(scheduler_output)
assert output is not None
return output
def profile(self, is_start: bool = True):
self.worker.profile(is_start)
def shutdown(self):
pass
def check_health(self) -> None:
# UniprocExecutor will always be healthy as long as
# it's running.
return