[Core] Add multiproc_worker_utils for multiprocessing-based workers (#4357)

This commit is contained in:
Nick Hill
2024-05-01 11:41:59 -07:00
committed by GitHub
parent 24750f4cad
commit a657bfc48a
2 changed files with 440 additions and 0 deletions

View File

@ -0,0 +1,176 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import sleep
from typing import Any, List, Tuple
import pytest
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
class DummyWorker:
"""Dummy version of vllm.worker.worker.Worker"""
def __init__(self, rank: int):
self.rank = rank
def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
sleep(0.05)
if isinstance(worker_input, Exception):
# simulate error case
raise worker_input
return self.rank, input
def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
result_handler = ResultHandler()
workers = [
ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
for rank in range(8)
]
worker_monitor = WorkerMonitor(workers, result_handler)
assert not worker_monitor.is_alive()
result_handler.start()
worker_monitor.start()
assert worker_monitor.is_alive()
return workers, worker_monitor
def test_local_workers() -> None:
"""Test workers with sync task submission"""
workers, worker_monitor = _start_workers()
def execute_workers(worker_input: str) -> None:
worker_outputs = [
worker.execute_method("worker_method", worker_input)
for worker in workers
]
for rank, output in enumerate(worker_outputs):
assert output.get() == (rank, input)
executor = ThreadPoolExecutor(max_workers=4)
# Test concurrent submission from different threads
futures = [
executor.submit(partial(execute_workers, f"thread {thread_num}"))
for thread_num in range(4)
]
for future in futures:
future.result()
# Test error case
exception = ValueError("fake error")
result = workers[0].execute_method("worker_method", exception)
try:
result.get()
pytest.fail("task should have failed")
except Exception as e:
assert isinstance(e, ValueError)
assert str(e) == "fake error"
# Test cleanup when a worker fails
assert worker_monitor.is_alive()
workers[3].process.kill()
# Other workers should get shut down here
worker_monitor.join(2)
# Ensure everything is stopped
assert not worker_monitor.is_alive()
assert all(not worker.process.is_alive() for worker in workers)
# Further attempts to submit tasks should fail
try:
_result = workers[0].execute_method("worker_method", "test")
pytest.fail("task should fail once workers have been shut down")
except Exception as e:
assert isinstance(e, ChildProcessError)
def test_local_workers_clean_shutdown() -> None:
"""Test clean shutdown"""
workers, worker_monitor = _start_workers()
assert worker_monitor.is_alive()
assert all(worker.process.is_alive() for worker in workers)
# Clean shutdown
worker_monitor.close()
worker_monitor.join(5)
# Ensure everything is stopped
assert not worker_monitor.is_alive()
assert all(not worker.process.is_alive() for worker in workers)
# Further attempts to submit tasks should fail
try:
_result = workers[0].execute_method("worker_method", "test")
pytest.fail("task should fail once workers have been shut down")
except Exception as e:
assert isinstance(e, ChildProcessError)
@pytest.mark.asyncio
async def test_local_workers_async() -> None:
"""Test local workers with async task submission"""
workers, worker_monitor = _start_workers()
async def execute_workers(worker_input: str) -> None:
worker_coros = [
worker.execute_method_async("worker_method", worker_input)
for worker in workers
]
results = await asyncio.gather(*worker_coros)
for rank, result in enumerate(results):
assert result == (rank, input)
tasks = [
asyncio.create_task(execute_workers(f"task {task_num}"))
for task_num in range(4)
]
for task in tasks:
await task
# Test error case
exception = ValueError("fake error")
try:
_result = await workers[0].execute_method_async(
"worker_method", exception)
pytest.fail("task should have failed")
except Exception as e:
assert isinstance(e, ValueError)
assert str(e) == "fake error"
# Test cleanup when a worker fails
assert worker_monitor.is_alive()
workers[3].process.kill()
# Other workers should get shut down here
worker_monitor.join(2)
# Ensure everything is stopped
assert not worker_monitor.is_alive()
assert all(not worker.process.is_alive() for worker in workers)
# Further attempts to submit tasks should fail
try:
_result = await workers[0].execute_method_async(
"worker_method", "test")
pytest.fail("task should fail once workers have been shut down")
except Exception as e:
assert isinstance(e, ChildProcessError)

View File

@ -0,0 +1,264 @@
import asyncio
import multiprocessing
import os
import sys
import threading
import traceback
import uuid
from dataclasses import dataclass
from multiprocessing import Queue
from multiprocessing.connection import wait
from multiprocessing.process import BaseProcess
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
TypeVar, Union)
from vllm.logger import init_logger
logger = init_logger(__name__)
T = TypeVar('T')
_TERMINATE = "TERMINATE" # sentinel
# ANSI color codes
CYAN = '\033[1;36m'
RESET = '\033[0;0m'
JOIN_TIMEOUT_S = 2
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
mp_method = os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
mp = multiprocessing.get_context(mp_method)
@dataclass
class Result(Generic[T]):
"""Result of task dispatched to worker"""
task_id: uuid.UUID
value: Optional[T] = None
exception: Optional[BaseException] = None
class ResultFuture(threading.Event, Generic[T]):
"""Synchronous future for non-async case"""
def __init__(self):
super().__init__()
self.result: Optional[Result[T]] = None
def set_result(self, result: Result[T]):
self.result = result
self.set()
def get(self) -> T:
self.wait()
assert self.result is not None
if self.result.exception is not None:
raise self.result.exception
return self.result.value # type: ignore[return-value]
def _set_future_result(future: Union[ResultFuture, asyncio.Future],
result: Result):
if isinstance(future, ResultFuture):
future.set_result(result)
return
loop = future.get_loop()
if result.exception is not None:
loop.call_soon_threadsafe(future.set_exception, result.exception)
else:
loop.call_soon_threadsafe(future.set_result, result.value)
class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
def __init__(self) -> None:
super().__init__(daemon=True)
self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
def run(self):
for result in iter(self.result_queue.get, _TERMINATE):
future = self.tasks.pop(result.task_id)
_set_future_result(future, result)
# Ensure that all waiters will receive an exception
for task_id, future in self.tasks.items():
_set_future_result(
future,
Result(task_id=task_id,
exception=ChildProcessError("worker died")))
def close(self):
self.result_queue.put(_TERMINATE)
class WorkerMonitor(threading.Thread):
"""Monitor worker status (in background thread)"""
def __init__(self, workers: List['ProcessWorkerWrapper'],
result_handler: ResultHandler):
super().__init__(daemon=True)
self.workers = workers
self.result_handler = result_handler
self._close = False
def run(self) -> None:
# Blocks until any worker exits
dead_sentinels = wait([w.process.sentinel for w in self.workers])
if not self._close:
self._close = True
# Kill / cleanup all workers
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode)
# Cleanup any remaining workers
logger.info("Killing local vLLM worker processes")
for worker in self.workers:
worker.kill_worker()
# Must be done after worker task queues are all closed
self.result_handler.close()
for worker in self.workers:
worker.process.join(JOIN_TIMEOUT_S)
def close(self):
if self._close:
return
self._close = True
logger.info("Terminating local vLLM worker processes")
for worker in self.workers:
worker.terminate_worker()
# Must be done after worker task queues are all closed
self.result_handler.close()
class ProcessWorkerWrapper:
"""Local process wrapper for vllm.worker.Worker,
for handling single-node multi-GPU tensor parallel."""
def __init__(self, result_handler: ResultHandler,
worker_factory: Callable[[], Any]) -> None:
self._task_queue = mp.Queue()
self.result_queue = result_handler.result_queue
self.tasks = result_handler.tasks
self.process: BaseProcess = mp.Process( # type: ignore[attr-defined]
target=_run_worker_process,
name="VllmWorkerProcess",
kwargs=dict(
worker_factory=worker_factory,
task_queue=self._task_queue,
result_queue=self.result_queue,
),
daemon=True)
self.process.start()
def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
method: str, args, kwargs):
task_id = uuid.uuid4()
self.tasks[task_id] = future
try:
self._task_queue.put((task_id, method, args, kwargs))
except BaseException as e:
del self.tasks[task_id]
raise ChildProcessError("worker died") from e
def execute_method(self, method: str, *args, **kwargs):
future: ResultFuture = ResultFuture()
self._enqueue_task(future, method, args, kwargs)
return future
async def execute_method_async(self, method: str, *args, **kwargs):
future = asyncio.get_running_loop().create_future()
self._enqueue_task(future, method, args, kwargs)
return await future
def terminate_worker(self):
try:
self._task_queue.put(_TERMINATE)
except ValueError:
self.process.kill()
self._task_queue.close()
def kill_worker(self):
self._task_queue.close()
self.process.kill()
def _run_worker_process(
worker_factory: Callable[[], Any],
task_queue: Queue,
result_queue: Queue,
) -> None:
"""Worker process event loop"""
# Add process-specific prefix to stdout and stderr
process_name = mp.current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
# Initialize worker
worker = worker_factory()
del worker_factory
# Accept tasks from the engine in task_queue
# and return task output in result_queue
logger.info("Worker ready; awaiting tasks")
try:
for items in iter(task_queue.get, _TERMINATE):
output = None
exception = None
task_id, method, args, kwargs = items
try:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
except BaseException as e:
tb = traceback.format_exc()
logger.error(
"Exception in worker %s while processing method %s: %s, %s",
process_name, method, e, tb)
exception = e
result_queue.put(
Result(task_id=task_id, value=output, exception=exception))
except KeyboardInterrupt:
pass
except Exception:
logger.exception("Worker failed")
logger.info("Worker exiting")
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
"""Prepend each output line with process-specific prefix"""
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
file_write = file.write
def write_with_prefix(s: str):
if not s:
return
if file.start_new_line: # type: ignore[attr-defined]
file_write(prefix)
idx = 0
while (next_idx := s.find('\n', idx)) != -1:
next_idx += 1
file_write(s[idx:next_idx])
if next_idx == len(s):
file.start_new_line = True # type: ignore[attr-defined]
return
file_write(prefix)
idx = next_idx
file_write(s[idx:])
file.start_new_line = False # type: ignore[attr-defined]
file.start_new_line = True # type: ignore[attr-defined]
file.write = write_with_prefix # type: ignore[method-assign]