mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[V1] [4/N] API Server: ZMQ/MP Utilities (#11541)
This commit is contained in:
@ -19,3 +19,4 @@ openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entr
|
||||
fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
|
||||
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
|
||||
requests
|
||||
zmq
|
||||
|
@ -7,7 +7,6 @@ from transformers import AutoTokenizer
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
@ -43,13 +42,11 @@ def test_engine_core(monkeypatch):
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
"""Setup the EngineCore."""
|
||||
engine_args = EngineArgs(model=MODEL_NAME)
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = AsyncLLM._get_executor_cls(vllm_config)
|
||||
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
executor_class=executor_class)
|
||||
"""Test basic request lifecycle."""
|
||||
|
||||
# First request.
|
||||
@ -151,13 +148,11 @@ def test_engine_core_advanced_sampling(monkeypatch):
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
"""Setup the EngineCore."""
|
||||
engine_args = EngineArgs(model=MODEL_NAME)
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = AsyncLLM._get_executor_cls(vllm_config)
|
||||
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
executor_class=executor_class)
|
||||
"""Test basic request lifecycle."""
|
||||
# First request.
|
||||
request: EngineCoreRequest = make_request()
|
||||
|
@ -86,11 +86,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
|
||||
UsageContext.UNKNOWN_CONTEXT)
|
||||
executor_class = AsyncLLM._get_executor_cls(vllm_config)
|
||||
client = EngineCoreClient.make_client(
|
||||
vllm_config,
|
||||
executor_class,
|
||||
UsageContext.UNKNOWN_CONTEXT,
|
||||
multiprocess_mode=multiprocessing_mode,
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
)
|
||||
|
||||
MAX_TOKENS = 20
|
||||
@ -158,11 +157,10 @@ async def test_engine_core_client_asyncio(monkeypatch):
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
executor_class = AsyncLLM._get_executor_cls(vllm_config)
|
||||
client = EngineCoreClient.make_client(
|
||||
vllm_config,
|
||||
executor_class,
|
||||
UsageContext.UNKNOWN_CONTEXT,
|
||||
multiprocess_mode=True,
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
)
|
||||
|
||||
MAX_TOKENS = 20
|
||||
|
@ -68,7 +68,7 @@ from vllm.entrypoints.utils import with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address, set_ulimit)
|
||||
is_valid_ipv6_address, kill_process_tree, set_ulimit)
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
@ -737,6 +737,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# The child processes will send SIGQUIT to this process when
|
||||
# any error happens. This process then cleans up the whole tree.
|
||||
# TODO(rob): move this into AsyncLLM.__init__ once we remove
|
||||
# the context manager below.
|
||||
def sigquit_handler(signum, frame):
|
||||
kill_process_tree(os.getpid())
|
||||
|
||||
signal.signal(signal.SIGQUIT, sigquit_handler)
|
||||
|
||||
async with build_async_engine_client(args) as engine_client:
|
||||
app = build_app(args)
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
@ -13,10 +12,9 @@ from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils.importing import HAS_TRITON
|
||||
from vllm.utils import cuda_is_initialized
|
||||
from vllm.utils import _check_multiproc_method, get_mp_context
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.triton_utils import maybe_set_triton_cache_manager
|
||||
@ -274,24 +272,6 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
|
||||
file.write = write_with_prefix # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _check_multiproc_method():
|
||||
if (cuda_is_initialized()
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
|
||||
logger.warning("CUDA was previously initialized. We must use "
|
||||
"the `spawn` multiprocessing start method. Setting "
|
||||
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
|
||||
"See https://docs.vllm.ai/en/latest/getting_started/"
|
||||
"debugging.html#python-multiprocessing "
|
||||
"for more information.")
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
def get_mp_context():
|
||||
_check_multiproc_method()
|
||||
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
|
||||
return multiprocessing.get_context(mp_method)
|
||||
|
||||
|
||||
def set_multiprocessing_worker_envs(parallel_config):
|
||||
""" Set up environment variables that should be used when there are workers
|
||||
in a multiprocessing environment. This should be called by the parent
|
||||
|
@ -10,6 +10,7 @@ import importlib.metadata
|
||||
import importlib.util
|
||||
import inspect
|
||||
import ipaddress
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import resource
|
||||
@ -20,6 +21,7 @@ import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import warnings
|
||||
import weakref
|
||||
@ -29,8 +31,9 @@ from collections.abc import Hashable, Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache, partial, wraps
|
||||
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
|
||||
Dict, Generator, Generic, List, Literal, NamedTuple,
|
||||
Optional, Tuple, Type, TypeVar, Union, overload)
|
||||
Dict, Generator, Generic, Iterator, List, Literal,
|
||||
NamedTuple, Optional, Tuple, Type, TypeVar, Union,
|
||||
overload)
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
@ -39,6 +42,8 @@ import psutil
|
||||
import torch
|
||||
import torch.types
|
||||
import yaml
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from packaging.version import Version
|
||||
from torch.library import Library
|
||||
from typing_extensions import ParamSpec, TypeIs, assert_never
|
||||
@ -1844,7 +1849,7 @@ def memory_profiling(
|
||||
result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/f46f394f4d4dbe4aae85403dec006199b34d2840/python/sglang/srt/utils.py#L630 # noqa: E501Curre
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
|
||||
def set_ulimit(target_soft_limit=65535):
|
||||
resource_type = resource.RLIMIT_NOFILE
|
||||
current_soft, current_hard = resource.getrlimit(resource_type)
|
||||
@ -1859,3 +1864,82 @@ def set_ulimit(target_soft_limit=65535):
|
||||
"with error %s. This can cause fd limit errors like"
|
||||
"`OSError: [Errno 24] Too many open files`. Consider "
|
||||
"increasing with ulimit -n", current_soft, e)
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
|
||||
def get_exception_traceback():
    """Return the formatted traceback of the exception being handled.

    Intended to be called from inside an ``except`` block. The result is the
    same text the interpreter would print for the active exception, returned
    as a single string so it can be passed to a logger.

    Returns:
        The formatted traceback, including the final
        ``ExceptionType: message`` line.
    """
    # format_exc() is the stdlib shorthand for
    # "".join(traceback.format_exception(*sys.exc_info())).
    return traceback.format_exc()
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
|
||||
def make_zmq_socket(
    ctx: Union[zmq.asyncio.Context, zmq.Context],  # type: ignore[name-defined]
    path: str,
    type: Any,
) -> Union[zmq.Socket, zmq.asyncio.Socket]:  # type: ignore[name-defined]
    """Make a ZMQ socket with the proper bind/connect semantics.

    PULL sockets connect to ``path``; PUSH sockets bind to it. The
    high-water mark is set to 0 for both, and the kernel buffer size is
    raised to 0.5GB on machines with plenty of memory.

    Args:
        ctx: ZMQ context (sync or asyncio) used to create the socket.
        path: Endpoint to connect or bind to.
        type: ZMQ socket type; only ``zmq.constants.PULL`` and
            ``zmq.constants.PUSH`` are supported.

    Returns:
        The configured socket, already connected (PULL) or bound (PUSH).

    Raises:
        ValueError: If ``type`` is neither PULL nor PUSH.
    """
    # Validate before creating the socket so an unsupported type cannot
    # leave an open socket behind on the caller's context.
    if type not in (zmq.constants.PULL, zmq.constants.PUSH):
        raise ValueError(f"Unknown Socket Type: {type}")

    mem = psutil.virtual_memory()

    # Calculate buffer size based on system memory
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    # For systems with substantial memory (>32GB total, >16GB available):
    # - Set a large 0.5GB buffer to improve throughput
    # For systems with less memory:
    # - Use system default (-1) to avoid excessive memory consumption
    if total_mem > 32 and available_mem > 16:
        buf_size = int(0.5 * 1024**3)  # 0.5GB in bytes
    else:
        buf_size = -1  # Use system default buffer size

    socket = ctx.socket(type)
    try:
        if type == zmq.constants.PULL:
            socket.setsockopt(zmq.constants.RCVHWM, 0)
            socket.setsockopt(zmq.constants.RCVBUF, buf_size)
            socket.connect(path)
        else:
            socket.setsockopt(zmq.constants.SNDHWM, 0)
            socket.setsockopt(zmq.constants.SNDBUF, buf_size)
            socket.bind(path)
    except Exception:
        # Configuration or connect/bind failed: release the socket here,
        # since the caller never receives a handle to close it.
        socket.close(linger=0)
        raise

    return socket
|
||||
|
||||
|
||||
@contextlib.contextmanager
def zmq_socket_ctx(
        path: str,
        type: Any) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
    """Yield a ZMQ socket backed by a private context.

    A dedicated ``zmq.Context`` with two IO threads is created for the
    socket and destroyed (``linger=0``) when the ``with`` block exits.
    A ``KeyboardInterrupt`` raised inside the block is logged at debug
    level and not re-raised.

    Args:
        path: Endpoint handed to ``make_zmq_socket``.
        type: ZMQ socket type handed to ``make_zmq_socket``.
    """
    zmq_ctx = zmq.Context(io_threads=2)  # type: ignore[attr-defined]
    try:
        sock = make_zmq_socket(zmq_ctx, path, type)
        yield sock
    except KeyboardInterrupt:
        logger.debug("Got Keyboard Interrupt.")
    finally:
        # Destroying the context also closes the socket created from it.
        zmq_ctx.destroy(linger=0)
|
||||
|
||||
|
||||
def _check_multiproc_method():
    """Force the ``spawn`` start method if CUDA is already initialized.

    When ``cuda_is_initialized()`` is true and the
    ``VLLM_WORKER_MULTIPROC_METHOD`` environment variable is not already
    ``"spawn"``, a warning is logged and the variable is set to ``"spawn"``.
    Otherwise nothing happens.
    """
    env_key = "VLLM_WORKER_MULTIPROC_METHOD"

    if not cuda_is_initialized():
        return
    if os.environ.get(env_key) == "spawn":
        return

    logger.warning("CUDA was previously initialized. We must use "
                   "the `spawn` multiprocessing start method. Setting "
                   "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
                   "See https://docs.vllm.ai/en/latest/getting_started/"
                   "debugging.html#python-multiprocessing "
                   "for more information.")
    os.environ[env_key] = "spawn"
|
||||
|
||||
|
||||
def get_mp_context():
    """Return the multiprocessing context vLLM workers should use.

    ``_check_multiproc_method()`` runs first so it can switch the start
    method to ``spawn`` before ``envs.VLLM_WORKER_MULTIPROC_METHOD`` is read.
    """
    _check_multiproc_method()
    return multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)
|
||||
|
@ -75,11 +75,11 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
# EngineCore (starts the engine in background process).
|
||||
self.engine_core = EngineCoreClient.make_client(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
usage_context=usage_context,
|
||||
multiprocess_mode=True,
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=self.log_stats,
|
||||
)
|
||||
|
||||
self.output_handler: Optional[asyncio.Task] = None
|
||||
|
@ -3,20 +3,19 @@ import queue
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.process import BaseProcess
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
import psutil
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from msgspec import msgpack
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.executor.multiproc_worker_utils import get_mp_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
||||
from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
|
||||
EngineCoreProfile, EngineCoreRequest,
|
||||
@ -25,14 +24,13 @@ from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import PickleEncoder
|
||||
from vllm.v1.utils import make_zmq_socket
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_MS = 5000
|
||||
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
|
||||
LOGGING_TIME_S = POLLING_TIMEOUT_S
|
||||
LOGGING_TIME_S = 5
|
||||
|
||||
|
||||
class EngineCore:
|
||||
@ -42,9 +40,10 @@ class EngineCore:
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
usage_context: UsageContext,
|
||||
log_stats: bool = False,
|
||||
):
|
||||
assert vllm_config.model_config.runner_type != "pooling"
|
||||
self.log_stats = log_stats
|
||||
|
||||
logger.info("Initializing an LLM engine (v%s) with config: %s",
|
||||
VLLM_VERSION, vllm_config)
|
||||
@ -134,29 +133,19 @@ class EngineCore:
|
||||
self.model_executor.profile(is_start)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineCoreProcHandle:
|
||||
proc: BaseProcess
|
||||
ready_path: str
|
||||
input_path: str
|
||||
output_path: str
|
||||
|
||||
|
||||
class EngineCoreProc(EngineCore):
|
||||
"""ZMQ-wrapper for running EngineCore in background process."""
|
||||
|
||||
READY_STR = "READY"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
usage_context: UsageContext,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
ready_path: str,
|
||||
ready_pipe: Connection,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False,
|
||||
):
|
||||
super().__init__(vllm_config, executor_class, usage_context)
|
||||
super().__init__(vllm_config, executor_class, log_stats)
|
||||
|
||||
# Background Threads and Queues for IO. These enable us to
|
||||
# overlap ZMQ socket IO with GPU since they release the GIL,
|
||||
@ -173,68 +162,7 @@ class EngineCoreProc(EngineCore):
|
||||
daemon=True).start()
|
||||
|
||||
# Send Readiness signal to EngineClient.
|
||||
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
|
||||
ready_socket.send_string(EngineCoreProc.READY_STR)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_startup(
|
||||
proc: BaseProcess,
|
||||
ready_path: str,
|
||||
) -> None:
|
||||
"""Wait until the EngineCore is ready."""
|
||||
|
||||
try:
|
||||
sync_ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
socket = sync_ctx.socket(zmq.constants.PULL)
|
||||
socket.connect(ready_path)
|
||||
|
||||
# Wait for EngineCore to send EngineCoreProc.READY_STR.
|
||||
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||
logger.debug("Waiting for EngineCoreProc to startup.")
|
||||
|
||||
if not proc.is_alive():
|
||||
raise RuntimeError("EngineCoreProc failed to start.")
|
||||
|
||||
message = socket.recv_string()
|
||||
assert message == EngineCoreProc.READY_STR
|
||||
|
||||
except BaseException as e:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
finally:
|
||||
sync_ctx.destroy(linger=0)
|
||||
|
||||
@staticmethod
|
||||
def make_engine_core_process(
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
usage_context: UsageContext,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
ready_path: str,
|
||||
) -> EngineCoreProcHandle:
|
||||
context = get_mp_context()
|
||||
|
||||
process_kwargs = {
|
||||
"input_path": input_path,
|
||||
"output_path": output_path,
|
||||
"ready_path": ready_path,
|
||||
"vllm_config": vllm_config,
|
||||
"executor_class": executor_class,
|
||||
"usage_context": usage_context,
|
||||
}
|
||||
# Run EngineCore busy loop in background process.
|
||||
proc = context.Process(target=EngineCoreProc.run_engine_core,
|
||||
kwargs=process_kwargs)
|
||||
proc.start()
|
||||
|
||||
# Wait for startup
|
||||
EngineCoreProc.wait_for_startup(proc, ready_path)
|
||||
return EngineCoreProcHandle(proc=proc,
|
||||
ready_path=ready_path,
|
||||
input_path=input_path,
|
||||
output_path=output_path)
|
||||
ready_pipe.send({"status": "READY"})
|
||||
|
||||
@staticmethod
|
||||
def run_engine_core(*args, **kwargs):
|
||||
@ -258,6 +186,7 @@ class EngineCoreProc(EngineCore):
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
parent_process = psutil.Process().parent()
|
||||
engine_core = None
|
||||
try:
|
||||
engine_core = EngineCoreProc(*args, **kwargs)
|
||||
@ -266,9 +195,10 @@ class EngineCoreProc(EngineCore):
|
||||
except SystemExit:
|
||||
logger.debug("EngineCore interrupted.")
|
||||
|
||||
except BaseException as e:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
except Exception:
|
||||
traceback = get_exception_traceback()
|
||||
logger.error("EngineCore hit an exception: %s", traceback)
|
||||
parent_process.send_signal(signal.SIGQUIT)
|
||||
|
||||
finally:
|
||||
if engine_core is not None:
|
||||
@ -309,6 +239,9 @@ class EngineCoreProc(EngineCore):
|
||||
def _log_stats(self):
|
||||
"""Log basic stats every LOGGING_TIME_S"""
|
||||
|
||||
if not self.log_stats:
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
|
||||
if now - self._last_logging_time > LOGGING_TIME_S:
|
||||
@ -339,7 +272,7 @@ class EngineCoreProc(EngineCore):
|
||||
decoder_add_req = PickleEncoder()
|
||||
decoder_abort_req = PickleEncoder()
|
||||
|
||||
with make_zmq_socket(input_path, zmq.constants.PULL) as socket:
|
||||
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
|
||||
while True:
|
||||
# (RequestType, RequestData)
|
||||
type_frame, data_frame = socket.recv_multipart(copy=False)
|
||||
@ -367,7 +300,7 @@ class EngineCoreProc(EngineCore):
|
||||
# Reuse send buffer.
|
||||
buffer = bytearray()
|
||||
|
||||
with make_zmq_socket(output_path, zmq.constants.PUSH) as socket:
|
||||
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
|
||||
while True:
|
||||
engine_core_outputs = self.output_queue.get()
|
||||
outputs = EngineCoreOutputs(outputs=engine_core_outputs)
|
||||
|
@ -1,19 +1,19 @@
|
||||
import os
|
||||
import weakref
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
|
||||
from vllm.utils import get_open_zmq_ipc_path
|
||||
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
|
||||
EngineCoreProfile, EngineCoreRequest,
|
||||
EngineCoreRequestType, EngineCoreRequestUnion)
|
||||
from vllm.v1.engine.core import (EngineCore, EngineCoreProc,
|
||||
EngineCoreProcHandle)
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.serial_utils import PickleEncoder
|
||||
from vllm.v1.utils import BackgroundProcHandle
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -31,10 +31,11 @@ class EngineCoreClient:
|
||||
|
||||
@staticmethod
|
||||
def make_client(
|
||||
*args,
|
||||
multiprocess_mode: bool,
|
||||
asyncio_mode: bool,
|
||||
**kwargs,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False,
|
||||
) -> "EngineCoreClient":
|
||||
|
||||
# TODO: support this for debugging purposes.
|
||||
@ -44,12 +45,12 @@ class EngineCoreClient:
|
||||
"is not currently supported.")
|
||||
|
||||
if multiprocess_mode and asyncio_mode:
|
||||
return AsyncMPClient(*args, **kwargs)
|
||||
return AsyncMPClient(vllm_config, executor_class, log_stats)
|
||||
|
||||
if multiprocess_mode and not asyncio_mode:
|
||||
return SyncMPClient(*args, **kwargs)
|
||||
return SyncMPClient(vllm_config, executor_class, log_stats)
|
||||
|
||||
return InprocClient(*args, **kwargs)
|
||||
return InprocClient(vllm_config, executor_class, log_stats)
|
||||
|
||||
def shutdown(self):
|
||||
pass
|
||||
@ -128,9 +129,10 @@ class MPClient(EngineCoreClient):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
asyncio_mode: bool,
|
||||
**kwargs,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False,
|
||||
):
|
||||
# Serialization setup.
|
||||
self.encoder = PickleEncoder()
|
||||
@ -143,7 +145,6 @@ class MPClient(EngineCoreClient):
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
|
||||
# Path for IPC.
|
||||
ready_path = get_open_zmq_ipc_path()
|
||||
output_path = get_open_zmq_ipc_path()
|
||||
input_path = get_open_zmq_ipc_path()
|
||||
|
||||
@ -156,47 +157,40 @@ class MPClient(EngineCoreClient):
|
||||
self.input_socket.bind(input_path)
|
||||
|
||||
# Start EngineCore in background process.
|
||||
self.proc_handle: Optional[EngineCoreProcHandle]
|
||||
self.proc_handle = EngineCoreProc.make_engine_core_process(
|
||||
*args,
|
||||
input_path=
|
||||
input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords
|
||||
output_path=output_path, # type: ignore[misc]
|
||||
ready_path=ready_path, # type: ignore[misc]
|
||||
**kwargs,
|
||||
)
|
||||
self._finalizer = weakref.finalize(self, self.shutdown)
|
||||
self.proc_handle: Optional[BackgroundProcHandle]
|
||||
self.proc_handle = BackgroundProcHandle(
|
||||
input_path=input_path,
|
||||
output_path=output_path,
|
||||
process_name="EngineCore",
|
||||
target_fn=EngineCoreProc.run_engine_core,
|
||||
process_kwargs={
|
||||
"vllm_config": vllm_config,
|
||||
"executor_class": executor_class,
|
||||
"log_stats": log_stats,
|
||||
})
|
||||
|
||||
def shutdown(self):
|
||||
# Shut down the zmq context.
|
||||
self.ctx.destroy(linger=0)
|
||||
|
||||
if hasattr(self, "proc_handle") and self.proc_handle:
|
||||
# Shutdown the process if needed.
|
||||
if self.proc_handle.proc.is_alive():
|
||||
self.proc_handle.proc.terminate()
|
||||
self.proc_handle.proc.join(5)
|
||||
|
||||
if self.proc_handle.proc.is_alive():
|
||||
kill_process_tree(self.proc_handle.proc.pid)
|
||||
|
||||
# Remove zmq ipc socket files
|
||||
ipc_sockets = [
|
||||
self.proc_handle.ready_path, self.proc_handle.output_path,
|
||||
self.proc_handle.input_path
|
||||
]
|
||||
for ipc_socket in ipc_sockets:
|
||||
socket_file = ipc_socket.replace("ipc://", "")
|
||||
if os and os.path.exists(socket_file):
|
||||
os.remove(socket_file)
|
||||
self.proc_handle.shutdown()
|
||||
self.proc_handle = None
|
||||
|
||||
|
||||
class SyncMPClient(MPClient):
|
||||
"""Synchronous client for multi-proc EngineCore."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, asyncio_mode=False, **kwargs)
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False):
|
||||
super().__init__(
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
)
|
||||
|
||||
def get_output(self) -> List[EngineCoreOutput]:
|
||||
|
||||
@ -225,8 +219,16 @@ class SyncMPClient(MPClient):
|
||||
class AsyncMPClient(MPClient):
|
||||
"""Asyncio-compatible client for multi-proc EngineCore."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, asyncio_mode=True, **kwargs)
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool = False):
|
||||
super().__init__(
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
)
|
||||
|
||||
async def get_output_async(self) -> List[EngineCoreOutput]:
|
||||
|
||||
|
@ -72,11 +72,11 @@ class LLMEngine:
|
||||
|
||||
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
|
||||
self.engine_core = EngineCoreClient.make_client(
|
||||
vllm_config,
|
||||
executor_class,
|
||||
usage_context,
|
||||
multiprocess_mode=multiprocess_mode,
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -17,13 +17,12 @@ from vllm.distributed import (destroy_distributed_environment,
|
||||
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
|
||||
MessageQueue)
|
||||
from vllm.executor.multiproc_worker_utils import (
|
||||
_add_prefix, get_mp_context, set_multiprocessing_worker_envs)
|
||||
_add_prefix, set_multiprocessing_worker_envs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_distributed_init_method, get_open_port,
|
||||
get_open_zmq_ipc_path)
|
||||
from vllm.utils import (get_distributed_init_method, get_mp_context,
|
||||
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import make_zmq_socket
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -250,7 +249,7 @@ class WorkerProc:
|
||||
worker_response_mq_handle = self.worker_response_mq.export_handle()
|
||||
|
||||
# Send Readiness signal to EngineCore process.
|
||||
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
|
||||
with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
|
||||
payload = pickle.dumps(worker_response_mq_handle,
|
||||
protocol=pickle.HIGHEST_PROTOCOL)
|
||||
ready_socket.send_string(WorkerProc.READY_STR)
|
||||
@ -352,7 +351,7 @@ class WorkerProc:
|
||||
ready_path: str,
|
||||
) -> Optional[Handle]:
|
||||
"""Wait until the Worker is ready."""
|
||||
with make_zmq_socket(ready_path, zmq.constants.PULL) as socket:
|
||||
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
|
||||
|
||||
# Wait for Worker to send READY.
|
||||
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||
|
@ -1,11 +1,11 @@
|
||||
import os
|
||||
import weakref
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
|
||||
overload)
|
||||
|
||||
import zmq
|
||||
from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
|
||||
Union, overload)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_mp_context, kill_process_tree
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -77,27 +77,58 @@ class ConstantList(Generic[T], Sequence):
|
||||
return len(self._x)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def make_zmq_socket(
|
||||
path: str,
|
||||
type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
|
||||
"""Context manager for a ZMQ socket"""
|
||||
class BackgroundProcHandle:
|
||||
"""
|
||||
Utility class to handle creation, readiness, and shutdown
|
||||
of background processes used by the AsyncLLM and LLMEngine.
|
||||
"""
|
||||
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
try:
|
||||
socket = ctx.socket(type)
|
||||
def __init__(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
process_name: str,
|
||||
target_fn: Callable,
|
||||
process_kwargs: Dict[Any, Any],
|
||||
):
|
||||
self._finalizer = weakref.finalize(self, self.shutdown)
|
||||
|
||||
if type == zmq.constants.PULL:
|
||||
socket.connect(path)
|
||||
elif type == zmq.constants.PUSH:
|
||||
socket.bind(path)
|
||||
else:
|
||||
raise ValueError(f"Unknown Socket Type: {type}")
|
||||
context = get_mp_context()
|
||||
reader, writer = context.Pipe(duplex=False)
|
||||
|
||||
yield socket
|
||||
assert ("ready_pipe" not in process_kwargs
|
||||
and "input_path" not in process_kwargs
|
||||
and "output_path" not in process_kwargs)
|
||||
process_kwargs["ready_pipe"] = writer
|
||||
process_kwargs["input_path"] = input_path
|
||||
process_kwargs["output_path"] = output_path
|
||||
self.input_path = input_path
|
||||
self.output_path = output_path
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.debug("Worker had Keyboard Interrupt.")
|
||||
# Run the target function's busy loop in a background process.
|
||||
self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
|
||||
self.proc.start()
|
||||
|
||||
finally:
|
||||
ctx.destroy(linger=0)
|
||||
# Wait for startup.
|
||||
if reader.recv()["status"] != "READY":
|
||||
raise RuntimeError(f"{process_name} initialization failed. "
|
||||
"See root cause above.")
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self):
|
||||
# Shutdown the process if needed.
|
||||
if hasattr(self, "proc") and self.proc.is_alive():
|
||||
self.proc.terminate()
|
||||
self.proc.join(5)
|
||||
|
||||
if self.proc.is_alive():
|
||||
kill_process_tree(self.proc.pid)
|
||||
|
||||
# Remove zmq ipc socket files
|
||||
ipc_sockets = [self.output_path, self.input_path]
|
||||
for ipc_socket in ipc_sockets:
|
||||
socket_file = ipc_socket.replace("ipc://", "")
|
||||
if os and os.path.exists(socket_file):
|
||||
os.remove(socket_file)
|
||||
|
Reference in New Issue
Block a user