Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[core] clean up executor class hierarchy between v1 and v0 (#12171)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -79,16 +79,6 @@ class ExecutorBase(ABC):
        b = min([r[1] for r in results])
        return a, b

    def initialize(self, num_gpu_blocks: int) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        For V1 compatibility.
        """
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
        self.collective_rpc("compile_or_warm_up_model")

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """

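The hunk above removes what appears to be a V1-compatibility `initialize` shim from the v0 executor base, while keeping the pattern of reducing per-worker results with `min`. The standalone sketch below (made-up worker numbers, not vLLM's actual return values) shows why the element-wise minimum is the safe aggregate when sizing a KV cache that every worker must be able to allocate:

# Minimal sketch: aggregate per-worker KV-cache capacity reports.
# Each worker reports (num_gpu_blocks, num_cpu_blocks) it can afford;
# the controller picks a size every worker can satisfy, so it takes
# the element-wise minimum. All values below are made up.
results = [
    (1200, 512),  # worker 0
    (1150, 512),  # worker 1 (least GPU headroom)
    (1300, 480),  # worker 2 (least CPU headroom)
]

num_gpu_blocks = min(r[0] for r in results)
num_cpu_blocks = min(r[1] for r in results)

assert (num_gpu_blocks, num_cpu_blocks) == (1150, 480)
print(num_gpu_blocks, num_cpu_blocks)
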
@@ -1,63 +1,92 @@
from abc import ABC, abstractmethod
from typing import Type

from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_distributed_executor import ( # noqa
    RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.executor.uniproc_executor import ( # noqa
    ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import ( # noqa
    UniProcExecutor as UniProcExecutorV0)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput


class Executor(ABC):
    """Abstract class for executors."""
class Executor(ExecutorBase):
    """
    Abstract class for v1 executors, mainly define some methods for v1.
    For methods shared by v0 and v1, define them in ExecutorBase"""

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
        executor_class: Type[Executor]
        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = (
            vllm_config.parallel_config.distributed_executor_backend)
            parallel_config.distributed_executor_backend)
        if distributed_executor_backend is None:
            # If the user does not specify the distributed executor backend,
            # we will choose the backend based on the world size.
            if parallel_config.world_size > 1:
                distributed_executor_backend = "mp"
            else:
                distributed_executor_backend = "uni"

        if distributed_executor_backend == "ray":
            from vllm.executor.ray_distributed_executor import ( # noqa
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend == "uni":
            executor_class = UniProcExecutor
        elif distributed_executor_backend == "external_launcher":
            # TODO: make v1 scheduling deterministic
            # to support external launcher
            executor_class = ExecutorWithExternalLauncher
        else:
            assert (distributed_executor_backend is None)
            from vllm.v1.executor.uniproc_executor import UniprocExecutor
            executor_class = UniprocExecutor
            raise ValueError("Unknown distributed executor backend: "
                             f"{distributed_executor_backend}")
        return executor_class

    @abstractmethod
    def __init__(self, vllm_config: VllmConfig) -> None:
        raise NotImplementedError

    @abstractmethod
    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        raise NotImplementedError
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
        self.collective_rpc("compile_or_warm_up_model")

    @abstractmethod
    def determine_available_memory(self) -> int:  # in bytes
        raise NotImplementedError
        output = self.collective_rpc("determine_available_memory")
        # Since we use a shared centralized controller, we take the minimum
        # memory size across all workers to make sure all the memory
        # operators can be applied to all workers.
        return min(output)

    @abstractmethod
    def get_kv_cache_spec(self) -> KVCacheSpec:
        raise NotImplementedError
        output = self.collective_rpc("get_kv_cache_spec")
        for x in output:
            assert x == output[0]
        return output[0]

    @abstractmethod
    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        raise NotImplementedError
        output = self.collective_rpc("execute_model",
                                     args=(scheduler_output, ))
        return output[0]

    @abstractmethod
    def profile(self, is_start: bool = True):
        raise NotImplementedError
        self.collective_rpc("profile", args=(is_start, ))

    @abstractmethod
    def shutdown(self):
        pass

    @abstractmethod
    def check_health(self) -> None:
        raise NotImplementedError
class UniProcExecutor(UniProcExecutorV0, Executor):
    pass


class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
    pass


class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    pass

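The rewritten v1 `Executor` is no longer a free-standing ABC; it derives from the shared `ExecutorBase`, and the concrete v1 executors are declared as empty subclasses such as `UniProcExecutor(UniProcExecutorV0, Executor)` that combine a v0 backend with the v1 interface through Python's method resolution order. The toy sketch below (hypothetical class names only, not vLLM code) illustrates how that cooperative inheritance resolves: v1-only helpers come from the interface class, while backend-specific plumbing such as `collective_rpc` comes from the v0-style class listed first.

# Minimal sketch of the v0/v1 mixin pattern, with hypothetical classes.
class ExecutorBaseLike:
    """Plays the role of the shared v0 ExecutorBase."""

    def collective_rpc(self, method: str, args: tuple = ()):
        # Pretend to fan the call out to workers; one result per worker.
        return [f"{method}{args} on worker {i}" for i in range(2)]


class V1ExecutorLike(ExecutorBaseLike):
    """Plays the role of the v1 Executor interface: adds v1-only methods."""

    def initialize(self, kv_cache_config) -> None:
        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
        self.collective_rpc("compile_or_warm_up_model")


class UniProcBackendLike(ExecutorBaseLike):
    """Plays the role of a concrete v0 backend (e.g. single-process)."""

    def collective_rpc(self, method: str, args: tuple = ()):
        # Single worker: just one result.
        return [f"{method}{args} on the only worker"]


class UniProcV1Like(UniProcBackendLike, V1ExecutorLike):
    """v1 executor = v0 backend + v1 interface, no method bodies repeated."""


executor = UniProcV1Like()
executor.initialize(kv_cache_config={"num_blocks": 8})
# MRO: UniProcV1Like -> UniProcBackendLike -> V1ExecutorLike -> ExecutorBaseLike
print([c.__name__ for c in UniProcV1Like.__mro__])
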
@@ -25,8 +25,6 @@ from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
                        get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -37,7 +35,7 @@ POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000

class MultiprocExecutor(Executor):

    def __init__(self, vllm_config: VllmConfig) -> None:
    def _init_executor(self) -> None:
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)
@@ -55,9 +53,6 @@ class MultiprocExecutor(Executor):

        signal.signal(signal.SIGUSR1, sigusr1_handler)

        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        assert self.world_size == tensor_parallel_size, (
@@ -82,7 +77,8 @@ class MultiprocExecutor(Executor):
        # Create workers
        self.workers: List[WorkerProcHandle] = []
        for rank in range(self.world_size):
            worker = WorkerProc.make_worker_process(vllm_config, rank, rank,
            worker = WorkerProc.make_worker_process(self.vllm_config, rank,
                                                    rank,
                                                    distributed_init_method,
                                                    scheduler_output_handle)
            self.workers.append(worker)
@@ -93,34 +89,6 @@ class MultiprocExecutor(Executor):
        for w in self.workers:
            w.worker_response_mq.wait_until_ready()

    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
        self.collective_rpc("compile_or_warm_up_model")

    def determine_available_memory(self) -> int:
        """
        Determine the available memory (in bytes) for KV cache by invoking the
        underlying worker.
        """
        memory_sizes = self.collective_rpc("determine_available_memory")

        # Since we use a shared centralized controller, we take the minimum
        # memory size across all workers to make sure all the memory
        # operators can be applied to all workers.
        return min(memory_sizes)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Get all kv cache needed by the model by invoking the underlying worker.
        """
        kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
        assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
        return kv_cache_specs[0]

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
@@ -172,18 +140,6 @@ class MultiprocExecutor(Executor):
            # Re-raise any other exceptions
            raise e

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        model_output = self.collective_rpc("execute_model",
                                           args=(scheduler_output, ))[0]
        return model_output

    def profile(self, is_start: bool = True):
        self.collective_rpc("profile", args=(is_start, ))
        return

    def _ensure_worker_termination(self):
        """Ensure that all worker processes are terminated. Assumes workers have
        received termination requests. Waits for processing, then sends

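After this cleanup `MultiprocExecutor` appears to keep only its process management and `collective_rpc` transport; `initialize`, `determine_available_memory`, `get_kv_cache_spec`, and `execute_model` are expressed once on the shared base purely in terms of `collective_rpc`. The standalone sketch below (toy in-process worker handles, not vLLM's message-queue implementation) illustrates that fan-out/collect pattern, including why `execute_model` only needs the rank-0 element of the gathered results:

from typing import Any, List


class ToyWorker:
    """Stand-in for a worker process; real workers live in separate processes."""

    def __init__(self, rank: int) -> None:
        self.rank = rank

    def determine_available_memory(self) -> int:
        # Made-up numbers: pretend each rank has slightly different headroom.
        return 8_000_000_000 - self.rank * 50_000_000

    def execute_model(self, scheduler_output: Any) -> Any:
        # Only rank 0 returns real output; others return None, as in the diff.
        return {"tokens": [1, 2, 3]} if self.rank == 0 else None


class ToyExecutor:
    def __init__(self, world_size: int) -> None:
        self.workers: List[ToyWorker] = [ToyWorker(r) for r in range(world_size)]

    def collective_rpc(self, method: str, args: tuple = ()) -> List[Any]:
        # Call the same method on every worker and gather one result per rank.
        return [getattr(w, method)(*args) for w in self.workers]

    def determine_available_memory(self) -> int:
        # Take the min so the chosen KV-cache size fits on every rank.
        return min(self.collective_rpc("determine_available_memory"))

    def execute_model(self, scheduler_output: Any) -> Any:
        # Rank 0 owns the returned output.
        return self.collective_rpc("execute_model", args=(scheduler_output, ))[0]


executor = ToyExecutor(world_size=4)
print(executor.determine_available_memory())  # 7850000000
print(executor.execute_model(scheduler_output=None))
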
@@ -1,344 +0,0 @@
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (RayWorkerWrapper,
                                        initialize_ray_cluster, ray)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


class RayExecutor(Executor):

    def __init__(self, vllm_config: VllmConfig) -> None:
        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config
        self.model_config = vllm_config.model_config
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # A list of workers to run a model.
        self.workers: List[RayWorkerWrapper] = []
        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                # Skip bundles that don't have GPUs,
                # as each worker needs one GPU.
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
            self.workers.append(worker)

        logger.debug("workers: %s", self.workers)
        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        worker_to_ip = dict(zip(self.workers, worker_ips))

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
               it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
               be placed first.
            3. Finally, if the work is on a node with smaller IP address, it
               should be placed first. This is simply a tiebreaker to make
               sure the workers are sorted in a deterministic way.
            """
            ip = worker_to_ip[worker]
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids")

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)

        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips)
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP` or "
                "`HOST_IP` environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
            "VLLM_USE_V1":
            str(int(envs.VLLM_USE_V1)),
            **({
                "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
            } if envs.VLLM_ATTENTION_BACKEND is not None else {})
        }, ) for (node_id, _) in worker_node_and_gpu_ids]

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          all_args=self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
        self._run_workers("initialize")
        self._run_workers("load_model")

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _get_env_vars_to_be_updated(self):
        return self._env_vars_for_all_workers

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """
        Return worker init args for a given rank.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )

    def determine_available_memory(self) -> int:
        """
        Determine the available GPU memory in bytes.

        This invokes `determine_available_memory` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.
        """

        memory_sizes = self._run_workers("determine_available_memory")

        # Since we use a shared centralized controller, we take the minimum
        # memory size across all workers to make sure all the memory
        # operators can be applied to all workers.
        return min(memory_sizes)

    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the KV cache in all workers.
        """
        self._run_workers("initialize_cache", kv_cache_config)
        self._run_workers("compile_or_warm_up_model")

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Get all kv cache needed by the model

        This invokes `get_kv_cache_spec` on each worker and asserts that
        they are identical. The KVCacheSpec is then returned.
        """
        kv_cache_specs = self._run_workers("get_kv_cache_spec")
        assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
        return kv_cache_specs[0]

    def _run_workers(
        self,
        method: str,
        *args,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> Any:
        """
        Runs the given method on all workers. Can be used in the following
        ways:

        Args:
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """
        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 0, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 0, None)

        ray_worker_refs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]
        return ray.get(ray_worker_refs)

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag()
        # Only the first worker (with rank 0) returns the execution result.
        # Others return None.
        output = ray.get(self.forward_dag.execute(scheduler_output))[0]
        return output

    def profile(self, is_start=True):
        raise NotImplementedError

    def shutdown(self):
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def check_health(self) -> None:
        logger.debug("Called check_health.")

    def _check_ray_compiled_graph_installation(self):
        import pkg_resources
        from packaging import version

        required_version = version.parse("2.39")
        current_version = version.parse(
            pkg_resources.get_distribution("ray").version)
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} is "
                             f"required, but found {current_version}")

        import importlib.util
        raycg = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
        if raycg is None:
            raise ValueError("Ray Compiled Graph is not installed. "
                             "Run `pip install ray[adag]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set."
                "Run `pip install ray[adag]` and check cupy installation.")

    def _compiled_ray_dag(self):
        assert self.parallel_config.use_ray
        self._check_ray_compiled_graph_installation()
        from ray.dag import InputNode, MultiOutputNode

        with InputNode() as input_batches:
            outputs = [
                worker.execute_model.bind(  # type: ignore[attr-defined]
                    input_batches) for worker in self.workers
            ]
            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile()

    def __del__(self):
        self.shutdown()

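The deleted `RayExecutor` ordered its Ray workers with a composite sort key: driver-node workers first, then nodes hosting fewer workers, then the IP string as a deterministic tiebreaker. A self-contained sketch of that key (plain tuples and made-up IPs instead of Ray actor handles) is:

from collections import Counter

driver_ip = "10.0.0.1"
# (worker name, node IP) pairs; values are made up for illustration.
workers = [
    ("w0", "10.0.0.2"),
    ("w1", "10.0.0.1"),
    ("w2", "10.0.0.3"),
    ("w3", "10.0.0.2"),
]
ip_counts = Counter(ip for _, ip in workers)


def sort_by_driver_then_worker_ip(worker):
    _, ip = worker
    # 1. driver-node workers first (False sorts before True),
    # 2. then nodes hosting fewer workers,
    # 3. then the IP string as a deterministic tiebreaker.
    return (ip != driver_ip, ip_counts[ip], ip)


print(sorted(workers, key=sort_by_driver_then_worker_ip))
# [('w1', '10.0.0.1'), ('w2', '10.0.0.3'), ('w0', '10.0.0.2'), ('w3', '10.0.0.2')]
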
@@ -1,280 +0,0 @@
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import get_ip
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

if TYPE_CHECKING:
    from vllm.v1.core.scheduler import SchedulerOutput

logger = init_logger(__name__)
PG_WAIT_TIMEOUT = 60

try:
    import ray
    from ray.util import placement_group_table
    from ray.util.placement_group import PlacementGroup
    try:
        from ray._private.state import available_resources_per_node
    except ImportError:
        # Ray 2.9.x doesn't expose `available_resources_per_node`
        from ray._private.state import state as _state
        available_resources_per_node = _state._available_resources_per_node

    class RayWorkerWrapper(WorkerWrapperBase):

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # Since the compiled DAG runs a main execution
            # in a different thread that calls cuda.set_device.
            # The flag indicates is set_device is called on
            # that thread. It will be removed soon.
            self.compiled_dag_cuda_device_set = False

        def get_node_ip(self) -> str:
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            node_id = ray.get_runtime_context().get_node_id()
            device_key = current_platform.ray_device_key
            if not device_key:
                raise RuntimeError("current platform %s does not support ray.",
                                   current_platform.device_name)
            gpu_ids = ray.get_runtime_context().get_accelerator_ids(
            )[device_key]
            return node_id, gpu_ids

        def setup_device_if_necessary(self):
            # TODO(swang): This is needed right now because Ray CG executes
            # on a background thread, so we need to reset torch's current
            # device.
            # We can remove this API after it is fixed in compiled graph.
            import torch
            assert self.worker is not None, "Worker is not initialized"
            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

        def execute_model(
            self,
            scheduler_output: "SchedulerOutput",
        ) -> ModelRunnerOutput:
            self.setup_device_if_necessary()
            assert self.worker is not None, "Worker is not initialized"
            output = self.worker.model_runner.execute_model(scheduler_output)
            return output

    ray_import_err = None

except ImportError as e:
    ray = None  # type: ignore
    ray_import_err = e
    RayWorkerWrapper = None  # type: ignore


def ray_is_available() -> bool:
    """Returns True if Ray is available."""
    return ray is not None


def assert_ray_available():
    """
    Raise an exception if Ray is not available.
    """
    if ray is None:
        raise ValueError("Failed to import Ray, please install Ray with "
                         "`pip install ray`.") from ray_import_err


def _verify_bundles(placement_group: "PlacementGroup",
                    parallel_config: ParallelConfig, device_str: str):
    """
    Verify a given placement group has bundles located in the right place.

    There are 2 rules.
    - Warn if all tensor parallel workers cannot fit in a single node.
    - Fail if driver node is not included in a placement group.

    Args:
        placement_group: The placement group to verify.
        parallel_config: The parallel configuration.
        device_str: The required device.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray.")
    pg_data = placement_group_table(placement_group)
    # bundle_idx -> node_id
    bundle_to_node_ids = pg_data["bundles_to_node_id"]
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundles = pg_data["bundles"]
    # node_id -> List of bundle (e.g., {"GPU": 1})
    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)

    for bundle_idx, node_id in bundle_to_node_ids.items():
        node_id_to_bundle[node_id].append(bundles[bundle_idx])
    driver_node_id = ray.get_runtime_context().get_node_id()

    if driver_node_id not in node_id_to_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in a placement "
            f"group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            "You don't have enough GPUs available in a current node. Check "
            "`ray status` to see if you have available GPUs in a node "
            f"{driver_node_id} before starting an vLLM engine.")

    for node_id, bundles in node_id_to_bundle.items():
        if len(bundles) < parallel_config.tensor_parallel_size:
            logger.warning(
                "tensor_parallel_size=%d "
                "is bigger than a reserved number of %ss (%d "
                "%ss) in a node %s. Tensor parallel workers can be "
                "spread out to 2+ nodes which can degrade the performance "
                "unless you have fast interconnect across nodes, like "
                "Infiniband. To resolve this issue, make sure you have more "
                "than %d GPUs available at each node.",
                parallel_config.tensor_parallel_size, device_str, len(bundles),
                device_str, node_id, parallel_config.tensor_parallel_size)


def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
    """Wait until a placement group is ready.

    It prints the informative log messages if the placement group is
    not created within time.

    """
    # Wait until PG is ready - this will block until all
    # requested resources are available, and will timeout
    # if they cannot be provisioned.
    placement_group_specs = current_placement_group.bundle_specs

    s = time.time()
    pg_ready_ref = current_placement_group.ready()
    wait_interval = 10
    while time.time() - s < PG_WAIT_TIMEOUT:
        ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval)
        if len(ready) > 0:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for creating a placement group of specs for "
            "%d seconds. specs=%s. Check "
            "`ray status` to see if you have enough resources.",
            int(time.time() - s), placement_group_specs)

    try:
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        raise ValueError(
            "Cannot provide a placement group of "
            f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
            "`ray status` to make sure the cluster has enough resources."
        ) from None


def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    it will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.
    """
    assert_ray_available()

    # Connect to a ray cluster.
    if current_platform.is_rocm() or current_platform.is_xpu():
        # Try to connect existing ray instance and create a new one if not found
        try:
            ray.init("auto")
        except ConnectionError:
            logger.warning(
                "No existing RAY instance detected. "
                "A new instance will be launched with current node resources.")
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # Placement group is already set.
        return

    device_str = current_platform.ray_device_key
    if not device_str:
        raise ValueError(
            f"current platform {current_platform.device_name} does not "
            "support ray.")
    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        device_bundles = 0
        for bundle in bundles:
            bundle_devices = bundle.get(device_str, 0)
            if bundle_devices > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 "
                    f"{device_str}.")
            if bundle_devices:
                device_bundles += 1
        if parallel_config.world_size > device_bundles:
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the placement group."
                f"Required number of devices: {parallel_config.world_size}. "
                f"Total number of devices: {device_bundles}.")
    else:
        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
        if parallel_config.world_size > num_devices_in_cluster:
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the placement group.")
        # Create a new placement group
        placement_group_specs: List[Dict[str, float]] = ([{
            device_str: 1.0
        } for _ in range(parallel_config.world_size)])

        # vLLM engine is also a worker to execute model with an accelerator,
        # so it requires to have the device in a current node. Check if
        # the current node has at least one device.
        current_ip = get_ip()
        current_node_id = ray.get_runtime_context().get_node_id()
        current_node_resource = available_resources_per_node()[current_node_id]
        if current_node_resource.get(device_str, 0) < 1:
            raise ValueError(
                f"Current node has no {device_str} available. "
                f"{current_node_resource=}. vLLM engine cannot start without "
                f"{device_str}. Make sure you have at least 1 {device_str} "
                f"available in a node {current_node_id=} {current_ip=}.")
        # This way, at least bundle is required to be created in a current
        # node.
        placement_group_specs[0][f"node:{current_ip}"] = 0.001

        # By default, Ray packs resources as much as possible.
        current_placement_group = ray.util.placement_group(
            placement_group_specs, strategy="PACK")
        _wait_until_pg_ready(current_placement_group)

    assert current_placement_group is not None
    _verify_bundles(current_placement_group, parallel_config, device_str)
    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group

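The deleted `initialize_ray_cluster` builds one single-device bundle per worker and pins the first bundle to the driver node by requesting a tiny amount of Ray's per-node `node:<ip>` resource. A minimal sketch of constructing that spec list (pure data, hypothetical driver IP, no Ray calls) is:

from typing import Dict, List

world_size = 4
device_str = "GPU"
current_ip = "10.0.0.1"  # hypothetical driver IP

# One bundle per worker, each asking for a single device.
placement_group_specs: List[Dict[str, float]] = [{
    device_str: 1.0
} for _ in range(world_size)]

# Reserve the first bundle on the driver node: Ray exposes a per-node custom
# resource named "node:<ip>", so requesting a tiny amount of it forces this
# bundle onto the driver's node.
placement_group_specs[0][f"node:{current_ip}"] = 0.001

print(placement_group_specs)
# [{'GPU': 1.0, 'node:10.0.0.1': 0.001}, {'GPU': 1.0}, {'GPU': 1.0}, {'GPU': 1.0}]
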
@@ -1,88 +0,0 @@
import os
from typing import Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker

logger = init_logger(__name__)


class UniprocExecutor(Executor):

    def __init__(self, vllm_config: VllmConfig) -> None:
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        self.worker: Worker = self._create_worker()
        self.worker.init_device()
        self.worker.load_model()

    def _create_worker(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Worker:
        """Return worker init args for a given rank."""
        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return Worker(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )

    def determine_available_memory(self) -> int:
        """Determine the available memory (in bytes) for KV cache by invoking
        the underlying worker.
        """
        return self.worker.determine_available_memory()

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Get all kv cache needed by the model by invoking the underlying
        worker.
        """
        return self.worker.get_kv_cache_spec()

    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """
        self.worker.initialize_cache(kv_cache_config)
        self.worker.compile_or_warm_up_model()

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        output = self.worker.execute_model(scheduler_output)
        assert output is not None
        return output

    def profile(self, is_start: bool = True):
        self.worker.profile(is_start)

    def shutdown(self):
        pass

    def check_health(self) -> None:
        # UniprocExecutor will always be healthy as long as
        # it's running.
        return