[core] platform agnostic executor via collective_rpc (#11256)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2025-01-15 13:45:21 +08:00
Committed by: GitHub
Parent: f218f9c24d
Commit: ad34c0df0f
43 changed files with 851 additions and 2641 deletions
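This commit replaces the per-platform executor classes (GPUExecutor, CPUExecutor, HPUExecutor, DistributedGPUExecutor, and their async variants) with platform-agnostic executors (UniProcExecutor, MultiprocessingDistributedExecutor, RayDistributedExecutor) built around a single collective_rpc interface: ExecutorBase methods such as add_lora, initialize_cache, and start_profile now fan a named method out to every worker instead of special-casing the device type. The following is a minimal sketch of that contract, not vLLM code; _SketchExecutor and its worker list are hypothetical and only illustrate the signature introduced in executor_base.py below.

from typing import Any, Dict, List, Optional, Tuple

class _SketchExecutor:
    """Hypothetical executor illustrating the collective_rpc contract."""

    def __init__(self, workers: List[Any]):
        self.workers = workers  # in vLLM these are worker (wrapper) objects

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        # Run `method` on every worker with the same arguments and collect
        # one result per worker. timeout is part of the interface; this
        # sketch ignores it.
        return [
            getattr(w, method)(*args, **(kwargs or {})) for w in self.workers
        ]

    def add_lora(self, lora_request) -> bool:
        # Higher-level executor methods become thin wrappers, mirroring the
        # ExecutorBase changes in this diff.
        return all(self.collective_rpc("add_lora", args=(lora_request, )))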

View File

@ -1,12 +1,13 @@
import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple
import pytest
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
from vllm.executor.uniproc_executor import UniProcExecutor
from vllm.sampling_params import SamplingParams
@ -14,21 +15,20 @@ class Mock:
...
class CustomGPUExecutor(GPUExecutor):
class CustomUniExecutor(UniProcExecutor):
def execute_model(self, *args, **kwargs):
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
# Drop marker to show that this was run
with open(".marker", "w"):
...
return super().execute_model(*args, **kwargs)
return super().collective_rpc(method, timeout, args, kwargs)
class CustomGPUExecutorAsync(GPUExecutorAsync):
async def execute_model_async(self, *args, **kwargs):
with open(".marker", "w"):
...
return await super().execute_model_async(*args, **kwargs)
CustomUniExecutorAsync = CustomUniExecutor
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@ -41,10 +41,6 @@ def test_custom_executor_type_checking(model):
engine_args = AsyncEngineArgs(model=model,
distributed_executor_backend=Mock)
AsyncLLMEngine.from_engine_args(engine_args)
with pytest.raises(TypeError):
engine_args = AsyncEngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutor)
AsyncLLMEngine.from_engine_args(engine_args)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@ -55,7 +51,7 @@ def test_custom_executor(model, tmp_path):
assert not os.path.exists(".marker")
engine_args = EngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutor)
model=model, distributed_executor_backend=CustomUniExecutor)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)
@ -75,7 +71,7 @@ def test_custom_executor_async(model, tmp_path):
assert not os.path.exists(".marker")
engine_args = AsyncEngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutorAsync)
model=model, distributed_executor_backend=CustomUniExecutorAsync)
engine = AsyncLLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)
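As the test above shows, a custom executor plugs in by passing the class itself as distributed_executor_backend. A hedged usage sketch following the same pattern (LoggingExecutor is an illustrative name; the model and request values simply mirror the test):

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.uniproc_executor import UniProcExecutor
from vllm.sampling_params import SamplingParams

class LoggingExecutor(UniProcExecutor):
    """Illustrative subclass: observe every RPC the engine issues."""

    def collective_rpc(self, method, timeout=None, args=(), kwargs=None):
        print(f"collective_rpc -> {method}")
        return super().collective_rpc(method, timeout, args, kwargs)

engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m",
               distributed_executor_backend=LoggingExecutor))
engine.add_request("req-0", "Hello", SamplingParams(max_tokens=1))
engine.step()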

View File

@ -6,16 +6,15 @@ from typing import Any, List, Tuple
import pytest
from vllm.config import VllmConfig
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.worker.worker_base import WorkerWrapperBase
class DummyWorker:
class DummyWorkerWrapper(WorkerWrapperBase):
"""Dummy version of vllm.worker.worker.Worker"""
def __init__(self, rank: int):
self.rank = rank
def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
sleep(0.05)
@ -28,9 +27,10 @@ class DummyWorker:
def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
result_handler = ResultHandler()
vllm_config = VllmConfig()
workers = [
ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
for rank in range(8)
ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config,
rank) for rank in range(8)
]
worker_monitor = WorkerMonitor(workers, result_handler)

View File

@ -2,6 +2,7 @@ import asyncio
import os
import socket
from typing import AsyncIterator, Tuple
from unittest.mock import patch
import pytest
import torch
@ -390,7 +391,10 @@ def test_bind_kv_cache_encoder_decoder():
def test_bind_kv_cache_pp():
cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
# this test runs with 1 GPU, but we simulate 2 GPUs
cfg = VllmConfig(
parallel_config=ParallelConfig(pipeline_parallel_size=2))
with set_current_vllm_config(cfg):
from vllm.attention import Attention

View File

@ -1294,8 +1294,11 @@ class ParallelConfig:
from vllm.executor import ray_utils
backend = "mp"
ray_found = ray_utils.ray_is_available()
if (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size):
if current_platform.is_neuron():
# neuron uses a single process to control multiple devices
backend = "uni"
elif (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size):
if not ray_found:
raise ValueError("Unable to load Ray which is "
"required for multi-node inference, "
@ -1328,13 +1331,14 @@ class ParallelConfig:
from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform
if self.distributed_executor_backend not in (
"ray", "mp", None) and not (isinstance(
"ray", "mp", "uni", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
self.distributed_executor_backend, ExecutorBase)):
raise ValueError(
"Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' or custom ExecutorBase subclass.")
"values are 'ray', 'mp' 'uni', or custom ExecutorBase"
" subclass.")
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
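The ParallelConfig changes above do two things: Neuron now defaults to the new single-process "uni" backend, and the validation set grows to 'ray', 'mp', 'uni', None, or an ExecutorBase subclass. A simplified, standalone restatement of the default-backend decision (the function and its boolean parameters are hypothetical; the real logic reads current_platform and cuda_device_count_stateless as shown above):

def pick_default_backend(is_neuron: bool, is_cuda: bool,
                         local_device_count: int, world_size: int,
                         ray_available: bool) -> str:
    # Condensed mirror of the branch order in ParallelConfig above.
    if is_neuron:
        # Neuron uses a single process to control multiple devices.
        return "uni"
    if is_cuda and local_device_count < world_size:
        if not ray_available:
            raise ValueError("Ray is required for multi-node inference")
        return "ray"
    # Otherwise fall back to multiprocessing on a single node.
    return "mp"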

View File

@ -862,12 +862,14 @@ def init_model_parallel_group(
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
from vllm.platforms import current_platform
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
use_custom_allreduce=use_custom_allreduce,
use_pynccl=current_platform.is_cuda_alike(),
use_custom_allreduce=current_platform.is_cuda_alike()
and use_custom_allreduce,
use_tpu_communicator=True,
use_hpu_communicator=True,
use_xpu_communicator=True,

View File

@ -18,9 +18,7 @@ from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
@ -620,69 +618,9 @@ class AsyncLLMEngine(EngineClient):
rt.new_requests_event.set()
@classmethod
def _get_executor_cls(
cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray":
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "hpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
executor_class = RayHPUExecutorAsync
else:
from vllm.executor.hpu_executor import HPUExecutorAsync
executor_class = HPUExecutorAsync
elif engine_config.device_config.device_type == "openvino":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with "
"the OpenVINO backend.")
from vllm.executor.openvino_executor import OpenVINOExecutorAsync
executor_class = OpenVINOExecutorAsync
elif engine_config.device_config.device_type == "xpu":
if distributed_executor_backend is None:
from vllm.executor.xpu_executor import XPUExecutorAsync
executor_class = XPUExecutorAsync
elif distributed_executor_backend == "ray":
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_xpu_executor import (
MultiprocessingXPUExecutorAsync)
executor_class = MultiprocessingXPUExecutorAsync
else:
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
elif distributed_executor_backend == "ray":
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutorAsync)
executor_class = MultiprocessingGPUExecutorAsync
else:
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
return executor_class
def _get_executor_cls(cls,
engine_config: VllmConfig) -> Type[ExecutorBase]:
return LLMEngine._get_executor_cls(engine_config)
@classmethod
def from_engine_args(
@ -700,9 +638,6 @@ class AsyncLLMEngine(EngineClient):
executor_class = cls._get_executor_cls(engine_config)
if executor_class.uses_ray:
initialize_ray_cluster(engine_config.parallel_config)
# Create the async LLM engine.
engine = cls(
vllm_config=engine_config,
@ -1242,23 +1177,12 @@ class AsyncLLMEngine(EngineClient):
self.engine.remove_logger(logger_name=logger_name)
async def start_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
self.engine.start_profile()
async def stop_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile")
self.engine.stop_profile()
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
self.engine.add_lora(lora_request)
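With the *Async executor classes deleted, AsyncLLMEngine simply reuses LLMEngine._get_executor_cls, and ExecutorAsyncBase (later in this diff) provides a default execute_model_async that wraps the synchronous execute_model with make_async. A simplified stand-in for that helper, shown only to illustrate the idea (the real implementation lives in vllm.utils):

import asyncio
from functools import partial
from typing import Any, Callable

def make_async_sketch(func: Callable[..., Any]) -> Callable[..., Any]:
    # Run a blocking callable in the default thread pool so the event loop
    # is not blocked; awaiting the wrapper yields the original return value.
    async def _async_wrapper(*args, **kwargs) -> Any:
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, partial(func, *args, **kwargs))
    return _async_wrapper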

View File

@ -28,8 +28,6 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
@ -442,64 +440,26 @@ class LLMEngine:
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {distributed_executor_backend}.")
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu":
elif engine_config.parallel_config.world_size > 1:
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutor
executor_class = RayTPUExecutor
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.device_config.device_type == "hpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_hpu_executor import RayHPUExecutor
executor_class = RayHPUExecutor
else:
from vllm.executor.hpu_executor import HPUExecutor
executor_class = HPUExecutor
elif engine_config.device_config.device_type == "openvino":
from vllm.executor.openvino_executor import OpenVINOExecutor
executor_class = OpenVINOExecutor
elif engine_config.device_config.device_type == "xpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutor
executor_class = RayXPUExecutor
from vllm.executor.ray_distributed_executor import (
RayDistributedExecutor)
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn needs calling `if __name__ == '__main__':``
# fork is not supported for xpu start new process.
logger.error(
"Both start methods (spawn and fork) have issue "
"on XPU if you use mp backend, Please try ray instead.")
else:
from vllm.executor.xpu_executor import XPUExecutor
executor_class = XPUExecutor
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingGPUExecutor
from vllm.executor.mp_distributed_executor import (
MultiprocessingDistributedExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingDistributedExecutor
elif distributed_executor_backend == "uni":
# JAX-style, single-process, multi-device executor.
from vllm.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
else:
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
from vllm.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
return executor_class
@classmethod
@ -1845,27 +1805,17 @@ class LLMEngine:
def list_prompt_adapters(self) -> List[int]:
return self.model_executor.list_prompt_adapters()
def start_profile(self) -> None:
self.model_executor.start_profile()
def stop_profile(self) -> None:
self.model_executor.stop_profile()
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health()
def start_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor: # noqa: E721
self.model_executor.start_profile()
else:
self.model_executor._run_workers("start_profile")
def stop_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor: # noqa: E721
self.model_executor.stop_profile()
else:
self.model_executor._run_workers("stop_profile")
def is_tracing_enabled(self) -> bool:
return self.tracer is not None
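Taken together, the new _get_executor_cls above collapses the old per-device dispatch into a small mapping, and the profiling hooks no longer need type checks against GPUExecutor because ExecutorBase.start_profile/stop_profile forward through collective_rpc. A condensed, hedged restatement of the class selection (the helper name and flat structure are illustrative; the exact nesting follows the diff above):

def resolve_executor_cls(distributed_executor_backend, world_size: int):
    from vllm.executor.executor_base import ExecutorBase
    if (isinstance(distributed_executor_backend, type)
            and issubclass(distributed_executor_backend, ExecutorBase)):
        # A user-supplied ExecutorBase subclass is used directly.
        return distributed_executor_backend
    if world_size > 1 and distributed_executor_backend == "ray":
        from vllm.executor.ray_distributed_executor import (
            RayDistributedExecutor)
        return RayDistributedExecutor
    if world_size > 1 and distributed_executor_backend == "mp":
        from vllm.executor.mp_distributed_executor import (
            MultiprocessingDistributedExecutor)
        return MultiprocessingDistributedExecutor
    # "uni" and the single-worker default both land on UniProcExecutor.
    from vllm.executor.uniproc_executor import UniProcExecutor
    return UniProcExecutor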

View File

@ -20,7 +20,6 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
@ -356,16 +355,10 @@ class MQLLMEngine:
self._errored_with = e
def start_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
self.engine.start_profile()
def stop_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile")
self.engine.stop_profile()
def signal_handler(*_) -> None:

View File

@ -1,299 +0,0 @@
import os
from functools import partial
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_distributed_init_method, get_open_port, make_async
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class CPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
#
# Environment variables for CPU executor
#
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# Intel OpenMP setting
ld_prealod_str = os.getenv("LD_PRELOAD", "")
if "libiomp5.so" in ld_prealod_str:
# The time(milliseconds) that a thread should wait after
# completing the execution of a parallel region, before sleeping.
os.environ['KMP_BLOCKTIME'] = "1"
# Prevents the CPU to run into low performance state
os.environ['KMP_TPAUSE'] = "0"
# Provides fine granularity parallelism
os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
# To hint IPEX uses shared memory based AllReduce
os.environ["LOCAL_WORLD_SIZE"] = str(
self.parallel_config.tensor_parallel_size)
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
ip = "127.0.0.1"
port = get_open_port()
self.distributed_init_method = get_distributed_init_method(ip, port)
is_async = isinstance(self, CPUExecutorAsync)
world_size = self.parallel_config.tensor_parallel_size
result_handler = ResultHandler()
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
self.workers = []
if is_async:
self.workers = [
ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
)) for rank in range(0, world_size)
]
self.driver_worker = self.workers[0]
self.workers = self.workers[1:]
self.driver_method_invoker = _async_driver_method_invoker
else:
self.driver_worker = self._create_worker()
self.driver_method_invoker = _driver_method_invoker
if world_size != 1:
self.workers = [
ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
)) for rank in range(1, world_size)
]
self.worker_monitor = None
if world_size != 1 or is_async:
if is_async:
async_worker_list = self.workers + [self.driver_worker]
else:
async_worker_list = self.workers
self.worker_monitor = WorkerMonitor(async_worker_list,
result_handler)
result_handler.start()
self.worker_monitor.start()
self._run_workers("init_device")
self._run_workers("load_model")
def _create_worker(
self,
local_rank: int = 0,
rank: int = 0,
):
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
assert self.distributed_init_method is not None
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=self.distributed_init_method,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=rank == 0,
)
wrapper.init_worker(**kwargs)
return wrapper.worker
def _run_workers(
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
# Start the workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
for worker in self.workers
]
if async_run_remote_workers_only:
# Just return futures
return worker_outputs
driver_worker_output = self.driver_method_invoker(
self.driver_worker, method, *args, **kwargs)
# Get the results of the workers.
return [driver_worker_output
] + [output.get() for output in worker_outputs]
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_method_invoker(self.driver_worker,
"determine_num_available_blocks")
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
# NOTE: `cpu block` for CPU backend is located on CPU memory but is
# referred as `gpu block`. Because we want to reuse the existing block
# management procedure.
logger.info("# CPU blocks: %d", num_gpu_blocks)
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if (self.parallel_config.tensor_parallel_size > 1
and self.parallel_worker_tasks is None):
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
)
output = self.driver_method_invoker(self.driver_worker,
"execute_model", execute_model_req)
return output
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
"""
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
self.driver_method_invoker(self.driver_worker, "execute_model", None)
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
def add_lora(self, lora_request: LoRARequest) -> bool:
return all(self._run_workers("add_lora", lora_request))
def remove_lora(self, lora_id: int) -> bool:
return all(self._run_workers("remove_lora", lora_id))
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return all(self._run_workers(
"pin_lora",
lora_id=lora_id,
))
def list_loras(self) -> Set[int]:
return self.driver_method_invoker(self.driver_worker, "list_loras")
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return all(
self._run_workers(
"add_prompt_adapter",
prompt_adapter_request,
))
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return all(
self._run_workers(
"remove_prompt_adapter",
prompt_adapter_id,
))
def list_prompt_adapters(self) -> Set[int]:
return self.driver_method_invoker(self.driver_worker,
"list_prompt_adapters")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return all(self._run_workers(
"pin_prompt_adapter",
prompt_adapter_id,
))
def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
if self.worker_monitor is not None and not self.worker_monitor.is_alive(
):
raise RuntimeError("Worker processes are not running")
def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
None)) is not None:
worker_monitor.close()
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
for result in parallel_worker_tasks:
result.get()
def start_profile(self) -> None:
self.driver_method_invoker(self.driver_worker, "start_profile")
def stop_profile(self) -> None:
self.driver_method_invoker(self.driver_worker, "stop_profile")
class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = await make_async(self.execute_model
)(execute_model_req=execute_model_req, )
return output
async def check_health_async(self) -> None:
self.check_health()
def _driver_method_invoker(driver, method: str, *args, **kwargs):
return getattr(driver, method)(*args, **kwargs)
def _async_driver_method_invoker(driver, method: str, *args, **kwargs):
return driver.execute_method(method, *args, **kwargs).get()

View File

@ -1,212 +0,0 @@
import asyncio
from abc import abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
logger = init_logger(__name__)
class DistributedGPUExecutor(GPUExecutor):
"""Abstract superclass of multi-GPU executor implementations."""
def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
super().__init__(*args, **kwargs)
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks", )
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache in all workers.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
self.model_config.max_model_len)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
self.model_config.max_model_len, max_concurrency)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_tensor_parallel_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results.
driver_outputs = self._driver_execute_model(execute_model_req)
assert driver_outputs is not None
return driver_outputs
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
self._driver_execute_model(execute_model_req=None)
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"pin_lora",
lora_id=lora_id,
)
def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self._run_workers("save_sharded_state",
path=path,
pattern=pattern,
max_size=max_size)
@abstractmethod
def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution loop
running in each of the remote workers. In this case, this method
returns None. Otherwise, this method returns the model output.
"""
raise NotImplementedError
@abstractmethod
def _run_workers(
self,
method: str,
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.
Args:
async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
"""
raise NotImplementedError
@abstractmethod
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
raise NotImplementedError
class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop())
# Only the driver worker returns the sampling results.
return await self._driver_execute_model_async(execute_model_req)
async def stop_remote_worker_execution_loop_async(self) -> None:
if self.parallel_worker_tasks is None:
return
await self._driver_execute_model_async()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await parallel_worker_tasks
@abstractmethod
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
"""Execute the model asynchronously in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
raise NotImplementedError
@abstractmethod
async def _start_worker_execution_loop(self):
"""Run execution loop on all workers. It guarantees all workers run
the loop or None of them is running the loop. Loop can be stopped by
`stop_remote_worker_execution_loop`.
The API is idempotent (guarantee only 1 loop run at any moment)."""
raise NotImplementedError

View File

@ -1,18 +1,24 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
logger = init_logger(__name__)
class ExecutorBase(ABC):
"""Base class for all executors.
An executor is responsible for executing the model on a specific device
type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
An executor is responsible for executing the model on one device,
or it can be a distributed executor
that can execute the model on multiple devices.
"""
@ -40,6 +46,20 @@ class ExecutorBase(ABC):
pass
@abstractmethod
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
The main interface of the executor, used to run a method on all workers
with homogeneous arguments.
If the arguments are heterogeneous, pack them into a list and unpack
it inside each worker's method; every worker knows its own rank.
"""
pass
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
@ -53,58 +73,113 @@ class ExecutorBase(ABC):
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
results = self.collective_rpc("determine_num_available_blocks")
a = min([r[0] for r in results])
b = min([r[1] for r in results])
return a, b
@abstractmethod
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
def initialize(self, num_gpu_blocks: int) -> None:
"""
raise NotImplementedError
Initialize the KV caches and begin the model execution loop of the
underlying workers.
For V1 compatibility.
"""
logger.info("# GPU blocks: %d", num_gpu_blocks)
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
self.collective_rpc("compile_or_warm_up_model")
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 workers.
logger.info("# %s blocks: %d, # CPU blocks: %d",
current_platform.dispatch_key, num_gpu_blocks,
num_cpu_blocks)
max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
self.model_config.max_model_len)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
self.model_config.max_model_len, max_concurrency)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache",
args=(num_gpu_blocks, num_cpu_blocks))
@abstractmethod
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences."""
raise NotImplementedError
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
output = self.collective_rpc("execute_model",
args=(execute_model_req, ))
return output[0]
def stop_remote_worker_execution_loop(self) -> None:
"""Releases parallel workers from model loop."""
return
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("add_lora", args=(lora_request, )))
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
assert lora_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("remove_lora", args=(lora_id, )))
@abstractmethod
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError # type: ignore
assert lora_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("pin_lora", args=(lora_id, )))
@abstractmethod
def list_loras(self) -> Set[int]:
raise NotImplementedError
sets = self.collective_rpc("list_loras")
for s in sets:
assert s == sets[0], "All workers should have the same LORAs."
return sets[0]
@abstractmethod
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
raise NotImplementedError
assert prompt_adapter_request.prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return all(
self.collective_rpc("add_prompt_adapter",
args=(prompt_adapter_request, )))
@abstractmethod
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return all(
self.collective_rpc("remove_prompt_adapter",
args=(prompt_adapter_id, )))
@abstractmethod
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError # type: ignore
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return all(
self.collective_rpc("pin_prompt_adapter",
args=(prompt_adapter_id, )))
@abstractmethod
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError
sets = self.collective_rpc("list_prompt_adapters")
for s in sets:
assert (s == sets[0]
), "All workers should have the same prompt adapters."
return sets[0]
def start_profile(self) -> None:
self.collective_rpc("start_profile")
def stop_profile(self) -> None:
self.collective_rpc("stop_profile")
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.collective_rpc("save_sharded_state",
kwargs=dict(path=path,
pattern=pattern,
max_size=max_size))
@abstractmethod
def check_health(self) -> None:
@ -119,15 +194,12 @@ class ExecutorBase(ABC):
def __del__(self):
self.shutdown()
class ExecutorAsyncBase(ExecutorBase):
@abstractmethod
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes one model step on the given sequences."""
raise NotImplementedError
output = await make_async(self.execute_model)(execute_model_req)
return output
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Releases parallel workers from model loop."""
@ -137,3 +209,128 @@ class ExecutorAsyncBase(ExecutorBase):
"""Checks if the executor is healthy. If not, it should raise an
exception."""
self.check_health()
class DistributedExecutorBase(ExecutorBase):
"""Abstract superclass of distributed executor implementations."""
def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
super().__init__(*args, **kwargs)
def execute_model(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
# TODO: unify into collective_rpc
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_tensor_parallel_workers_only=True)
# Only the driver worker returns the sampling results.
driver_outputs = self._driver_execute_model(execute_model_req)
assert driver_outputs is not None
return driver_outputs
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
self._driver_execute_model(execute_model_req=None)
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
@abstractmethod
def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution loop
running in each of the remote workers. In this case, this method
returns None. Otherwise, this method returns the model output.
"""
raise NotImplementedError
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
return self._run_workers(method, *args, **(kwargs or {}))
@abstractmethod
def _run_workers(
self,
method: str,
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.
Args:
async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
# TODO: simplify and merge with collective_rpc
"""
raise NotImplementedError
@abstractmethod
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_tensor_parallel_workers_only to complete."""
raise NotImplementedError
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop())
# Only the driver worker returns the sampling results.
return await self._driver_execute_model_async(execute_model_req)
async def stop_remote_worker_execution_loop_async(self) -> None:
if self.parallel_worker_tasks is None:
return
await self._driver_execute_model_async()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await parallel_worker_tasks
@abstractmethod
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
"""Execute the model asynchronously in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
raise NotImplementedError
@abstractmethod
async def _start_worker_execution_loop(self):
"""Run execution loop on all workers. It guarantees all workers run
the loop or None of them is running the loop. Loop can be stopped by
`stop_remote_worker_execution_loop`.
The API is idempotent (guarantee only 1 loop run at any moment)."""
raise NotImplementedError
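The collective_rpc docstring above recommends packing heterogeneous per-rank arguments into a single list that every worker receives, letting each worker index into it with its own rank; the init_worker(all_kwargs) call in the multiprocessing executor later in this diff uses exactly that convention. A tiny self-contained sketch of the pattern (the class and method names here are hypothetical stand-ins, not vLLM classes):

from typing import Any, List

class _RankedWorker:
    """Hypothetical stand-in for a vLLM worker; only the rank matters here."""

    def __init__(self, rank: int):
        self.rank = rank

    def init_per_rank(self, all_args: List[Any]) -> Any:
        # Every worker receives the same list and unpacks its own entry,
        # precisely because it knows its own rank.
        return all_args[self.rank]

workers = [_RankedWorker(rank) for rank in range(2)]
all_args = [{"local_rank": 0}, {"local_rank": 1}]  # heterogeneous payload
# A collective_rpc("init_per_rank", args=(all_args, )) call would run this
# loop across processes; here we fan out in-process for illustration.
print([w.init_per_rank(all_args) for w in workers])
# -> [{'local_rank': 0}, {'local_rank': 1}]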

View File

@ -1,145 +0,0 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
def create_worker(**kwargs):
vllm_config = kwargs.get("vllm_config")
wrapper = WorkerWrapperBase(vllm_config=vllm_config)
wrapper.init_worker(**kwargs)
return wrapper.worker
class GPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _get_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
"""Return worker init args for a given rank."""
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
return create_worker(**self._get_worker_kwargs(
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method))
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
self.model_config.max_model_len)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
self.model_config.max_model_len, max_concurrency)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
assert prompt_adapter_request.prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.driver_worker.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
return self.driver_worker.list_prompt_adapters()
def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return
def start_profile(self) -> None:
self.driver_worker.start_profile()
def stop_profile(self) -> None:
self.driver_worker.stop_profile()
class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[Union[SamplerOutput, PoolerOutput]]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req)
return output

View File

@ -1,202 +0,0 @@
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import contextlib
import os
from typing import Any, Dict, List, Optional, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class HPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
self._init_worker()
def _get_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
"""Return worker init args for a given rank."""
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=rank == 0,
)
def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
def _init_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# HPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
from vllm_hpu_extension.profiler import HabanaMemoryProfiler
with HabanaMemoryProfiler() as cache_init_m:
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
msg = f"init_cache_engine took {cache_init_m.get_summary_string()}"
logger.info(msg)
def finish_measurements(self):
self.driver_worker.finish_measurements()
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501
log_graph_compilation_all = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0'
log_graph_compilation = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION',
'0') != '0' or log_graph_compilation_all
log_cpu_fallbacks_all = os.environ.get(
'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
'0') != '0' or log_cpu_fallbacks_all
if log_graph_compilation or log_cpu_fallbacks:
from habana_frameworks.torch.hpu.metrics import metric_localcontext
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
is_prompt = any([
seq_group_metadata.is_prompt
for seq_group_metadata in seq_group_metadata_list
])
max_context_len = max([
max([
len(v.prompt_token_ids) + len(v.output_token_ids)
for v in seq_group_metadata.seq_data.values()
]) for seq_group_metadata in seq_group_metadata_list
]) # whoa, that's some spicy stuff right here
max_num_blocks = (
(max_context_len - 1) // self.cache_config.block_size) + 1
input_stats = (f'is_prompt: {is_prompt}, '
f'num_seqs: {len(seq_group_metadata_list)}, '
f'max_context_len: {max_context_len}, '
f'max_num_blocks {max_num_blocks}')
gc_ctx = metric_localcontext(
"graph_compilation"
) if log_graph_compilation else contextlib.nullcontext()
cpu_fallback_ctx = metric_localcontext(
"cpu_fallback"
) if log_cpu_fallbacks else contextlib.nullcontext()
with gc_ctx as gc_local_metric, \
cpu_fallback_ctx as cpu_fallback_local_metric:
output = self.driver_worker.execute_model(execute_model_req)
if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0
) or log_graph_compilation_all:
msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
f"{gc_local_metric.stats()}, {input_stats}")
logger.warning(msg)
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] >
0) or log_cpu_fallbacks_all:
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
logger.warning(msg)
return output
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return
def start_profile(self) -> None:
self.driver_worker.start_profile()
def stop_profile(self) -> None:
self.driver_worker.stop_profile()
def shutdown(self) -> None:
self.driver_worker.shutdown_inc()
class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output

View File

@ -1,32 +1,26 @@
import asyncio
import os
from functools import partial
from typing import Any, List, Optional
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
from vllm.executor.executor_base import DistributedExecutorBase
from vllm.executor.multiproc_worker_utils import (
ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
get_distributed_init_method, get_open_port, make_async,
update_environment_variables)
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"""Python multiprocessing-based multi-GPU executor"""
class MultiprocessingDistributedExecutor(DistributedExecutorBase):
"""Python multiprocessing-based distributed executor"""
uses_ray: bool = False
def _init_executor(self) -> None:
self._check_executor_parameters()
# Create the parallel GPU workers.
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
@ -55,15 +49,9 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
else:
result_handler = ResultHandler()
for rank in range(1, world_size):
worker = ProcessWorkerWrapper(
result_handler,
partial(
create_worker,
**self._get_worker_kwargs(
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)))
worker = ProcessWorkerWrapper(result_handler,
WorkerWrapperBase,
self.vllm_config, rank)
self.workers.append(worker)
if rank % tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
@ -77,32 +65,30 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well
self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)
all_kwargs = []
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
for i in range(world_size):
local_rank = i
rank = i
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self._run_workers("init_worker", all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
def _check_executor_parameters(self):
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local gpu count ({cuda_device_count})")
assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")
self.driver_exec_model = make_async(self.driver_worker.execute_model)
self.pp_locks: Optional[List[asyncio.Lock]] = None
def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
@@ -172,15 +158,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
for result in parallel_worker_tasks:
result.get()
class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
self.pp_locks: Optional[List[asyncio.Lock]] = None
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
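
The new initialization path above assembles one kwargs dict per rank and then issues a single "init_worker" broadcast rather than relying on backend-specific setup. A standalone sketch of how those per-rank kwargs line up, using illustrative values instead of the real config objects (the toy drops the `not self.parallel_config` guard and keeps only the rank % tensor_parallel_size check):

# Hypothetical values; in the executor these come from parallel_config.
world_size = 4
tensor_parallel_size = 2
distributed_init_method = "tcp://127.0.0.1:29500"  # placeholder address

all_kwargs = []
for rank in range(world_size):
    all_kwargs.append(
        dict(
            local_rank=rank,  # single-node case: local_rank == rank
            rank=rank,
            distributed_init_method=distributed_init_method,
            # rank 0 of every tensor-parallel group acts as that group's driver
            is_driver_worker=rank % tensor_parallel_size == 0,
        ))

for kwargs in all_kwargs:
    print(kwargs["rank"], kwargs["is_driver_worker"])  # ranks 0 and 2 are drivers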

View File

@@ -12,6 +12,7 @@ from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import _check_multiproc_method, get_mp_context
@@ -147,7 +148,8 @@ class ProcessWorkerWrapper:
for handling single-node multi-GPU tensor parallel."""
def __init__(self, result_handler: ResultHandler,
worker_factory: Callable[[], Any]) -> None:
worker_factory: Callable[[VllmConfig, int], Any],
vllm_config: VllmConfig, rank: int) -> None:
self.mp = get_mp_context()
self._task_queue = self.mp.Queue()
self.result_queue = result_handler.result_queue
@@ -159,6 +161,8 @@ class ProcessWorkerWrapper:
worker_factory=worker_factory,
task_queue=self._task_queue,
result_queue=self.result_queue,
vllm_config=vllm_config,
rank=rank,
),
daemon=True)
@@ -199,9 +203,11 @@ class ProcessWorkerWrapper:
def _run_worker_process(
worker_factory: Callable[[], Any],
worker_factory: Callable[[VllmConfig, int], Any],
task_queue: Queue,
result_queue: Queue,
vllm_config: VllmConfig,
rank: int,
) -> None:
"""Worker process event loop"""
@@ -212,7 +218,7 @@ def _run_worker_process(
_add_prefix(sys.stderr, process_name, pid)
# Initialize worker
worker = worker_factory()
worker = worker_factory(vllm_config, rank)
del worker_factory
# Accept tasks from the engine in task_queue
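
The factory handed to the worker process now receives (vllm_config, rank) explicitly instead of arriving pre-bound through functools.partial, so the child process constructs the wrapper itself from plain, picklable arguments. A toy comparison of the two call shapes; ToyConfig and ToyWorkerWrapper are stand-ins for the real classes:

from dataclasses import dataclass
from functools import partial

@dataclass
class ToyConfig:  # stand-in for the real VllmConfig
    model: str = "dummy"

class ToyWorkerWrapper:  # stand-in for WorkerWrapperBase
    def __init__(self, config: ToyConfig, rank: int):
        self.config, self.rank = config, rank

# Old shape: the parent pre-binds every argument and ships a zero-arg factory.
old_factory = partial(ToyWorkerWrapper, ToyConfig(), 3)
worker_old = old_factory()

# New shape: the parent ships the class plus (config, rank) separately and the
# child calls factory(config, rank) itself, as _run_worker_process now does.
new_factory = ToyWorkerWrapper
worker_new = new_factory(ToyConfig(), 3)

assert worker_old.rank == worker_new.rank == 3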

View File

@@ -1,26 +0,0 @@
import vllm.envs as envs
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
from vllm.executor.xpu_executor import XPUExecutor
from vllm.logger import init_logger
from vllm.utils import make_async
logger = init_logger(__name__)
class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor):
"""Python multiprocessing-based multi-XPU executor"""
def _check_executor_parameters(self):
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
if mp_method != "spawn":
raise RuntimeError(
"XPU multiprocess executor only support spawn as mp method")
class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor,
MultiprocessingGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)

View File

@@ -1,114 +0,0 @@
from typing import List, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class NeuronExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
assert (self.lora_config is
None), "LoRA is not supported for Neuron backend."
assert (not self.speculative_config
), "Speculative decoding not yet supported for Neuron backend."
# Instantiate the worker and load the model to the device.
self._init_worker()
def _init_worker(self):
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
wrapper.init_worker(
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
)
self.driver_worker = wrapper.worker
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
assert (not execute_model_req.blocks_to_swap_in
and not execute_model_req.blocks_to_swap_out
and not execute_model_req.blocks_to_copy), (
"Cache operations are not supported for Neuron backend.")
assert execute_model_req.num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def add_prompt_adapter(self, prompt_adapter_request) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the Neuron backend.")
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the Neuron backend.")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the Neuron backend.")
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Soft prompt is currently not supported by the Neuron backend.")
def check_health(self) -> None:
# NeuronExecutor will always be healthy as long as
# it's running.
return
class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output
async def check_health_async(self) -> None:
# NeuronExecutor will always be healthy as long as
# it's running.
return

View File

@@ -1,125 +0,0 @@
from typing import List, Set, Tuple
import openvino as ov
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class OpenVINOExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
assert self.device_config.device_type == "openvino"
assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
assert current_platform.is_openvino_cpu() or \
current_platform.is_openvino_gpu(), \
"OpenVINO backend supports only CPU and GPU devices"
# Instantiate the worker and load the model to CPU.
self._init_worker()
def _init_worker(self):
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
wrapper.init_worker(
ov_core=ov.Core(),
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self.driver_worker = wrapper.worker
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker."""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
# NOTE: In case of a CPU device, `cpu block` for OpenVINO backend
# is located on CPU memory but is referred as `gpu block`.
# Because we want to reuse the existing block management procedure.
device_blocks = num_gpu_blocks
swap_blocks = num_cpu_blocks
logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d",
envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def add_prompt_adapter(self, prompt_adapter_request) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the OPENVINO backend.")
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the OPENVINO backend.")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the OPENVINO backend.")
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Soft prompt is currently not supported by the OPENVINO backend.")
def check_health(self) -> None:
# OpenVINOExecutor will always be healthy as long as
# it's running.
return
class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output
async def check_health_async(self) -> None:
# OpenVINOExecutor will always be healthy as long as
# it's running.
return
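
The deleted Neuron and OpenVINO executors above repeat the same delegation boilerplate around a single in-process worker. The generic alternative is to dispatch by method name, the way _run_workers(method, *args, **kwargs) and execute_method do elsewhere in this diff. A minimal, hypothetical sketch of that idea; ToyWorker, ToyExecutor and run_method are illustrative names only:

class ToyWorker:
    def init_device(self) -> str:
        return "device ready"

    def load_model(self) -> str:
        return "model loaded"

class ToyExecutor:
    """Toy single-process executor that forwards calls by method name."""

    def __init__(self, worker):
        self.worker = worker

    def run_method(self, method: str, *args, **kwargs):
        # The same name-based dispatch a distributed executor would apply to
        # every worker before gathering the per-worker results.
        return getattr(self.worker, method)(*args, **kwargs)

executor = ToyExecutor(ToyWorker())
print(executor.run_method("init_device"))  # "device ready"
print(executor.run_method("load_model"))   # "model loaded"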

View File

@@ -1,24 +1,29 @@
import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import msgspec
import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.executor_base import (
DistributedExecutorBase) # yapf: disable
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
ray)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async)
if ray is not None:
from ray.actor import ActorHandle
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
ActorHandle = None
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -26,12 +31,29 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
class RayGPUExecutor(DistributedGPUExecutor):
@dataclass
class RayWorkerMetaData:
"""
Metadata for a Ray worker.
The order of ray worker creation can be random,
and we need to reset the rank after creating all workers.
"""
worker: ActorHandle
created_rank: int
adjusted_rank: int = -1
ip: str = ""
class RayDistributedExecutor(DistributedExecutorBase):
uses_ray: bool = True
def _init_executor(self) -> None:
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
if envs.VLLM_USE_V1:
# v1 always uses the compiled DAG and SPMD worker.
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
@@ -53,6 +75,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
"VLLM_USE_RAY_COMPILED_DAG=1")
assert self.uses_ray
initialize_ray_cluster(self.parallel_config)
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
@@ -66,6 +89,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
self.use_v1 = envs.VLLM_USE_V1
self.pp_locks: Optional[List[asyncio.Lock]] = None
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
self.driver_worker.execute_method)
def shutdown(self) -> None:
if hasattr(self, "forward_dag") and self.forward_dag is not None:
@@ -123,9 +153,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the workers.
driver_ip = get_ip()
workers = []
rank = 0
worker_metadata: List[RayWorkerMetaData] = []
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
if not bundle.get(current_platform.ray_device_key, 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
@@ -133,38 +164,51 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
workers.append(worker)
if current_platform.ray_device_key == "GPU":
# NV+AMD GPUs, and Intel XPUs
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rank=rank)
else:
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rank=rank)
worker_metadata.append(
RayWorkerMetaData(worker=worker, created_rank=rank))
rank += 1
worker_ip_refs = [
worker.get_node_ip.remote() # type: ignore[attr-defined]
for worker in workers
]
worker_ips = ray.get(worker_ip_refs)
worker_ips = ray.get([
each.worker.get_node_ip.remote() # type: ignore[attr-defined]
for each in worker_metadata
])
for each, ip in zip(worker_metadata, worker_ips):
each.ip = ip
if not self.use_ray_spmd_worker:
for i in range(len(workers)):
worker = workers[i]
worker_ip = worker_ips[i]
for i, each in enumerate(worker_metadata):
# find and remove the dummy worker from the list
worker = each.worker
worker_ip = each.ip
if self.driver_dummy_worker is None and worker_ip == driver_ip:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config)
workers.pop(i)
worker_ips.pop(i)
self.workers = workers
vllm_config=self.vllm_config, rank=0)
worker_metadata.pop(i)
break
else:
self.workers = workers
logger.debug("workers: %s", self.workers)
logger.debug("workers: %s", worker_metadata)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
@@ -176,9 +220,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
worker_to_ip = dict(zip(self.workers, worker_ips))
def sort_by_driver_then_worker_ip(worker):
def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
@@ -188,13 +230,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
3. Finally, if the worker is on a node with a smaller IP address, it
should be placed first.
"""
ip = worker_to_ip[worker]
return (ip != driver_ip, ip_counts[ip], ip)
ip = item.ip
return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
sorted_worker_metadata = sorted(worker_metadata,
key=sort_by_driver_then_worker_ip)
start_rank = 0 if self.use_ray_spmd_worker else 1
for i, item in enumerate(sorted_worker_metadata):
item.adjusted_rank = i + start_rank
self.workers = [item.worker for item in sorted_worker_metadata]
rerank_mapping = {
item.created_rank: item.adjusted_rank
for item in sorted_worker_metadata
}
self._run_workers("adjust_rank", rerank_mapping)
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
@@ -235,21 +287,29 @@ class RayGPUExecutor(DistributedGPUExecutor):
" each node.")
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"CUDA_VISIBLE_DEVICES":
all_args_to_update_environment_variables = [{
current_platform.device_control_env_var:
",".join(map(str, node_gpus[node_id])),
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
**({
"VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
} if envs.VLLM_ATTENTION_BACKEND is not None else {})
}, ) for (node_id, _) in worker_node_and_gpu_ids]
} for (node_id, _) in worker_node_and_gpu_ids]
for args in all_args_to_update_environment_variables:
# some carry-over env vars from the driver
# TODO: refactor platform-specific env vars
for name in [
"VLLM_ATTENTION_BACKEND",
"TPU_CHIPS_PER_HOST_BOUNDS",
"TPU_HOST_BOUNDS",
"VLLM_USE_V1",
"VLLM_TRACE_FUNCTION",
]:
if name in os.environ:
args[name] = os.environ[name]
self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)
self._run_workers("update_environment_variables",
all_args=self._get_env_vars_to_be_updated())
self._get_env_vars_to_be_updated())
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
@@ -265,14 +325,19 @@ class RayGPUExecutor(DistributedGPUExecutor):
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
all_kwargs = []
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
local_rank = node_workers[node_id].index(rank)
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self._run_workers("init_worker", all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
@@ -332,9 +397,15 @@ class RayGPUExecutor(DistributedGPUExecutor):
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
serialized_data = self.input_encoder.encode(execute_model_req)
if self.use_v1:
serialized_data = execute_model_req
else:
serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data))
output = self.output_decoder.decode(outputs[0])
if self.use_v1:
output = outputs[0]
else:
output = self.output_decoder.decode(outputs[0])
return output
def _run_workers(
@@ -342,8 +413,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
method: str,
*args,
async_run_tensor_parallel_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
@@ -356,8 +425,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
@@ -368,26 +435,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers)
# If using SPMD worker, all workers are the same, so we should execute
# the args on all workers. Otherwise, we skip the first worker's args
# because those args will go to the driver worker.
first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, first_worker_args_index, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, first_worker_args_index, None)
# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
worker.execute_method.remote(method, *args, **kwargs)
for worker in ray_workers
]
if async_run_tensor_parallel_workers_only:
@@ -399,13 +453,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
driver_worker_output = [
self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
self.driver_worker.execute_method(method, *args, **kwargs)
]
# Get the results of the ray workers.
@@ -467,11 +517,18 @@ class RayGPUExecutor(DistributedGPUExecutor):
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
# Each PP worker takes in the output of the previous PP worker,
# and the TP group executes in SPMD fashion.
outputs = [
worker.execute_model_spmd.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
if self.use_v1:
outputs = [
worker.execute_model.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
else:
outputs = [
worker.execute_model_spmd.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
last_pp_rank = len(self.pp_tp_workers) - 1
if pp_rank < last_pp_rank:
@@ -497,17 +554,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
def __del__(self):
self.shutdown()
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pp_locks: Optional[List[asyncio.Lock]] = None
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
self.driver_worker.execute_method)
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
@@ -568,5 +614,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
]
return await asyncio.gather(*coros)
def __del__(self):
self.shutdown()
def check_health(self) -> None:
# Assume that the Ray workers are healthy.
# TODO: check the health of the Ray workers
return
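
To make the rank adjustment above concrete, here is a self-contained toy run of the same sorting and re-ranking logic with fabricated IPs; ToyWorkerMetaData mirrors RayWorkerMetaData and plain strings stand in for Ray actor handles:

from dataclasses import dataclass
from typing import Dict

@dataclass
class ToyWorkerMetaData:
    worker: str
    created_rank: int
    adjusted_rank: int = -1
    ip: str = ""

driver_ip = "10.0.0.1"
metadata = [
    ToyWorkerMetaData("w0", 0, ip="10.0.0.2"),
    ToyWorkerMetaData("w1", 1, ip="10.0.0.1"),  # lives on the driver node
    ToyWorkerMetaData("w2", 2, ip="10.0.0.2"),
]

ip_counts: Dict[str, int] = {}
for item in metadata:
    ip_counts[item.ip] = ip_counts.get(item.ip, 0) + 1

def sort_key(item: ToyWorkerMetaData):
    # driver-node workers first, then less crowded nodes, then by IP string
    return (0 if item.ip == driver_ip else 1, ip_counts[item.ip], item.ip)

sorted_metadata = sorted(metadata, key=sort_key)
start_rank = 1  # rank 0 stays with the in-process driver worker (non-SPMD case)
for i, item in enumerate(sorted_metadata):
    item.adjusted_rank = i + start_rank

rerank_mapping = {item.created_rank: item.adjusted_rank for item in sorted_metadata}
print(rerank_mapping)  # {1: 1, 0: 2, 2: 3}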

View File

@@ -1,515 +0,0 @@
import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import msgspec
import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
class RayHPUExecutor(DistributedGPUExecutor):
uses_ray: bool = True
def _init_executor(self) -> None:
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if self.use_ray_compiled_dag:
assert self.use_ray_spmd_worker, (
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
"VLLM_USE_RAY_SPMD_WORKER=1")
if self.use_ray_spmd_worker:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert self.use_ray_compiled_dag, (
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1")
assert self.uses_ray
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
def shutdown(self) -> None:
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.forward_dag = None
def finish_measurements(self):
self._run_workers("finish_measurements")
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
# Otherwise, the ray workers are allocated with a full GPU.
num_gpus = 1
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("HPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={'HPU': num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
if self.use_ray_spmd_worker:
self.workers.append(worker)
else:
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config)
else:
# Else, added to the list of workers.
self.workers.append(worker)
logger.debug("workers: %s", self.workers)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")
worker_ips = [
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
for worker in self.workers
]
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(worker):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the worker is on a node with a smaller IP address, it
should be placed first.
"""
ip = ray.get(worker.get_node_ip.remote())
return (ip != driver_ip, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
worker_node_and_gpu_ids = []
for worker in [self.driver_dummy_worker] + self.workers:
if worker is None:
# driver_dummy_worker can be None when using ray spmd worker.
continue
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote()) \
) # type: ignore
node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids = [int(x) for x in gpu_ids]
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP` "
"environment variable, make sure it is unique for"
" each node.")
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for (node_id, _) in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(
self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = []
# Enforce rank order for correct rank to return final output.
for index, worker in enumerate(self.workers):
# The driver worker is rank 0 and not in self.workers.
rank = index + 1
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)
def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return super().execute_model(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data))
output = self.output_decoder.decode(outputs[0])
return output
def _run_workers(
self,
method: str,
*args,
async_run_tensor_parallel_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
ways:
Args:
- async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
"spmd mode.")
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers)
# If using SPMD worker, all workers are the same, so we should execute
# the args on all workers. Otherwise, we skip the first worker's args
# because those args will go to the driver worker.
first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, first_worker_args_index, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, first_worker_args_index, None)
# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
]
if async_run_tensor_parallel_workers_only:
# Just return futures
return ray_worker_outputs
driver_worker_output = []
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
driver_worker_output = [
self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
]
# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _check_ray_adag_installation(self):
import pkg_resources
from packaging import version
required_version = version.parse("2.35")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
# TODO: update the constraint once we adapt to the backward
# incompatible API change from ray 2.36
if current_version != required_version:
raise ValueError(f"Ray version {required_version} is "
f"required, but found {current_version}")
import importlib.util
adag_spec = importlib.util.find_spec(
"ray.experimental.compiled_dag_ref")
if adag_spec is None:
raise ValueError("Ray accelerated DAG is not installed. "
"Run `pip install ray[adag]` to install it.")
def _compiled_ray_dag(self, enable_asyncio: bool):
assert self.parallel_config.use_ray
self._check_ray_adag_installation()
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
with InputNode() as input_data:
# Example DAG: PP=2, TP=4
# (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501
# -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501
# -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501
# -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501
# All workers in the first TP group will take in the
# ExecuteModelRequest as input.
outputs = [input_data for _ in self.pp_tp_workers[0]]
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
# Each PP worker takes in the output of the previous PP worker,
# and the TP group executes in SPMD fashion.
outputs = [
worker.execute_model_spmd.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
last_pp_rank = len(self.pp_tp_workers) - 1
if pp_rank < last_pp_rank:
# Specify how intermediate tensors should be passed
# between pp stages, no need to specify for the last
# pp stage.
transport = "auto"
outputs = [
output.with_type_hint(
TorchTensorType(transport=transport))
for output in outputs
]
forward_dag = MultiOutputNode(outputs)
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def __del__(self):
self.shutdown()
class RayHPUExecutorAsync(RayHPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pp_locks: Optional[List[asyncio.Lock]] = None
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
self.driver_worker.execute_method)
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return await super().execute_model_async(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
serialized_data = self.input_encoder.encode(execute_model_req)
dag_future = await self.forward_dag.execute_async(serialized_data)
outputs = await dag_future
return self.output_decoder.decode(outputs[0])
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
if not self.tp_driver_workers:
return await self.driver_exec_method("execute_model",
execute_model_req)
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self.pp_locks = [
asyncio.Lock()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
tasks = [
asyncio.create_task(
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
"execute_model", execute_model_req))
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(driver_worker.execute_method.remote,
self.pp_locks[pp_rank],
"execute_model", execute_model_req)))
results = await asyncio.gather(*tasks)
# Only the last PP stage has the final results.
return results[-1]
async def _start_worker_execution_loop(self):
assert not self.use_ray_spmd_worker, (
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)
def __del__(self):
self.shutdown()
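
The pp_tp_workers grouping above flattens a PP x TP grid into one rank per worker (rank = pp_rank * tensor_parallel_size + tp_rank). A tiny standalone illustration of that indexing, using integers in place of Ray worker handles:

pipeline_parallel_size = 2
tensor_parallel_size = 4
workers = list(range(pipeline_parallel_size * tensor_parallel_size))  # stand-in ranks

pp_tp_workers = []
for pp_rank in range(pipeline_parallel_size):
    pp_tp_workers.append([])
    for tp_rank in range(tensor_parallel_size):
        rank = pp_rank * tensor_parallel_size + tp_rank
        pp_tp_workers[pp_rank].append(workers[rank])

print(pp_tp_workers)  # [[0, 1, 2, 3], [4, 5, 6, 7]]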

View File

@@ -1,343 +0,0 @@
import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
Union)
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.executor.tpu_executor import TPUExecutor
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
class RayTPUExecutor(TPUExecutor):
uses_ray: bool = True
def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
super().__init__(*args, **kwargs)
def _init_executor(self) -> None:
assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel TPU workers.
self._init_workers_ray(placement_group)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
# Create the workers.
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("TPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# need to override the Ray environment variables manually.
override_env = {}
if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
override_env.update({
"TPU_CHIPS_PER_HOST_BOUNDS":
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
})
if "TPU_HOST_BOUNDS" in os.environ:
override_env.update(
{"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
worker = ray.remote(
num_cpus=0,
resources={"TPU": 1},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
if override_env:
worker.override_env_vars.remote(override_env)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config)
else:
# Else, added to the list of workers.
self.workers.append(worker)
logger.debug("workers: %s", self.workers)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any TPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"TPU node.")
worker_ips = [
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
for worker in self.workers
]
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(worker):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the worker is on a node with a smaller IP address, it
should be placed first.
"""
ip = ray.get(worker.get_node_ip.remote())
return (ip != driver_ip, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
# Get the set of TPU IDs used on each node.
worker_node_and_gpu_ids = []
for worker in [self.driver_dummy_worker] + self.workers:
if worker is None:
# driver_dummy_worker can be None when using ray spmd worker.
continue
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote()) \
) # type: ignore
node_workers = defaultdict(list)
for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for _ in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
if len(node_workers) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def _run_workers(
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
ways:
- async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than blocking
on the results.
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
if async_run_remote_workers_only:
# Just return futures
return ray_worker_outputs
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def determine_num_available_blocks(self) -> Tuple[int, int]:
num_blocks = self._run_workers("determine_num_available_blocks", )
num_tpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return num_tpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results.
return self._driver_execute_model(execute_model_req)
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
self._driver_execute_model()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_method = make_async(self.driver_worker.execute_method)
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop())
# Only the driver worker returns the sampling results.
return await self._driver_execute_model_async(execute_model_req)
async def stop_remote_worker_execution_loop_async(self) -> None:
if self.parallel_worker_tasks is None:
return
await self._driver_execute_model_async()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await parallel_worker_tasks
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_method("execute_model",
execute_model_req)
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers
]
return await asyncio.gather(*coros)
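
The deleted _run_workers above distributes arguments with itertools.repeat and islice: shared args are repeated once per worker, while explicit per-worker all_args reserve index 0 for the driver and hand the remaining entries to the remote workers. A small runnable illustration of both cases, with placeholder worker names:

from itertools import islice, repeat

workers = ["worker-1", "worker-2", "worker-3"]
count = len(workers)

# Case 1: every remote worker shares the same positional args.
args = ("load_model",)
print(list(zip(workers, repeat(args, count))))

# Case 2: per-worker args are given explicitly; index 0 belongs to the driver,
# so the remote workers consume the slice starting at index 1.
all_args = [("driver",), ("remote-a",), ("remote-b",), ("remote-c",)]
print(list(zip(workers, islice(all_args, 1, None))))
# [('worker-1', ('remote-a',)), ('worker-2', ('remote-b',)), ('worker-3', ('remote-c',))]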

View File

@@ -1,7 +1,7 @@
import os
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import msgspec
@@ -13,6 +13,10 @@ from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
PG_WAIT_TIMEOUT = 1800
@@ -95,6 +99,26 @@ try:
return output
def setup_device_if_necessary(self):
# TODO(swang): This is needed right now because Ray CG executes
# on a background thread, so we need to reset torch's current
# device.
# We can remove this API after it is fixed in compiled graph.
import torch
assert self.worker is not None, "Worker is not initialized"
if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> "ModelRunnerOutput":
self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized"
output = self.worker.model_runner.execute_model(scheduler_output)
return output
def override_env_vars(self, vars: Dict[str, str]):
os.environ.update(vars)
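
setup_device_if_necessary above binds the CUDA device lazily because compiled-graph calls can arrive on a background thread that has not set it yet. A device-free toy of the same set-once pattern; the print stands in for the torch.cuda.set_device call made by the real code:

class LazyDeviceBinding:
    """Toy version of the set-device-once pattern used above."""

    def __init__(self, device: str):
        self.device = device
        self._device_set = False  # mirrors compiled_dag_cuda_device_set

    def setup_device_if_necessary(self) -> None:
        if not self._device_set:
            # torch.cuda.set_device(self.device) would go here
            print(f"binding current thread to {self.device}")
            self._device_set = True

    def execute(self, payload: str) -> str:
        # Every entry point re-checks the flag because the call may arrive on
        # a thread that has not bound the device yet.
        self.setup_device_if_necessary()
        return f"ran {payload} on {self.device}"

binding = LazyDeviceBinding("cuda:0")
print(binding.execute("step-1"))  # binds on first use
print(binding.execute("step-2"))  # flag already set, nothing rebound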

View File

@@ -1,40 +0,0 @@
import asyncio
from typing import List, Optional
import ray
import vllm.envs as envs
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync
from vllm.executor.xpu_executor import XPUExecutor
from vllm.logger import init_logger
from vllm.utils import make_async
logger = init_logger(__name__)
class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
def _get_env_vars_to_be_updated(self):
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
for worker in [self.driver_dummy_worker] + self.workers:
if worker is None:
# driver_dummy_worker can be None when using ray spmd worker.
continue
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote())) # type: ignore
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for (_, _) in worker_node_and_gpu_ids]
return all_args_to_update_environment_variables
class RayXPUExecutorAsync(RayXPUExecutor, RayGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_method = make_async(self.driver_worker.execute_method)
self.pp_locks: Optional[List[asyncio.Lock]] = None

View File

@ -1,142 +0,0 @@
from typing import Any, Dict, List, Optional, Set, Tuple
import torch
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
logger = init_logger(__name__)
class TPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
assert not self.scheduler_config.chunked_prefill_enabled, (
"Chunked prefill is not yet supported for TPU backend")
assert not self.speculative_config, (
"Speculative decoding is not yet supported for TPU backend")
if self.model_config.dtype in (torch.float16, torch.float32):
logger.warning(
"The TPU backend currently does not support %s. "
"Using bfloat16 instead.", self.model_config.dtype)
self.model_config.dtype = torch.bfloat16
# Instantiate the worker and load the model to the device.
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _get_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None,
) -> Dict[str, Any]:
"""Return worker init args for a given rank."""
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=rank == 0,
)
def _create_worker(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None,
):
if self.scheduler_config.is_multi_step:
from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker
worker = MultiStepTPUWorker(**self._get_worker_kwargs(
local_rank, rank, distributed_init_method))
return worker
else:
from vllm.worker.tpu_worker import TPUWorker
worker = TPUWorker(**self._get_worker_kwargs(
local_rank, rank, distributed_init_method))
return worker
def initialize_cache(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
"""Initialize the KV cache by invoking the underlying worker."""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker."""
return self.driver_worker.determine_num_available_blocks()
def execute_model(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError(
"LoRA is currently not supported by the TPU backend.")
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRA is currently not supported by the TPU backend.")
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRA is currently not supported by the TPU backend.")
def list_loras(self) -> Set[int]:
raise NotImplementedError(
"LoRA is currently not supported by the TPU backend.")
def add_prompt_adapter(self, prompt_adapter_request) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the TPU backend.")
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the TPU backend.")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the TPU backend.")
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Soft prompt is currently not supported by the TPU backend.")
def check_health(self) -> None:
# TPUExecutor will always be healthy as long as it's running.
return
class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> SamplerOutput:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req)
return output

View File

@ -0,0 +1,57 @@
from typing import Any, Dict, List, Optional, Tuple
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class UniProcExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rank=0)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
local_rank = 0
rank = 0
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
if kwargs is None:
kwargs = {}
try:
func = getattr(self.driver_worker, method)
except AttributeError:
raise NotImplementedError(f"Method {method} is not implemented.") \
from None
answer = func(*args, **kwargs)
return [answer]
def check_health(self) -> None:
# UniProcExecutor will always be healthy as long as
# it's running.
return
UniProcExecutorAsync = UniProcExecutor
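To make the contract of this new interface concrete, here is a hedged, toy sketch (not vLLM code; ToyWorker and its load_model method are made up): the method name is looked up on each worker (here only the single driver worker) and the per-worker results come back as a list.
# Toy sketch of the collective_rpc contract shown above.
from typing import Any, Dict, List, Optional, Tuple

class ToyWorker:

    def load_model(self, path: str) -> str:
        return f"loaded {path}"

class ToyUniProcExecutor:

    def __init__(self):
        self.driver_worker = ToyWorker()

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        kwargs = kwargs or {}
        func = getattr(self.driver_worker, method)
        return [func(*args, **kwargs)]  # one entry per worker

print(ToyUniProcExecutor().collective_rpc("load_model", args=("/tmp/model",)))
# -> ['loaded /tmp/model']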

View File

@ -1,39 +0,0 @@
from typing import List, Optional, Union
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
logger = init_logger(__name__)
class XPUExecutor(GPUExecutor):
uses_ray: bool = False
def _init_executor(self) -> None:
assert self.device_config.device_type == "xpu"
assert self.speculative_config is None, (
"Speculative decoding not yet supported for XPU backend")
GPUExecutor._init_executor(self)
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
output = self.driver_worker.execute_model(execute_model_req)
return output
class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req)
return output

View File

@ -1,3 +1,4 @@
import os
from typing import TYPE_CHECKING, Optional
import psutil
@ -105,6 +106,32 @@ class CpuPlatform(Platform):
else:
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
assert vllm_config.device_config.device_type == "cpu"
#
# Environment variables for CPU executor
#
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# Intel OpenMP setting
ld_preload_str = os.getenv("LD_PRELOAD", "")
if "libiomp5.so" in ld_preload_str:
# The time (in milliseconds) that a thread should wait after
# completing the execution of a parallel region, before sleeping.
os.environ['KMP_BLOCKTIME'] = "1"
# Prevents the CPU from dropping into a low-performance state
os.environ['KMP_TPAUSE'] = "0"
# Provides fine-grained parallelism
os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
# Hint IPEX to use shared-memory-based AllReduce
os.environ["LOCAL_WORLD_SIZE"] = str(
vllm_config.parallel_config.tensor_parallel_size)
@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on CPU.")
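For reference, the effect of the LD_PRELOAD branch above can be reproduced outside of vLLM with a small helper; this is only an illustration of which knobs get set, not part of the change:
# Stand-alone helper mirroring the logic above: returns the environment
# variables that would be applied for a given LD_PRELOAD value.
import os

def intel_omp_env(ld_preload: str) -> dict:
    env = {"TORCHINDUCTOR_COMPILE_THREADS": "1"}
    if "libiomp5.so" in ld_preload:
        env.update({
            "KMP_BLOCKTIME": "1",
            "KMP_TPAUSE": "0",
            "KMP_FORKJOIN_BARRIER_PATTERN": "dist,dist",
            "KMP_PLAIN_BARRIER_PATTERN": "dist,dist",
            "KMP_REDUCTION_BARRIER_PATTERN": "dist,dist",
        })
    return env

print(intel_omp_env(os.getenv("LD_PRELOAD", "")))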

View File

@ -139,6 +139,28 @@ class CudaPlatformBase(Platform):
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
world_size = parallel_config.world_size
tensor_parallel_size = parallel_config.tensor_parallel_size
from vllm.utils import (cuda_device_count_stateless,
update_environment_variables)
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
cuda_device_count = cuda_device_count_stateless()
# Use a TP-specific message for the more common TP-only misconfiguration.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to be no larger than the local GPU count ({cuda_device_count})")
assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is no larger than the local GPU count ({cuda_device_count})")
cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16

View File

@ -35,6 +35,14 @@ class NeuronPlatform(Platform):
parallel_config.worker_cls = \
"vllm.worker.neuron_worker.NeuronWorker"
if parallel_config.world_size > 1:
parallel_config.distributed_executor_backend = "uni"
assert (vllm_config.lora_config is
None), "LoRA is not supported for Neuron backend."
assert (not vllm_config.speculative_config
), "Speculative decoding not yet supported for Neuron backend."
cache_config = vllm_config.cache_config
if cache_config:
# neuron needs block_size = max_model_len

View File

@ -66,9 +66,8 @@ class OpenVinoPlatform(Platform):
from vllm.utils import GiB_bytes
parallel_config = vllm_config.parallel_config
assert (
parallel_config.world_size == 1
), "OpenVINOExecutor only supports single CPU socket currently."
assert (parallel_config.world_size == 1
), "OpenVINO only supports single CPU socket currently."
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = \
@ -141,3 +140,10 @@ class OpenVinoPlatform(Platform):
raise RuntimeError(
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")
assert vllm_config.device_config.device_type == "openvino"
assert vllm_config.lora_config is None, \
"OpenVINO backend doesn't support LoRA"
assert cls.is_openvino_cpu() or \
cls.is_openvino_gpu(), \
"OpenVINO backend supports only CPU and GPU devices"

View File

@ -72,6 +72,16 @@ class TpuPlatform(Platform):
assert vllm_config.speculative_config is None, \
"TPU does not support speculative decoding"
assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
"Chunked prefill is not yet supported for TPU backend")
assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend")
if vllm_config.model_config.dtype in (torch.float16, torch.float32):
logger.warning(
"The TPU backend currently does not support %s. "
"Using bfloat16 instead.", vllm_config.model_config.dtype)
vllm_config.model_config.dtype = torch.bfloat16
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":

View File

@ -78,17 +78,31 @@ class XPUPlatform(Platform):
raise NotImplementedError(
"XPU does not support speculative decoding")
if vllm_config.device_config is not None:
assert vllm_config.device_config.device_type == "xpu"
# check and update parallel config
parallel_config = vllm_config.parallel_config
if (parallel_config.distributed_executor_backend is not None
and parallel_config.distributed_executor_backend != "ray"):
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
if parallel_config.distributed_executor_backend is None:
parallel_config.distributed_executor_backend = "ray"
elif parallel_config.distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn requires guarding the entry point with `if __name__ == '__main__':`
# and fork is not supported for starting new processes on XPU.
logger.error(
"Both start methods (spawn and fork) have issues "
"on XPU with the mp backend, setting it to ray instead.")
parallel_config.distributed_executor_backend = "ray"
elif parallel_config.distributed_executor_backend != "ray":
logger.warning(
"%s is not supported on XPU, falling back to the ray"
" distributed executor backend.",
parallel_config.distributed_executor_backend)
parallel_config.distributed_executor_backend = "ray"
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
@classmethod
def is_pin_memory_available(cls):

View File

@ -9,17 +9,15 @@ from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import WorkerWrapperBase
from vllm.worker.worker_base import DelegateWorkerBase
class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase):
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
"""Worker for Medusa.
"""
def __init__(self, *args, **kwargs):
super().__init__(kwargs.get("vllm_config"))
self.init_worker(*args, **kwargs)
DelegateWorkerBase.__init__(self, *args, **kwargs)
# Lazy initialization list.
self._proposer: Top1Proposer

View File

@ -16,10 +16,10 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import WorkerWrapperBase
from vllm.worker.worker_base import DelegateWorkerBase
class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
@ -32,15 +32,12 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
"""
def __init__(self, *args, **kwargs):
super().__init__(kwargs.get("vllm_config"))
self.init_worker(*args, **kwargs)
DelegateWorkerBase.__init__(self, *args, **kwargs)
# Lazy initialization list.
self._proposer: SpeculativeProposer
def init_device(self) -> None:
self.worker.init_device()
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
@ -56,18 +53,6 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
True)
def determine_num_available_blocks(self) -> Tuple[int, int]:
return self.worker.determine_num_available_blocks()
def get_cache_block_size_bytes(self) -> int:
return self.worker.get_cache_block_size_bytes()
def initialize_cache(self, *args, **kwargs) -> None:
self.worker.initialize_cache(*args, **kwargs)
def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
return self.worker.execute_model(*args, **kwargs)
@torch.inference_mode()
def sampler_output(
self,

View File

@ -40,8 +40,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker_base import (LoraNotSupportedWorkerBase, WorkerBase,
WorkerWrapperBase)
from vllm.utils import resolve_obj_by_qualname
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__)
@ -64,8 +64,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
target_worker_config = copy.deepcopy(vllm_config)
target_worker_config.parallel_config.worker_cls =\
target_worker_config.parallel_config.sd_worker_cls
target_worker = WorkerWrapperBase(vllm_config=target_worker_config)
target_worker.init_worker(*args, **kwargs)
cls = resolve_obj_by_qualname(
target_worker_config.parallel_config.worker_cls)
target_worker = cls(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\

View File

@ -14,8 +14,9 @@ class Executor(ABC):
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "ray":
from vllm.v1.executor.ray_executor import RayExecutor
executor_class = RayExecutor
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor)
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor

View File

@ -246,9 +246,18 @@ class WorkerProc:
ready_path: str,
):
self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config)
wrapper.init_worker(vllm_config, local_rank, rank,
distributed_init_method)
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rank=rank)
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: List[Dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
]
all_kwargs[rank] = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
"distributed_init_method": distributed_init_method,
}
wrapper.init_worker(all_kwargs)
self.worker = wrapper.worker
pid = os.getpid()
@ -270,7 +279,7 @@ class WorkerProc:
ready_socket.send_string(WorkerProc.READY_STR)
ready_socket.send(payload)
self.worker.initialize()
self.worker.init_device()
self.worker.load_model()
@staticmethod

View File

@ -27,7 +27,7 @@ class UniprocExecutor(Executor):
self.observability_config = vllm_config.observability_config
self.worker: Worker = self._create_worker()
self.worker.initialize()
self.worker.init_device()
self.worker.load_model()
def _create_worker(

View File

@ -33,6 +33,7 @@ class Worker:
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
@ -75,7 +76,7 @@ class Worker:
else:
self.profiler = None
def initialize(self):
def init_device(self):
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow

View File

@ -2,6 +2,7 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import contextlib
import gc
import os
from typing import List, Optional, Set, Tuple, Type
@ -18,6 +19,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
@ -124,6 +126,70 @@ class HPUWorker(LocalOrDistributedWorkerBase):
def load_model(self):
self.model_runner.load_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
assert execute_model_req is not None
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501
log_graph_compilation_all = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0'
log_graph_compilation = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION',
'0') != '0' or log_graph_compilation_all
log_cpu_fallbacks_all = os.environ.get(
'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
'0') != '0' or log_cpu_fallbacks_all
if log_graph_compilation or log_cpu_fallbacks:
from habana_frameworks.torch.hpu.metrics import metric_localcontext
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
is_prompt = any([
seq_group_metadata.is_prompt
for seq_group_metadata in seq_group_metadata_list
])
max_context_len = max([
max([
len(v.prompt_token_ids) + len(v.output_token_ids)
for v in seq_group_metadata.seq_data.values()
]) for seq_group_metadata in seq_group_metadata_list
]) # longest context length across all sequences in the batch
max_num_blocks = (
(max_context_len - 1) // self.cache_config.block_size) + 1
input_stats = (f'is_prompt: {is_prompt}, '
f'num_seqs: {len(seq_group_metadata_list)}, '
f'max_context_len: {max_context_len}, '
f'max_num_blocks {max_num_blocks}')
gc_ctx = metric_localcontext(
"graph_compilation"
) if log_graph_compilation else contextlib.nullcontext()
cpu_fallback_ctx = metric_localcontext(
"cpu_fallback"
) if log_cpu_fallbacks else contextlib.nullcontext()
with gc_ctx as gc_local_metric, \
cpu_fallback_ctx as cpu_fallback_local_metric:
output = LocalOrDistributedWorkerBase.execute_model(
self, execute_model_req)
if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0
) or log_graph_compilation_all:
msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
f"{gc_local_metric.stats()}, {input_stats}")
logger.warning(msg)
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] >
0) or log_cpu_fallbacks_all:
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
logger.warning(msg)
return output
output = LocalOrDistributedWorkerBase.execute_model(
self, execute_model_req)
return output
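For users of these knobs, any non-"0" value enables the corresponding logging; an illustrative way to opt in (set before the engine or worker is created) is:
# Illustrative: enable per-step HPU logging via the environment variables
# checked above; "1" is an arbitrary non-"0" value.
import os

os.environ["VLLM_HPU_LOG_STEP_GRAPH_COMPILATION"] = "1"  # log only when compilations occur
os.environ["VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL"] = "1"  # log CPU fallbacks every step
# ... then construct the engine/worker as usual; execute_model will emit
# VLLM_HPU_STEP_GRAPH_COMPILATION / VLLM_HPU_STEP_CPU_FALLBACK warnings.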
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many

View File

@ -8,6 +8,7 @@ from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
@ -25,6 +26,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = True,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
@ -37,7 +39,22 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.model_runner: NeuronModelRunner = NeuronModelRunner(
vllm_config=vllm_config)
self.is_driver_worker = True
self.is_driver_worker = is_driver_worker
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
assert execute_model_req is not None
assert (not execute_model_req.blocks_to_swap_in
and not execute_model_req.blocks_to_swap_out
and not execute_model_req.blocks_to_copy), (
"Cache operations are not supported for Neuron backend.")
assert execute_model_req.num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")
output = LocalOrDistributedWorkerBase.execute_model(
self, execute_model_req)
return output
def init_device(self) -> None:
self.init_distributed_environment()
@ -103,13 +120,14 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def init_distributed_environment(self):
"""Neuron uses transformers-neuronx for tensor parallelism.
vLLM still needs the environment inited when TP/PP > 1
It has only one process to control multiple devices.
vLLM still needs the environment initialized when TP/PP > 1,
so we initialize a distributed environment with one process.
"""
init_distributed_environment(
world_size=1,
rank=self.rank,
local_rank=self.local_rank,
rank=0,
local_rank=0,
distributed_init_method=self.distributed_init_method,
backend="gloo",
)

View File

@ -211,16 +211,14 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
def __init__(
self,
ov_core: ov.Core,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
is_driver_worker: bool = False,
) -> None:
self.ov_core = ov_core
WorkerBase.__init__(self, vllm_config)
self.ov_core = ov.Core()
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
@ -237,7 +235,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
self.model_runner = OpenVINOModelRunner(
self.ov_core,
vllm_config=self.vllm_config,
kv_cache_dtype=kv_cache_dtype,
kv_cache_dtype=self.vllm_config.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
)
# Uninitialized cache engine. Will be initialized by

View File

@ -88,7 +88,6 @@ class WorkerBase(ABC):
if output is None:
return None
@abstractmethod
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
@ -119,6 +118,58 @@ class WorkerBase(ABC):
raise NotImplementedError
class DelegateWorkerBase(WorkerBase):
"""
A class that delegates all methods to an inner WorkerBase instance. This is
useful for wrapping an existing worker with additional behavior,
e.g. for speculative decoding.
"""
worker: WorkerBase
def __init__(
self,
*args,
**kwargs,
) -> None:
vllm_config: VllmConfig = kwargs.get("vllm_config")
cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls)
self.worker = cls(*args, **kwargs)
def init_device(self) -> None:
self.worker.init_device()
def determine_num_available_blocks(self) -> Tuple[int, int]:
return self.worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
return self.worker.execute_model(execute_model_req)
def get_cache_block_size_bytes(self) -> int:
return self.worker.get_cache_block_size_bytes()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.worker.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.worker.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.worker.list_loras()
def __getattr__(self, attr):
return getattr(self.worker, attr)
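As an illustration of the intended usage (mirroring the spec-decode workers earlier in this diff, but with a made-up class), a wrapper subclasses DelegateWorkerBase, lets it construct the worker named by parallel_config.worker_cls, and overrides only what it needs; everything else falls through via __getattr__.
# Hypothetical wrapper, for illustration only.
from vllm.worker.worker_base import DelegateWorkerBase

class CountingWorker(DelegateWorkerBase):
    """Counts engine steps while delegating all real work."""

    def __init__(self, *args, **kwargs):
        DelegateWorkerBase.__init__(self, *args, **kwargs)
        self.num_steps = 0

    def execute_model(self, execute_model_req=None):
        self.num_steps += 1
        return self.worker.execute_model(execute_model_req)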
class LoraNotSupportedWorkerBase(WorkerBase):
"""Partial implementation of WorkerBase that raises exceptions when LoRA
methods are invoked.
@ -419,17 +470,31 @@ class WorkerWrapperBase:
def __init__(
self,
vllm_config: VllmConfig,
rank: int = 0,
) -> None:
self.rank = rank
self.vllm_config = vllm_config
trust_remote_code = vllm_config.model_config.trust_remote_code
self.worker: Optional[WorkerBase] = None
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
if vllm_config.model_config is not None:
# it can be None in tests
trust_remote_code = vllm_config.model_config.trust_remote_code
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
@staticmethod
def update_environment_variables(envs: Dict[str, str]) -> None:
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
"""
Adjust the rank based on the given mapping.
It is only used during executor initialization,
to adjust worker ranks after all the workers have been created.
"""
if self.rank in rank_mapping:
self.rank = rank_mapping[self.rank]
def update_environment_variables(self, envs_list: List[Dict[str,
str]]) -> None:
envs = envs_list[self.rank]
key = 'CUDA_VISIBLE_DEVICES'
if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
@ -437,11 +502,12 @@ class WorkerWrapperBase:
del os.environ[key]
update_environment_variables(envs)
def init_worker(self, *args, **kwargs):
def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
"""
Here we inject some common logic before initializing the worker.
The kwargs for this rank are taken from all_kwargs and passed to the
worker class constructor.
"""
kwargs = all_kwargs[self.rank]
enable_trace_function_call_for_thread(self.vllm_config)
# see https://github.com/NVIDIA/nccl/issues/1234
@ -452,7 +518,7 @@ class WorkerWrapperBase:
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
self.worker = worker_class(*args, **kwargs)
self.worker = worker_class(**kwargs)
assert self.worker is not None
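To make the new calling convention explicit, here is a toy sketch (not vLLM code) of the executor side: one kwargs dict per rank is built, the full list is handed to every wrapper, and each wrapper indexes it by its own rank, as init_worker does above.
# Toy classes for illustration; each wrapper picks all_kwargs[self.rank].
from typing import Any, Dict, List

class ToyWrapper:

    def __init__(self, rank: int):
        self.rank = rank
        self.worker = None

    def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
        kwargs = all_kwargs[self.rank]
        self.worker = ("toy-worker", kwargs["rank"], kwargs["local_rank"])

world_size = 4
all_kwargs = [{"rank": r, "local_rank": r} for r in range(world_size)]
wrappers = [ToyWrapper(rank=r) for r in range(world_size)]
for w in wrappers:
    w.init_worker(all_kwargs)
print([w.worker for w in wrappers])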
def execute_method(self, method: str, *args, **kwargs):