[Frontend] run-batch supports V1 (#21541)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-25 11:05:55 +08:00
committed by GitHub
parent fe56180c7f
commit 34ddcf9ff4
5 changed files with 56 additions and 25 deletions

View File

@ -167,7 +167,8 @@ async def run_vllm_async(
from vllm import SamplingParams
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing
engine_args,
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
) as llm:
model_config = await llm.get_model_config()
assert all(

View File

@ -295,8 +295,6 @@ async def test_metrics_exist(server: RemoteOpenAIServer,
def test_metrics_exist_run_batch(use_v1: bool):
if use_v1:
pytest.skip("Skipping test on vllm V1")
input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}""" # noqa: E501
base_url = "0.0.0.0"
@ -323,7 +321,8 @@ def test_metrics_exist_run_batch(use_v1: bool):
base_url,
"--port",
port,
], )
],
env={"VLLM_USE_V1": "1" if use_v1 else "0"})
def is_server_up(url):
try:

View File

@ -148,7 +148,9 @@ async def run_vllm_async(
from vllm import SamplingParams
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:
engine_args,
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
) as llm:
model_config = await llm.get_model_config()
assert all(
model_config.max_model_len >= (request.prompt_len +

View File

@ -149,6 +149,9 @@ async def lifespan(app: FastAPI):
@asynccontextmanager
async def build_async_engine_client(
args: Namespace,
*,
usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
disable_frontend_multiprocessing: Optional[bool] = None,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
@ -156,15 +159,24 @@ async def build_async_engine_client(
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
if disable_frontend_multiprocessing is None:
disable_frontend_multiprocessing = bool(
args.disable_frontend_multiprocessing)
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing,
client_config) as engine:
engine_args,
usage_context=usage_context,
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
client_config=client_config,
) as engine:
yield engine
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
*,
usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
disable_frontend_multiprocessing: bool = False,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
@ -177,7 +189,6 @@ async def build_async_engine_client_from_engine_args(
"""
# Create the EngineConfig (determines if we can use V1).
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# V1 AsyncLLM.
@ -1811,7 +1822,10 @@ async def run_server_worker(listen_address,
if log_config is not None:
uvicorn_kwargs['log_config'] = log_config
async with build_async_engine_client(args, client_config) as engine_client:
async with build_async_engine_client(
args,
client_config=client_config,
) as engine_client:
maybe_register_tokenizer_info_endpoint(args)
app = build_app(args)

View File

@ -3,6 +3,7 @@
import asyncio
import tempfile
from argparse import Namespace
from collections.abc import Awaitable
from http import HTTPStatus
from io import StringIO
@ -13,10 +14,12 @@ import torch
from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
@ -310,36 +313,37 @@ async def run_request(serving_engine_func: Callable,
return batch_output
async def main(args):
async def run_batch(
engine_client: EngineClient,
vllm_config: VllmConfig,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
model_config = await engine.get_model_config()
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
model_config = vllm_config.model_config
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine,
engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
)
openai_serving_chat = OpenAIServingChat(
engine,
engine_client,
model_config,
openai_serving_models,
args.response_role,
@ -349,7 +353,7 @@ async def main(args):
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if "generate" in model_config.supported_tasks else None
openai_serving_embedding = OpenAIServingEmbedding(
engine,
engine_client,
model_config,
openai_serving_models,
request_logger=request_logger,
@ -362,7 +366,7 @@ async def main(args):
"num_labels", 0) == 1)
openai_serving_scores = ServingScores(
engine,
engine_client,
model_config,
openai_serving_models,
request_logger=request_logger,
@ -457,6 +461,17 @@ async def main(args):
await write_file(args.output_file, responses, args.output_tmp_dir)
async def main(args: Namespace):
async with build_async_engine_client(
args,
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
disable_frontend_multiprocessing=False,
) as engine_client:
vllm_config = await engine_client.get_vllm_config()
await run_batch(engine_client, vllm_config, args)
if __name__ == "__main__":
args = parse_args()