mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Perf][Frontend] eliminate api_key and x_request_id headers middleware overhead (#19946)
Signed-off-by: Yazan-Sharaya <yazan.sharaya.yes@gmail.com>
This commit is contained in:
@ -146,11 +146,6 @@ completion = client.chat.completions.create(
|
||||
Only `X-Request-Id` HTTP request header is supported for now. It can be enabled
|
||||
with `--enable-request-id-headers`.
|
||||
|
||||
> Note that enablement of the headers can impact performance significantly at high QPS
|
||||
> rates. We recommend implementing HTTP headers at the router level (e.g. via Istio),
|
||||
> rather than within the vLLM layer for this reason.
|
||||
> See [this PR](https://github.com/vllm-project/vllm/pull/11529) for more details.
|
||||
|
||||
??? Code
|
||||
|
||||
```python
|
||||
|
116
tests/entrypoints/openai/test_optional_middleware.py
Normal file
116
tests/entrypoints/openai/test_optional_middleware.py
Normal file
@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for middleware that's off by default and can be toggled through
|
||||
server arguments, mainly --api-key and --enable-request-id-headers.
|
||||
"""
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
# Use a small embeddings model for faster startup and smaller memory footprint.
|
||||
# Since we are not testing any chat functionality,
|
||||
# using a chat capable model is overkill.
|
||||
MODEL_NAME = "intfloat/multilingual-e5-small"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(request: pytest.FixtureRequest):
|
||||
passed_params = []
|
||||
if hasattr(request, "param"):
|
||||
passed_params = request.param
|
||||
if isinstance(passed_params, str):
|
||||
passed_params = [passed_params]
|
||||
|
||||
args = [
|
||||
"--task",
|
||||
"embed",
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"float16",
|
||||
"--max-model-len",
|
||||
"512",
|
||||
"--enforce-eager",
|
||||
"--max-num-seqs",
|
||||
"2",
|
||||
*passed_params
|
||||
]
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_api_token(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("v1/models"))
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_request_id_header(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("health"))
|
||||
assert "X-Request-Id" not in response.headers
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"server",
|
||||
[["--api-key", "test"]],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_api_token(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("v1/models"))
|
||||
assert response.status_code == HTTPStatus.UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"server",
|
||||
[["--api-key", "test"]],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_passed_api_token(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("v1/models"),
|
||||
headers={"Authorization": "Bearer test"})
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"server",
|
||||
[["--api-key", "test"]],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_v1_api_token(server: RemoteOpenAIServer):
|
||||
# Authorization check is skipped for any paths that
|
||||
# don't start with /v1 (e.g. /v1/chat/completions).
|
||||
response = requests.get(server.url_for("health"))
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"server",
|
||||
["--enable-request-id-headers"],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_request_id_header(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("health"))
|
||||
assert "X-Request-Id" in response.headers
|
||||
assert len(response.headers.get("X-Request-Id", "")) == 32
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"server",
|
||||
["--enable-request-id-headers"],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_request_id_header(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("health"),
|
||||
headers={"X-Request-Id": "Custom"})
|
||||
assert "X-Request-Id" in response.headers
|
||||
assert response.headers.get("X-Request-Id") == "Custom"
|
@ -14,7 +14,7 @@ import socket
|
||||
import tempfile
|
||||
import uuid
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterator, Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
@ -30,8 +30,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from prometheus_client import make_asgi_app
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import State
|
||||
from starlette.datastructures import URL, Headers, MutableHeaders, State
|
||||
from starlette.routing import Mount
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -1061,6 +1062,74 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""
|
||||
Pure ASGI middleware that authenticates each request by checking
|
||||
if the Authorization header exists and equals "Bearer {api_key}".
|
||||
|
||||
Notes
|
||||
-----
|
||||
There are two cases in which authentication is skipped:
|
||||
1. The HTTP method is OPTIONS.
|
||||
2. The request path doesn't start with /v1 (e.g. /health).
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp, api_token: str) -> None:
|
||||
self.app = app
|
||||
self.api_token = api_token
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive,
|
||||
send: Send) -> Awaitable[None]:
|
||||
if scope["type"] not in ("http",
|
||||
"websocket") or scope["method"] == "OPTIONS":
|
||||
# scope["type"] can be "lifespan" or "startup" for example,
|
||||
# in which case we don't need to do anything
|
||||
return self.app(scope, receive, send)
|
||||
root_path = scope.get("root_path", "")
|
||||
url_path = URL(scope=scope).path.removeprefix(root_path)
|
||||
headers = Headers(scope=scope)
|
||||
# Type narrow to satisfy mypy.
|
||||
if url_path.startswith("/v1") and headers.get(
|
||||
"Authorization") != f"Bearer {self.api_token}":
|
||||
response = JSONResponse(content={"error": "Unauthorized"},
|
||||
status_code=401)
|
||||
return response(scope, receive, send)
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
|
||||
class XRequestIdMiddleware:
|
||||
"""
|
||||
Middleware the set's the X-Request-Id header for each response
|
||||
to a random uuid4 (hex) value if the header isn't already
|
||||
present in the request, otherwise use the provided request id.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive,
|
||||
send: Send) -> Awaitable[None]:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
# Extract the request headers.
|
||||
request_headers = Headers(scope=scope)
|
||||
|
||||
async def send_with_request_id(message: Message) -> None:
|
||||
"""
|
||||
Custom send function to mutate the response headers
|
||||
and append X-Request-Id to it.
|
||||
"""
|
||||
if message["type"] == "http.response.start":
|
||||
response_headers = MutableHeaders(raw=message["headers"])
|
||||
request_id = request_headers.get("X-Request-Id",
|
||||
uuid.uuid4().hex)
|
||||
response_headers.append("X-Request-Id", request_id)
|
||||
await send(message)
|
||||
|
||||
return self.app(scope, receive, send_with_request_id)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
|
||||
if args.disable_fastapi_docs:
|
||||
app = FastAPI(openapi_url=None,
|
||||
@ -1108,33 +1177,10 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
|
||||
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
|
||||
if token := args.api_key or envs.VLLM_API_KEY:
|
||||
|
||||
@app.middleware("http")
|
||||
async def authentication(request: Request, call_next):
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
url_path = request.url.path
|
||||
if app.root_path and url_path.startswith(app.root_path):
|
||||
url_path = url_path[len(app.root_path):]
|
||||
if not url_path.startswith("/v1"):
|
||||
return await call_next(request)
|
||||
if request.headers.get("Authorization") != "Bearer " + token:
|
||||
return JSONResponse(content={"error": "Unauthorized"},
|
||||
status_code=401)
|
||||
return await call_next(request)
|
||||
app.add_middleware(AuthenticationMiddleware, api_token=token)
|
||||
|
||||
if args.enable_request_id_headers:
|
||||
logger.warning(
|
||||
"CAUTION: Enabling X-Request-Id headers in the API Server. "
|
||||
"This can harm performance at high QPS.")
|
||||
|
||||
@app.middleware("http")
|
||||
async def add_request_id(request: Request, call_next):
|
||||
request_id = request.headers.get(
|
||||
"X-Request-Id") or uuid.uuid4().hex
|
||||
response = await call_next(request)
|
||||
response.headers["X-Request-Id"] = request_id
|
||||
return response
|
||||
app.add_middleware(XRequestIdMiddleware)
|
||||
|
||||
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
|
||||
logger.warning("CAUTION: Enabling log response in the API Server. "
|
||||
|
@ -216,7 +216,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"--enable-request-id-headers",
|
||||
action="store_true",
|
||||
help="If specified, API server will add X-Request-Id header to "
|
||||
"responses. Caution: this hurts performance at high QPS.")
|
||||
"responses.")
|
||||
parser.add_argument(
|
||||
"--enable-auto-tool-choice",
|
||||
action="store_true",
|
||||
|
Reference in New Issue
Block a user