[Perf][Frontend] eliminate api_key and x_request_id headers middleware overhead (#19946)

Signed-off-by: Yazan-Sharaya <yazan.sharaya.yes@gmail.com>
Authored by Yazan Sharaya on 2025-06-27 07:44:14 +03:00; committed by GitHub
parent cd4cfee689
commit 6e244ae091
4 changed files with 190 additions and 33 deletions
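For context (not part of the commit itself): the `@app.middleware("http")` decorator that this change removes is backed by Starlette's `BaseHTTPMiddleware`, which builds a `Request` object and re-streams the response body on every call, even when the hook does nothing. A pure ASGI middleware forwards the raw `scope`/`receive`/`send` directly, which is what makes it cheaper. A minimal sketch of the two styles, with illustrative names only:

```python
from fastapi import FastAPI, Request
from starlette.types import ASGIApp, Receive, Scope, Send

app = FastAPI()


# Decorator style (what the commit removes for the api_key / x_request_id
# paths): every request is wrapped in BaseHTTPMiddleware machinery.
@app.middleware("http")
async def passthrough_http(request: Request, call_next):
    return await call_next(request)


# Pure ASGI style (what the commit switches to): a plain callable that only
# touches the raw ASGI primitives on the hot path.
class PassthroughASGIMiddleware:

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive,
                       send: Send) -> None:
        await self.app(scope, receive, send)


app.add_middleware(PassthroughASGIMiddleware)
```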


@@ -146,11 +146,6 @@ completion = client.chat.completions.create(
Only `X-Request-Id` HTTP request header is supported for now. It can be enabled
with `--enable-request-id-headers`.

-> Note that enablement of the headers can impact performance significantly at high QPS
-> rates. We recommend implementing HTTP headers at the router level (e.g. via Istio),
-> rather than within the vLLM layer for this reason.
-> See [this PR](https://github.com/vllm-project/vllm/pull/11529) for more details.
-
??? Code
```python
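As a client-side illustration (the address is a placeholder and assumes a server launched with `--enable-request-id-headers` and no `--api-key`), the request id can either be generated by the server or supplied by the caller:

```python
import requests

BASE_URL = "http://localhost:8000"  # placeholder for a locally running server

# No header sent: the middleware attaches a 32-character uuid4 hex id.
print(requests.get(f"{BASE_URL}/v1/models").headers.get("X-Request-Id"))

# Header sent by the client: the same value is echoed back in the response.
resp = requests.get(f"{BASE_URL}/v1/models",
                    headers={"X-Request-Id": "my-trace-id"})
print(resp.headers["X-Request-Id"])  # "my-trace-id"
```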


@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for middleware that's off by default and can be toggled through
server arguments, mainly --api-key and --enable-request-id-headers.
"""
from http import HTTPStatus

import pytest
import requests

from ...utils import RemoteOpenAIServer

# Use a small embeddings model for faster startup and smaller memory footprint.
# Since we are not testing any chat functionality,
# using a chat capable model is overkill.
MODEL_NAME = "intfloat/multilingual-e5-small"


@pytest.fixture(scope="module")
def server(request: pytest.FixtureRequest):
    passed_params = []
    if hasattr(request, "param"):
        passed_params = request.param
    if isinstance(passed_params, str):
        passed_params = [passed_params]

    args = [
        "--task",
        "embed",
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "512",
        "--enforce-eager",
        "--max-num-seqs",
        "2",
        *passed_params
    ]
    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.asyncio
async def test_no_api_token(server: RemoteOpenAIServer):
    response = requests.get(server.url_for("v1/models"))
    assert response.status_code == HTTPStatus.OK


@pytest.mark.asyncio
async def test_no_request_id_header(server: RemoteOpenAIServer):
    response = requests.get(server.url_for("health"))
    assert "X-Request-Id" not in response.headers


@pytest.mark.parametrize(
    "server",
    [["--api-key", "test"]],
    indirect=True,
)
@pytest.mark.asyncio
async def test_missing_api_token(server: RemoteOpenAIServer):
    response = requests.get(server.url_for("v1/models"))
    assert response.status_code == HTTPStatus.UNAUTHORIZED


@pytest.mark.parametrize(
    "server",
    [["--api-key", "test"]],
    indirect=True,
)
@pytest.mark.asyncio
async def test_passed_api_token(server: RemoteOpenAIServer):
    response = requests.get(server.url_for("v1/models"),
                            headers={"Authorization": "Bearer test"})
    assert response.status_code == HTTPStatus.OK


@pytest.mark.parametrize(
    "server",
    [["--api-key", "test"]],
    indirect=True,
)
@pytest.mark.asyncio
async def test_not_v1_api_token(server: RemoteOpenAIServer):
    # Authorization check is skipped for any paths that
    # don't start with /v1 (e.g. /health).
    response = requests.get(server.url_for("health"))
    assert response.status_code == HTTPStatus.OK


@pytest.mark.parametrize(
    "server",
    ["--enable-request-id-headers"],
    indirect=True,
)
@pytest.mark.asyncio
async def test_enable_request_id_header(server: RemoteOpenAIServer):
    response = requests.get(server.url_for("health"))
    assert "X-Request-Id" in response.headers
    assert len(response.headers.get("X-Request-Id", "")) == 32


@pytest.mark.parametrize(
    "server",
    ["--enable-request-id-headers"],
    indirect=True,
)
@pytest.mark.asyncio
async def test_custom_request_id_header(server: RemoteOpenAIServer):
    response = requests.get(server.url_for("health"),
                            headers={"X-Request-Id": "Custom"})
    assert "X-Request-Id" in response.headers
    assert response.headers.get("X-Request-Id") == "Custom"
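The `server` fixture above leans on pytest's indirect parametrization to forward extra CLI flags to the launched server. A standalone sketch of the same mechanism, with hypothetical names and no vLLM dependency:

```python
import pytest


@pytest.fixture
def server_args(request: pytest.FixtureRequest):
    # With indirect=True the parametrize() value arrives as request.param;
    # unparametrized tests have no .param attribute, so default to no flags.
    params = getattr(request, "param", [])
    return [params] if isinstance(params, str) else list(params)


@pytest.mark.parametrize("server_args", ["--enable-request-id-headers"],
                         indirect=True)
def test_single_flag(server_args):
    assert server_args == ["--enable-request-id-headers"]


def test_no_flags(server_args):
    assert server_args == []
```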


@@ -14,7 +14,7 @@ import socket
import tempfile
import uuid
from argparse import Namespace
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
@@ -30,8 +30,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
-from starlette.datastructures import State
+from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.routing import Mount
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
from typing_extensions import assert_never

import vllm.envs as envs
@@ -1061,6 +1062,74 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
    return None


class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization header exists and equals "Bearer {api_key}".

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, api_token: str) -> None:
        self.app = app
        self.api_token = api_token

    def __call__(self, scope: Scope, receive: Receive,
                 send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http",
                                 "websocket") or scope["method"] == "OPTIONS":
            # scope["type"] can be "lifespan" or "startup" for example,
            # in which case we don't need to do anything
            return self.app(scope, receive, send)

        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        # Type narrow to satisfy mypy.
        if url_path.startswith("/v1") and headers.get(
                "Authorization") != f"Bearer {self.api_token}":
            response = JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return response(scope, receive, send)

        return self.app(scope, receive, send)


class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header for each response
    to a random uuid4 (hex) value if the header isn't already
    present in the request, otherwise uses the provided request id.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive,
                 send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            return self.app(scope, receive, send)

        # Extract the request headers.
        request_headers = Headers(scope=scope)

        async def send_with_request_id(message: Message) -> None:
            """
            Custom send function to mutate the response headers
            and append X-Request-Id to it.
            """
            if message["type"] == "http.response.start":
                response_headers = MutableHeaders(raw=message["headers"])
                request_id = request_headers.get("X-Request-Id",
                                                 uuid.uuid4().hex)
                response_headers.append("X-Request-Id", request_id)
            await send(message)

        return self.app(scope, receive, send_with_request_id)


def build_app(args: Namespace) -> FastAPI:
    if args.disable_fastapi_docs:
        app = FastAPI(openapi_url=None,
@@ -1108,33 +1177,10 @@ def build_app(args: Namespace) -> FastAPI:
    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if token := args.api_key or envs.VLLM_API_KEY:
-        @app.middleware("http")
-        async def authentication(request: Request, call_next):
-            if request.method == "OPTIONS":
-                return await call_next(request)
-            url_path = request.url.path
-            if app.root_path and url_path.startswith(app.root_path):
-                url_path = url_path[len(app.root_path):]
-            if not url_path.startswith("/v1"):
-                return await call_next(request)
-            if request.headers.get("Authorization") != "Bearer " + token:
-                return JSONResponse(content={"error": "Unauthorized"},
-                                    status_code=401)
-            return await call_next(request)
+        app.add_middleware(AuthenticationMiddleware, api_token=token)

    if args.enable_request_id_headers:
-        logger.warning(
-            "CAUTION: Enabling X-Request-Id headers in the API Server. "
-            "This can harm performance at high QPS.")
-        @app.middleware("http")
-        async def add_request_id(request: Request, call_next):
-            request_id = request.headers.get(
-                "X-Request-Id") or uuid.uuid4().hex
-            response = await call_next(request)
-            response.headers["X-Request-Id"] = request_id
-            return response
+        app.add_middleware(XRequestIdMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning("CAUTION: Enabling log response in the API Server. "


@@ -216,7 +216,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        "--enable-request-id-headers",
        action="store_true",
        help="If specified, API server will add X-Request-Id header to "
-        "responses. Caution: this hurts performance at high QPS.")
+        "responses.")
    parser.add_argument(
        "--enable-auto-tool-choice",
        action="store_true",