Files
vllm-dev/vllm/v1/executor/multiproc_executor.py
2025-08-29 08:17:27 -07:00

622 lines
24 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import os
import pickle
import signal
import threading
import time
import traceback
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from threading import Thread
from typing import Any, Callable, Optional, Union, cast
import cloudpickle
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.multiproc_worker_utils import (
set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (decorate_logs, get_distributed_init_method,
get_loopback_ip, get_mp_context, get_open_port,
set_process_title)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
# Module-level logger used both by the executor (parent process) and by the
# WorkerProc code that runs inside each spawned worker process.
logger = init_logger(__name__)
class MultiprocExecutor(Executor):
    """Executor that runs one worker process per rank on the local node.

    Requests are broadcast to every worker through a single broadcast
    ``MessageQueue``; each worker replies on its own response queue.
    """

    supports_pp: bool = True

    def _init_executor(self) -> None:
        """Spawn the worker processes and wire up the message queues.

        Raises:
            AssertionError: if world_size != tensor_parallel_size *
                pipeline_parallel_size.
        """
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)
        self.is_failed = False
        self.shutdown_event = threading.Event()
        self.failure_callback: Optional[FailureCallback] = None
        self.io_thread_pool: Optional[ThreadPoolExecutor] = None

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        pp_parallel_size = self.parallel_config.pipeline_parallel_size
        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
            f"world_size ({self.world_size}) must be equal to the "
            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
            f"_parallel_size ({pp_parallel_size}). ")

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # get_loopback_ip() for communication.
        distributed_init_method = get_distributed_init_method(
            get_loopback_ip(), get_open_port())

        # Initialize worker and set up message queues for SchedulerOutputs
        # and ModelRunnerOutputs
        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
        self.rpc_broadcast_mq = MessageQueue(self.world_size,
                                             self.world_size,
                                             max_chunk_bytes=max_chunk_bytes)
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Create workers
        unready_workers: list[UnreadyWorkerProcHandle] = []
        success = False
        try:
            for rank in range(self.world_size):
                unready_workers.append(
                    WorkerProc.make_worker_process(
                        vllm_config=self.vllm_config,
                        local_rank=rank,
                        rank=rank,
                        distributed_init_method=distributed_init_method,
                        input_shm_handle=scheduler_output_handle,
                    ))

            # Workers must be created before wait_for_ready to avoid
            # deadlock, since worker.init_device() does a device sync.
            self.workers = WorkerProc.wait_for_ready(unready_workers)

            # Ensure message queues are ready. Will deadlock if re-ordered
            # Must be kept consistent with the WorkerProc.
            self.rpc_broadcast_mq.wait_until_ready()
            for w in self.workers:
                w.worker_response_mq.wait_until_ready()

            self.start_worker_monitor()
            success = True
        finally:
            if not success:
                # Clean up the worker procs if there was a failure.
                # Close death_writers first to signal workers to exit
                for uw in unready_workers:
                    if uw.death_writer is not None:
                        uw.death_writer.close()
                self._ensure_worker_termination(
                    [uw.proc for uw in unready_workers])

        # For pipeline parallel, we use a thread pool for asynchronous
        # execute_model.
        if self.max_concurrent_batches > 1:
            # Note: must use only 1 IO thread to keep dequeue sequence
            # from the response queue
            # _async_aggregate_workers_output also assumes a single IO thread
            self.io_thread_pool = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="mp_exec_io")

        self.output_rank = self._get_output_rank()
        self.has_connector = self.vllm_config.kv_transfer_config is not None
        self.kv_output_aggregator = KVOutputAggregator(
            self.parallel_config.world_size)

    def start_worker_monitor(self):
        """Start a daemon thread that watches for any worker process exit.

        The thread holds only a weak reference to the executor so it does
        not keep it alive.
        """
        workers = self.workers
        self_ref = weakref.ref(self)

        # Monitors worker process liveness. If any die unexpectedly,
        # logs an error, shuts down the executor and invokes the failure
        # callback to inform the engine.
        def monitor_workers():
            sentinels = [h.proc.sentinel for h in workers]
            died = multiprocessing.connection.wait(sentinels)
            _self = self_ref()
            # A worker exiting during our own shutdown is expected.
            if not _self or getattr(_self, 'shutting_down', False):
                return
            _self.is_failed = True
            proc_name = next(h.proc.name for h in workers
                             if h.proc.sentinel == died[0])
            logger.error(
                "Worker proc %s died unexpectedly, "
                "shutting down executor.", proc_name)
            _self.shutdown()
            callback = _self.failure_callback
            if callback is not None:
                # Clear first so the callback fires at most once.
                _self.failure_callback = None
                callback()

        Thread(target=monitor_workers,
               daemon=True,
               name="MultiprocWorkerMonitor").start()

    def register_failure_callback(self, callback: FailureCallback):
        """Register ``callback`` to run when a worker dies.

        If the executor has already failed, the callback runs immediately.
        """
        if self.is_failed:
            callback()
        else:
            self.failure_callback = callback

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Run one model step on all workers.

        Returns the output of ``self.output_rank``. When
        ``max_concurrent_batches > 1`` a ``Future`` is returned instead.
        With a KV connector configured, outputs from all workers are
        aggregated first.
        """
        non_block = self.max_concurrent_batches > 1

        if not self.has_connector:
            # get output only from a single worker (output_rank)
            (output, ) = self.collective_rpc(
                "execute_model",
                args=(scheduler_output, ),
                unique_reply_rank=self.output_rank,
                non_block=non_block,
                timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
            return output

        # get output from all workers
        outputs = self.collective_rpc(
            "execute_model",
            args=(scheduler_output, ),
            non_block=non_block,
            timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)

        # aggregate all workers output to a single output
        if non_block:
            return self.kv_output_aggregator.async_aggregate(
                outputs, self.output_rank)
        return self.kv_output_aggregator.aggregate(outputs, self.output_rank)

    def execute_dummy_batch(self) -> None:
        """Ask all workers to run a dummy batch.

        Only the output rank replies.
        """
        self.collective_rpc("execute_dummy_batch",
                            unique_reply_rank=self.output_rank)

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Fetch draft token ids from the output-rank worker."""
        # OPTIMIZATION: Get output only from a single worker (output_rank)
        outputs = self.collective_rpc("take_draft_token_ids",
                                      unique_reply_rank=self.output_rank)
        return outputs[0]

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False,
                       unique_reply_rank: Optional[int] = None) -> list[Any]:
        """Broadcast an RPC to every worker and collect the replies.

        Args:
            method: Name of a worker method, or a callable that is
                cloudpickled and invoked as ``method(worker, *args,
                **kwargs)`` on each worker.
            timeout: Overall deadline in seconds for collecting replies;
                ``None`` waits forever.
            args: Positional arguments for the call.
            kwargs: Keyword arguments for the call.
            non_block: Return ``Future`` objects instead of results. Only
                valid when the IO thread pool exists
                (``max_concurrent_batches > 1``).
            unique_reply_rank: If set, only this rank replies and the
                returned list has a single element.

        Returns:
            One result (or ``Future``) per replying worker, in rank order.

        Raises:
            RuntimeError: if the executor has failed, a worker reports
                failure, or ``non_block`` is requested without an IO pool.
            TimeoutError: if the replies do not arrive before the deadline.
        """
        if self.is_failed:
            raise RuntimeError("Executor failed.")

        # Validate before broadcasting. Once the request is enqueued the
        # workers execute it and write replies; bailing out afterwards would
        # leave those replies unread, and later calls would dequeue them as
        # if they were their own responses.
        if non_block and self.io_thread_pool is None:
            raise RuntimeError("non_block can only be used when"
                               " max_concurrent_batches > 1")

        deadline = None if timeout is None else time.monotonic() + timeout
        kwargs = kwargs or {}

        # NOTE: If the args are heterogeneous, then we pack them into a list,
        # and unpack them in the method of every worker, because every worker
        # knows their own rank.
        try:
            if isinstance(method, str):
                send_method = method
            else:
                send_method = cloudpickle.dumps(
                    method, protocol=pickle.HIGHEST_PROTOCOL)
            self.rpc_broadcast_mq.enqueue(
                (send_method, args, kwargs, unique_reply_rank))

            workers = (self.workers[unique_reply_rank],
                       ) if unique_reply_rank is not None else self.workers
            responses = []

            def get_response(w: WorkerProcHandle,
                             dequeue_timeout: Optional[float] = None,
                             cancel_event: Optional[threading.Event] = None):
                status, result = w.worker_response_mq.dequeue(
                    timeout=dequeue_timeout, cancel=cancel_event)

                if status != WorkerProc.ResponseStatus.SUCCESS:
                    raise RuntimeError(
                        f"Worker failed with error '{result}', please check the"
                        " stack trace above for the root cause")
                return result

            for w in workers:
                dequeue_timeout = None if deadline is None else (
                    deadline - time.monotonic())

                if self.io_thread_pool is not None:
                    # We must consume worker_response_mq from a single thread.
                    result = self.io_thread_pool.submit(  # type: ignore
                        get_response, w, dequeue_timeout, self.shutdown_event)
                    if not non_block:
                        result = result.result()
                elif not non_block:
                    result = get_response(w, dequeue_timeout)
                else:
                    # Only reachable if shutdown() cleared the IO pool
                    # concurrently after the up-front check above.
                    raise RuntimeError("non_block can only be used when"
                                       " max_concurrent_batches > 1")

                responses.append(result)

            return responses
        except TimeoutError as e:
            raise TimeoutError(f"RPC call to {method} timed out.") from e

    @staticmethod
    def _ensure_worker_termination(worker_procs: list[BaseProcess]):
        """Ensure that all worker processes are terminated.

        Callers are expected to have already asked the workers to exit
        (e.g. by closing their death pipes). Any process still alive is
        sent SIGTERM, given up to 4 seconds to exit, and then SIGKILLed if
        it is still running.
        """

        def wait_for_termination(procs, timeout):
            if not time:
                # If we are in late stage shutdown, the interpreter may replace
                # `time` with `None`.
                return all(not proc.is_alive() for proc in procs)
            # Use a monotonic clock so wall-clock adjustments cannot shorten
            # or stretch the grace period.
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if all(not proc.is_alive() for proc in procs):
                    return True
                time.sleep(0.1)
            return False

        # Send SIGTERM if still running
        active_procs = [proc for proc in worker_procs if proc.is_alive()]
        for p in active_procs:
            p.terminate()
        if not wait_for_termination(active_procs, 4):
            # Send SIGKILL if still running
            active_procs = [p for p in active_procs if p.is_alive()]
            for p in active_procs:
                p.kill()

    def shutdown(self):
        """Properly shut down the executor and its workers.

        Safe to call more than once; only the first call does any work.
        """
        if not getattr(self, 'shutting_down', False):
            self.shutting_down = True
            # Unblocks any IO-thread dequeue waiting on a response.
            self.shutdown_event.set()

            if self.io_thread_pool is not None:
                self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
                self.io_thread_pool = None

            if workers := getattr(self, 'workers', None):
                for w in workers:
                    # Close death_writer to signal child processes to exit
                    if w.death_writer is not None:
                        w.death_writer.close()
                        w.death_writer = None
                    w.worker_response_mq = None
                self._ensure_worker_termination([w.proc for w in workers])

        self.rpc_broadcast_mq = None

    def check_health(self) -> None:
        """Run ``check_health`` on all workers with a 10 second timeout."""
        self.collective_rpc("check_health", timeout=10)

    @property
    def max_concurrent_batches(self) -> int:
        """Number of batches that may be in flight at once.

        This is 2 with async scheduling; otherwise it is the pipeline
        parallel size.
        """
        if self.scheduler_config.async_scheduling:
            return 2
        return self.parallel_config.pipeline_parallel_size

    def _get_output_rank(self) -> int:
        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
        # (the first TP worker of the last PP stage).
        # Example:
        # Assuming TP=8, PP=4, then the world_size=32
        # 0-7, PP rank 0
        # 8-15, PP rank 1
        # 16-23, PP rank 2
        # 24-31, PP rank 3
        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
        return self.world_size - self.parallel_config.tensor_parallel_size
@dataclass
class UnreadyWorkerProcHandle:
    """Handle to a spawned worker process that has not yet reported READY.

    Positional field order: ``(proc, rank, ready_pipe, death_writer)``.
    """
    # The spawned worker process object.
    proc: BaseProcess
    # Rank of the worker; also its index in the list of ready handles.
    rank: int
    # Parent-side read end of the pipe on which the worker sends its READY
    # message (the worker closes it on failure, yielding EOF here).
    ready_pipe: Connection
    # Parent-side write end of the death pipe; closing it signals the child
    # to exit.
    death_writer: Union[Connection, None] = None
@dataclass
class WorkerProcHandle:
    """Handle to a worker process that has reported READY.

    Compared with ``UnreadyWorkerProcHandle``, the ready pipe is replaced by
    the queue on which the worker publishes its RPC replies.
    """
    # The running worker process object.
    proc: BaseProcess
    # Rank of the worker.
    rank: int
    worker_response_mq: MessageQueue  # The worker process writes to this MQ
    # Parent-side write end of the death pipe; closing it signals the child
    # to exit.
    death_writer: Optional[Connection] = None

    @classmethod
    def from_unready_handle(
            cls, unready_handle: UnreadyWorkerProcHandle,
            worker_response_mq: MessageQueue) -> "WorkerProcHandle":
        """Build a ready handle from an unready one plus its response queue.

        The process, rank and death-pipe writer are carried over unchanged.
        """
        return cls(unready_handle.proc, unready_handle.rank,
                   worker_response_mq, unready_handle.death_writer)
class WorkerProc:
    """Wrapper that runs one Worker in a separate process."""

    READY_STR = "READY"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        input_shm_handle: Handle,
    ):
        """Create the worker, attach its message queues and load the model.

        Runs inside the child process (called from ``worker_main``).
        """
        self.rank = rank
        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
        # TODO: move `init_worker` to executor level as a collective rpc call
        all_kwargs: list[dict] = [
            {} for _ in range(vllm_config.parallel_config.world_size)
        ]
        # The first rank of each TP group acts as the driver worker.
        is_driver_worker = (
            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
        all_kwargs[rank] = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "is_driver_worker": is_driver_worker,
        }
        wrapper.init_worker(all_kwargs)
        self.worker = wrapper

        # Tag the process title / log lines with e.g. "PP1_TP3".
        pp_size = vllm_config.parallel_config.pipeline_parallel_size
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        pp_str = f"PP{rank // tp_size}" if pp_size > 1 else ""
        tp_str = f"TP{rank % tp_size}" if tp_size > 1 else ""
        suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}"
        process_name = "VllmWorker"
        if suffix:
            set_process_title(suffix, append=True)
            process_name = f"{process_name} {suffix}"
        decorate_logs(process_name)

        # Initialize MessageQueue for receiving SchedulerOutput
        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
            input_shm_handle, self.worker.rank)

        # Initializes a message queue for sending the model output
        self.worker_response_mq = MessageQueue(1, 1)

        # Initialize device and loads weights
        self.worker.init_device()
        self.worker.load_model()

    @staticmethod
    def make_worker_process(
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            input_shm_handle,  # Receive SchedulerOutput
    ) -> UnreadyWorkerProcHandle:
        """Start a daemon process running ``worker_main`` for ``rank``.

        Returns a handle holding the parent ends of the ready pipe and the
        death pipe.
        """
        context = get_mp_context()
        # (reader, writer)
        reader, writer = context.Pipe(duplex=False)

        # Create death pipe to detect parent process exit
        death_reader, death_writer = context.Pipe(duplex=False)

        process_kwargs = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "input_shm_handle": input_shm_handle,
            "ready_pipe": (reader, writer),
            "death_pipe": death_reader,
        }
        # Run EngineCore busy loop in background process.
        proc = context.Process(target=WorkerProc.worker_main,
                               kwargs=process_kwargs,
                               name=f"VllmWorker-{rank}",
                               daemon=True)

        proc.start()
        # The parent only reads from the ready pipe.
        writer.close()
        # Keep death_writer open in parent - when parent exits,
        # death_reader in child will get EOFError
        return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)

    @staticmethod
    def wait_for_ready(
        unready_proc_handles: list[UnreadyWorkerProcHandle]
    ) -> list[WorkerProcHandle]:
        """Block until every worker reports READY over its ready pipe.

        Returns:
            One ``WorkerProcHandle`` per worker, indexed by rank.

        Raises:
            Exception: if a worker reports a non-READY status or closes its
                ready pipe without sending anything (EOF).
        """

        e = Exception("WorkerProc initialization failed due to "
                      "an exception in a background process. "
                      "See stack trace for root cause.")

        pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
        ready_proc_handles: list[Optional[WorkerProcHandle]] = (
            [None] * len(unready_proc_handles))
        try:
            while pipes:
                ready = multiprocessing.connection.wait(pipes.keys())
                for pipe in ready:
                    assert isinstance(pipe, Connection)
                    try:
                        # Wait until the WorkerProc is ready.
                        unready_proc_handle = pipes.pop(pipe)
                        response: dict[str, Any] = pipe.recv()
                        if response["status"] != WorkerProc.READY_STR:
                            raise e

                        # Extract the message queue handle.
                        worker_response_mq = MessageQueue.create_from_handle(
                            response["handle"], 0)
                        ready_proc_handles[unready_proc_handle.rank] = (
                            WorkerProcHandle.from_unready_handle(
                                unready_proc_handle, worker_response_mq))

                    except EOFError:
                        e.__suppress_context__ = True
                        raise e from None

                    finally:
                        # Close connection.
                        pipe.close()
        finally:
            # Only non-empty when leaving via an exception: close the ready
            # pipes of workers we never heard from rather than leaving their
            # file descriptors to the garbage collector.
            for pipe in pipes:
                pipe.close()

        return cast(list[WorkerProcHandle], ready_proc_handles)

    def shutdown(self):
        """Drop the message queues and tear down distributed state."""
        self.rpc_broadcast_mq = None
        self.worker_response_mq = None
        destroy_model_parallel()
        destroy_distributed_environment()

    @staticmethod
    def worker_main(*args, **kwargs):
        """ Worker initialization and execution loops.
        This runs a background process.

        Expects ``ready_pipe`` (a ``(reader, writer)`` pair) and optionally
        ``death_pipe`` in ``kwargs``; the remaining arguments are forwarded
        to ``WorkerProc.__init__``.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the worker
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        worker = None
        # tuple[Connection, Connection]
        reader, ready_writer = kwargs.pop("ready_pipe")
        death_pipe = kwargs.pop("death_pipe", None)

        # Start death monitoring thread if death_pipe is provided
        if death_pipe is not None:

            def monitor_parent_death():
                try:
                    # This will block until parent process exits (pipe closes)
                    death_pipe.recv()
                except EOFError:
                    # Parent process has exited (or closed its end of the
                    # pipe during shutdown), terminate this worker
                    logger.info("Parent process exited, terminating worker")
                    # Send signal to self to trigger clean shutdown
                    os.kill(os.getpid(), signal.SIGTERM)
                except Exception as e:
                    logger.warning("Death monitoring error: %s", e)

            death_monitor = Thread(target=monitor_parent_death,
                                   daemon=True,
                                   name="WorkerDeathMonitor")
            death_monitor.start()

        try:
            # The child only writes to the ready pipe.
            reader.close()
            worker = WorkerProc(*args, **kwargs)

            # Send READY once we know everything is loaded
            ready_writer.send({
                "status":
                WorkerProc.READY_STR,
                "handle":
                worker.worker_response_mq.export_handle(),
            })

            # Ensure message queues are ready. Will deadlock if re-ordered.
            # Must be kept consistent with the Executor
            worker.rpc_broadcast_mq.wait_until_ready()
            worker.worker_response_mq.wait_until_ready()
            ready_writer.close()
            ready_writer = None

            worker.worker_busy_loop()

        except Exception:
            # NOTE: if an Exception arises in busy_loop, we send
            # a FAILURE message over the MQ RPC to notify the Executor,
            # which triggers system shutdown.
            # TODO(rob): handle case where the MQ itself breaks.

            if ready_writer is not None:
                logger.exception("WorkerProc failed to start.")
            else:
                logger.exception("WorkerProc failed.")

            # The parent sends a SIGTERM to all worker processes if
            # any worker dies. Set this value so we don't re-throw
            # SystemExit() to avoid zmq exceptions in __del__.
            shutdown_requested = True

        finally:
            if ready_writer is not None:
                ready_writer.close()
            if death_pipe is not None:
                death_pipe.close()
            # Clean up once worker exits busy loop
            if worker is not None:
                worker.shutdown()

    class ResponseStatus(Enum):
        SUCCESS = auto()
        FAILURE = auto()

    def worker_busy_loop(self):
        """Main busy loop for Multiprocessing Workers.

        Repeatedly dequeues ``(method, args, kwargs, output_rank)`` from the
        broadcast queue, runs it on the wrapped worker, and enqueues a
        ``(ResponseStatus, result)`` reply when ``output_rank`` is ``None``
        or equals this worker's rank. The loop never returns normally; it
        ends only through an exception that ``except Exception`` does not
        catch (such as ``SystemExit`` raised by the signal handler).
        """
        while True:
            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

            try:
                if isinstance(method, str):
                    func = getattr(self.worker, method)
                elif isinstance(method, bytes):
                    func = partial(cloudpickle.loads(method), self.worker)
                else:
                    # Without this branch `func` would be unbound on the
                    # first request, or silently reuse the previous
                    # request's callable on later ones.
                    raise TypeError("Unsupported RPC method type: "
                                    f"{type(method).__name__}")
                output = func(*args, **kwargs)
            except Exception as e:
                # Notes have been introduced in python 3.11
                if hasattr(e, "add_note"):
                    e.add_note(traceback.format_exc())
                logger.exception("WorkerProc hit an exception.")
                # exception might not be serializable, so we convert it to
                # string, only for logging purpose.
                if output_rank is None or self.rank == output_rank:
                    self.worker_response_mq.enqueue(
                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
                continue

            if output_rank is None or self.rank == output_rank:
                self.worker_response_mq.enqueue(
                    (WorkerProc.ResponseStatus.SUCCESS, output))