[Bugfix][FE]: Always include usage with --enable-force-include-usage (#20983)

Signed-off-by: Max Wittig <max.wittig@siemens.com>
Signed-off-by: Antoine Auger <antoineauger@users.noreply.github.com>
Co-authored-by: Antoine Auger <antoineauger@users.noreply.github.com>
Max Wittig authored on 2025-10-14 09:17:39 +02:00; committed by GitHub
parent d32c611f45
commit fd85c9f426
11 changed files with 172 additions and 30 deletions

View File

@@ -107,6 +107,7 @@ markers = [
"distributed: run this test only in distributed GPU tests",
"skip_v1: do not run this test with v1",
"optional: optional tests that are automatically skipped, include --optional to run them",
"extra_server_args: extra arguments to pass to the server fixture",
]
[tool.ty.src]
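
The new extra_server_args marker registered above lets a test module hand extra CLI flags to a shared server fixture. As a purely illustrative sketch (the actual fixture wiring in the vLLM test suite may differ), a module-scoped fixture could consume the marker like this, reusing the RemoteOpenAIServer helper that the new tests below import:

import pytest

from ...utils import RemoteOpenAIServer  # same helper the new tests import


@pytest.fixture(scope="module")
def server(request):
    # Hypothetical sketch: collect extra CLI flags declared via
    # pytest.mark.extra_server_args([...]) on the requesting module.
    marker = request.node.get_closest_marker("extra_server_args")
    extra_args = list(marker.args[0]) if marker and marker.args else []
    args = ["--dtype", "bfloat16", "--enforce-eager", *extra_args]
    with RemoteOpenAIServer("Qwen/Qwen3-0.6B", args) as remote_server:
        yield remote_server


# A test module could then opt in with:
# pytestmark = pytest.mark.extra_server_args(["--enable-force-include-usage"])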

View File

@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import openai
import pytest
import pytest_asyncio
from ...utils import RemoteOpenAIServer
@pytest.fixture(scope="module")
def chat_server_with_force_include_usage(request): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"128",
"--enforce-eager",
"--max-num-seqs",
"1",
"--enable-force-include-usage",
"--port",
"55857",
"--gpu-memory-utilization",
"0.2",
]
with RemoteOpenAIServer("Qwen/Qwen3-0.6B", args, auto_port=False) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def chat_client_with_force_include_usage(chat_server_with_force_include_usage):
async with chat_server_with_force_include_usage.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
async def test_chat_with_enable_force_include_usage(
chat_client_with_force_include_usage: openai.AsyncOpenAI,
):
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
]
stream = await chat_client_with_force_include_usage.chat.completions.create(
model="Qwen/Qwen3-0.6B",
messages=messages,
max_completion_tokens=10,
extra_body=dict(min_tokens=10),
temperature=0.0,
stream=True,
)
last_completion_tokens = 0
async for chunk in stream:
if not len(chunk.choices):
assert chunk.usage.prompt_tokens >= 0
assert (
last_completion_tokens == 0
or chunk.usage.completion_tokens > last_completion_tokens
or (
not chunk.choices
and chunk.usage.completion_tokens == last_completion_tokens
)
)
assert chunk.usage.total_tokens == (
chunk.usage.prompt_tokens + chunk.usage.completion_tokens
)
else:
assert chunk.usage is None
@pytest.fixture(scope="module")
def transcription_server_with_force_include_usage():
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-num-seqs",
"1",
"--enforce-eager",
"--enable-force-include-usage",
"--gpu-memory-utilization",
"0.2",
]
with RemoteOpenAIServer("openai/whisper-large-v3-turbo", args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def transcription_client_with_force_include_usage(
transcription_server_with_force_include_usage,
):
async with (
transcription_server_with_force_include_usage.get_async_client() as async_client
):
yield async_client
@pytest.mark.asyncio
async def test_transcription_with_enable_force_include_usage(
transcription_client_with_force_include_usage, winning_call
):
res = (
await transcription_client_with_force_include_usage.audio.transcriptions.create(
model="openai/whisper-large-v3-turbo",
file=winning_call,
language="en",
temperature=0.0,
stream=True,
timeout=30,
)
)
async for chunk in res:
if not len(chunk.choices):
# final usage sent
usage = chunk.usage
assert isinstance(usage, dict)
assert usage["prompt_tokens"] > 0
assert usage["completion_tokens"] > 0
assert usage["total_tokens"] > 0
else:
assert not hasattr(chunk, "usage")

View File

@@ -1808,6 +1808,7 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
@@ -1818,6 +1819,7 @@
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None

View File

@@ -104,6 +104,13 @@ def make_arg_parser(parser: FlexibleArgumentParser):
default=False,
help="If set to True, enable prompt_tokens_details in usage.",
)
parser.add_argument(
"--enable-force-include-usage",
action="store_true",
default=False,
help="If set to True, include usage on every request "
"(even when stream_options is not specified)",
)
return parser
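
The batch runner gains the same --enable-force-include-usage switch that the online serving frontend already exposes (the api_server hunks above only forward it). For reference, a minimal client-side sketch of what the flag guarantees for a streaming chat call against a server launched with it; the model name and port are illustrative:

from openai import OpenAI

# Assumes a server started roughly as:
#   vllm serve Qwen/Qwen3-0.6B --enable-force-include-usage
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="Qwen/Qwen3-0.6B",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    stream=True,  # note: no stream_options={"include_usage": True}
)

final_usage = None
for chunk in stream:
    if not chunk.choices:
        # With the flag set, the terminal chunk carries usage even though
        # the client never asked for it.
        final_usage = chunk.usage
print(final_usage)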
@@ -361,6 +368,7 @@ async def run_batch(
chat_template=None,
chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
)
if "generate" in supported_tasks
else None

View File

@@ -58,7 +58,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.entrypoints.utils import get_max_tokens
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
@@ -101,7 +101,6 @@ class OpenAIServingChat(OpenAIServing):
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage,
log_error_stack=log_error_stack,
)
@@ -352,7 +351,6 @@ class OpenAIServingChat(OpenAIServing):
conversation,
tokenizer,
request_metadata,
enable_force_include_usage=self.enable_force_include_usage,
)
try:
@@ -518,7 +516,6 @@ class OpenAIServingChat(OpenAIServing):
conversation: list[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
@@ -596,13 +593,9 @@
return
stream_options = request.stream_options
if stream_options:
include_usage = stream_options.include_usage or enable_force_include_usage
include_continuous_usage = (
include_usage and stream_options.continuous_usage_stats
)
else:
include_usage, include_continuous_usage = False, False
include_usage, include_continuous_usage = should_include_usage(
stream_options, self.enable_force_include_usage
)
try:
async for res in result_generator:
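
This refactor is the heart of the chat-streaming fix: the removed branch set include_usage to False whenever the client omitted stream_options, so --enable-force-include-usage had no effect on such requests. A small runnable sketch of the difference (old_include_usage reproduces the removed branch purely for illustration; it is not part of vLLM):

from vllm.entrypoints.utils import should_include_usage


def old_include_usage(stream_options, enable_force_include_usage):
    # Reproduction of the removed branch: the force flag was only consulted
    # when the client actually sent stream_options.
    if stream_options:
        return stream_options.include_usage or enable_force_include_usage
    return False


# Client streams without stream_options while the server forces usage:
print(old_include_usage(None, True))     # False -> usage silently dropped
print(should_include_usage(None, True))  # (True, False) -> final usage chunk sent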

View File

@@ -27,7 +27,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
@@ -56,11 +56,11 @@ class OpenAIServingCompletion(OpenAIServing):
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage,
log_error_stack=log_error_stack,
)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.enable_force_include_usage = enable_force_include_usage
if self.default_sampling_params:
source = self.model_config.generation_config
source = "model" if source == "auto" else source
@@ -256,7 +256,6 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata,
enable_force_include_usage=self.enable_force_include_usage,
)
# Non-streaming response
@@ -320,7 +319,6 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts: int,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
@@ -331,13 +329,9 @@
first_iteration = True
stream_options = request.stream_options
if stream_options:
include_usage = stream_options.include_usage or enable_force_include_usage
include_continuous_usage = (
include_usage and stream_options.continuous_usage_stats
)
else:
include_usage, include_continuous_usage = False, False
include_usage, include_continuous_usage = should_include_usage(
stream_options, self.enable_force_include_usage
)
try:
async for prompt_idx, res in result_generator:

View File

@@ -249,7 +249,6 @@ class OpenAIServing:
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False,
):
super().__init__()
@@ -260,8 +259,6 @@
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.enable_force_include_usage = enable_force_include_usage
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self._apply_mistral_chat_template_async = make_async(
apply_mistral_chat_template, executor=self._tokenizer_executor

View File

@@ -127,7 +127,6 @@ class OpenAIServingResponses(OpenAIServing):
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage,
log_error_stack=log_error_stack,
)

View File

@@ -37,6 +37,7 @@ class OpenAIServingTranscription(OpenAISpeechToText):
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
@@ -45,6 +46,7 @@
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="transcribe",
log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage,
)
async def create_transcription(
@@ -96,6 +98,7 @@ class OpenAIServingTranslation(OpenAISpeechToText):
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
@@ -104,6 +107,7 @@
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="translate",
log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage,
)
async def create_translation(

View File

@@ -58,6 +58,7 @@ class OpenAISpeechToText(OpenAIServing):
return_tokens_as_token_ids: bool = False,
task_type: Literal["transcribe", "translate"] = "transcribe",
log_error_stack: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
@@ -74,6 +75,8 @@
self.model_config, task_type
)
self.enable_force_include_usage = enable_force_include_usage
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
if self.default_sampling_params:
@@ -261,9 +264,7 @@
completion_tokens = 0
num_prompt_tokens = 0
include_usage = (
request.stream_include_usage if request.stream_include_usage else False
)
include_usage = self.enable_force_include_usage or request.stream_include_usage
include_continuous_usage = (
request.stream_continuous_usage_stats
if include_usage and request.stream_continuous_usage_stats
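
The hunk above is cut off mid-expression; condensed, the post-change speech-to-text logic mirrors the chat path: usage is emitted when the request opts in or the server forces it, and per-chunk stats additionally require the continuous flag. An equivalent standalone paraphrase (the helper name and flattened parameters are illustrative, not vLLM API):

def speech_to_text_usage_flags(
    enable_force_include_usage: bool,
    stream_include_usage: bool | None,
    stream_continuous_usage_stats: bool | None,
) -> tuple[bool, bool]:
    # Paraphrase of the new derivation, with the attribute accesses on
    # `self` and `request` flattened into plain parameters.
    include_usage = enable_force_include_usage or bool(stream_include_usage)
    include_continuous_usage = bool(
        include_usage and stream_continuous_usage_stats
    )
    return include_usage, include_continuous_usage


# Forced usage with no client opt-in: final usage only, no per-chunk stats.
print(speech_to_text_usage_flags(True, None, None))  # (True, False)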

View File

@@ -14,7 +14,11 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
StreamOptions,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
@@ -237,3 +241,16 @@ def log_non_default_args(args: Namespace | EngineArgs):
)
logger.info("non-default args: %s", non_default_args)
def should_include_usage(
stream_options: StreamOptions | None, enable_force_include_usage: bool
) -> tuple[bool, bool]:
if stream_options:
include_usage = stream_options.include_usage or enable_force_include_usage
include_continuous_usage = include_usage and bool(
stream_options.continuous_usage_stats
)
else:
include_usage, include_continuous_usage = enable_force_include_usage, False
return include_usage, include_continuous_usage
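
To make the new helper's contract explicit, the following checks spell out how the two return values fall out for the common combinations (StreamOptions is the request model imported at the top of this file; the field names match the removed branches above):

from vllm.entrypoints.openai.protocol import StreamOptions
from vllm.entrypoints.utils import should_include_usage

# No stream_options from the client: usage only if the server forces it.
assert should_include_usage(None, False) == (False, False)
assert should_include_usage(None, True) == (True, False)

# Explicit opt-in: final usage chunk, but no per-chunk stats.
opts = StreamOptions(include_usage=True, continuous_usage_stats=False)
assert should_include_usage(opts, False) == (True, False)

# Continuous per-chunk usage requires include_usage (explicit or forced).
opts = StreamOptions(include_usage=True, continuous_usage_stats=True)
assert should_include_usage(opts, False) == (True, True)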