[V1] [4/N] API Server: ZMQ/MP Utilities (#11541)

This commit is contained in:
Robert Shaw
2024-12-27 20:45:08 -05:00
committed by GitHub
parent a60731247f
commit df04dffade
12 changed files with 242 additions and 210 deletions

View File

@ -19,3 +19,4 @@ openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entr
fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
requests
zmq

View File

@ -7,7 +7,6 @@ from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core import EngineCore
@ -43,13 +42,11 @@ def test_engine_core(monkeypatch):
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
vllm_config = engine_args.create_engine_config()
executor_class = AsyncLLM._get_executor_cls(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class=executor_class)
"""Test basic request lifecycle."""
# First request.
@ -151,13 +148,11 @@ def test_engine_core_advanced_sampling(monkeypatch):
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
vllm_config = engine_args.create_engine_config()
executor_class = AsyncLLM._get_executor_cls(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class=executor_class)
"""Test basic request lifecycle."""
# First request.
request: EngineCoreRequest = make_request()

View File

@ -86,11 +86,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config)
client = EngineCoreClient.make_client(
vllm_config,
executor_class,
UsageContext.UNKNOWN_CONTEXT,
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
)
MAX_TOKENS = 20
@ -158,11 +157,10 @@ async def test_engine_core_client_asyncio(monkeypatch):
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config)
client = EngineCoreClient.make_client(
vllm_config,
executor_class,
UsageContext.UNKNOWN_CONTEXT,
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
)
MAX_TOKENS = 20

View File

@ -68,7 +68,7 @@ from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
is_valid_ipv6_address, kill_process_tree, set_ulimit)
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
@ -737,6 +737,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal.signal(signal.SIGTERM, signal_handler)
# The child processes will send SIGQUIT to this process when
# any error happens. This process then clean up the whole tree.
# TODO(rob): move this into AsyncLLM.__init__ once we remove
# the context manager below.
def sigquit_handler(signum, frame):
kill_process_tree(os.getpid())
signal.signal(signal.SIGQUIT, sigquit_handler)
async with build_async_engine_client(args) as engine_client:
app = build_app(args)

View File

@ -1,5 +1,4 @@
import asyncio
import multiprocessing
import os
import sys
import threading
@ -13,10 +12,9 @@ from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import cuda_is_initialized
from vllm.utils import _check_multiproc_method, get_mp_context
if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager
@ -274,24 +272,6 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file.write = write_with_prefix # type: ignore[method-assign]
def _check_multiproc_method():
if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/getting_started/"
"debugging.html#python-multiprocessing "
"for more information.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def get_mp_context():
_check_multiproc_method()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)
def set_multiprocessing_worker_envs(parallel_config):
""" Set up environment variables that should be used when there are workers
in a multiprocessing environment. This should be called by the parent

View File

@ -10,6 +10,7 @@ import importlib.metadata
import importlib.util
import inspect
import ipaddress
import multiprocessing
import os
import re
import resource
@ -20,6 +21,7 @@ import sys
import tempfile
import threading
import time
import traceback
import uuid
import warnings
import weakref
@ -29,8 +31,9 @@ from collections.abc import Hashable, Iterable, Mapping
from dataclasses import dataclass, field
from functools import lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, List, Literal, NamedTuple,
Optional, Tuple, Type, TypeVar, Union, overload)
Dict, Generator, Generic, Iterator, List, Literal,
NamedTuple, Optional, Tuple, Type, TypeVar, Union,
overload)
from uuid import uuid4
import numpy as np
@ -39,6 +42,8 @@ import psutil
import torch
import torch.types
import yaml
import zmq
import zmq.asyncio
from packaging.version import Version
from torch.library import Library
from typing_extensions import ParamSpec, TypeIs, assert_never
@ -1844,7 +1849,7 @@ def memory_profiling(
result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa
# Adapted from: https://github.com/sgl-project/sglang/blob/f46f394f4d4dbe4aae85403dec006199b34d2840/python/sglang/srt/utils.py#L630 # noqa: E501Curre
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
def set_ulimit(target_soft_limit=65535):
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
@ -1859,3 +1864,82 @@ def set_ulimit(target_soft_limit=65535):
"with error %s. This can cause fd limit errors like"
"`OSError: [Errno 24] Too many open files`. Consider "
"increasing with ulimit -n", current_soft, e)
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
def get_exception_traceback():
etype, value, tb = sys.exc_info()
err_str = "".join(traceback.format_exception(etype, value, tb))
return err_str
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
path: str,
type: Any,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
mem = psutil.virtual_memory()
socket = ctx.socket(type)
# Calculate buffer size based on system memory
total_mem = mem.total / 1024**3
available_mem = mem.available / 1024**3
# For systems with substantial memory (>32GB total, >16GB available):
# - Set a large 0.5GB buffer to improve throughput
# For systems with less memory:
# - Use system default (-1) to avoid excessive memory consumption
if total_mem > 32 and available_mem > 16:
buf_size = int(0.5 * 1024**3) # 0.5GB in bytes
else:
buf_size = -1 # Use system default buffer size
if type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
socket.connect(path)
elif type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.bind(path)
else:
raise ValueError(f"Unknown Socket Type: {type}")
return socket
@contextlib.contextmanager
def zmq_socket_ctx(
path: str,
type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined]
try:
yield make_zmq_socket(ctx, path, type)
except KeyboardInterrupt:
logger.debug("Got Keyboard Interrupt.")
finally:
ctx.destroy(linger=0)
def _check_multiproc_method():
if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/getting_started/"
"debugging.html#python-multiprocessing "
"for more information.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def get_mp_context():
_check_multiproc_method()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)

View File

@ -75,11 +75,11 @@ class AsyncLLM(EngineClient):
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_client(
vllm_config=vllm_config,
executor_class=executor_class,
usage_context=usage_context,
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
)
self.output_handler: Optional[asyncio.Task] = None

View File

@ -3,20 +3,19 @@ import queue
import signal
import threading
import time
from dataclasses import dataclass
from multiprocessing.process import BaseProcess
from multiprocessing.connection import Connection
from typing import List, Tuple, Type
import psutil
import zmq
import zmq.asyncio
from msgspec import msgpack
from vllm.config import CacheConfig, VllmConfig
from vllm.executor.multiproc_worker_utils import get_mp_context
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
@ -25,14 +24,13 @@ from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.utils import make_zmq_socket
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
LOGGING_TIME_S = POLLING_TIMEOUT_S
LOGGING_TIME_S = 5
class EngineCore:
@ -42,9 +40,10 @@ class EngineCore:
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
usage_context: UsageContext,
log_stats: bool = False,
):
assert vllm_config.model_config.runner_type != "pooling"
self.log_stats = log_stats
logger.info("Initializing an LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)
@ -134,29 +133,19 @@ class EngineCore:
self.model_executor.profile(is_start)
@dataclass
class EngineCoreProcHandle:
proc: BaseProcess
ready_path: str
input_path: str
output_path: str
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
READY_STR = "READY"
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
usage_context: UsageContext,
input_path: str,
output_path: str,
ready_path: str,
ready_pipe: Connection,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
):
super().__init__(vllm_config, executor_class, usage_context)
super().__init__(vllm_config, executor_class, log_stats)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
@ -173,68 +162,7 @@ class EngineCoreProc(EngineCore):
daemon=True).start()
# Send Readiness signal to EngineClient.
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
ready_socket.send_string(EngineCoreProc.READY_STR)
@staticmethod
def wait_for_startup(
proc: BaseProcess,
ready_path: str,
) -> None:
"""Wait until the EngineCore is ready."""
try:
sync_ctx = zmq.Context() # type: ignore[attr-defined]
socket = sync_ctx.socket(zmq.constants.PULL)
socket.connect(ready_path)
# Wait for EngineCore to send EngineCoreProc.READY_STR.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for EngineCoreProc to startup.")
if not proc.is_alive():
raise RuntimeError("EngineCoreProc failed to start.")
message = socket.recv_string()
assert message == EngineCoreProc.READY_STR
except BaseException as e:
logger.exception(e)
raise e
finally:
sync_ctx.destroy(linger=0)
@staticmethod
def make_engine_core_process(
vllm_config: VllmConfig,
executor_class: Type[Executor],
usage_context: UsageContext,
input_path: str,
output_path: str,
ready_path: str,
) -> EngineCoreProcHandle:
context = get_mp_context()
process_kwargs = {
"input_path": input_path,
"output_path": output_path,
"ready_path": ready_path,
"vllm_config": vllm_config,
"executor_class": executor_class,
"usage_context": usage_context,
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=EngineCoreProc.run_engine_core,
kwargs=process_kwargs)
proc.start()
# Wait for startup
EngineCoreProc.wait_for_startup(proc, ready_path)
return EngineCoreProcHandle(proc=proc,
ready_path=ready_path,
input_path=input_path,
output_path=output_path)
ready_pipe.send({"status": "READY"})
@staticmethod
def run_engine_core(*args, **kwargs):
@ -258,6 +186,7 @@ class EngineCoreProc(EngineCore):
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parent_process = psutil.Process().parent()
engine_core = None
try:
engine_core = EngineCoreProc(*args, **kwargs)
@ -266,9 +195,10 @@ class EngineCoreProc(EngineCore):
except SystemExit:
logger.debug("EngineCore interrupted.")
except BaseException as e:
logger.exception(e)
raise e
except Exception:
traceback = get_exception_traceback()
logger.error("EngineCore hit an exception: %s", traceback)
parent_process.send_signal(signal.SIGQUIT)
finally:
if engine_core is not None:
@ -309,6 +239,9 @@ class EngineCoreProc(EngineCore):
def _log_stats(self):
"""Log basic stats every LOGGING_TIME_S"""
if not self.log_stats:
return
now = time.time()
if now - self._last_logging_time > LOGGING_TIME_S:
@ -339,7 +272,7 @@ class EngineCoreProc(EngineCore):
decoder_add_req = PickleEncoder()
decoder_abort_req = PickleEncoder()
with make_zmq_socket(input_path, zmq.constants.PULL) as socket:
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
@ -367,7 +300,7 @@ class EngineCoreProc(EngineCore):
# Reuse send buffer.
buffer = bytearray()
with make_zmq_socket(output_path, zmq.constants.PUSH) as socket:
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
while True:
engine_core_outputs = self.output_queue.get()
outputs = EngineCoreOutputs(outputs=engine_core_outputs)

View File

@ -1,19 +1,19 @@
import os
import weakref
from typing import List, Optional
from typing import List, Optional, Type
import msgspec
import zmq
import zmq.asyncio
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
from vllm.utils import get_open_zmq_ipc_path
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine.core import (EngineCore, EngineCoreProc,
EngineCoreProcHandle)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.utils import BackgroundProcHandle
logger = init_logger(__name__)
@ -31,10 +31,11 @@ class EngineCoreClient:
@staticmethod
def make_client(
*args,
multiprocess_mode: bool,
asyncio_mode: bool,
**kwargs,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
) -> "EngineCoreClient":
# TODO: support this for debugging purposes.
@ -44,12 +45,12 @@ class EngineCoreClient:
"is not currently supported.")
if multiprocess_mode and asyncio_mode:
return AsyncMPClient(*args, **kwargs)
return AsyncMPClient(vllm_config, executor_class, log_stats)
if multiprocess_mode and not asyncio_mode:
return SyncMPClient(*args, **kwargs)
return SyncMPClient(vllm_config, executor_class, log_stats)
return InprocClient(*args, **kwargs)
return InprocClient(vllm_config, executor_class, log_stats)
def shutdown(self):
pass
@ -128,9 +129,10 @@ class MPClient(EngineCoreClient):
def __init__(
self,
*args,
asyncio_mode: bool,
**kwargs,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
):
# Serialization setup.
self.encoder = PickleEncoder()
@ -143,7 +145,6 @@ class MPClient(EngineCoreClient):
self.ctx = zmq.Context() # type: ignore[attr-defined]
# Path for IPC.
ready_path = get_open_zmq_ipc_path()
output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
@ -156,47 +157,40 @@ class MPClient(EngineCoreClient):
self.input_socket.bind(input_path)
# Start EngineCore in background process.
self.proc_handle: Optional[EngineCoreProcHandle]
self.proc_handle = EngineCoreProc.make_engine_core_process(
*args,
input_path=
input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords
output_path=output_path, # type: ignore[misc]
ready_path=ready_path, # type: ignore[misc]
**kwargs,
)
self._finalizer = weakref.finalize(self, self.shutdown)
self.proc_handle: Optional[BackgroundProcHandle]
self.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=output_path,
process_name="EngineCore",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
"vllm_config": vllm_config,
"executor_class": executor_class,
"log_stats": log_stats,
})
def shutdown(self):
# Shut down the zmq context.
self.ctx.destroy(linger=0)
if hasattr(self, "proc_handle") and self.proc_handle:
# Shutdown the process if needed.
if self.proc_handle.proc.is_alive():
self.proc_handle.proc.terminate()
self.proc_handle.proc.join(5)
if self.proc_handle.proc.is_alive():
kill_process_tree(self.proc_handle.proc.pid)
# Remove zmq ipc socket files
ipc_sockets = [
self.proc_handle.ready_path, self.proc_handle.output_path,
self.proc_handle.input_path
]
for ipc_socket in ipc_sockets:
socket_file = ipc_socket.replace("ipc://", "")
if os and os.path.exists(socket_file):
os.remove(socket_file)
self.proc_handle.shutdown()
self.proc_handle = None
class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""
def __init__(self, *args, **kwargs):
super().__init__(*args, asyncio_mode=False, **kwargs)
def __init__(self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False):
super().__init__(
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
)
def get_output(self) -> List[EngineCoreOutput]:
@ -225,8 +219,16 @@ class SyncMPClient(MPClient):
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
def __init__(self, *args, **kwargs):
super().__init__(*args, asyncio_mode=True, **kwargs)
def __init__(self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False):
super().__init__(
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
)
async def get_output_async(self) -> List[EngineCoreOutput]:

View File

@ -72,11 +72,11 @@ class LLMEngine:
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client(
vllm_config,
executor_class,
usage_context,
multiprocess_mode=multiprocess_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
@classmethod

View File

@ -17,13 +17,12 @@ from vllm.distributed import (destroy_distributed_environment,
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.executor.multiproc_worker_utils import (
_add_prefix, get_mp_context, set_multiprocessing_worker_envs)
_add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_open_port,
get_open_zmq_ipc_path)
from vllm.utils import (get_distributed_init_method, get_mp_context,
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import make_zmq_socket
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -250,7 +249,7 @@ class WorkerProc:
worker_response_mq_handle = self.worker_response_mq.export_handle()
# Send Readiness signal to EngineCore process.
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
payload = pickle.dumps(worker_response_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_string(WorkerProc.READY_STR)
@ -352,7 +351,7 @@ class WorkerProc:
ready_path: str,
) -> Optional[Handle]:
"""Wait until the Worker is ready."""
with make_zmq_socket(ready_path, zmq.constants.PULL) as socket:
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
# Wait for Worker to send READY.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:

View File

@ -1,11 +1,11 @@
import os
import weakref
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
overload)
import zmq
from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
Union, overload)
from vllm.logger import init_logger
from vllm.utils import get_mp_context, kill_process_tree
logger = init_logger(__name__)
@ -77,27 +77,58 @@ class ConstantList(Generic[T], Sequence):
return len(self._x)
@contextmanager
def make_zmq_socket(
path: str,
type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
class BackgroundProcHandle:
"""
Utility class to handle creation, readiness, and shutdown
of background processes used by the AsyncLLM and LLMEngine.
"""
ctx = zmq.Context() # type: ignore[attr-defined]
try:
socket = ctx.socket(type)
def __init__(
self,
input_path: str,
output_path: str,
process_name: str,
target_fn: Callable,
process_kwargs: Dict[Any, Any],
):
self._finalizer = weakref.finalize(self, self.shutdown)
if type == zmq.constants.PULL:
socket.connect(path)
elif type == zmq.constants.PUSH:
socket.bind(path)
else:
raise ValueError(f"Unknown Socket Type: {type}")
context = get_mp_context()
reader, writer = context.Pipe(duplex=False)
yield socket
assert ("ready_pipe" not in process_kwargs
and "input_path" not in process_kwargs
and "output_path" not in process_kwargs)
process_kwargs["ready_pipe"] = writer
process_kwargs["input_path"] = input_path
process_kwargs["output_path"] = output_path
self.input_path = input_path
self.output_path = output_path
except KeyboardInterrupt:
logger.debug("Worker had Keyboard Interrupt.")
# Run Detokenizer busy loop in background process.
self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
self.proc.start()
finally:
ctx.destroy(linger=0)
# Wait for startup.
if reader.recv()["status"] != "READY":
raise RuntimeError(f"{process_name} initialization failed. "
"See root cause above.")
def __del__(self):
self.shutdown()
def shutdown(self):
# Shutdown the process if needed.
if hasattr(self, "proc") and self.proc.is_alive():
self.proc.terminate()
self.proc.join(5)
if self.proc.is_alive():
kill_process_tree(self.proc.pid)
# Remove zmq ipc socket files
ipc_sockets = [self.output_path, self.input_path]
for ipc_socket in ipc_sockets:
socket_file = ipc_socket.replace("ipc://", "")
if os and os.path.exists(socket_file):
os.remove(socket_file)