mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	[Bug][Frontend] Improve ZMQ client robustness (#7443)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
		
							
								
								
									
										0
									
								
								tests/entrypoints/openai/rpc/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tests/entrypoints/openai/rpc/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										119
									
								
								tests/entrypoints/openai/rpc/test_zmq_client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								tests/entrypoints/openai/rpc/test_zmq_client.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,119 @@ | |||||||
|  | import asyncio | ||||||
|  | import tempfile | ||||||
|  | import unittest | ||||||
|  | import unittest.mock | ||||||
|  | import uuid | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  | import pytest_asyncio | ||||||
|  |  | ||||||
|  | from vllm.engine.async_llm_engine import AsyncLLMEngine | ||||||
|  | from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient, | ||||||
|  |                                                 RPCClientClosedError) | ||||||
|  | from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture(scope="function") | ||||||
|  | def tmp_socket(): | ||||||
|  |     with tempfile.TemporaryDirectory() as td: | ||||||
|  |         yield f"ipc://{td}/{uuid.uuid4()}" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest_asyncio.fixture(scope="function") | ||||||
|  | async def dummy_server(tmp_socket, monkeypatch): | ||||||
|  |     dummy_engine = unittest.mock.AsyncMock() | ||||||
|  |  | ||||||
|  |     def dummy_engine_builder(*args, **kwargs): | ||||||
|  |         return dummy_engine | ||||||
|  |  | ||||||
|  |     with monkeypatch.context() as m: | ||||||
|  |         m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder) | ||||||
|  |         server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket) | ||||||
|  |  | ||||||
|  |     loop = asyncio.get_running_loop() | ||||||
|  |     server_task = loop.create_task(server.run_server_loop()) | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         yield server | ||||||
|  |     finally: | ||||||
|  |         server_task.cancel() | ||||||
|  |         server.cleanup() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest_asyncio.fixture(scope="function") | ||||||
|  | async def client(tmp_socket): | ||||||
|  |     client = AsyncEngineRPCClient(rpc_path=tmp_socket) | ||||||
|  |     # Sanity check: the server is connected | ||||||
|  |     await client._wait_for_server_rpc() | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         yield client | ||||||
|  |     finally: | ||||||
|  |         client.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server, | ||||||
|  |                                                 client: AsyncEngineRPCClient): | ||||||
|  |     with monkeypatch.context() as m: | ||||||
|  |         # Make the server _not_ reply with a model config | ||||||
|  |         m.setattr(dummy_server, "get_config", lambda x: None) | ||||||
|  |         m.setattr(client, "_data_timeout", 10) | ||||||
|  |  | ||||||
|  |         # And ensure the task completes anyway | ||||||
|  |         # (client.setup() invokes server.get_config()) | ||||||
|  |         client_task = asyncio.get_running_loop().create_task(client.setup()) | ||||||
|  |         with pytest.raises(TimeoutError, match="Server didn't reply within"): | ||||||
|  |             await asyncio.wait_for(client_task, timeout=0.05) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_client_aborts_use_timeouts(monkeypatch, dummy_server, | ||||||
|  |                                           client: AsyncEngineRPCClient): | ||||||
|  |     with monkeypatch.context() as m: | ||||||
|  |         # Hang all abort requests | ||||||
|  |         m.setattr(dummy_server, "abort", lambda x: None) | ||||||
|  |         m.setattr(client, "_data_timeout", 10) | ||||||
|  |  | ||||||
|  |         # Ensure the client doesn't hang | ||||||
|  |         client_task = asyncio.get_running_loop().create_task( | ||||||
|  |             client.abort("test request id")) | ||||||
|  |         with pytest.raises(TimeoutError, match="Server didn't reply within"): | ||||||
|  |             await asyncio.wait_for(client_task, timeout=0.05) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_client_data_methods_reraise_exceptions( | ||||||
|  |         monkeypatch, dummy_server, client: AsyncEngineRPCClient): | ||||||
|  |     with monkeypatch.context() as m: | ||||||
|  |         # Make the server raise some random exception | ||||||
|  |         exception = RuntimeError("Client test exception") | ||||||
|  |  | ||||||
|  |         def raiser(): | ||||||
|  |             raise exception | ||||||
|  |  | ||||||
|  |         m.setattr(dummy_server.engine, "get_model_config", raiser) | ||||||
|  |         m.setattr(client, "_data_timeout", 10) | ||||||
|  |  | ||||||
|  |         client_task = asyncio.get_running_loop().create_task(client.setup()) | ||||||
|  |         # And ensure the task completes, raising the exception | ||||||
|  |         with pytest.raises(RuntimeError, match=str(exception)): | ||||||
|  |             await asyncio.wait_for(client_task, timeout=0.05) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_client_errors_after_closing(monkeypatch, dummy_server, | ||||||
|  |                                            client: AsyncEngineRPCClient): | ||||||
|  |  | ||||||
|  |     client.close() | ||||||
|  |  | ||||||
|  |     # Healthchecks and generate requests will fail with explicit errors | ||||||
|  |     with pytest.raises(RPCClientClosedError): | ||||||
|  |         await client.check_health() | ||||||
|  |     with pytest.raises(RPCClientClosedError): | ||||||
|  |         async for _ in client.generate(None, None, None): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |     # But no-ops like aborting will pass | ||||||
|  |     await client.abort("test-request-id") | ||||||
|  |     await client.do_log_stats() | ||||||
| @ -6,7 +6,7 @@ import os | |||||||
| import re | import re | ||||||
| import tempfile | import tempfile | ||||||
| from argparse import Namespace | from argparse import Namespace | ||||||
| from contextlib import asynccontextmanager | from contextlib import asynccontextmanager, suppress | ||||||
| from http import HTTPStatus | from http import HTTPStatus | ||||||
| from typing import AsyncIterator, Optional, Set | from typing import AsyncIterator, Optional, Set | ||||||
|  |  | ||||||
| @ -83,7 +83,8 @@ async def lifespan(app: FastAPI): | |||||||
|     async def _force_log(): |     async def _force_log(): | ||||||
|         while True: |         while True: | ||||||
|             await asyncio.sleep(10) |             await asyncio.sleep(10) | ||||||
|             await async_engine_client.do_log_stats() |             with suppress(Exception): | ||||||
|  |                 await async_engine_client.do_log_stats() | ||||||
|  |  | ||||||
|     if not engine_args.disable_log_stats: |     if not engine_args.disable_log_stats: | ||||||
|         task = asyncio.create_task(_force_log()) |         task = asyncio.create_task(_force_log()) | ||||||
|  | |||||||
| @ -10,10 +10,6 @@ from vllm.sampling_params import SamplingParams | |||||||
| # Success string used for RPC instructions. | # Success string used for RPC instructions. | ||||||
| VLLM_RPC_SUCCESS_STR = "SUCCESS" | VLLM_RPC_SUCCESS_STR = "SUCCESS" | ||||||
|  |  | ||||||
| # Timeouts. |  | ||||||
| VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000 |  | ||||||
| VLLM_RPC_HEALTH_TIMEOUT_MS = 10000 |  | ||||||
|  |  | ||||||
| # Minimum value of ZMQ.SOCKET_LIMIT to run mp. | # Minimum value of ZMQ.SOCKET_LIMIT to run mp. | ||||||
| VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 | VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,5 +1,5 @@ | |||||||
| import asyncio | import asyncio | ||||||
| from contextlib import contextmanager | from contextlib import contextmanager, suppress | ||||||
| from typing import Any, AsyncGenerator, Mapping, Optional | from typing import Any, AsyncGenerator, Mapping, Optional | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| @ -11,13 +11,12 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, | |||||||
|                          ParallelConfig, SchedulerConfig) |                          ParallelConfig, SchedulerConfig) | ||||||
| # yapf: disable | # yapf: disable | ||||||
| from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, | from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, | ||||||
|                                          VLLM_RPC_HEALTH_TIMEOUT_MS, |  | ||||||
|                                          VLLM_RPC_SERVER_START_TIMEOUT_MS, |  | ||||||
|                                          VLLM_RPC_SOCKET_LIMIT_CUTOFF, |                                          VLLM_RPC_SOCKET_LIMIT_CUTOFF, | ||||||
|                                          VLLM_RPC_SUCCESS_STR, |                                          VLLM_RPC_SUCCESS_STR, | ||||||
|                                          VLLM_RPC_ZMQ_HWM, RPCAbortRequest, |                                          VLLM_RPC_ZMQ_HWM, RPCAbortRequest, | ||||||
|                                          RPCGenerateRequest, RPCUtilityRequest) |                                          RPCGenerateRequest, RPCUtilityRequest) | ||||||
| # yapf: enable | # yapf: enable | ||||||
|  | from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS | ||||||
| from vllm.inputs import PromptInputs | from vllm.inputs import PromptInputs | ||||||
| from vllm.logger import init_logger | from vllm.logger import init_logger | ||||||
| from vllm.lora.request import LoRARequest | from vllm.lora.request import LoRARequest | ||||||
| @ -32,6 +31,17 @@ logger = init_logger(__name__) | |||||||
| INPROC_PROXY_PATH = f"inproc://{uuid4()}" | INPROC_PROXY_PATH = f"inproc://{uuid4()}" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RPCClientClosedError(Exception): | ||||||
|  |     """Exception class raised when the client is used post-close. | ||||||
|  |      | ||||||
|  |     The client can be closed, which closes the ZMQ context. This normally | ||||||
|  |     happens on server shutdown. In some cases, methods like abort and  | ||||||
|  |     do_log_stats will still be called and then try to open a socket, which  | ||||||
|  |     causes a ZMQError and creates a huge stack trace. | ||||||
|  |     So, we throw this error such that we can suppress it. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |  | ||||||
| class AsyncEngineRPCClient: | class AsyncEngineRPCClient: | ||||||
|     """ |     """ | ||||||
|     RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. |     RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. | ||||||
| @ -85,6 +95,8 @@ class AsyncEngineRPCClient: | |||||||
|  |  | ||||||
|     def __init__(self, rpc_path: str): |     def __init__(self, rpc_path: str): | ||||||
|         self.context = zmq.asyncio.Context() |         self.context = zmq.asyncio.Context() | ||||||
|  |         self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS | ||||||
|  |         self._errored = False | ||||||
|  |  | ||||||
|         # Maximum number of sockets that can be opened (typically 65536). |         # Maximum number of sockets that can be opened (typically 65536). | ||||||
|         # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) |         # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) | ||||||
| @ -143,7 +155,6 @@ class AsyncEngineRPCClient: | |||||||
|  |  | ||||||
|         # Wait until server is ready. |         # Wait until server is ready. | ||||||
|         await self._wait_for_server_rpc() |         await self._wait_for_server_rpc() | ||||||
|         self._errored = False |  | ||||||
|  |  | ||||||
|         # Get the configs. |         # Get the configs. | ||||||
|         self.model_config = await self._get_model_config_rpc() |         self.model_config = await self._get_model_config_rpc() | ||||||
| @ -170,6 +181,15 @@ class AsyncEngineRPCClient: | |||||||
|     @contextmanager |     @contextmanager | ||||||
|     def to_proxy_socket(self): |     def to_proxy_socket(self): | ||||||
|         # Connect to the RPCServer via the proxy. |         # Connect to the RPCServer via the proxy. | ||||||
|  |  | ||||||
|  |         # Raise a sensible error if the client was already closed. | ||||||
|  |         # This can happen if a server shutdown is triggered but some coroutines | ||||||
|  |         # are still running requests. | ||||||
|  |         # There should not be a race condition with this check because we don't | ||||||
|  |         # yield to the event loop between here and opening the socket. | ||||||
|  |         if self.context.closed: | ||||||
|  |             raise RPCClientClosedError("The ZMQ client has already shut down") | ||||||
|  |  | ||||||
|         # Note that we use DEALER to enable asynchronous communication |         # Note that we use DEALER to enable asynchronous communication | ||||||
|         # to enable streaming. |         # to enable streaming. | ||||||
|         socket = self.context.socket(zmq.constants.DEALER) |         socket = self.context.socket(zmq.constants.DEALER) | ||||||
| @ -189,9 +209,18 @@ class AsyncEngineRPCClient: | |||||||
|             # Ping RPCServer with a request. |             # Ping RPCServer with a request. | ||||||
|             await socket.send_multipart([cloudpickle.dumps(request)]) |             await socket.send_multipart([cloudpickle.dumps(request)]) | ||||||
|  |  | ||||||
|  |             # Make sure the server responds | ||||||
|  |             if await socket.poll(timeout=self._data_timeout) == 0: | ||||||
|  |                 raise TimeoutError("Server didn't reply within " | ||||||
|  |                                    f"{self._data_timeout} ms") | ||||||
|  |  | ||||||
|             # Await the data from the Server. |             # Await the data from the Server. | ||||||
|             data = cloudpickle.loads(await socket.recv()) |             data = cloudpickle.loads(await socket.recv()) | ||||||
|  |  | ||||||
|  |         if isinstance(data, Exception): | ||||||
|  |             # Re-raise exceptions returned by the server | ||||||
|  |             raise data | ||||||
|  |  | ||||||
|         if not isinstance(data, expected_type): |         if not isinstance(data, expected_type): | ||||||
|             # LoRAConfig can be None. |             # LoRAConfig can be None. | ||||||
|             if expected_type == LoRAConfig and data is None: |             if expected_type == LoRAConfig and data is None: | ||||||
| @ -208,29 +237,28 @@ class AsyncEngineRPCClient: | |||||||
|             self, |             self, | ||||||
|             request: RPC_REQUEST_TYPE, |             request: RPC_REQUEST_TYPE, | ||||||
|             error_message: str, |             error_message: str, | ||||||
|             timeout: Optional[int] = None, |  | ||||||
|             socket: Optional[zmq.asyncio.Socket] = None): |             socket: Optional[zmq.asyncio.Socket] = None): | ||||||
|         """Send one-way RPC request to trigger an action.""" |         """Send one-way RPC request to trigger an action.""" | ||||||
|  |  | ||||||
|         async def do_rpc_call(socket: zmq.asyncio.Socket, |         async def do_rpc_call(socket: zmq.asyncio.Socket, | ||||||
|                               request: RPC_REQUEST_TYPE, |                               request: RPC_REQUEST_TYPE): | ||||||
|                               timeout=None): |  | ||||||
|  |  | ||||||
|             await socket.send_multipart([cloudpickle.dumps(request)]) |             await socket.send_multipart([cloudpickle.dumps(request)]) | ||||||
|  |  | ||||||
|             if timeout is not None and await socket.poll(timeout=timeout) == 0: |             if await socket.poll(timeout=self._data_timeout) == 0: | ||||||
|                 raise TimeoutError(f"Server didn't reply within {timeout} ms") |                 raise TimeoutError("Server didn't reply within " | ||||||
|  |                                    f"{self._data_timeout} ms") | ||||||
|  |  | ||||||
|             return cloudpickle.loads(await socket.recv()) |             return cloudpickle.loads(await socket.recv()) | ||||||
|  |  | ||||||
|         # Make a new socket connection. |         # Make a new socket connection. | ||||||
|         if socket is None: |         if socket is None: | ||||||
|             with self.to_proxy_socket() as socket: |             with self.to_proxy_socket() as socket: | ||||||
|                 response = await do_rpc_call(socket, request, timeout) |                 response = await do_rpc_call(socket, request) | ||||||
|  |  | ||||||
|         # Use existing socket connection. |         # Use existing socket connection. | ||||||
|         else: |         else: | ||||||
|             response = await do_rpc_call(socket, request, timeout) |             response = await do_rpc_call(socket, request) | ||||||
|  |  | ||||||
|         if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: |         if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: | ||||||
|             if isinstance(response, Exception): |             if isinstance(response, Exception): | ||||||
| @ -255,8 +283,7 @@ class AsyncEngineRPCClient: | |||||||
|  |  | ||||||
|         await self._send_one_way_rpc_request( |         await self._send_one_way_rpc_request( | ||||||
|             request=RPCUtilityRequest.IS_SERVER_READY, |             request=RPCUtilityRequest.IS_SERVER_READY, | ||||||
|             error_message="Unable to start RPC Server", |             error_message="Unable to start RPC Server") | ||||||
|             timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS) |  | ||||||
|  |  | ||||||
|     async def _get_model_config_rpc(self) -> ModelConfig: |     async def _get_model_config_rpc(self) -> ModelConfig: | ||||||
|         """Get the ModelConfig object from the RPC Server""" |         """Get the ModelConfig object from the RPC Server""" | ||||||
| @ -308,17 +335,17 @@ class AsyncEngineRPCClient: | |||||||
|  |  | ||||||
|     async def abort(self, request_id: str): |     async def abort(self, request_id: str): | ||||||
|         """Send an ABORT_REQUEST signal to the RPC Server""" |         """Send an ABORT_REQUEST signal to the RPC Server""" | ||||||
|  |         with suppress(RPCClientClosedError): | ||||||
|         await self._send_one_way_rpc_request( |             await self._send_one_way_rpc_request( | ||||||
|             request=RPCAbortRequest(request_id), |                 request=RPCAbortRequest(request_id), | ||||||
|             error_message=f"RPCAbortRequest {request_id} failed") |                 error_message=f"RPCAbortRequest {request_id} failed") | ||||||
|  |  | ||||||
|     async def do_log_stats(self): |     async def do_log_stats(self): | ||||||
|         """Send a DO_LOG_STATS signal to the RPC Server""" |         """Send a DO_LOG_STATS signal to the RPC Server""" | ||||||
|  |         with suppress(RPCClientClosedError): | ||||||
|         await self._send_one_way_rpc_request( |             await self._send_one_way_rpc_request( | ||||||
|             request=RPCUtilityRequest.DO_LOG_STATS, |                 request=RPCUtilityRequest.DO_LOG_STATS, | ||||||
|             error_message="RPCRequest DO_LOG_STATS failed.") |                 error_message="RPCRequest DO_LOG_STATS failed.") | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def is_running(self) -> bool: |     def is_running(self) -> bool: | ||||||
| @ -393,7 +420,6 @@ class AsyncEngineRPCClient: | |||||||
|         await self._send_one_way_rpc_request( |         await self._send_one_way_rpc_request( | ||||||
|             request=RPCUtilityRequest.IS_SERVER_HEALTHY, |             request=RPCUtilityRequest.IS_SERVER_HEALTHY, | ||||||
|             error_message="Got Unhealthy response from RPC Server", |             error_message="Got Unhealthy response from RPC Server", | ||||||
|             timeout=VLLM_RPC_HEALTH_TIMEOUT_MS, |  | ||||||
|             socket=socket) |             socket=socket) | ||||||
|  |  | ||||||
|     async def encode(self, *args, |     async def encode(self, *args, | ||||||
|  | |||||||
| @ -56,6 +56,7 @@ if TYPE_CHECKING: | |||||||
|     VERBOSE: bool = False |     VERBOSE: bool = False | ||||||
|     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False |     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False | ||||||
|     VLLM_TEST_FORCE_FP8_MARLIN: bool = False |     VLLM_TEST_FORCE_FP8_MARLIN: bool = False | ||||||
|  |     VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 | ||||||
|     VLLM_ALLOW_ENGINE_USE_RAY: bool = False |     VLLM_ALLOW_ENGINE_USE_RAY: bool = False | ||||||
|     VLLM_PLUGINS: Optional[List[str]] = None |     VLLM_PLUGINS: Optional[List[str]] = None | ||||||
|     VLLM_TORCH_PROFILER_DIR: Optional[str] = None |     VLLM_TORCH_PROFILER_DIR: Optional[str] = None | ||||||
| @ -374,6 +375,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { | |||||||
|     (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in |     (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in | ||||||
|      ("1", "true")), |      ("1", "true")), | ||||||
|  |  | ||||||
|  |     # Time in ms for the zmq client to wait for a response from the backend | ||||||
|  |     # server for simple data operations | ||||||
|  |     "VLLM_RPC_GET_DATA_TIMEOUT_MS": | ||||||
|  |     lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), | ||||||
|  |  | ||||||
|     # If set, allow running the engine as a separate ray actor, |     # If set, allow running the engine as a separate ray actor, | ||||||
|     # which is a deprecated feature soon to be removed. |     # which is a deprecated feature soon to be removed. | ||||||
|     # See https://github.com/vllm-project/vllm/issues/7045 |     # See https://github.com/vllm-project/vllm/issues/7045 | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user