[core] platform agnostic executor via collective_rpc (#11256)
Signed-off-by: youkaichao <youkaichao@gmail.com>
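This change replaces the per-platform executors with a single worker-RPC interface. As a quick orientation (an illustrative sketch, not part of the diff; it assumes the public EngineArgs/LLMEngine API and the test model used below), the new pattern is to ask the executor to run a named worker method on every worker and collect one result per worker:

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
executor = engine.model_executor

# One entry per worker; the executor no longer needs to know the platform.
results = executor.collective_rpc("determine_num_available_blocks")
num_gpu_blocks = min(r[0] for r in results)
num_cpu_blocks = min(r[1] for r in results)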
@@ -1,12 +1,13 @@
import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
from vllm.executor.uniproc_executor import UniProcExecutor
from vllm.sampling_params import SamplingParams

@@ -14,21 +15,20 @@ class Mock:
    ...


class CustomGPUExecutor(GPUExecutor):
class CustomUniExecutor(UniProcExecutor):

    def execute_model(self, *args, **kwargs):
    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        # Drop marker to show that this was ran
        with open(".marker", "w"):
            ...
        return super().execute_model(*args, **kwargs)
        return super().collective_rpc(method, timeout, args, kwargs)


class CustomGPUExecutorAsync(GPUExecutorAsync):

    async def execute_model_async(self, *args, **kwargs):
        with open(".marker", "w"):
            ...
        return await super().execute_model_async(*args, **kwargs)
CustomUniExecutorAsync = CustomUniExecutor


@pytest.mark.parametrize("model", ["facebook/opt-125m"])

@@ -41,10 +41,6 @@ def test_custom_executor_type_checking(model):
        engine_args = AsyncEngineArgs(model=model,
                                      distributed_executor_backend=Mock)
        AsyncLLMEngine.from_engine_args(engine_args)
    with pytest.raises(TypeError):
        engine_args = AsyncEngineArgs(
            model=model, distributed_executor_backend=CustomGPUExecutor)
        AsyncLLMEngine.from_engine_args(engine_args)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])

@@ -55,7 +51,7 @@ def test_custom_executor(model, tmp_path):
        assert not os.path.exists(".marker")

        engine_args = EngineArgs(
            model=model, distributed_executor_backend=CustomGPUExecutor)
            model=model, distributed_executor_backend=CustomUniExecutor)
        engine = LLMEngine.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

@@ -75,7 +71,7 @@ def test_custom_executor_async(model, tmp_path):
        assert not os.path.exists(".marker")

        engine_args = AsyncEngineArgs(
            model=model, distributed_executor_backend=CustomGPUExecutorAsync)
            model=model, distributed_executor_backend=CustomUniExecutorAsync)
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)
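The updated test above shows the supported customization path: subclass UniProcExecutor, override collective_rpc, and pass the class itself as distributed_executor_backend. A minimal sketch along the same lines (the LoggingUniExecutor name is hypothetical, not part of the repository):

from typing import Any, Dict, List, Optional, Tuple

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.uniproc_executor import UniProcExecutor


class LoggingUniExecutor(UniProcExecutor):
    # Hypothetical subclass that logs every RPC it forwards to the worker.

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        print(f"collective_rpc -> {method}")
        return super().collective_rpc(method, timeout, args, kwargs)


engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m",
               distributed_executor_backend=LoggingUniExecutor))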
@@ -6,16 +6,15 @@ from typing import Any, List, Tuple

import pytest

from vllm.config import VllmConfig
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                  ResultHandler, WorkerMonitor)
from vllm.worker.worker_base import WorkerWrapperBase


class DummyWorker:
class DummyWorkerWrapper(WorkerWrapperBase):
    """Dummy version of vllm.worker.worker.Worker"""

    def __init__(self, rank: int):
        self.rank = rank

    def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
        sleep(0.05)

@@ -28,9 +27,10 @@ class DummyWorker:

def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
    result_handler = ResultHandler()
    vllm_config = VllmConfig()
    workers = [
        ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
        for rank in range(8)
        ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config,
                             rank) for rank in range(8)
    ]

    worker_monitor = WorkerMonitor(workers, result_handler)
@@ -2,6 +2,7 @@ import asyncio
import os
import socket
from typing import AsyncIterator, Tuple
from unittest.mock import patch

import pytest
import torch

@@ -390,7 +391,10 @@ def test_bind_kv_cache_encoder_decoder():


def test_bind_kv_cache_pp():
    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
    with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
        # this test runs with 1 GPU, but we simulate 2 GPUs
        cfg = VllmConfig(
            parallel_config=ParallelConfig(pipeline_parallel_size=2))
    with set_current_vllm_config(cfg):
        from vllm.attention import Attention
@@ -1294,8 +1294,11 @@ class ParallelConfig:
            from vllm.executor import ray_utils
            backend = "mp"
            ray_found = ray_utils.ray_is_available()
            if (current_platform.is_cuda()
                    and cuda_device_count_stateless() < self.world_size):
            if current_platform.is_neuron():
                # neuron uses single process to control multiple devices
                backend = "uni"
            elif (current_platform.is_cuda()
                  and cuda_device_count_stateless() < self.world_size):
                if not ray_found:
                    raise ValueError("Unable to load Ray which is "
                                     "required for multi-node inference, "

@@ -1328,13 +1331,14 @@ class ParallelConfig:
        from vllm.executor.executor_base import ExecutorBase
        from vllm.platforms import current_platform
        if self.distributed_executor_backend not in (
                "ray", "mp", None) and not (isinstance(
                "ray", "mp", "uni", None) and not (isinstance(
                    self.distributed_executor_backend, type) and issubclass(
                        self.distributed_executor_backend, ExecutorBase)):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp' or custom ExecutorBase subclass.")
                "values are 'ray', 'mp' 'uni', or custom ExecutorBase"
                " subclass.")
        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()
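With this validation change, distributed_executor_backend accepts "ray", "mp", the new "uni" value, None, or an ExecutorBase subclass. "uni" is normally chosen automatically (e.g. on Neuron, per the hunk above), but as a hedged sketch it can also be requested explicitly:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m",
                         distributed_executor_backend="uni")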
@@ -862,12 +862,14 @@ def init_model_parallel_group(
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    from vllm.platforms import current_platform
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=use_custom_allreduce,
        use_pynccl=current_platform.is_cuda_alike(),
        use_custom_allreduce=current_platform.is_cuda_alike()
        and use_custom_allreduce,
        use_tpu_communicator=True,
        use_hpu_communicator=True,
        use_xpu_communicator=True,
@@ -18,9 +18,7 @@ from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger

@@ -620,69 +618,9 @@ class AsyncLLMEngine(EngineClient):
            rt.new_requests_event.set()

    @classmethod
    def _get_executor_cls(
            cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif engine_config.device_config.device_type == "tpu":
            if distributed_executor_backend == "ray":
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                executor_class = RayTPUExecutorAsync
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutorAsync
                executor_class = TPUExecutorAsync
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif engine_config.device_config.device_type == "hpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
                executor_class = RayHPUExecutorAsync
            else:
                from vllm.executor.hpu_executor import HPUExecutorAsync
                executor_class = HPUExecutorAsync
        elif engine_config.device_config.device_type == "openvino":
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            executor_class = OpenVINOExecutorAsync
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif distributed_executor_backend == "ray":
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
            elif distributed_executor_backend == "mp":
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                executor_class = MultiprocessingXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
        elif distributed_executor_backend == "ray":
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        return executor_class
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        return LLMEngine._get_executor_cls(engine_config)

    @classmethod
    def from_engine_args(

@@ -700,9 +638,6 @@ class AsyncLLMEngine(EngineClient):

        executor_class = cls._get_executor_cls(engine_config)

        if executor_class.uses_ray:
            initialize_ray_cluster(engine_config.parallel_config)

        # Create the async LLM engine.
        engine = cls(
            vllm_config=engine_config,

@@ -1242,23 +1177,12 @@ class AsyncLLMEngine(EngineClient):
        self.engine.remove_logger(logger_name=logger_name)

    async def start_profile(self) -> None:
        # using type instead of isinstance to check to avoid capturing
        # inherited classes
        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
            self.engine.model_executor.start_profile()
        else:
            self.engine.model_executor._run_workers("start_profile")
        self.engine.start_profile()

    async def stop_profile(self) -> None:
        # using type instead of isinstance to check to avoid capturing
        # inherited classes
        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
            self.engine.model_executor.stop_profile()
        else:
            self.engine.model_executor._run_workers("stop_profile")
        self.engine.stop_profile()

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""
        self.engine.add_lora(lora_request)
@@ -28,8 +28,6 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import (
    get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                         PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt

@@ -442,64 +440,26 @@ class LLMEngine:
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif engine_config.device_config.device_type == "tpu":
        elif engine_config.parallel_config.world_size > 1:
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutor
                executor_class = RayTPUExecutor
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutor
                executor_class = TPUExecutor
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif engine_config.device_config.device_type == "hpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_hpu_executor import RayHPUExecutor
                executor_class = RayHPUExecutor
            else:
                from vllm.executor.hpu_executor import HPUExecutor
                executor_class = HPUExecutor
        elif engine_config.device_config.device_type == "openvino":
            from vllm.executor.openvino_executor import OpenVINOExecutor
            executor_class = OpenVINOExecutor
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutor
                executor_class = RayXPUExecutor
                from vllm.executor.ray_distributed_executor import (
                    RayDistributedExecutor)
                executor_class = RayDistributedExecutor
            elif distributed_executor_backend == "mp":
                # FIXME(kunshang):
                # spawn needs calling `if __name__ == '__main__':``
                # fork is not supported for xpu start new process.
                logger.error(
                    "Both start methods (spawn and fork) have issue "
                    "on XPU if you use mp backend, Please try ray instead.")
            else:
                from vllm.executor.xpu_executor import XPUExecutor
                executor_class = XPUExecutor
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                "multiprocessing distributed executor backend does not "
                "support VLLM_USE_RAY_SPMD_WORKER=1")
            executor_class = MultiprocessingGPUExecutor
                from vllm.executor.mp_distributed_executor import (
                    MultiprocessingDistributedExecutor)
                assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                    "multiprocessing distributed executor backend does not "
                    "support VLLM_USE_RAY_SPMD_WORKER=1")
                executor_class = MultiprocessingDistributedExecutor
            elif distributed_executor_backend == "uni":
                # JAX-style, single-process, multi-device executor.
                from vllm.executor.uniproc_executor import UniProcExecutor
                executor_class = UniProcExecutor
        else:
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor
            from vllm.executor.uniproc_executor import UniProcExecutor
            executor_class = UniProcExecutor
        return executor_class

    @classmethod

@@ -1845,27 +1805,17 @@ class LLMEngine:
    def list_prompt_adapters(self) -> List[int]:
        return self.model_executor.list_prompt_adapters()

    def start_profile(self) -> None:
        self.model_executor.start_profile()

    def stop_profile(self) -> None:
        self.model_executor.stop_profile()

    def check_health(self) -> None:
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()

    def start_profile(self) -> None:
        # using type instead of isinstance to check to avoid capturing
        # inherited classes (MultiprocessingGPUExecutor)
        if type(self.model_executor) == GPUExecutor:  # noqa: E721
            self.model_executor.start_profile()
        else:
            self.model_executor._run_workers("start_profile")

    def stop_profile(self) -> None:
        # using type instead of isinstance to check to avoid capturing
        # inherited classes (MultiprocessingGPUExecutor)
        if type(self.model_executor) == GPUExecutor:  # noqa: E721
            self.model_executor.stop_profile()
        else:
            self.model_executor._run_workers("stop_profile")

    def is_tracing_enabled(self) -> bool:
        return self.tracer is not None
@@ -20,7 +20,6 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         RPCStartupResponse,
                                         RPCUProfileRequest)
# yapf: enable
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext

@@ -356,16 +355,10 @@ class MQLLMEngine:
            self._errored_with = e

    def start_profile(self) -> None:
        if type(self.engine.model_executor) is GPUExecutor:
            self.engine.model_executor.start_profile()
        else:
            self.engine.model_executor._run_workers("start_profile")
        self.engine.start_profile()

    def stop_profile(self) -> None:
        if type(self.engine.model_executor) is GPUExecutor:
            self.engine.model_executor.stop_profile()
        else:
            self.engine.model_executor._run_workers("stop_profile")
        self.engine.stop_profile()


def signal_handler(*_) -> None:
@@ -1,299 +0,0 @@
import os
from functools import partial
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                  ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_distributed_init_method, get_open_port, make_async
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class CPUExecutor(ExecutorBase):

    uses_ray: bool = False

    def _init_executor(self) -> None:
        assert self.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP setting
        ld_prealod_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_prealod_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ['KMP_BLOCKTIME'] = "1"
            # Prevents the CPU to run into low performance state
            os.environ['KMP_TPAUSE'] = "0"
            # Provides fine granularity parallelism
            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            self.parallel_config.tensor_parallel_size)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # 127.0.0.1 for communication.
        ip = "127.0.0.1"
        port = get_open_port()
        self.distributed_init_method = get_distributed_init_method(ip, port)

        is_async = isinstance(self, CPUExecutorAsync)

        world_size = self.parallel_config.tensor_parallel_size
        result_handler = ResultHandler()
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        self.workers = []

        if is_async:
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_worker,
                        rank=rank,
                        local_rank=rank,
                    )) for rank in range(0, world_size)
            ]
            self.driver_worker = self.workers[0]
            self.workers = self.workers[1:]
            self.driver_method_invoker = _async_driver_method_invoker
        else:
            self.driver_worker = self._create_worker()
            self.driver_method_invoker = _driver_method_invoker

            if world_size != 1:
                self.workers = [
                    ProcessWorkerWrapper(
                        result_handler,
                        partial(
                            self._create_worker,
                            rank=rank,
                            local_rank=rank,
                        )) for rank in range(1, world_size)
                ]

        self.worker_monitor = None
        if world_size != 1 or is_async:
            if is_async:
                async_worker_list = self.workers + [self.driver_worker]
            else:
                async_worker_list = self.workers
            self.worker_monitor = WorkerMonitor(async_worker_list,
                                                result_handler)
            result_handler.start()
            self.worker_monitor.start()

        self._run_workers("init_device")
        self._run_workers("load_model")

    def _create_worker(
        self,
        local_rank: int = 0,
        rank: int = 0,
    ):

        wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)

        assert self.distributed_init_method is not None

        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=self.distributed_init_method,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=rank == 0,
        )
        wrapper.init_worker(**kwargs)

        return wrapper.worker

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_remote_workers_only: If True the method will be run only
                in the remote workers, not the driver worker. It will also be
                run asynchronously and return a list of futures rather than
                blocking on the results.
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the workers first.
        worker_outputs = [
            worker.execute_method(method, *args, **kwargs)
            for worker in self.workers
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return worker_outputs

        driver_worker_output = self.driver_method_invoker(
            self.driver_worker, method, *args, **kwargs)

        # Get the results of the workers.
        return [driver_worker_output
                ] + [output.get() for output in worker_outputs]

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_method_invoker(self.driver_worker,
                                          "determine_num_available_blocks")

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """
        # NOTE: We log here to avoid multiple logs when number of workers is
        # greater than one. We could log in the engine, but not all executors
        # have GPUs.
        # NOTE: `cpu block` for CPU backend is located on CPU memory but is
        # referred as `gpu block`. Because we want to reuse the existing block
        # management procedure.
        logger.info("# CPU blocks: %d", num_gpu_blocks)

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if (self.parallel_config.tensor_parallel_size > 1
                and self.parallel_worker_tasks is None):
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
            )
        output = self.driver_method_invoker(self.driver_worker,
                                            "execute_model", execute_model_req)
        return output

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return
        """
        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        self.driver_method_invoker(self.driver_worker, "execute_model", None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return all(self._run_workers("add_lora", lora_request))

    def remove_lora(self, lora_id: int) -> bool:
        return all(self._run_workers("remove_lora", lora_id))

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self._run_workers(
            "pin_lora",
            lora_id=lora_id,
        ))

    def list_loras(self) -> Set[int]:
        return self.driver_method_invoker(self.driver_worker, "list_loras")

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return all(
            self._run_workers(
                "add_prompt_adapter",
                prompt_adapter_request,
            ))

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return all(
            self._run_workers(
                "remove_prompt_adapter",
                prompt_adapter_id,
            ))

    def list_prompt_adapters(self) -> Set[int]:
        return self.driver_method_invoker(self.driver_worker,
                                          "list_prompt_adapters")

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return all(self._run_workers(
            "pin_prompt_adapter",
            prompt_adapter_id,
        ))

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        if self.worker_monitor is not None and not self.worker_monitor.is_alive(
        ):
            raise RuntimeError("Worker processes are not running")

    def shutdown(self):
        if (worker_monitor := getattr(self, "worker_monitor",
                                      None)) is not None:
            worker_monitor.close()

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        for result in parallel_worker_tasks:
            result.get()

    def start_profile(self) -> None:
        self.driver_method_invoker(self.driver_worker, "start_profile")

    def stop_profile(self) -> None:
        self.driver_method_invoker(self.driver_worker, "stop_profile")


class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        output = await make_async(self.execute_model
                                  )(execute_model_req=execute_model_req, )
        return output

    async def check_health_async(self) -> None:
        self.check_health()


def _driver_method_invoker(driver, method: str, *args, **kwargs):
    return getattr(driver, method)(*args, **kwargs)


def _async_driver_method_invoker(driver, method: str, *args, **kwargs):
    return driver.execute_method(method, *args, **kwargs).get()
@@ -1,212 +0,0 @@
import asyncio
from abc import abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest

logger = init_logger(__name__)


class DistributedGPUExecutor(GPUExecutor):
    """Abstract superclass of multi-GPU executor implementations."""

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks", )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers.
        """

        # NOTE: We log here to avoid multiple logs when number of workers is
        # greater than one. We could log in the engine, but not all executors
        # have GPUs.
        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)
        max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
                           self.model_config.max_model_len)
        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                    self.model_config.max_model_len, max_concurrency)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_tensor_parallel_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        driver_outputs = self._driver_execute_model(execute_model_req)
        assert driver_outputs is not None
        return driver_outputs

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model(execute_model_req=None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "add_lora",
            lora_request=lora_request,
        )

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "pin_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> Set[int]:
        return self._run_workers("list_loras")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        self._run_workers("save_sharded_state",
                          path=path,
                          pattern=pattern,
                          max_size=max_size)

    @abstractmethod
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers. In this case, this method
        returns None. Otherwise, this method returns the model output.
        """
        raise NotImplementedError

    @abstractmethod
    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of futures
                rather than blocking on the results.
        """
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        raise NotImplementedError


class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run execution loop on all workers. It guarantees all workers run
        the loop or None of them is running the loop. Loop can be stopped by
        `stop_remote_worker_execution_loop`.
        The API is idempotent (guarantee only 1 loop run at any moment)."""
        raise NotImplementedError
@@ -1,18 +1,24 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async

logger = init_logger(__name__)


class ExecutorBase(ABC):
    """Base class for all executors.

    An executor is responsible for executing the model on a specific device
    type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
    An executor is responsible for executing the model on one device,
    or it can be a distributed executor
    that can execute the model on multiple devices.
    """

@@ -40,6 +46,20 @@ class ExecutorBase(ABC):
        pass

    @abstractmethod
    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        """
        The main interface of the executor to run a method on all workers,
        with homogeneous arguments.
        If the args are heterogeneous, then we can pack them into a list,
        and unpack them in the method of every worker, because every worker
        knows their own rank.
        """
        pass

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.
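The collective_rpc docstring above describes the heterogeneous-argument pattern: pack one entry per rank into a single argument and let each worker pick its own slice. A sketch of that pattern (executor stands for any ExecutorBase instance; custom_method and the per-rank payload are illustrative names, not part of the vLLM API):

# One payload per rank, packed into a single homogeneous argument.
all_ranks_data = [{"seed": 100 + rank} for rank in range(4)]
results = executor.collective_rpc("custom_method", args=(all_ranks_data, ))

# Inside the worker, assuming it knows its own rank (vLLM workers do):
#
#     def custom_method(self, all_ranks_data):
#         my_data = all_ranks_data[self.rank]
#         ...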
@ -53,58 +73,113 @@ class ExecutorBase(ABC):
|
||||
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
|
||||
appended to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
results = self.collective_rpc("determine_num_available_blocks")
|
||||
a = min([r[0] for r in results])
|
||||
b = min([r[1] for r in results])
|
||||
return a, b
|
||||
|
||||
@abstractmethod
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache with the given size in blocks.
|
||||
def initialize(self, num_gpu_blocks: int) -> None:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
Initialize the KV caches and begin the model execution loop of the
|
||||
underlying workers.
|
||||
For V1 compatibility.
|
||||
"""
|
||||
logger.info("# GPU blocks: %d", num_gpu_blocks)
|
||||
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
|
||||
self.collective_rpc("compile_or_warm_up_model")
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
|
||||
"""Initialize the KV cache by invoking the underlying worker.
|
||||
"""
|
||||
# NOTE: This is logged in the executor because there can be >1 workers.
|
||||
logger.info("# %s blocks: %d, # CPU blocks: %d",
|
||||
current_platform.dispatch_key, num_gpu_blocks,
|
||||
num_cpu_blocks)
|
||||
max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
|
||||
self.model_config.max_model_len)
|
||||
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
|
||||
self.model_config.max_model_len, max_concurrency)
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
self.collective_rpc("initialize_cache",
|
||||
args=(num_gpu_blocks, num_cpu_blocks))
|
||||
|
||||
@abstractmethod
|
||||
def execute_model(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""Executes at least one model step on the given sequences."""
|
||||
raise NotImplementedError
|
||||
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
||||
output = self.collective_rpc("execute_model",
|
||||
args=(execute_model_req, ))
|
||||
return output[0]
|
||||
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
"""Releases parallel workers from model loop."""
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
raise NotImplementedError
|
||||
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
||||
return all(self.collective_rpc("add_lora", args=(lora_request, )))
|
||||
|
||||
@abstractmethod
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
assert lora_id > 0, "lora_id must be greater than 0."
|
||||
return all(self.collective_rpc("remove_lora", args=(lora_id, )))
|
||||
|
||||
@abstractmethod
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError # type: ignore
|
||||
assert lora_id > 0, "lora_id must be greater than 0."
|
||||
return all(self.collective_rpc("pin_lora", args=(lora_id, )))
|
||||
|
||||
@abstractmethod
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise NotImplementedError
|
||||
sets = self.collective_rpc("list_loras")
|
||||
for s in sets:
|
||||
assert s == sets[0], "All workers should have the same LORAs."
|
||||
return sets[0]
|
||||
|
||||
@abstractmethod
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
raise NotImplementedError
|
||||
assert prompt_adapter_request.prompt_adapter_id > 0, \
|
||||
"prompt_adapter_id must be greater than 0."
|
||||
return all(
|
||||
self.collective_rpc("add_prompt_adapter",
|
||||
args=(prompt_adapter_request, )))
|
||||
|
||||
@abstractmethod
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
assert prompt_adapter_id > 0, \
|
||||
"prompt_adapter_id must be greater than 0."
|
||||
return all(
|
||||
self.collective_rpc("remove_prompt_adapter",
|
||||
args=(prompt_adapter_id, )))
|
||||
|
||||
@abstractmethod
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
raise NotImplementedError # type: ignore
|
||||
assert prompt_adapter_id > 0, \
|
||||
"prompt_adapter_id must be greater than 0."
|
||||
return all(
|
||||
self.collective_rpc("pin_prompt_adapter",
|
||||
args=(prompt_adapter_id, )))
|
||||
|
||||
@abstractmethod
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
raise NotImplementedError
|
||||
sets = self.collective_rpc("list_prompt_adapters")
|
||||
for s in sets:
|
||||
assert (s == sets[0]
|
||||
), "All workers should have the same prompt adapters."
|
||||
return sets[0]
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.collective_rpc("start_profile")
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
self.collective_rpc("stop_profile")
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
self.collective_rpc("save_sharded_state",
|
||||
kwargs=dict(path=path,
|
||||
pattern=pattern,
|
||||
max_size=max_size))
|
||||
|
||||
@abstractmethod
|
||||
def check_health(self) -> None:
|
||||
@ -119,15 +194,12 @@ class ExecutorBase(ABC):
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
|
||||
class ExecutorAsyncBase(ExecutorBase):
|
||||
|
||||
@abstractmethod
|
||||
async def execute_model_async(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
"""Executes one model step on the given sequences."""
|
||||
raise NotImplementedError
|
||||
output = await make_async(self.execute_model)(execute_model_req)
|
||||
return output
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||
"""Releases parallel workers from model loop."""
|
||||
@ -137,3 +209,128 @@ class ExecutorAsyncBase(ExecutorBase):
|
||||
"""Checks if the executor is healthy. If not, it should raise an
|
||||
exception."""
|
||||
self.check_health()
|
||||
|
||||
|
||||
class DistributedExecutorBase(ExecutorBase):
|
||||
"""Abstract superclass of distributed executor implementations."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# This is non-None when the execute model loop is running
|
||||
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
|
||||
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> List[SamplerOutput]:
|
||||
# TODO: unify into collective_rpc
|
||||
if self.parallel_worker_tasks is None:
|
||||
self.parallel_worker_tasks = self._run_workers(
|
||||
"start_worker_execution_loop",
|
||||
async_run_tensor_parallel_workers_only=True)
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
driver_outputs = self._driver_execute_model(execute_model_req)
|
||||
assert driver_outputs is not None
|
||||
return driver_outputs
|
||||
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
if self.parallel_worker_tasks is None:
|
||||
return
|
||||
|
||||
self._driver_execute_model(execute_model_req=None)
|
||||
parallel_worker_tasks = self.parallel_worker_tasks
|
||||
self.parallel_worker_tasks = None
|
||||
# Ensure that workers exit model loop cleanly
|
||||
# (this will raise otherwise)
|
||||
self._wait_for_tasks_completion(parallel_worker_tasks)
|
||||
|
||||
@abstractmethod
|
||||
def _driver_execute_model(
|
||||
self, execute_model_req: Optional[ExecuteModelRequest]
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""Run execute_model in the driver worker.
|
||||
|
||||
Passing None will cause the driver to stop the model execution loop
|
||||
running in each of the remote workers. In this case, this method
|
||||
returns None. Otherwise, this method returns the model output.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def collective_rpc(self,
|
||||
method: str,
|
||||
timeout: Optional[float] = None,
|
||||
args: Tuple = (),
|
||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
||||
return self._run_workers(method, *args, **(kwargs or {}))
|
||||
|
||||
@abstractmethod
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers.
|
||||
|
||||
Args:
|
||||
async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
|
||||
# TODO: simplify and merge with collective_rpc
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||
"""Wait for futures returned from _run_workers() with
|
||||
async_run_remote_workers_only to complete."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def execute_model_async(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
if self.parallel_worker_tasks is None:
|
||||
# Start model execution loop running in the parallel workers
|
||||
self.parallel_worker_tasks = asyncio.create_task(
|
||||
self._start_worker_execution_loop())
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
return await self._driver_execute_model_async(execute_model_req)
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||
if self.parallel_worker_tasks is None:
|
||||
return
|
||||
|
||||
await self._driver_execute_model_async()
|
||||
parallel_worker_tasks = self.parallel_worker_tasks
|
||||
self.parallel_worker_tasks = None
|
||||
# Ensure that workers exit model loop cleanly
|
||||
# (this will raise otherwise)
|
||||
await parallel_worker_tasks
|
||||
|
||||
@abstractmethod
|
||||
async def _driver_execute_model_async(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> List[SamplerOutput]:
|
||||
"""Execute the model asynchronously in the driver worker.
|
||||
|
||||
Passing None will cause the driver to stop the model execution
|
||||
loop running in each of the remote workers.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _start_worker_execution_loop(self):
|
||||
"""Run execution loop on all workers. It guarantees all workers run
|
||||
the loop or None of them is running the loop. Loop can be stopped by
|
||||
`stop_remote_worker_execution_loop`.
|
||||
The API is idempotent (guarantee only 1 loop run at any moment)."""
|
||||
raise NotImplementedError
|
||||
|
@ -1,145 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def create_worker(**kwargs):
|
||||
vllm_config = kwargs.get("vllm_config")
|
||||
wrapper = WorkerWrapperBase(vllm_config=vllm_config)
|
||||
wrapper.init_worker(**kwargs)
|
||||
return wrapper.worker
|
||||
|
||||
|
||||
class GPUExecutor(ExecutorBase):
|
||||
|
||||
uses_ray: bool = False
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
"""Initialize the worker and load the model.
|
||||
"""
|
||||
assert self.parallel_config.world_size == 1, (
|
||||
"GPUExecutor only supports single GPU.")
|
||||
|
||||
self.driver_worker = self._create_worker()
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
|
||||
def _get_worker_kwargs(
|
||||
self,
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Return worker init args for a given rank."""
|
||||
if distributed_init_method is None:
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
return dict(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=(not self.parallel_config)
|
||||
or (rank % self.parallel_config.tensor_parallel_size == 0),
|
||||
)
|
||||
|
||||
def _create_worker(self,
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
distributed_init_method: Optional[str] = None):
|
||||
return create_worker(**self._get_worker_kwargs(
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method))
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks by invoking the
|
||||
underlying worker.
|
||||
"""
|
||||
return self.driver_worker.determine_num_available_blocks()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
|
||||
"""Initialize the KV cache by invoking the underlying worker.
|
||||
"""
|
||||
# NOTE: This is logged in the executor because there can be >1 worker
|
||||
# with other executors. We could log in the engine level, but work
|
||||
# remains to abstract away the device for non-GPU configurations.
|
||||
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
|
||||
num_cpu_blocks)
|
||||
max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
|
||||
self.model_config.max_model_len)
|
||||
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
|
||||
self.model_config.max_model_len, max_concurrency)
|
||||
|
||||
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
output = self.driver_worker.execute_model(execute_model_req)
return output

def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.driver_worker.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()

def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
assert prompt_adapter_request.prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.driver_worker.add_prompt_adapter(prompt_adapter_request)

def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)

def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)

def list_prompt_adapters(self) -> Set[int]:
return self.driver_worker.list_prompt_adapters()

def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return

def start_profile(self) -> None:
self.driver_worker.start_profile()

def stop_profile(self) -> None:
self.driver_worker.stop_profile()


class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):

async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[Union[SamplerOutput, PoolerOutput]]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req)
return output
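The pattern above, one thin pass-through per executor method, is what this commit collapses into a single collective_rpc entry point on a platform-agnostic executor. A minimal sketch of that idea, using illustrative names rather than the exact vLLM classes:

from typing import Any, Dict, List, Optional, Tuple


class SketchUniProcExecutor:
    """Hypothetical single-worker executor: every call becomes an RPC."""

    def __init__(self, worker: Any):
        self.driver_worker = worker

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        # With a single in-process worker, "collective" just means calling
        # the method locally and wrapping the result in a list.
        # (timeout is unused in this single-process sketch.)
        return [getattr(self.driver_worker, method)(*args, **(kwargs or {}))]

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        return self.collective_rpc("determine_num_available_blocks")[0]

    def add_lora(self, lora_request: Any) -> bool:
        return self.collective_rpc("add_lora", args=(lora_request, ))[0]

In the actual change, the per-backend executors are replaced by generic ones that route init_worker, init_device, load_model, and similar calls through this kind of RPC, as the diffs below show.
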
@ -1,202 +0,0 @@
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import contextlib
import os
from typing import Any, Dict, List, Optional, Set, Tuple

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class HPUExecutor(ExecutorBase):

uses_ray: bool = False

def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
self._init_worker()

def _get_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
"""Return worker init args for a given rank."""
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=rank == 0,
)

def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker

def _init_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")

self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# HPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
from vllm_hpu_extension.profiler import HabanaMemoryProfiler
with HabanaMemoryProfiler() as cache_init_m:
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
msg = f"init_cache_engine took {cache_init_m.get_summary_string()}"
logger.info(msg)

def finish_measurements(self):
self.driver_worker.finish_measurements()

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501
log_graph_compilation_all = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0'
log_graph_compilation = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION',
'0') != '0' or log_graph_compilation_all
log_cpu_fallbacks_all = os.environ.get(
'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
'0') != '0' or log_cpu_fallbacks_all
if log_graph_compilation or log_cpu_fallbacks:
from habana_frameworks.torch.hpu.metrics import metric_localcontext
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
is_prompt = any([
seq_group_metadata.is_prompt
for seq_group_metadata in seq_group_metadata_list
])
max_context_len = max([
max([
len(v.prompt_token_ids) + len(v.output_token_ids)
for v in seq_group_metadata.seq_data.values()
]) for seq_group_metadata in seq_group_metadata_list
]) # whoa, that's some spicy stuff right here
max_num_blocks = (
(max_context_len - 1) // self.cache_config.block_size) + 1
input_stats = (f'is_prompt: {is_prompt}, '
f'num_seqs: {len(seq_group_metadata_list)}, '
f'max_context_len: {max_context_len}, '
f'max_num_blocks {max_num_blocks}')
gc_ctx = metric_localcontext(
"graph_compilation"
) if log_graph_compilation else contextlib.nullcontext()
cpu_fallback_ctx = metric_localcontext(
"cpu_fallback"
) if log_cpu_fallbacks else contextlib.nullcontext()
with gc_ctx as gc_local_metric, \
cpu_fallback_ctx as cpu_fallback_local_metric:
output = self.driver_worker.execute_model(execute_model_req)
if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0
) or log_graph_compilation_all:
msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
f"{gc_local_metric.stats()}, {input_stats}")
logger.warning(msg)
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] >
0) or log_cpu_fallbacks_all:
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
logger.warning(msg)

return output

output = self.driver_worker.execute_model(execute_model_req)
return output

def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.driver_worker.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()

def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")

def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")

def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")

def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")

def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return

def start_profile(self) -> None:
self.driver_worker.start_profile()

def stop_profile(self) -> None:
self.driver_worker.stop_profile()

def shutdown(self) -> None:
self.driver_worker.shutdown_inc()


class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase):

async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output
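A side note on the deleted HPU logging path above: max_num_blocks is a ceiling division of the longest context length by the cache block size. A tiny worked example with assumed values:

max_context_len, block_size = 100, 16  # assumed values, not from the diff
max_num_blocks = ((max_context_len - 1) // block_size) + 1
print(max_num_blocks)  # 7, i.e. ceil(100 / 16)
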
@ -1,32 +1,26 @@
import asyncio
import os
from functools import partial
from typing import Any, List, Optional

from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
from vllm.executor.executor_base import DistributedExecutorBase
from vllm.executor.multiproc_worker_utils import (
ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
get_distributed_init_method, get_open_port, make_async,
update_environment_variables)
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"""Python multiprocessing-based multi-GPU executor"""
class MultiprocessingDistributedExecutor(DistributedExecutorBase):
"""Python multiprocessing-based distributed executor"""

uses_ray: bool = False

def _init_executor(self) -> None:
self._check_executor_parameters()

# Create the parallel GPU workers.
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
@ -55,15 +49,9 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
else:
result_handler = ResultHandler()
for rank in range(1, world_size):
worker = ProcessWorkerWrapper(
result_handler,
partial(
create_worker,
**self._get_worker_kwargs(
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)))
worker = ProcessWorkerWrapper(result_handler,
WorkerWrapperBase,
self.vllm_config, rank)
self.workers.append(worker)
if rank % tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
@ -77,32 +65,30 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well

self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)

all_kwargs = []
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
for i in range(world_size):
local_rank = i
rank = i
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self._run_workers("init_worker", all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
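The is_driver_worker expression used when building all_kwargs above marks rank 0 of every tensor-parallel group as a driver. A small standalone check with an assumed topology:

tensor_parallel_size, world_size = 2, 4  # assumed: TP=2, PP=2
drivers = [rank for rank in range(world_size)
           if rank % tensor_parallel_size == 0]
print(drivers)  # [0, 2] -- one driver worker per TP group
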
def _check_executor_parameters(self):
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size

# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})

cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local gpu count ({cuda_device_count})")

assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")
self.driver_exec_model = make_async(self.driver_worker.execute_model)
self.pp_locks: Optional[List[asyncio.Lock]] = None

def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
@ -172,15 +158,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
for result in parallel_worker_tasks:
result.get()


class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
DistributedGPUExecutorAsync):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
self.pp_locks: Optional[List[asyncio.Lock]] = None

async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
@ -12,6 +12,7 @@ from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import _check_multiproc_method, get_mp_context
@ -147,7 +148,8 @@ class ProcessWorkerWrapper:
for handling single-node multi-GPU tensor parallel."""

def __init__(self, result_handler: ResultHandler,
worker_factory: Callable[[], Any]) -> None:
worker_factory: Callable[[VllmConfig, int], Any],
vllm_config: VllmConfig, rank: int) -> None:
self.mp = get_mp_context()
self._task_queue = self.mp.Queue()
self.result_queue = result_handler.result_queue
@ -159,6 +161,8 @@ class ProcessWorkerWrapper:
worker_factory=worker_factory,
task_queue=self._task_queue,
result_queue=self.result_queue,
vllm_config=vllm_config,
rank=rank,
),
daemon=True)

@ -199,9 +203,11 @@ class ProcessWorkerWrapper:


def _run_worker_process(
worker_factory: Callable[[], Any],
worker_factory: Callable[[VllmConfig, int], Any],
task_queue: Queue,
result_queue: Queue,
vllm_config: VllmConfig,
rank: int,
) -> None:
"""Worker process event loop"""

@ -212,7 +218,7 @@ def _run_worker_process(
_add_prefix(sys.stderr, process_name, pid)

# Initialize worker
worker = worker_factory()
worker = worker_factory(vllm_config, rank)
del worker_factory

# Accept tasks from the engine in task_queue
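The worker_factory signature change above means the worker object is now constructed inside the child process from just (vllm_config, rank), instead of pickling a fully bound factory. A rough, self-contained sketch of that pattern with stand-in names (not the vLLM API):

import multiprocessing as mp


class DummyWorker:
    # Stand-in for the role WorkerWrapperBase plays: built from (config, rank).
    def __init__(self, vllm_config, rank):
        self.rank = rank

    def ping(self):
        print(f"worker {self.rank} is alive")


def _child_loop(worker_factory, vllm_config, rank, task_queue):
    # The worker is built here, inside the child process, so only the
    # lightweight (config, rank) pair crosses the process boundary.
    worker = worker_factory(vllm_config, rank)
    for method, args, kwargs in iter(task_queue.get, None):
        getattr(worker, method)(*args, **kwargs)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    task_queue = ctx.Queue()
    proc = ctx.Process(target=_child_loop,
                       args=(DummyWorker, {"cfg": "stub"}, 0, task_queue))
    proc.start()
    task_queue.put(("ping", (), {}))
    task_queue.put(None)  # sentinel: stop the loop
    proc.join()
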
@ -1,26 +0,0 @@
import vllm.envs as envs
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
from vllm.executor.xpu_executor import XPUExecutor
from vllm.logger import init_logger
from vllm.utils import make_async

logger = init_logger(__name__)


class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor):
"""Python multiprocessing-based multi-XPU executor"""

def _check_executor_parameters(self):
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
if mp_method != "spawn":
raise RuntimeError(
"XPU multiprocess executor only support spawn as mp method")


class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor,
MultiprocessingGPUExecutorAsync):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
@ -1,114 +0,0 @@
from typing import List, Set, Tuple

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class NeuronExecutor(ExecutorBase):

uses_ray: bool = False

def _init_executor(self) -> None:
assert (self.lora_config is
None), "LoRA is not supported for Neuron backend."
assert (not self.speculative_config
), "Speculative decoding not yet supported for Neuron backend."

# Instantiate the worker and load the model to the device.
self._init_worker()

def _init_worker(self):
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
wrapper.init_worker(
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
)
self.driver_worker = wrapper.worker
self.driver_worker.init_device()
self.driver_worker.load_model()

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
assert (not execute_model_req.blocks_to_swap_in
and not execute_model_req.blocks_to_swap_out
and not execute_model_req.blocks_to_copy), (
"Cache operations are not supported for Neuron backend.")
assert execute_model_req.num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")

output = self.driver_worker.execute_model(execute_model_req)
return output

def add_lora(self, lora_request: LoRARequest) -> bool:
return self.driver_worker.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()

def add_prompt_adapter(self, prompt_adapter_request) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the Neuron backend.")

def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the Neuron backend.")

def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the Neuron backend.")

def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Soft prompt is currently not supported by the Neuron backend.")

def check_health(self) -> None:
# NeuronExecutor will always be healthy as long as
# it's running.
return


class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):

async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output

async def check_health_async(self) -> None:
# NeuronExecutor will always be healthy as long as
# it's running.
return
@ -1,125 +0,0 @@
from typing import List, Set, Tuple

import openvino as ov

import vllm.envs as envs
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class OpenVINOExecutor(ExecutorBase):

uses_ray: bool = False

def _init_executor(self) -> None:
assert self.device_config.device_type == "openvino"
assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
assert current_platform.is_openvino_cpu() or \
current_platform.is_openvino_gpu(), \
"OpenVINO backend supports only CPU and GPU devices"

# Instantiate the worker and load the model to CPU.
self._init_worker()

def _init_worker(self):

wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)

distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
wrapper.init_worker(
ov_core=ov.Core(),
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self.driver_worker = wrapper.worker
self.driver_worker.init_device()
self.driver_worker.load_model()

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker."""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
# NOTE: In case of a CPU device, `cpu block` for OpenVINO backend
# is located on CPU memory but is referred as `gpu block`.
# Because we want to reuse the existing block management procedure.
device_blocks = num_gpu_blocks
swap_blocks = num_cpu_blocks
logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d",
envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
return output

def add_lora(self, lora_request: LoRARequest) -> bool:
return self.driver_worker.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()

def add_prompt_adapter(self, prompt_adapter_request) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the OPENVINO backend.")

def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the OPENVINO backend.")

def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Soft prompt is currently not supported by the OPENVINO backend.")

def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Soft prompt is currently not supported by the OPENVINO backend.")

def check_health(self) -> None:
# OpenVINOExecutor will always be healthy as long as
# it's running.
return


class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase):

async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output

async def check_health_async(self) -> None:
# OpenVINOExecutor will always be healthy as long as
# it's running.
return
@ -1,24 +1,29 @@
import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import msgspec

import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.executor_base import (
DistributedExecutorBase) # yapf: disable
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
ray)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async)

if ray is not None:
from ray.actor import ActorHandle
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
ActorHandle = None

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -26,12 +31,29 @@ if TYPE_CHECKING:
logger = init_logger(__name__)


class RayGPUExecutor(DistributedGPUExecutor):
@dataclass
class RayWorkerMetaData:
"""
Metadata for a Ray worker.
The order of ray worker creation can be random,
and we need to reset the rank after creating all workers.
"""
worker: ActorHandle
created_rank: int
adjusted_rank: int = -1
ip: str = ""


class RayDistributedExecutor(DistributedExecutorBase):

uses_ray: bool = True

def _init_executor(self) -> None:
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
if envs.VLLM_USE_V1:
# v1 always uses the compiled DAG and SPMD worker.
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
@ -53,6 +75,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
"VLLM_USE_RAY_COMPILED_DAG=1")

assert self.uses_ray
initialize_ray_cluster(self.parallel_config)
placement_group = self.parallel_config.placement_group

# Disable Ray usage stats collection.
@ -66,6 +89,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
self.use_v1 = envs.VLLM_USE_V1

self.pp_locks: Optional[List[asyncio.Lock]] = None
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
self.driver_worker.execute_method)

def shutdown(self) -> None:
if hasattr(self, "forward_dag") and self.forward_dag is not None:
@ -123,9 +153,10 @@ class RayGPUExecutor(DistributedGPUExecutor):

# Create the workers.
driver_ip = get_ip()
workers = []
rank = 0
worker_metadata: List[RayWorkerMetaData] = []
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
if not bundle.get(current_platform.ray_device_key, 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
@ -133,38 +164,51 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_bundle_index=bundle_id,
)

worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
workers.append(worker)
if current_platform.ray_device_key == "GPU":
# NV+AMD GPUs, and Intel XPUs
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rank=rank)
else:
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rank=rank)
worker_metadata.append(
RayWorkerMetaData(worker=worker, created_rank=rank))
rank += 1

worker_ip_refs = [
worker.get_node_ip.remote() # type: ignore[attr-defined]
for worker in workers
]
worker_ips = ray.get(worker_ip_refs)
worker_ips = ray.get([
each.worker.get_node_ip.remote() # type: ignore[attr-defined]
for each in worker_metadata
])

for each, ip in zip(worker_metadata, worker_ips):
each.ip = ip

if not self.use_ray_spmd_worker:
for i in range(len(workers)):
worker = workers[i]
worker_ip = worker_ips[i]
for i, each in enumerate(worker_metadata):
# find and remove the dummy worker from the list
worker = each.worker
worker_ip = each.ip
if self.driver_dummy_worker is None and worker_ip == driver_ip:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config)
workers.pop(i)
worker_ips.pop(i)
self.workers = workers
vllm_config=self.vllm_config, rank=0)
worker_metadata.pop(i)
break
else:
self.workers = workers

logger.debug("workers: %s", self.workers)
logger.debug("workers: %s", worker_metadata)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
@ -176,9 +220,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1

worker_to_ip = dict(zip(self.workers, worker_ips))

def sort_by_driver_then_worker_ip(worker):
def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
@ -188,13 +230,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
3. Finally, if the work is on a node with smaller IP address, it
should be placed first.
"""
ip = worker_to_ip[worker]
return (ip != driver_ip, ip_counts[ip], ip)
ip = item.ip
return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
sorted_worker_metadata = sorted(worker_metadata,
key=sort_by_driver_then_worker_ip)
start_rank = 0 if self.use_ray_spmd_worker else 1
for i, item in enumerate(sorted_worker_metadata):
item.adjusted_rank = i + start_rank
self.workers = [item.worker for item in sorted_worker_metadata]
rerank_mapping = {
item.created_rank: item.adjusted_rank
for item in sorted_worker_metadata
}
self._run_workers("adjust_rank", rerank_mapping)

# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
@ -235,21 +287,29 @@ class RayGPUExecutor(DistributedGPUExecutor):
" each node.")

# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"CUDA_VISIBLE_DEVICES":
all_args_to_update_environment_variables = [{
current_platform.device_control_env_var:
",".join(map(str, node_gpus[node_id])),
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
**({
"VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
} if envs.VLLM_ATTENTION_BACKEND is not None else {})
}, ) for (node_id, _) in worker_node_and_gpu_ids]
} for (node_id, _) in worker_node_and_gpu_ids]

for args in all_args_to_update_environment_variables:
# some carry-over env vars from the driver
# TODO: refactor platform-specific env vars
for name in [
"VLLM_ATTENTION_BACKEND",
"TPU_CHIPS_PER_HOST_BOUNDS",
"TPU_HOST_BOUNDS",
"VLLM_USE_V1",
"VLLM_TRACE_FUNCTION",
]:
if name in os.environ:
args[name] = os.environ[name]

self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)

self._run_workers("update_environment_variables",
all_args=self._get_env_vars_to_be_updated())
self._get_env_vars_to_be_updated())

if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
@ -265,14 +325,19 @@ class RayGPUExecutor(DistributedGPUExecutor):
driver_ip, get_open_port())

# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
all_kwargs = []
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
local_rank = node_workers[node_id].index(rank)
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self._run_workers("init_worker", all_kwargs)

self._run_workers("init_device")
self._run_workers("load_model",
@ -332,9 +397,15 @@ class RayGPUExecutor(DistributedGPUExecutor):
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

serialized_data = self.input_encoder.encode(execute_model_req)
if self.use_v1:
serialized_data = execute_model_req
else:
serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data))
output = self.output_decoder.decode(outputs[0])
if self.use_v1:
output = outputs[0]
else:
output = self.output_decoder.decode(outputs[0])
return output

def _run_workers(
@ -342,8 +413,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
method: str,
*args,
async_run_tensor_parallel_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
@ -356,8 +425,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
@ -368,26 +435,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")

count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers)
# If using SPMD worker, all workers are the same, so we should execute
# the args on all workers. Otherwise, we skip the first worker's args
# because those args will go to the driver worker.
first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, first_worker_args_index, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, first_worker_args_index, None)

# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
worker.execute_method.remote(method, *args, **kwargs)
for worker in ray_workers
]

if async_run_tensor_parallel_workers_only:
@ -399,13 +453,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

# Start the driver worker after all the ray workers.
driver_worker_output = [
self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
self.driver_worker.execute_method(method, *args, **kwargs)
]

# Get the results of the ray workers.
@ -467,11 +517,18 @@ class RayGPUExecutor(DistributedGPUExecutor):
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
# Each PP worker takes in the output of the previous PP worker,
# and the TP group executes in SPMD fashion.
outputs = [
worker.execute_model_spmd.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
if self.use_v1:
outputs = [
worker.execute_model.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
else:
outputs = [
worker.execute_model_spmd.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]

last_pp_rank = len(self.pp_tp_workers) - 1
if pp_rank < last_pp_rank:
@ -497,17 +554,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
def __del__(self):
self.shutdown()


class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pp_locks: Optional[List[asyncio.Lock]] = None
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
self.driver_worker.execute_method)

async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
@ -568,5 +614,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
]
return await asyncio.gather(*coros)

def __del__(self):
self.shutdown()
def check_health(self) -> None:
# Assume that the Ray workers are healthy.
# TODO: check the health of the Ray workers
return
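To make the created_rank to adjusted_rank remapping in RayDistributedExecutor above concrete, here is a small self-contained sketch with fabricated values: three workers, the driver node at 10.0.0.1, and non-SPMD mode so remote ranks start at 1.

from dataclasses import dataclass


@dataclass
class Meta:  # simplified stand-in for RayWorkerMetaData
    created_rank: int
    ip: str
    adjusted_rank: int = -1


driver_ip = "10.0.0.1"  # assumed
metas = [Meta(0, "10.0.0.2"), Meta(1, "10.0.0.1"), Meta(2, "10.0.0.2")]
ip_counts = {}
for m in metas:
    ip_counts[m.ip] = ip_counts.get(m.ip, 0) + 1

# Same sort key as in the diff: driver node first, sparser nodes next, then IP.
metas.sort(key=lambda m: (0 if m.ip == driver_ip else 1,
                          ip_counts[m.ip], m.ip))
start_rank = 1  # non-SPMD: rank 0 stays with the driver worker
for i, m in enumerate(metas):
    m.adjusted_rank = i + start_rank
print({m.created_rank: m.adjusted_rank for m in metas})
# -> {1: 1, 0: 2, 2: 3}: the worker on the driver's node becomes rank 1
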
@ -1,515 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import islice, repeat
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import msgspec
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
||||
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
||||
from vllm.executor.msgspec_utils import encode_hook
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
|
||||
get_ip, get_open_port, make_async)
|
||||
|
||||
if ray is not None:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RayHPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
uses_ray: bool = True
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
|
||||
# If the env var is set, it uses the Ray's compiled DAG API
|
||||
# which optimizes the control plane overhead.
|
||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||
# Currently, this requires USE_RAY_SPMD_WORKER=True.
|
||||
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
|
||||
# If the env var is set, then we do not distinguish between the
|
||||
# "driver worker" vs other workers. Also, the rank 0 worker will
|
||||
# be executed in a remote Ray worker. Currently this requires
|
||||
# USE_RAY_COMPILED_DAG=True.
|
||||
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
|
||||
if self.use_ray_compiled_dag:
|
||||
assert self.use_ray_spmd_worker, (
|
||||
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
|
||||
"VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
if self.use_ray_spmd_worker:
|
||||
# TODO: Support SPMD worker for non-DAG Ray executor.
|
||||
assert self.use_ray_compiled_dag, (
|
||||
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
|
||||
"VLLM_USE_RAY_COMPILED_DAG=1")
|
||||
|
||||
assert self.uses_ray
|
||||
placement_group = self.parallel_config.placement_group
|
||||
|
||||
# Disable Ray usage stats collection.
|
||||
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
||||
if ray_usage != "1":
|
||||
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
||||
|
||||
# Create the parallel GPU workers.
|
||||
self._init_workers_ray(placement_group)
|
||||
|
||||
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
|
||||
self.output_decoder = msgspec.msgpack.Decoder(
|
||||
Optional[List[SamplerOutput]])
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if hasattr(self, "forward_dag") and self.forward_dag is not None:
|
||||
self.forward_dag.teardown()
|
||||
import ray
|
||||
for worker in self.workers:
|
||||
ray.kill(worker)
|
||||
self.forward_dag = None
|
||||
|
||||
def finish_measurements(self):
|
||||
self._run_workers("finish_measurements")
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
# Otherwise, the ray workers are allocated with a full GPU.
|
||||
num_gpus = 1
|
||||
|
||||
# The driver dummy worker does not actually use any resources.
|
||||
# It holds the resource for the driver worker.
|
||||
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
|
||||
# The remaining workers are the actual ray actors.
|
||||
self.workers: List[RayWorkerWrapper] = []
|
||||
|
||||
# Used in ray compiled DAG: indexed first by PP rank,
|
||||
# and then TP rank. In other words, the inner list is
|
||||
# the TP group of workers for a PP rank.
|
||||
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
|
||||
|
||||
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get("HPU", 0):
|
||||
continue
|
||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=bundle_id,
|
||||
)
|
||||
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
resources={'HPU': num_gpus},
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
|
||||
|
||||
if self.use_ray_spmd_worker:
|
||||
self.workers.append(worker)
|
||||
else:
|
||||
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||
# If the worker is on the same node as the driver, we use it
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
vllm_config=self.vllm_config)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
|
||||
logger.debug("workers: %s", self.workers)
|
||||
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
|
||||
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
|
||||
raise ValueError(
|
||||
"Ray does not allocate any GPUs on the driver node. Consider "
|
||||
"adjusting the Ray placement group or running the driver on a "
|
||||
"GPU node.")
|
||||
|
||||
worker_ips = [
|
||||
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
|
||||
for worker in self.workers
|
||||
]
|
||||
ip_counts: Dict[str, int] = {}
|
||||
for ip in worker_ips:
|
||||
ip_counts[ip] = ip_counts.get(ip, 0) + 1
|
||||
|
||||
def sort_by_driver_then_worker_ip(worker):
|
||||
"""
|
||||
Sort the workers based on 3 properties:
|
||||
1. If the worker is on the same node as the driver (vllm engine),
|
||||
it should be placed first.
|
||||
2. Then, if the worker is on a node with fewer workers, it should
|
||||
be placed first.
|
||||
3. Finally, if the work is on a node with smaller IP address, it
|
||||
should be placed first.
|
||||
"""
|
||||
ip = ray.get(worker.get_node_ip.remote())
|
||||
return (ip != driver_ip, ip_counts[ip], ip)
|
||||
|
||||
# After sorting, the workers on the same node will be
|
||||
# close to each other, and the workers on the driver
|
||||
# node will be placed first.
|
||||
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
|
||||
|
||||
worker_node_and_gpu_ids = []
|
||||
for worker in [self.driver_dummy_worker] + self.workers:
|
||||
if worker is None:
|
||||
# driver_dummy_worker can be None when using ray spmd worker.
|
||||
continue
|
||||
worker_node_and_gpu_ids.append(
|
||||
ray.get(worker.get_node_and_gpu_ids.remote()) \
|
||||
) # type: ignore
|
||||
|
||||
node_workers = defaultdict(list) # node id -> list of worker ranks
|
||||
node_gpus = defaultdict(list) # node id -> list of gpu ids
|
||||
|
||||
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
|
||||
node_workers[node_id].append(i)
|
||||
# `gpu_ids` can be a list of strings or integers.
|
||||
# convert them to integers for consistency.
|
||||
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
|
||||
# string sorting is not sufficient.
|
||||
# see https://github.com/vllm-project/vllm/issues/5590
|
||||
gpu_ids = [int(x) for x in gpu_ids]
|
||||
node_gpus[node_id].extend(gpu_ids)
|
||||
for node_id, gpu_ids in node_gpus.items():
|
||||
node_gpus[node_id] = sorted(gpu_ids)
|
||||
|
||||
all_ips = set(worker_ips + [driver_ip])
|
||||
n_ips = len(all_ips)
|
||||
n_nodes = len(node_workers)
|
||||
|
||||
if n_nodes != n_ips:
|
||||
raise RuntimeError(
|
||||
f"Every node should have a unique IP address. Got {n_nodes}"
|
||||
f" nodes with node ids {list(node_workers.keys())} and "
|
||||
f"{n_ips} unique IP addresses {all_ips}. Please check your"
|
||||
" network configuration. If you set `VLLM_HOST_IP` "
|
||||
"environment variable, make sure it is unique for"
|
||||
" each node.")
|
||||
|
||||
# Set environment variables for the driver and workers.
|
||||
all_args_to_update_environment_variables = [({
|
||||
"VLLM_TRACE_FUNCTION":
|
||||
str(envs.VLLM_TRACE_FUNCTION),
|
||||
}, ) for (node_id, _) in worker_node_and_gpu_ids]
|
||||
self._run_workers("update_environment_variables",
|
||||
all_args=all_args_to_update_environment_variables)
|
||||
|
||||
if len(node_gpus) == 1:
|
||||
# in single node case, we don't need to get the IP address.
|
||||
# the loopback address is sufficient
|
||||
# NOTE: a node may have several IP addresses, one for each
|
||||
# network interface. `get_ip()` might return any of them,
|
||||
# while they might not work for communication inside the node
|
||||
# if the network setup is complicated. Using the loopback address
|
||||
# solves this issue, as it always works for communication inside
|
||||
# the node.
|
||||
driver_ip = "127.0.0.1"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
driver_ip, get_open_port())
|
||||
|
||||
# Initialize the actual workers inside worker wrapper.
|
||||
init_worker_all_kwargs = [
|
||||
self._get_worker_kwargs(
|
||||
local_rank=node_workers[node_id].index(rank),
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
|
||||
]
|
||||
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
|
||||
|
||||
self._run_workers("init_device")
|
||||
self._run_workers("load_model",
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers)
|
||||
|
||||
if self.use_ray_spmd_worker:
|
||||
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
|
||||
self.pp_tp_workers.append([])
|
||||
for tp_rank in range(
|
||||
self.parallel_config.tensor_parallel_size):
|
||||
# PP=2, TP=4
|
||||
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
|
||||
rank = (pp_rank * self.parallel_config.tensor_parallel_size
|
||||
) + tp_rank
|
||||
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
|
||||
assert pp_rank < len(self.pp_tp_workers)
|
||||
self.pp_tp_workers[pp_rank].append(self.workers[rank])
|
||||
|
||||
# This is the list of workers that are rank 0 of each TP group EXCEPT
|
||||
# global rank 0. These are the workers that will broadcast to the
|
||||
# rest of the workers.
|
||||
self.tp_driver_workers: List[RayWorkerWrapper] = []
|
||||
# This is the list of workers that are not drivers and not the first
|
||||
# worker in a TP group. These are the workers that will be
|
||||
# broadcasted to.
|
||||
self.non_driver_workers: List[RayWorkerWrapper] = []
|
||||
|
||||
# Enforce rank order for correct rank to return final output.
|
||||
for index, worker in enumerate(self.workers):
|
||||
# The driver worker is rank 0 and not in self.workers.
|
||||
rank = index + 1
|
||||
if rank % self.parallel_config.tensor_parallel_size == 0:
|
||||
self.tp_driver_workers.append(worker)
|
||||
else:
|
||||
self.non_driver_workers.append(worker)
|
||||
|
||||
def _driver_execute_model(
|
||||
self, execute_model_req: Optional[ExecuteModelRequest]
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""Run execute_model in the driver worker.
|
||||
|
||||
Passing None will cause the driver to stop the model execution
|
||||
loop running in each of the remote workers.
|
||||
"""
|
||||
assert not self.use_ray_spmd_worker, (
|
||||
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
return self.driver_worker.execute_method("execute_model",
|
||||
execute_model_req)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
if not self.use_ray_spmd_worker:
|
||||
return super().execute_model(execute_model_req)
|
||||
|
||||
if self.forward_dag is None:
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
serialized_data = self.input_encoder.encode(execute_model_req)
|
||||
outputs = ray.get(self.forward_dag.execute(serialized_data))
|
||||
output = self.output_decoder.decode(outputs[0])
|
||||
return output
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
all_args: Optional[List[Tuple[Any, ...]]] = None,
|
||||
all_kwargs: Optional[List[Dict[str, Any]]] = None,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers. Can be used in the following
|
||||
ways:
|
||||
|
||||
Args:
|
||||
- async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
- args/kwargs: All workers share the same args/kwargs
|
||||
- all_args/all_kwargs: args/kwargs for each worker are specified
|
||||
individually
|
||||
"""
|
||||
if self.use_ray_spmd_worker:
|
||||
assert not async_run_tensor_parallel_workers_only, (
|
||||
"async_run_tensor_parallel_workers_only is not supported for "
|
||||
"spmd mode.")
|
||||
|
||||
if max_concurrent_workers:
|
||||
raise NotImplementedError(
|
||||
"max_concurrent_workers is not supported yet.")
|
||||
|
||||
count = len(self.workers) if not \
|
||||
async_run_tensor_parallel_workers_only \
|
||||
else len(self.non_driver_workers)
|
||||
# If using SPMD worker, all workers are the same, so we should execute
|
||||
# the args on all workers. Otherwise, we skip the first worker's args
|
||||
# because those args will go to the driver worker.
|
||||
first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
|
||||
all_worker_args = repeat(args, count) if all_args is None \
|
||||
else islice(all_args, first_worker_args_index, None)
|
||||
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
|
||||
else islice(all_kwargs, first_worker_args_index, None)
|
||||
|
||||
# Start the ray workers first.
|
||||
ray_workers = self.workers
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
ray_workers = self.non_driver_workers
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
|
||||
for (worker, worker_args, worker_kwargs
|
||||
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
|
||||
]
|
||||
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
# Just return futures
|
||||
return ray_worker_outputs
|
||||
|
||||
driver_worker_output = []
|
||||
# In SPMD mode, the driver worker is the same as any other worker,
|
||||
# so we only explicitly execute on the driver worker if using a
|
||||
# non-SPMD worker class.
|
||||
if not self.use_ray_spmd_worker:
|
||||
driver_args = args if all_args is None else all_args[0]
|
||||
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
|
||||
|
||||
# Start the driver worker after all the ray workers.
|
||||
driver_worker_output = [
|
||||
self.driver_worker.execute_method(method, *driver_args,
|
||||
**driver_kwargs)
|
||||
]
|
||||
|
||||
# Get the results of the ray workers.
|
||||
if self.workers:
|
||||
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||
|
||||
return driver_worker_output + ray_worker_outputs
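A brief usage sketch of the two calling conventions described in the docstring above; the `run_workers_examples` helper and the "echo_rank" method name are illustrative assumptions, not part of this diff.

# Illustrative sketch only: how callers combine shared and per-rank arguments.
def run_workers_examples(executor, world_size: int) -> None:
    # 1) Broadcast the same call (shared args/kwargs) to every worker.
    executor._run_workers("load_model")
    # 2) Per-worker positional args: one tuple per rank. With SPMD mode
    #    disabled, the first tuple is consumed by the driver worker.
    executor._run_workers("echo_rank",
                          all_args=[(rank, ) for rank in range(world_size)])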
|
||||
|
||||
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||
"""Wait for futures returned from _run_workers() with
|
||||
async_run_remote_workers_only to complete."""
|
||||
ray.get(parallel_worker_tasks)
|
||||
|
||||
def _check_ray_adag_installation(self):
|
||||
import pkg_resources
|
||||
from packaging import version
|
||||
|
||||
required_version = version.parse("2.35")
|
||||
current_version = version.parse(
|
||||
pkg_resources.get_distribution("ray").version)
|
||||
# TODO: update the constraint once we adapt to the backward
|
||||
# incompatible API change from ray 2.36
|
||||
if current_version != required_version:
|
||||
raise ValueError(f"Ray version {required_version} is "
|
||||
f"required, but found {current_version}")
|
||||
|
||||
import importlib.util
|
||||
adag_spec = importlib.util.find_spec(
|
||||
"ray.experimental.compiled_dag_ref")
|
||||
if adag_spec is None:
|
||||
raise ValueError("Ray accelerated DAG is not installed. "
|
||||
"Run `pip install ray[adag]` to install it.")
|
||||
|
||||
def _compiled_ray_dag(self, enable_asyncio: bool):
|
||||
assert self.parallel_config.use_ray
|
||||
self._check_ray_adag_installation()
|
||||
from ray.dag import InputNode, MultiOutputNode
|
||||
from ray.experimental.channel.torch_tensor_type import TorchTensorType
|
||||
|
||||
with InputNode() as input_data:
|
||||
# Example DAG: PP=2, TP=4
|
||||
# (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501
|
||||
# -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501
|
||||
# -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501
|
||||
# -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501
|
||||
|
||||
# All workers in the first TP group will take in the
|
||||
# ExecuteModelRequest as input.
|
||||
outputs = [input_data for _ in self.pp_tp_workers[0]]
|
||||
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
|
||||
# Each PP worker takes in the output of the previous PP worker,
|
||||
# and the TP group executes in SPMD fashion.
|
||||
outputs = [
|
||||
worker.execute_model_spmd.
|
||||
bind( # type: ignore[attr-defined]
|
||||
outputs[i]) for i, worker in enumerate(tp_group)
|
||||
]
|
||||
|
||||
last_pp_rank = len(self.pp_tp_workers) - 1
|
||||
if pp_rank < last_pp_rank:
|
||||
# Specify how intermediate tensors should be passed
|
||||
# between pp stages, no need to specify for the last
|
||||
# pp stage.
|
||||
transport = "auto"
|
||||
outputs = [
|
||||
output.with_type_hint(
|
||||
TorchTensorType(transport=transport))
|
||||
for output in outputs
|
||||
]
|
||||
|
||||
forward_dag = MultiOutputNode(outputs)
|
||||
|
||||
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
|
||||
class RayHPUExecutorAsync(RayHPUExecutor, DistributedGPUExecutorAsync):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.pp_locks: Optional[List[asyncio.Lock]] = None
|
||||
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
|
||||
if not self.use_ray_compiled_dag:
|
||||
self.driver_exec_method = make_async(
|
||||
self.driver_worker.execute_method)
|
||||
|
||||
async def execute_model_async(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
if not self.use_ray_spmd_worker:
|
||||
return await super().execute_model_async(execute_model_req)
|
||||
|
||||
if self.forward_dag is None:
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
|
||||
|
||||
serialized_data = self.input_encoder.encode(execute_model_req)
|
||||
dag_future = await self.forward_dag.execute_async(serialized_data)
|
||||
outputs = await dag_future
|
||||
return self.output_decoder.decode(outputs[0])
|
||||
|
||||
async def _driver_execute_model_async(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
assert not self.use_ray_spmd_worker, (
|
||||
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
if not self.tp_driver_workers:
|
||||
return await self.driver_exec_method("execute_model",
|
||||
execute_model_req)
|
||||
if self.pp_locks is None:
|
||||
# This locks each pipeline parallel stage so multiple virtual
|
||||
# engines can't execute on the same stage at the same time
|
||||
# We create the locks here to avoid creating them in the constructor
|
||||
# which uses a different asyncio loop.
|
||||
self.pp_locks = [
|
||||
asyncio.Lock()
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
|
||||
"execute_model", execute_model_req))
|
||||
]
|
||||
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
|
||||
start=1):
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
_run_task_with_lock(driver_worker.execute_method.remote,
|
||||
self.pp_locks[pp_rank],
|
||||
"execute_model", execute_model_req)))
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Only the last PP stage has the final results.
|
||||
return results[-1]
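The `_run_task_with_lock` helper used above is defined elsewhere in vllm; below is a minimal sketch, assuming only asyncio, of what such a lock-guarded runner looks like.

import asyncio
from typing import Any, Awaitable, Callable


async def _run_task_with_lock(task: Callable[..., Awaitable[Any]],
                              lock: asyncio.Lock, *args, **kwargs) -> Any:
    # Serialize access to a pipeline-parallel stage so that multiple
    # virtual engines cannot drive the same stage concurrently.
    async with lock:
        return await task(*args, **kwargs)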
|
||||
|
||||
async def _start_worker_execution_loop(self):
|
||||
assert not self.use_ray_spmd_worker, (
|
||||
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
coros = [
|
||||
worker.execute_method.remote("start_worker_execution_loop")
|
||||
for worker in self.non_driver_workers
|
||||
]
|
||||
return await asyncio.gather(*coros)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
@ -1,343 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import islice, repeat
|
||||
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
from vllm.executor.tpu_executor import TPUExecutor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
|
||||
if ray is not None:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RayTPUExecutor(TPUExecutor):
|
||||
|
||||
uses_ray: bool = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# This is non-None when the execute model loop is running
|
||||
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
|
||||
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
|
||||
# Updated by implementations that require additional args to be passed
|
||||
# to the _run_workers execute_model call
|
||||
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
assert self.parallel_config.distributed_executor_backend == "ray"
|
||||
placement_group = self.parallel_config.placement_group
|
||||
|
||||
# Disable Ray usage stats collection.
|
||||
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
||||
if ray_usage != "1":
|
||||
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
||||
|
||||
# Create the parallel TPU workers.
|
||||
self._init_workers_ray(placement_group)
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
# The driver dummy worker does not actually use any resources.
|
||||
# It holds the resource for the driver worker.
|
||||
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
|
||||
# The remaining workers are the actual ray actors.
|
||||
self.workers: List[RayWorkerWrapper] = []
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get("TPU", 0):
|
||||
continue
|
||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=bundle_id,
|
||||
)
|
||||
|
||||
# GKE does not fetch environment information from metadata server
|
||||
# and instead sets these from within the Ray process. Therefore we
|
||||
# need to override the Ray environment variables manually.
|
||||
override_env = {}
|
||||
if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
|
||||
override_env.update({
|
||||
"TPU_CHIPS_PER_HOST_BOUNDS":
|
||||
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
|
||||
})
|
||||
if "TPU_HOST_BOUNDS" in os.environ:
|
||||
override_env.update(
|
||||
{"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
|
||||
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
resources={"TPU": 1},
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
|
||||
if override_env:
|
||||
worker.override_env_vars.remote(override_env)
|
||||
|
||||
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||
# If the worker is on the same node as the driver, we use it
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
vllm_config=self.vllm_config)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
|
||||
logger.debug("workers: %s", self.workers)
|
||||
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
|
||||
if self.driver_dummy_worker is None:
|
||||
raise ValueError(
|
||||
"Ray does not allocate any TPUs on the driver node. Consider "
|
||||
"adjusting the Ray placement group or running the driver on a "
|
||||
"TPU node.")
|
||||
|
||||
worker_ips = [
|
||||
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
|
||||
for worker in self.workers
|
||||
]
|
||||
ip_counts: Dict[str, int] = {}
|
||||
for ip in worker_ips:
|
||||
ip_counts[ip] = ip_counts.get(ip, 0) + 1
|
||||
|
||||
def sort_by_driver_then_worker_ip(worker):
|
||||
"""
|
||||
Sort the workers based on 3 properties:
|
||||
1. If the worker is on the same node as the driver (vllm engine),
|
||||
it should be placed first.
|
||||
2. Then, if the worker is on a node with fewer workers, it should
|
||||
be placed first.
|
||||
            3. Finally, if the worker is on a node with a smaller IP address, it
|
||||
should be placed first.
|
||||
"""
|
||||
ip = ray.get(worker.get_node_ip.remote())
|
||||
return (ip != driver_ip, ip_counts[ip], ip)
|
||||
|
||||
# After sorting, the workers on the same node will be
|
||||
# close to each other, and the workers on the driver
|
||||
# node will be placed first.
|
||||
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
|
||||
|
||||
# Get the set of TPU IDs used on each node.
|
||||
worker_node_and_gpu_ids = []
|
||||
for worker in [self.driver_dummy_worker] + self.workers:
|
||||
if worker is None:
|
||||
# driver_dummy_worker can be None when using ray spmd worker.
|
||||
continue
|
||||
worker_node_and_gpu_ids.append(
|
||||
ray.get(worker.get_node_and_gpu_ids.remote()) \
|
||||
) # type: ignore
|
||||
|
||||
node_workers = defaultdict(list)
|
||||
for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
|
||||
node_workers[node_id].append(i)
|
||||
|
||||
# Set environment variables for the driver and workers.
|
||||
all_args_to_update_environment_variables = [({
|
||||
"VLLM_TRACE_FUNCTION":
|
||||
str(envs.VLLM_TRACE_FUNCTION),
|
||||
}, ) for _ in worker_node_and_gpu_ids]
|
||||
self._run_workers("update_environment_variables",
|
||||
all_args=all_args_to_update_environment_variables)
|
||||
|
||||
if len(node_workers) == 1:
|
||||
# in single node case, we don't need to get the IP address.
|
||||
# the loopback address is sufficient
|
||||
# NOTE: a node may have several IP addresses, one for each
|
||||
# network interface. `get_ip()` might return any of them,
|
||||
# while they might not work for communication inside the node
|
||||
# if the network setup is complicated. Using the loopback address
|
||||
# solves this issue, as it always works for communication inside
|
||||
# the node.
|
||||
driver_ip = "127.0.0.1"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
driver_ip, get_open_port())
|
||||
|
||||
# Initialize the actual workers inside worker wrapper.
|
||||
init_worker_all_kwargs = [
|
||||
self._get_worker_kwargs(
|
||||
local_rank=node_workers[node_id].index(rank),
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
|
||||
]
|
||||
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
|
||||
|
||||
self._run_workers("init_device")
|
||||
self._run_workers("load_model",
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers)
|
||||
|
||||
def _driver_execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
"""Run execute_model in the driver worker.
|
||||
|
||||
Passing None will cause the driver to stop the model execution
|
||||
loop running in each of the remote workers.
|
||||
"""
|
||||
return self.driver_worker.execute_method("execute_model",
|
||||
execute_model_req)
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_remote_workers_only: bool = False,
|
||||
all_args: Optional[List[Tuple[Any, ...]]] = None,
|
||||
all_kwargs: Optional[List[Dict[str, Any]]] = None,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
use_ray_compiled_dag: bool = False,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers. Can be used in the following
|
||||
ways:
|
||||
|
||||
- async_run_remote_workers_only: If True the method will be run only
|
||||
in the remote workers, not the driver worker. It will also be
|
||||
run asynchronously and return a list of futures rather than blocking
|
||||
on the results.
|
||||
- args/kwargs: All workers share the same args/kwargs
|
||||
- all_args/all_kwargs: args/kwargs for each worker are specified
|
||||
individually
|
||||
"""
|
||||
|
||||
if max_concurrent_workers:
|
||||
raise NotImplementedError(
|
||||
"max_concurrent_workers is not supported yet.")
|
||||
|
||||
count = len(self.workers)
|
||||
all_worker_args = repeat(args, count) if all_args is None \
|
||||
else islice(all_args, 1, None)
|
||||
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
|
||||
else islice(all_kwargs, 1, None)
|
||||
|
||||
# Start the ray workers first.
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
|
||||
for (worker, worker_args, worker_kwargs
|
||||
) in zip(self.workers, all_worker_args, all_worker_kwargs)
|
||||
]
|
||||
|
||||
if async_run_remote_workers_only:
|
||||
# Just return futures
|
||||
return ray_worker_outputs
|
||||
|
||||
driver_args = args if all_args is None else all_args[0]
|
||||
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
|
||||
|
||||
# Start the driver worker after all the ray workers.
|
||||
driver_worker_output = self.driver_worker.execute_method(
|
||||
method, *driver_args, **driver_kwargs)
|
||||
# Get the results of the ray workers.
|
||||
if self.workers:
|
||||
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||
|
||||
return [driver_worker_output] + ray_worker_outputs
|
||||
|
||||
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||
"""Wait for futures returned from _run_workers() with
|
||||
async_run_remote_workers_only to complete."""
|
||||
ray.get(parallel_worker_tasks)
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
num_blocks = self._run_workers("determine_num_available_blocks", )
|
||||
num_tpu_blocks = min(b[0] for b in num_blocks)
|
||||
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||
return num_tpu_blocks, num_cpu_blocks
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
|
||||
num_cpu_blocks)
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
self._run_workers("initialize_cache",
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> List[SamplerOutput]:
|
||||
if self.parallel_worker_tasks is None:
|
||||
self.parallel_worker_tasks = self._run_workers(
|
||||
"start_worker_execution_loop",
|
||||
async_run_remote_workers_only=True,
|
||||
**self.extra_execute_model_run_workers_kwargs)
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
return self._driver_execute_model(execute_model_req)
|
||||
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
if self.parallel_worker_tasks is None:
|
||||
return
|
||||
|
||||
self._driver_execute_model()
|
||||
parallel_worker_tasks = self.parallel_worker_tasks
|
||||
self.parallel_worker_tasks = None
|
||||
# Ensure that workers exit model loop cleanly
|
||||
# (this will raise otherwise)
|
||||
self._wait_for_tasks_completion(parallel_worker_tasks)
|
||||
|
||||
|
||||
class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.driver_exec_method = make_async(self.driver_worker.execute_method)
|
||||
|
||||
async def execute_model_async(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
if self.parallel_worker_tasks is None:
|
||||
# Start model execution loop running in the parallel workers
|
||||
self.parallel_worker_tasks = asyncio.create_task(
|
||||
self._start_worker_execution_loop())
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
return await self._driver_execute_model_async(execute_model_req)
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||
if self.parallel_worker_tasks is None:
|
||||
return
|
||||
|
||||
await self._driver_execute_model_async()
|
||||
parallel_worker_tasks = self.parallel_worker_tasks
|
||||
self.parallel_worker_tasks = None
|
||||
# Ensure that workers exit model loop cleanly
|
||||
# (this will raise otherwise)
|
||||
await parallel_worker_tasks
|
||||
|
||||
async def _driver_execute_model_async(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
return await self.driver_exec_method("execute_model",
|
||||
execute_model_req)
|
||||
|
||||
async def _start_worker_execution_loop(self):
|
||||
coros = [
|
||||
worker.execute_method.remote("start_worker_execution_loop")
|
||||
for worker in self.workers
|
||||
]
|
||||
return await asyncio.gather(*coros)
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import msgspec
|
||||
|
||||
@ -13,6 +13,10 @@ from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.utils import get_ip
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
PG_WAIT_TIMEOUT = 1800
|
||||
|
||||
@ -95,6 +99,26 @@ try:
|
||||
|
||||
return output
|
||||
|
||||
def setup_device_if_necessary(self):
|
||||
# TODO(swang): This is needed right now because Ray CG executes
|
||||
# on a background thread, so we need to reset torch's current
|
||||
# device.
|
||||
# We can remove this API after it is fixed in compiled graph.
|
||||
import torch
|
||||
assert self.worker is not None, "Worker is not initialized"
|
||||
if not self.compiled_dag_cuda_device_set:
|
||||
torch.cuda.set_device(self.worker.device)
|
||||
self.compiled_dag_cuda_device_set = True
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> "ModelRunnerOutput":
|
||||
self.setup_device_if_necessary()
|
||||
assert self.worker is not None, "Worker is not initialized"
|
||||
output = self.worker.model_runner.execute_model(scheduler_output)
|
||||
return output
|
||||
|
||||
def override_env_vars(self, vars: Dict[str, str]):
|
||||
os.environ.update(vars)
|
||||
|
||||
|
@ -1,40 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import ray
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync
|
||||
from vllm.executor.xpu_executor import XPUExecutor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import make_async
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
|
||||
|
||||
def _get_env_vars_to_be_updated(self):
|
||||
# Get the set of GPU IDs used on each node.
|
||||
worker_node_and_gpu_ids = []
|
||||
for worker in [self.driver_dummy_worker] + self.workers:
|
||||
if worker is None:
|
||||
# driver_dummy_worker can be None when using ray spmd worker.
|
||||
continue
|
||||
worker_node_and_gpu_ids.append(
|
||||
ray.get(worker.get_node_and_gpu_ids.remote())) # type: ignore
|
||||
|
||||
# Set environment variables for the driver and workers.
|
||||
all_args_to_update_environment_variables = [({
|
||||
"VLLM_TRACE_FUNCTION":
|
||||
str(envs.VLLM_TRACE_FUNCTION),
|
||||
}, ) for (_, _) in worker_node_and_gpu_ids]
|
||||
return all_args_to_update_environment_variables
|
||||
|
||||
|
||||
class RayXPUExecutorAsync(RayXPUExecutor, RayGPUExecutorAsync):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.driver_exec_method = make_async(self.driver_worker.execute_method)
|
||||
self.pp_locks: Optional[List[asyncio.Lock]] = None
|
@ -1,142 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUExecutor(ExecutorBase):
|
||||
|
||||
uses_ray: bool = False
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
assert not self.scheduler_config.chunked_prefill_enabled, (
|
||||
"Chunked prefill is not yet supported for TPU backend")
|
||||
assert not self.speculative_config, (
|
||||
"Speculative decoding is not yet supported for TPU backend")
|
||||
if self.model_config.dtype in (torch.float16, torch.float32):
|
||||
logger.warning(
|
||||
"The TPU backend currently does not support %s. "
|
||||
"Using bfloat16 instead.", self.model_config.dtype)
|
||||
self.model_config.dtype = torch.bfloat16
|
||||
|
||||
# Instantiate the worker and load the model to the device.
|
||||
self.driver_worker = self._create_worker()
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
|
||||
def _get_worker_kwargs(
|
||||
self,
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Return worker init args for a given rank."""
|
||||
if distributed_init_method is None:
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
return dict(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=rank == 0,
|
||||
)
|
||||
|
||||
def _create_worker(
|
||||
self,
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
):
|
||||
if self.scheduler_config.is_multi_step:
|
||||
from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker
|
||||
worker = MultiStepTPUWorker(**self._get_worker_kwargs(
|
||||
local_rank, rank, distributed_init_method))
|
||||
return worker
|
||||
else:
|
||||
from vllm.worker.tpu_worker import TPUWorker
|
||||
|
||||
worker = TPUWorker(**self._get_worker_kwargs(
|
||||
local_rank, rank, distributed_init_method))
|
||||
return worker
|
||||
|
||||
def initialize_cache(
|
||||
self,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
) -> None:
|
||||
"""Initialize the KV cache by invoking the underlying worker."""
|
||||
# NOTE: This is logged in the executor because there can be >1 worker
|
||||
# with other executors. We could log in the engine level, but work
|
||||
# remains to abstract away the device for non-GPU configurations.
|
||||
logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
|
||||
num_cpu_blocks)
|
||||
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks by invoking the
|
||||
underlying worker."""
|
||||
return self.driver_worker.determine_num_available_blocks()
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> List[SamplerOutput]:
|
||||
output = self.driver_worker.execute_model(execute_model_req)
|
||||
return output
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
raise NotImplementedError(
|
||||
"LoRA is currently not supported by the TPU backend.")
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"LoRA is currently not supported by the TPU backend.")
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"LoRA is currently not supported by the TPU backend.")
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise NotImplementedError(
|
||||
"LoRA is currently not supported by the TPU backend.")
|
||||
|
||||
def add_prompt_adapter(self, prompt_adapter_request) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Soft prompt is currently not supported by the TPU backend.")
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Soft prompt is currently not supported by the TPU backend.")
|
||||
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Soft prompt is currently not supported by the TPU backend.")
|
||||
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
raise NotImplementedError(
|
||||
"Soft prompt is currently not supported by the TPU backend.")
|
||||
|
||||
def check_health(self) -> None:
|
||||
# TPUExecutor will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
|
||||
class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):
|
||||
|
||||
async def execute_model_async(
|
||||
self,
|
||||
sexecute_model_req: ExecuteModelRequest,
|
||||
) -> SamplerOutput:
|
||||
output = await make_async(self.driver_worker.execute_model
|
||||
)(sexecute_model_req)
|
||||
return output
|
vllm/executor/uniproc_executor.py — 57 lines (normal file)
@ -0,0 +1,57 @@
from typing import Any, Dict, List, Optional, Tuple

from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class UniProcExecutor(ExecutorBase):

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.
        """
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                               rank=0)
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        local_rank = 0
        rank = 0
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=(not self.parallel_config)
            or (rank % self.parallel_config.tensor_parallel_size == 0),
        )
        self.collective_rpc("init_worker", args=([kwargs], ))
        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        if kwargs is None:
            kwargs = {}
        try:
            func = getattr(self.driver_worker, method)
        except AttributeError:
            raise NotImplementedError(f"Method {method} is not implemented.") \
                from None
        answer = func(*args, **kwargs)
        return [answer]

    def check_health(self) -> None:
        # UniProcExecutor will always be healthy as long as
        # it's running.
        return


UniProcExecutorAsync = UniProcExecutor
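A hedged usage sketch (not part of the diff): `determine_num_available_blocks` is an existing worker method, while the helper name and the aggregation below are illustrative assumptions.

def profile_num_gpu_blocks(executor: UniProcExecutor) -> int:
    # collective_rpc returns one result per worker; this executor drives a
    # single in-process worker, so the list has exactly one element.
    results = executor.collective_rpc("determine_num_available_blocks")
    return min(num_gpu_blocks for num_gpu_blocks, _ in results)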
|
@ -1,39 +0,0 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
||||
from vllm.utils import make_async
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUExecutor(GPUExecutor):
|
||||
|
||||
uses_ray: bool = False
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
assert self.device_config.device_type == "xpu"
|
||||
assert self.speculative_config is None, (
|
||||
"Speculative decoding not yet supported for XPU backend")
|
||||
|
||||
GPUExecutor._init_executor(self)
|
||||
|
||||
def execute_model(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
||||
output = self.driver_worker.execute_model(execute_model_req)
|
||||
return output
|
||||
|
||||
|
||||
class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase):
|
||||
|
||||
async def execute_model_async(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> List[SamplerOutput]:
|
||||
output = await make_async(self.driver_worker.execute_model
|
||||
)(execute_model_req=execute_model_req)
|
||||
return output
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import psutil
|
||||
@ -105,6 +106,32 @@ class CpuPlatform(Platform):
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
|
||||
|
||||
assert vllm_config.device_config.device_type == "cpu"
|
||||
|
||||
#
|
||||
# Environment variables for CPU executor
|
||||
#
|
||||
|
||||
# Disable torch async compiling which won't work with daemonic processes
|
||||
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
|
||||
|
||||
        # Intel OpenMP settings
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time (in milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ['KMP_BLOCKTIME'] = "1"
            # Prevent the CPU from dropping into a low-performance state
            os.environ['KMP_TPAUSE'] = "0"
            # Provide fine-grained parallelism
            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"

        # Hint IPEX to use its shared-memory-based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            vllm_config.parallel_config.tensor_parallel_size)
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls) -> bool:
|
||||
logger.warning("Pin memory is not supported on CPU.")
|
||||
|
@ -139,6 +139,28 @@ class CudaPlatformBase(Platform):
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
world_size = parallel_config.world_size
|
||||
tensor_parallel_size = parallel_config.tensor_parallel_size
|
||||
|
||||
from vllm.utils import (cuda_device_count_stateless,
|
||||
update_environment_variables)
|
||||
|
||||
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
|
||||
if "CUDA_VISIBLE_DEVICES" not in os.environ:
|
||||
update_environment_variables({
|
||||
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
|
||||
})
|
||||
|
||||
cuda_device_count = cuda_device_count_stateless()
|
||||
        # Use a TP-focused message for the more common TP-only case.
        assert tensor_parallel_size <= cuda_device_count, (
            f"please set tensor_parallel_size ({tensor_parallel_size}) "
            f"to at most the max local gpu count ({cuda_device_count})")

        assert world_size <= cuda_device_count, (
            f"please ensure that world_size ({world_size}) "
            f"is at most the max local gpu count ({cuda_device_count})")
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 16
|
||||
|
@ -35,6 +35,14 @@ class NeuronPlatform(Platform):
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.neuron_worker.NeuronWorker"
|
||||
|
||||
if parallel_config.world_size > 1:
|
||||
parallel_config.distributed_executor_backend = "uni"
|
||||
|
||||
assert (vllm_config.lora_config is
|
||||
None), "LoRA is not supported for Neuron backend."
|
||||
assert (not vllm_config.speculative_config
|
||||
), "Speculative decoding not yet supported for Neuron backend."
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config:
|
||||
# neuron needs block_size = max_model_len
|
||||
|
@ -66,9 +66,8 @@ class OpenVinoPlatform(Platform):
|
||||
from vllm.utils import GiB_bytes
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
assert (
|
||||
parallel_config.world_size == 1
|
||||
), "OpenVINOExecutor only supports single CPU socket currently."
|
||||
assert (parallel_config.world_size == 1
|
||||
), "OpenVINO only supports single CPU socket currently."
|
||||
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = \
|
||||
@ -141,3 +140,10 @@ class OpenVinoPlatform(Platform):
|
||||
raise RuntimeError(
|
||||
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
|
||||
f" {kv_cache_space}, expect a positive integer value.")
|
||||
|
||||
assert vllm_config.device_config.device_type == "openvino"
|
||||
assert vllm_config.lora_config is None, \
|
||||
"OpenVINO backend doesn't support LoRA"
|
||||
assert cls.is_openvino_cpu() or \
|
||||
cls.is_openvino_gpu(), \
|
||||
"OpenVINO backend supports only CPU and GPU devices"
|
||||
|
@ -72,6 +72,16 @@ class TpuPlatform(Platform):
|
||||
assert vllm_config.speculative_config is None, \
|
||||
"TPU does not support speculative decoding"
|
||||
|
||||
assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
|
||||
"Chunked prefill is not yet supported for TPU backend")
|
||||
assert not vllm_config.speculative_config, (
|
||||
"Speculative decoding is not yet supported for TPU backend")
|
||||
if vllm_config.model_config.dtype in (torch.float16, torch.float32):
|
||||
logger.warning(
|
||||
"The TPU backend currently does not support %s. "
|
||||
"Using bfloat16 instead.", vllm_config.model_config.dtype)
|
||||
vllm_config.model_config.dtype = torch.bfloat16
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
|
@ -78,17 +78,31 @@ class XPUPlatform(Platform):
|
||||
raise NotImplementedError(
|
||||
"XPU does not support speculative decoding")
|
||||
|
||||
if vllm_config.device_config is not None:
|
||||
assert vllm_config.device_config.device_type == "xpu"
|
||||
|
||||
# check and update parallel config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if (parallel_config.distributed_executor_backend is not None
|
||||
and parallel_config.distributed_executor_backend != "ray"):
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
|
||||
|
||||
if parallel_config.distributed_executor_backend is None:
|
||||
parallel_config.distributed_executor_backend = "ray"
|
||||
elif parallel_config.distributed_executor_backend == "mp":
|
||||
            # FIXME(kunshang):
            # spawn requires the entry point to be guarded by
            # `if __name__ == '__main__':`, and fork is not supported for
            # starting new processes on XPU.
            logger.error(
                "Both start methods (spawn and fork) have issues "
                "with the mp backend on XPU; setting it to ray instead.")
|
||||
parallel_config.distributed_executor_backend = "ray"
|
||||
|
||||
elif parallel_config.distributed_executor_backend != "ray":
|
||||
logger.warning(
|
||||
"%s is not supported on XPU, fallback to ray distributed"
|
||||
" executor backend.",
|
||||
parallel_config.distributed_executor_backend)
|
||||
parallel_config.distributed_executor_backend = "ray"
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
|
@ -9,17 +9,15 @@ from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
from vllm.worker.worker_base import DelegateWorkerBase
|
||||
|
||||
|
||||
class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase):
|
||||
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
|
||||
"""Worker for Medusa.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(kwargs.get("vllm_config"))
|
||||
self.init_worker(*args, **kwargs)
|
||||
|
||||
DelegateWorkerBase.__init__(self, *args, **kwargs)
|
||||
# Lazy initialization list.
|
||||
self._proposer: Top1Proposer
|
||||
|
||||
|
@ -16,10 +16,10 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
from vllm.worker.worker_base import DelegateWorkerBase
|
||||
|
||||
|
||||
class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
|
||||
class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
|
||||
"""The MultiStepWorker is equivalent to a Worker except that it allows
|
||||
multiple forward passes in a single call, assuming the scheduler has
|
||||
allocated enough space to store the additional KV. This reduces overhead
|
||||
@ -32,15 +32,12 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(kwargs.get("vllm_config"))
|
||||
self.init_worker(*args, **kwargs)
|
||||
|
||||
DelegateWorkerBase.__init__(self, *args, **kwargs)
|
||||
# Lazy initialization list.
|
||||
self._proposer: SpeculativeProposer
|
||||
|
||||
def init_device(self) -> None:
|
||||
self.worker.init_device()
|
||||
|
||||
self._proposer = Top1Proposer(
|
||||
weakref.proxy(self), # type: ignore[arg-type]
|
||||
self.device,
|
||||
@ -56,18 +53,6 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
|
||||
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
|
||||
True)
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
return self.worker.determine_num_available_blocks()
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
return self.worker.get_cache_block_size_bytes()
|
||||
|
||||
def initialize_cache(self, *args, **kwargs) -> None:
|
||||
self.worker.initialize_cache(*args, **kwargs)
|
||||
|
||||
def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
|
||||
return self.worker.execute_model(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
|
@ -40,8 +40,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
|
||||
get_all_num_logprobs,
|
||||
get_sampled_token_logprobs, nvtx_range,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.worker.worker_base import (LoraNotSupportedWorkerBase, WorkerBase,
|
||||
WorkerWrapperBase)
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -64,8 +64,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
||||
target_worker_config = copy.deepcopy(vllm_config)
|
||||
target_worker_config.parallel_config.worker_cls =\
|
||||
target_worker_config.parallel_config.sd_worker_cls
|
||||
target_worker = WorkerWrapperBase(vllm_config=target_worker_config)
|
||||
target_worker.init_worker(*args, **kwargs)
|
||||
cls = resolve_obj_by_qualname(
|
||||
target_worker_config.parallel_config.worker_cls)
|
||||
target_worker = cls(*args, **kwargs)
|
||||
# Set the disable_logprobs variable in the TargetModelRunner instance
|
||||
# as per its value specified in the SpeculativeConfig.
|
||||
target_worker.model_runner.disable_logprobs =\
|
||||
|
@ -14,8 +14,9 @@ class Executor(ABC):
|
||||
distributed_executor_backend = (
|
||||
vllm_config.parallel_config.distributed_executor_backend)
|
||||
if distributed_executor_backend == "ray":
|
||||
from vllm.v1.executor.ray_executor import RayExecutor
|
||||
executor_class = RayExecutor
|
||||
from vllm.executor.ray_distributed_executor import ( # noqa
|
||||
RayDistributedExecutor)
|
||||
executor_class = RayDistributedExecutor
|
||||
elif distributed_executor_backend == "mp":
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
executor_class = MultiprocExecutor
|
||||
|
@ -246,9 +246,18 @@ class WorkerProc:
|
||||
ready_path: str,
|
||||
):
|
||||
self.rank = rank
|
||||
wrapper = WorkerWrapperBase(vllm_config=vllm_config)
|
||||
wrapper.init_worker(vllm_config, local_rank, rank,
|
||||
distributed_init_method)
|
||||
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rank=rank)
|
||||
# TODO: move `init_worker` to executor level as a collective rpc call
|
||||
all_kwargs: List[Dict] = [
|
||||
{} for _ in range(vllm_config.parallel_config.world_size)
|
||||
]
|
||||
all_kwargs[rank] = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_rank": local_rank,
|
||||
"rank": rank,
|
||||
"distributed_init_method": distributed_init_method,
|
||||
}
|
||||
wrapper.init_worker(all_kwargs)
|
||||
self.worker = wrapper.worker
|
||||
|
||||
pid = os.getpid()
|
||||
@ -270,7 +279,7 @@ class WorkerProc:
|
||||
ready_socket.send_string(WorkerProc.READY_STR)
|
||||
ready_socket.send(payload)
|
||||
|
||||
self.worker.initialize()
|
||||
self.worker.init_device()
|
||||
self.worker.load_model()
|
||||
|
||||
@staticmethod
|
||||
|
@ -27,7 +27,7 @@ class UniprocExecutor(Executor):
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.worker: Worker = self._create_worker()
|
||||
self.worker.initialize()
|
||||
self.worker.init_device()
|
||||
self.worker.load_model()
|
||||
|
||||
def _create_worker(
|
||||
|
@ -33,6 +33,7 @@ class Worker:
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
|
||||
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
@ -75,7 +76,7 @@ class Worker:
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def initialize(self):
|
||||
def init_device(self):
|
||||
if self.device_config.device.type == "cuda":
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
# the synchronization point. This causes the memory usage to grow
|
||||
|
@ -2,6 +2,7 @@
|
||||
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
|
||||
###############################################################################
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
from typing import List, Optional, Set, Tuple, Type
|
||||
@ -18,6 +19,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import bind_kv_cache
|
||||
@ -124,6 +126,70 @@ class HPUWorker(LocalOrDistributedWorkerBase):
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
assert execute_model_req is not None
|
||||
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
|
||||
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
|
||||
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
|
||||
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501
|
||||
log_graph_compilation_all = os.environ.get(
|
||||
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0'
|
||||
log_graph_compilation = os.environ.get(
|
||||
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION',
|
||||
'0') != '0' or log_graph_compilation_all
|
||||
log_cpu_fallbacks_all = os.environ.get(
|
||||
'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
|
||||
log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
|
||||
'0') != '0' or log_cpu_fallbacks_all
|
||||
if log_graph_compilation or log_cpu_fallbacks:
|
||||
from habana_frameworks.torch.hpu.metrics import metric_localcontext
|
||||
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
||||
is_prompt = any([
|
||||
seq_group_metadata.is_prompt
|
||||
for seq_group_metadata in seq_group_metadata_list
|
||||
])
|
||||
max_context_len = max([
|
||||
max([
|
||||
len(v.prompt_token_ids) + len(v.output_token_ids)
|
||||
for v in seq_group_metadata.seq_data.values()
|
||||
]) for seq_group_metadata in seq_group_metadata_list
|
||||
            ])  # max prompt+output length across all sequences in the batch
|
||||
max_num_blocks = (
|
||||
(max_context_len - 1) // self.cache_config.block_size) + 1
|
||||
input_stats = (f'is_prompt: {is_prompt}, '
|
||||
f'num_seqs: {len(seq_group_metadata_list)}, '
|
||||
f'max_context_len: {max_context_len}, '
|
||||
f'max_num_blocks {max_num_blocks}')
|
||||
gc_ctx = metric_localcontext(
|
||||
"graph_compilation"
|
||||
) if log_graph_compilation else contextlib.nullcontext()
|
||||
cpu_fallback_ctx = metric_localcontext(
|
||||
"cpu_fallback"
|
||||
) if log_cpu_fallbacks else contextlib.nullcontext()
|
||||
with gc_ctx as gc_local_metric, \
|
||||
cpu_fallback_ctx as cpu_fallback_local_metric:
|
||||
output = LocalOrDistributedWorkerBase.execute_model(
|
||||
self, execute_model_req)
|
||||
if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0
|
||||
) or log_graph_compilation_all:
|
||||
msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
|
||||
f"{gc_local_metric.stats()}, {input_stats}")
|
||||
logger.warning(msg)
|
||||
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] >
|
||||
0) or log_cpu_fallbacks_all:
|
||||
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
|
||||
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
|
||||
logger.warning(msg)
|
||||
|
||||
return output
|
||||
|
||||
output = LocalOrDistributedWorkerBase.execute_model(
|
||||
self, execute_model_req)
|
||||
return output
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
|
@ -8,6 +8,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.worker.neuron_model_runner import NeuronModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
|
||||
@ -25,6 +26,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = True,
|
||||
) -> None:
|
||||
WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
self.local_rank = local_rank
|
||||
@ -37,7 +39,22 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
|
||||
self.model_runner: NeuronModelRunner = NeuronModelRunner(
|
||||
vllm_config=vllm_config)
|
||||
self.is_driver_worker = True
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
assert execute_model_req is not None
|
||||
assert (not execute_model_req.blocks_to_swap_in
|
||||
and not execute_model_req.blocks_to_swap_out
|
||||
and not execute_model_req.blocks_to_copy), (
|
||||
"Cache operations are not supported for Neuron backend.")
|
||||
assert execute_model_req.num_lookahead_slots == 0, (
|
||||
"lookahead not supported for Neuron backend.")
|
||||
output = LocalOrDistributedWorkerBase.execute_model(
|
||||
self, execute_model_req)
|
||||
return output
|
||||
|
||||
def init_device(self) -> None:
|
||||
self.init_distributed_environment()
|
||||
@ -103,13 +120,14 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
|
||||
def init_distributed_environment(self):
|
||||
"""Neuron uses transformers-neuronx for tensor parallelism.
|
||||
|
||||
vLLM still needs the environment inited when TP/PP > 1
|
||||
It has only one process to control multiple devices.
|
||||
vLLM still needs the environment initialized when TP/PP > 1,
|
||||
so we initialize a distributed environment with one process.
|
||||
"""
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=self.rank,
|
||||
local_rank=self.local_rank,
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
distributed_init_method=self.distributed_init_method,
|
||||
backend="gloo",
|
||||
)
|
||||
|
@ -211,16 +211,14 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ov_core: ov.Core,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
|
||||
is_driver_worker: bool = False,
|
||||
) -> None:
|
||||
self.ov_core = ov_core
|
||||
WorkerBase.__init__(self, vllm_config)
|
||||
self.ov_core = ov.Core()
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
@ -237,7 +235,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
|
||||
self.model_runner = OpenVINOModelRunner(
|
||||
self.ov_core,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
kv_cache_dtype=self.vllm_config.cache_config.cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
|
@ -88,7 +88,6 @@ class WorkerBase(ABC):
|
||||
if output is None:
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
@ -119,6 +118,58 @@ class WorkerBase(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DelegateWorkerBase(WorkerBase):
|
||||
"""
|
||||
A class that delegates all methods to another WorkerBase instance. This is
|
||||
useful for creating a WorkerBase that wraps another WorkerBase instance,
|
||||
e.g. speculative decoding.
|
||||
"""
|
||||
worker: WorkerBase
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
vllm_config: VllmConfig = kwargs.get("vllm_config")
|
||||
cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls)
|
||||
self.worker = cls(*args, **kwargs)
|
||||
|
||||
def init_device(self) -> None:
|
||||
self.worker.init_device()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
return self.worker.determine_num_available_blocks()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
return self.worker.execute_model(execute_model_req)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
return self.worker.get_cache_block_size_bytes()
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.worker.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.worker.remove_lora(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.worker.pin_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.worker.list_loras()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.worker, attr)
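A hedged sketch of how a wrapper in the spirit of MedusaWorker or MultiStepWorker builds on DelegateWorkerBase; the class name and the extra `propose` method are illustrative assumptions, not vllm APIs.

class MyProposerWorker(DelegateWorkerBase):

    def __init__(self, *args, **kwargs):
        # DelegateWorkerBase resolves parallel_config.worker_cls and builds
        # the underlying worker from the same args/kwargs.
        DelegateWorkerBase.__init__(self, *args, **kwargs)

    def propose(self, execute_model_req):
        # Attributes not defined here (model_runner, device, ...) fall
        # through to the wrapped worker via __getattr__.
        return self.worker.execute_model(execute_model_req)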
|
||||
|
||||
|
||||
class LoraNotSupportedWorkerBase(WorkerBase):
|
||||
"""Partial implementation of WorkerBase that raises exceptions when LoRA
|
||||
methods are invoked.
|
||||
@ -419,17 +470,31 @@ class WorkerWrapperBase:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
self.rank = rank
|
||||
self.vllm_config = vllm_config
|
||||
trust_remote_code = vllm_config.model_config.trust_remote_code
|
||||
self.worker: Optional[WorkerBase] = None
|
||||
if trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
if vllm_config.model_config is not None:
|
||||
# it can be None in tests
|
||||
trust_remote_code = vllm_config.model_config.trust_remote_code
|
||||
if trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
@staticmethod
|
||||
def update_environment_variables(envs: Dict[str, str]) -> None:
|
||||
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
|
||||
"""
|
||||
Adjust the rank based on the given mapping.
|
||||
It is only used during the initialization of the executor,
|
||||
to adjust the rank of workers after we create all workers.
|
||||
"""
|
||||
if self.rank in rank_mapping:
|
||||
self.rank = rank_mapping[self.rank]
|
||||
|
||||
def update_environment_variables(self, envs_list: List[Dict[str,
|
||||
str]]) -> None:
|
||||
envs = envs_list[self.rank]
|
||||
key = 'CUDA_VISIBLE_DEVICES'
|
||||
if key in envs and key in os.environ:
|
||||
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
|
||||
@ -437,11 +502,12 @@ class WorkerWrapperBase:
|
||||
del os.environ[key]
|
||||
update_environment_variables(envs)
|
||||
|
||||
def init_worker(self, *args, **kwargs):
|
||||
def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Here we inject some common logic before initializing the worker.
|
||||
Arguments are passed to the worker class constructor.
|
||||
"""
|
||||
kwargs = all_kwargs[self.rank]
|
||||
enable_trace_function_call_for_thread(self.vllm_config)
|
||||
|
||||
# see https://github.com/NVIDIA/nccl/issues/1234
|
||||
@ -452,7 +518,7 @@ class WorkerWrapperBase:
|
||||
|
||||
worker_class = resolve_obj_by_qualname(
|
||||
self.vllm_config.parallel_config.worker_cls)
|
||||
self.worker = worker_class(*args, **kwargs)
|
||||
self.worker = worker_class(**kwargs)
|
||||
assert self.worker is not None
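A hedged sketch of the per-rank kwargs list that the new `init_worker` signature expects; the helper name is an assumption, but the keys mirror the WorkerProc change above.

from typing import Any, Dict, List


def build_init_worker_args(vllm_config,
                           distributed_init_method: str) -> List[Dict[str, Any]]:
    # One kwargs dict per rank; each WorkerWrapperBase picks the entry that
    # matches its own rank inside init_worker().
    world_size = vllm_config.parallel_config.world_size
    return [{
        "vllm_config": vllm_config,
        "local_rank": rank,  # assumes a single node for simplicity
        "rank": rank,
        "distributed_init_method": distributed_init_method,
    } for rank in range(world_size)]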
|
||||
|
||||
def execute_method(self, method: str, *args, **kwargs):
|
||||
|