mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Core] Add multiproc_worker_utils
for multiprocessing-based workers (#4357)
This commit is contained in:
176
tests/engine/test_multiproc_workers.py
Normal file
176
tests/engine/test_multiproc_workers.py
Normal file
@ -0,0 +1,176 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
||||
ResultHandler, WorkerMonitor)
|
||||
|
||||
|
||||
class DummyWorker:
|
||||
"""Dummy version of vllm.worker.worker.Worker"""
|
||||
|
||||
def __init__(self, rank: int):
|
||||
self.rank = rank
|
||||
|
||||
def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
|
||||
sleep(0.05)
|
||||
|
||||
if isinstance(worker_input, Exception):
|
||||
# simulate error case
|
||||
raise worker_input
|
||||
|
||||
return self.rank, input
|
||||
|
||||
|
||||
def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
|
||||
result_handler = ResultHandler()
|
||||
workers = [
|
||||
ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
|
||||
for rank in range(8)
|
||||
]
|
||||
|
||||
worker_monitor = WorkerMonitor(workers, result_handler)
|
||||
assert not worker_monitor.is_alive()
|
||||
|
||||
result_handler.start()
|
||||
worker_monitor.start()
|
||||
assert worker_monitor.is_alive()
|
||||
|
||||
return workers, worker_monitor
|
||||
|
||||
|
||||
def test_local_workers() -> None:
|
||||
"""Test workers with sync task submission"""
|
||||
|
||||
workers, worker_monitor = _start_workers()
|
||||
|
||||
def execute_workers(worker_input: str) -> None:
|
||||
worker_outputs = [
|
||||
worker.execute_method("worker_method", worker_input)
|
||||
for worker in workers
|
||||
]
|
||||
|
||||
for rank, output in enumerate(worker_outputs):
|
||||
assert output.get() == (rank, input)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
# Test concurrent submission from different threads
|
||||
futures = [
|
||||
executor.submit(partial(execute_workers, f"thread {thread_num}"))
|
||||
for thread_num in range(4)
|
||||
]
|
||||
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
# Test error case
|
||||
exception = ValueError("fake error")
|
||||
result = workers[0].execute_method("worker_method", exception)
|
||||
try:
|
||||
result.get()
|
||||
pytest.fail("task should have failed")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ValueError)
|
||||
assert str(e) == "fake error"
|
||||
|
||||
# Test cleanup when a worker fails
|
||||
assert worker_monitor.is_alive()
|
||||
workers[3].process.kill()
|
||||
|
||||
# Other workers should get shut down here
|
||||
worker_monitor.join(2)
|
||||
|
||||
# Ensure everything is stopped
|
||||
assert not worker_monitor.is_alive()
|
||||
assert all(not worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Further attempts to submit tasks should fail
|
||||
try:
|
||||
_result = workers[0].execute_method("worker_method", "test")
|
||||
pytest.fail("task should fail once workers have been shut down")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ChildProcessError)
|
||||
|
||||
|
||||
def test_local_workers_clean_shutdown() -> None:
|
||||
"""Test clean shutdown"""
|
||||
|
||||
workers, worker_monitor = _start_workers()
|
||||
|
||||
assert worker_monitor.is_alive()
|
||||
assert all(worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Clean shutdown
|
||||
worker_monitor.close()
|
||||
|
||||
worker_monitor.join(5)
|
||||
|
||||
# Ensure everything is stopped
|
||||
assert not worker_monitor.is_alive()
|
||||
assert all(not worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Further attempts to submit tasks should fail
|
||||
try:
|
||||
_result = workers[0].execute_method("worker_method", "test")
|
||||
pytest.fail("task should fail once workers have been shut down")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ChildProcessError)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_workers_async() -> None:
|
||||
"""Test local workers with async task submission"""
|
||||
|
||||
workers, worker_monitor = _start_workers()
|
||||
|
||||
async def execute_workers(worker_input: str) -> None:
|
||||
worker_coros = [
|
||||
worker.execute_method_async("worker_method", worker_input)
|
||||
for worker in workers
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*worker_coros)
|
||||
for rank, result in enumerate(results):
|
||||
assert result == (rank, input)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(execute_workers(f"task {task_num}"))
|
||||
for task_num in range(4)
|
||||
]
|
||||
|
||||
for task in tasks:
|
||||
await task
|
||||
|
||||
# Test error case
|
||||
exception = ValueError("fake error")
|
||||
try:
|
||||
_result = await workers[0].execute_method_async(
|
||||
"worker_method", exception)
|
||||
pytest.fail("task should have failed")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ValueError)
|
||||
assert str(e) == "fake error"
|
||||
|
||||
# Test cleanup when a worker fails
|
||||
assert worker_monitor.is_alive()
|
||||
workers[3].process.kill()
|
||||
|
||||
# Other workers should get shut down here
|
||||
worker_monitor.join(2)
|
||||
|
||||
# Ensure everything is stopped
|
||||
assert not worker_monitor.is_alive()
|
||||
assert all(not worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Further attempts to submit tasks should fail
|
||||
try:
|
||||
_result = await workers[0].execute_method_async(
|
||||
"worker_method", "test")
|
||||
pytest.fail("task should fail once workers have been shut down")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ChildProcessError)
|
264
vllm/executor/multiproc_worker_utils.py
Normal file
264
vllm/executor/multiproc_worker_utils.py
Normal file
@ -0,0 +1,264 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Queue
|
||||
from multiprocessing.connection import wait
|
||||
from multiprocessing.process import BaseProcess
|
||||
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
|
||||
TypeVar, Union)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
_TERMINATE = "TERMINATE" # sentinel
|
||||
|
||||
# ANSI color codes
|
||||
CYAN = '\033[1;36m'
|
||||
RESET = '\033[0;0m'
|
||||
|
||||
JOIN_TIMEOUT_S = 2
|
||||
|
||||
# Use dedicated multiprocess context for workers.
|
||||
# Both spawn and fork work
|
||||
mp_method = os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
mp = multiprocessing.get_context(mp_method)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result(Generic[T]):
|
||||
"""Result of task dispatched to worker"""
|
||||
|
||||
task_id: uuid.UUID
|
||||
value: Optional[T] = None
|
||||
exception: Optional[BaseException] = None
|
||||
|
||||
|
||||
class ResultFuture(threading.Event, Generic[T]):
|
||||
"""Synchronous future for non-async case"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.result: Optional[Result[T]] = None
|
||||
|
||||
def set_result(self, result: Result[T]):
|
||||
self.result = result
|
||||
self.set()
|
||||
|
||||
def get(self) -> T:
|
||||
self.wait()
|
||||
assert self.result is not None
|
||||
if self.result.exception is not None:
|
||||
raise self.result.exception
|
||||
return self.result.value # type: ignore[return-value]
|
||||
|
||||
|
||||
def _set_future_result(future: Union[ResultFuture, asyncio.Future],
|
||||
result: Result):
|
||||
if isinstance(future, ResultFuture):
|
||||
future.set_result(result)
|
||||
return
|
||||
loop = future.get_loop()
|
||||
if result.exception is not None:
|
||||
loop.call_soon_threadsafe(future.set_exception, result.exception)
|
||||
else:
|
||||
loop.call_soon_threadsafe(future.set_result, result.value)
|
||||
|
||||
|
||||
class ResultHandler(threading.Thread):
|
||||
"""Handle results from all workers (in background thread)"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(daemon=True)
|
||||
self.result_queue = mp.Queue()
|
||||
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
|
||||
|
||||
def run(self):
|
||||
for result in iter(self.result_queue.get, _TERMINATE):
|
||||
future = self.tasks.pop(result.task_id)
|
||||
_set_future_result(future, result)
|
||||
# Ensure that all waiters will receive an exception
|
||||
for task_id, future in self.tasks.items():
|
||||
_set_future_result(
|
||||
future,
|
||||
Result(task_id=task_id,
|
||||
exception=ChildProcessError("worker died")))
|
||||
|
||||
def close(self):
|
||||
self.result_queue.put(_TERMINATE)
|
||||
|
||||
|
||||
class WorkerMonitor(threading.Thread):
|
||||
"""Monitor worker status (in background thread)"""
|
||||
|
||||
def __init__(self, workers: List['ProcessWorkerWrapper'],
|
||||
result_handler: ResultHandler):
|
||||
super().__init__(daemon=True)
|
||||
self.workers = workers
|
||||
self.result_handler = result_handler
|
||||
self._close = False
|
||||
|
||||
def run(self) -> None:
|
||||
# Blocks until any worker exits
|
||||
dead_sentinels = wait([w.process.sentinel for w in self.workers])
|
||||
if not self._close:
|
||||
self._close = True
|
||||
|
||||
# Kill / cleanup all workers
|
||||
for worker in self.workers:
|
||||
process = worker.process
|
||||
if process.sentinel in dead_sentinels:
|
||||
process.join(JOIN_TIMEOUT_S)
|
||||
if process.exitcode is not None and process.exitcode != 0:
|
||||
logger.error("Worker %s pid %s died, exit code: %s",
|
||||
process.name, process.pid, process.exitcode)
|
||||
# Cleanup any remaining workers
|
||||
logger.info("Killing local vLLM worker processes")
|
||||
for worker in self.workers:
|
||||
worker.kill_worker()
|
||||
# Must be done after worker task queues are all closed
|
||||
self.result_handler.close()
|
||||
|
||||
for worker in self.workers:
|
||||
worker.process.join(JOIN_TIMEOUT_S)
|
||||
|
||||
def close(self):
|
||||
if self._close:
|
||||
return
|
||||
self._close = True
|
||||
logger.info("Terminating local vLLM worker processes")
|
||||
for worker in self.workers:
|
||||
worker.terminate_worker()
|
||||
# Must be done after worker task queues are all closed
|
||||
self.result_handler.close()
|
||||
|
||||
|
||||
class ProcessWorkerWrapper:
|
||||
"""Local process wrapper for vllm.worker.Worker,
|
||||
for handling single-node multi-GPU tensor parallel."""
|
||||
|
||||
def __init__(self, result_handler: ResultHandler,
|
||||
worker_factory: Callable[[], Any]) -> None:
|
||||
self._task_queue = mp.Queue()
|
||||
self.result_queue = result_handler.result_queue
|
||||
self.tasks = result_handler.tasks
|
||||
self.process: BaseProcess = mp.Process( # type: ignore[attr-defined]
|
||||
target=_run_worker_process,
|
||||
name="VllmWorkerProcess",
|
||||
kwargs=dict(
|
||||
worker_factory=worker_factory,
|
||||
task_queue=self._task_queue,
|
||||
result_queue=self.result_queue,
|
||||
),
|
||||
daemon=True)
|
||||
|
||||
self.process.start()
|
||||
|
||||
def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
|
||||
method: str, args, kwargs):
|
||||
task_id = uuid.uuid4()
|
||||
self.tasks[task_id] = future
|
||||
try:
|
||||
self._task_queue.put((task_id, method, args, kwargs))
|
||||
except BaseException as e:
|
||||
del self.tasks[task_id]
|
||||
raise ChildProcessError("worker died") from e
|
||||
|
||||
def execute_method(self, method: str, *args, **kwargs):
|
||||
future: ResultFuture = ResultFuture()
|
||||
self._enqueue_task(future, method, args, kwargs)
|
||||
return future
|
||||
|
||||
async def execute_method_async(self, method: str, *args, **kwargs):
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self._enqueue_task(future, method, args, kwargs)
|
||||
return await future
|
||||
|
||||
def terminate_worker(self):
|
||||
try:
|
||||
self._task_queue.put(_TERMINATE)
|
||||
except ValueError:
|
||||
self.process.kill()
|
||||
self._task_queue.close()
|
||||
|
||||
def kill_worker(self):
|
||||
self._task_queue.close()
|
||||
self.process.kill()
|
||||
|
||||
|
||||
def _run_worker_process(
|
||||
worker_factory: Callable[[], Any],
|
||||
task_queue: Queue,
|
||||
result_queue: Queue,
|
||||
) -> None:
|
||||
"""Worker process event loop"""
|
||||
|
||||
# Add process-specific prefix to stdout and stderr
|
||||
process_name = mp.current_process().name
|
||||
pid = os.getpid()
|
||||
_add_prefix(sys.stdout, process_name, pid)
|
||||
_add_prefix(sys.stderr, process_name, pid)
|
||||
|
||||
# Initialize worker
|
||||
worker = worker_factory()
|
||||
del worker_factory
|
||||
|
||||
# Accept tasks from the engine in task_queue
|
||||
# and return task output in result_queue
|
||||
logger.info("Worker ready; awaiting tasks")
|
||||
try:
|
||||
for items in iter(task_queue.get, _TERMINATE):
|
||||
output = None
|
||||
exception = None
|
||||
task_id, method, args, kwargs = items
|
||||
try:
|
||||
executor = getattr(worker, method)
|
||||
output = executor(*args, **kwargs)
|
||||
except BaseException as e:
|
||||
tb = traceback.format_exc()
|
||||
logger.error(
|
||||
"Exception in worker %s while processing method %s: %s, %s",
|
||||
process_name, method, e, tb)
|
||||
exception = e
|
||||
result_queue.put(
|
||||
Result(task_id=task_id, value=output, exception=exception))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Worker failed")
|
||||
|
||||
logger.info("Worker exiting")
|
||||
|
||||
|
||||
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
|
||||
"""Prepend each output line with process-specific prefix"""
|
||||
|
||||
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
|
||||
file_write = file.write
|
||||
|
||||
def write_with_prefix(s: str):
|
||||
if not s:
|
||||
return
|
||||
if file.start_new_line: # type: ignore[attr-defined]
|
||||
file_write(prefix)
|
||||
idx = 0
|
||||
while (next_idx := s.find('\n', idx)) != -1:
|
||||
next_idx += 1
|
||||
file_write(s[idx:next_idx])
|
||||
if next_idx == len(s):
|
||||
file.start_new_line = True # type: ignore[attr-defined]
|
||||
return
|
||||
file_write(prefix)
|
||||
idx = next_idx
|
||||
file_write(s[idx:])
|
||||
file.start_new_line = False # type: ignore[attr-defined]
|
||||
|
||||
file.start_new_line = True # type: ignore[attr-defined]
|
||||
file.write = write_with_prefix # type: ignore[method-assign]
|
Reference in New Issue
Block a user