[Bug][Frontend] Improve ZMQ client robustness (#7443)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde
2024-08-21 20:18:11 -06:00
committed by GitHub
parent df1a21131d
commit cde9183b40
6 changed files with 176 additions and 28 deletions

View File

View File

@ -0,0 +1,119 @@
import asyncio
import tempfile
import unittest
import unittest.mock
import uuid
import pytest
import pytest_asyncio
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
RPCClientClosedError)
from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest_asyncio.fixture(scope="function")
async def dummy_server(tmp_socket, monkeypatch):
dummy_engine = unittest.mock.AsyncMock()
def dummy_engine_builder(*args, **kwargs):
return dummy_engine
with monkeypatch.context() as m:
m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder)
server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.run_server_loop())
try:
yield server
finally:
server_task.cancel()
server.cleanup()
@pytest_asyncio.fixture(scope="function")
async def client(tmp_socket):
client = AsyncEngineRPCClient(rpc_path=tmp_socket)
# Sanity check: the server is connected
await client._wait_for_server_rpc()
try:
yield client
finally:
client.close()
@pytest.mark.asyncio
async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
client: AsyncEngineRPCClient):
with monkeypatch.context() as m:
# Make the server _not_ reply with a model config
m.setattr(dummy_server, "get_config", lambda x: None)
m.setattr(client, "_data_timeout", 10)
# And ensure the task completes anyway
# (client.setup() invokes server.get_config())
client_task = asyncio.get_running_loop().create_task(client.setup())
with pytest.raises(TimeoutError, match="Server didn't reply within"):
await asyncio.wait_for(client_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
client: AsyncEngineRPCClient):
with monkeypatch.context() as m:
# Hang all abort requests
m.setattr(dummy_server, "abort", lambda x: None)
m.setattr(client, "_data_timeout", 10)
# Ensure the client doesn't hang
client_task = asyncio.get_running_loop().create_task(
client.abort("test request id"))
with pytest.raises(TimeoutError, match="Server didn't reply within"):
await asyncio.wait_for(client_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_data_methods_reraise_exceptions(
monkeypatch, dummy_server, client: AsyncEngineRPCClient):
with monkeypatch.context() as m:
# Make the server raise some random exception
exception = RuntimeError("Client test exception")
def raiser():
raise exception
m.setattr(dummy_server.engine, "get_model_config", raiser)
m.setattr(client, "_data_timeout", 10)
client_task = asyncio.get_running_loop().create_task(client.setup())
# And ensure the task completes, raising the exception
with pytest.raises(RuntimeError, match=str(exception)):
await asyncio.wait_for(client_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_errors_after_closing(monkeypatch, dummy_server,
client: AsyncEngineRPCClient):
client.close()
# Healthchecks and generate requests will fail with explicit errors
with pytest.raises(RPCClientClosedError):
await client.check_health()
with pytest.raises(RPCClientClosedError):
async for _ in client.generate(None, None, None):
pass
# But no-ops like aborting will pass
await client.abort("test-request-id")
await client.do_log_stats()

View File

@ -6,7 +6,7 @@ import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set
@ -83,6 +83,7 @@ async def lifespan(app: FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
with suppress(Exception):
await async_engine_client.do_log_stats()
if not engine_args.disable_log_stats:

View File

@ -10,10 +10,6 @@ from vllm.sampling_params import SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR = "SUCCESS"
# Timeouts.
VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
VLLM_RPC_HEALTH_TIMEOUT_MS = 10000
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000

View File

@ -1,5 +1,5 @@
import asyncio
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Mapping, Optional
from uuid import uuid4
@ -11,13 +11,12 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf: disable
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTH_TIMEOUT_MS,
VLLM_RPC_SERVER_START_TIMEOUT_MS,
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -32,6 +31,17 @@ logger = init_logger(__name__)
INPROC_PROXY_PATH = f"inproc://{uuid4()}"
class RPCClientClosedError(Exception):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class AsyncEngineRPCClient:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
def __init__(self, rpc_path: str):
self.context = zmq.asyncio.Context()
self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
self._errored = False
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
@ -143,7 +155,6 @@ class AsyncEngineRPCClient:
# Wait until server is ready.
await self._wait_for_server_rpc()
self._errored = False
# Get the configs.
self.model_config = await self._get_model_config_rpc()
@ -170,6 +181,15 @@ class AsyncEngineRPCClient:
@contextmanager
def to_proxy_socket(self):
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
# are still running requests.
# There should not be a race condition with this check because we don't
# yield to the event loop between here and opening the socket.
if self.context.closed:
raise RPCClientClosedError("The ZMQ client has already shut down")
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket = self.context.socket(zmq.constants.DEALER)
@ -189,9 +209,18 @@ class AsyncEngineRPCClient:
# Ping RPCServer with a request.
await socket.send_multipart([cloudpickle.dumps(request)])
# Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
if isinstance(data, Exception):
# Re-raise exceptions returned by the server
raise data
if not isinstance(data, expected_type):
# LoRAConfig can be None.
if expected_type == LoRAConfig and data is None:
@ -208,29 +237,28 @@ class AsyncEngineRPCClient:
self,
request: RPC_REQUEST_TYPE,
error_message: str,
timeout: Optional[int] = None,
socket: Optional[zmq.asyncio.Socket] = None):
"""Send one-way RPC request to trigger an action."""
async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE,
timeout=None):
request: RPC_REQUEST_TYPE):
await socket.send_multipart([cloudpickle.dumps(request)])
if timeout is not None and await socket.poll(timeout=timeout) == 0:
raise TimeoutError(f"Server didn't reply within {timeout} ms")
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
return cloudpickle.loads(await socket.recv())
# Make a new socket connection.
if socket is None:
with self.to_proxy_socket() as socket:
response = await do_rpc_call(socket, request, timeout)
response = await do_rpc_call(socket, request)
# Use existing socket connection.
else:
response = await do_rpc_call(socket, request, timeout)
response = await do_rpc_call(socket, request)
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
if isinstance(response, Exception):
@ -255,8 +283,7 @@ class AsyncEngineRPCClient:
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server",
timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)
error_message="Unable to start RPC Server")
async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
@ -308,14 +335,14 @@ class AsyncEngineRPCClient:
async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server"""
with suppress(RPCClientClosedError):
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id),
error_message=f"RPCAbortRequest {request_id} failed")
async def do_log_stats(self):
"""Send a DO_LOG_STATS signal to the RPC Server"""
with suppress(RPCClientClosedError):
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.")
@ -393,7 +420,6 @@ class AsyncEngineRPCClient:
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
error_message="Got Unhealthy response from RPC Server",
timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
socket=socket)
async def encode(self, *args,

View File

@ -56,6 +56,7 @@ if TYPE_CHECKING:
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
@ -374,6 +375,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
("1", "true")),
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
"VLLM_RPC_GET_DATA_TIMEOUT_MS":
lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
# If set, allow running the engine as a separate ray actor,
# which is a deprecated feature soon to be removed.
# See https://github.com/vllm-project/vllm/issues/7045