feat: Add --enable-log-outputs flag for logging model generations (#20707)

Signed-off-by: Adrian Garcia <adrian.garcia@inceptionai.ai>

Author:    Adrián García García
Date:      2025-08-07 10:10:13 +04:00 (committed by GitHub)
Commit:    8e8e0b6af1
Parent:    82216dc21f

5 changed files with 412 additions and 26 deletions

View File

@@ -10,11 +10,12 @@ from dataclasses import dataclass
from json.decoder import JSONDecodeError
from tempfile import NamedTemporaryFile
from typing import Any
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from vllm.entrypoints.logger import RequestLogger
from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger,
enable_trace_function_call, init_logger)
from vllm.logging_utils import NewLineFormatter
@@ -228,9 +229,10 @@ def test_prepare_object_to_dump():
list_obj = [1, 2, 3]
assert prepare_object_to_dump(list_obj) == '[1, 2, 3]'
-dict_obj = {'a': 1, 'b': 'b'}
+dict_obj = {"a": 1, "b": "b"}
assert prepare_object_to_dump(dict_obj) in [
"{a: 1, b: 'b'}", "{b: 'b', a: 1}"
"{a: 1, b: 'b'}",
"{b: 'b', a: 1}",
]
set_obj = {1, 2, 3}
@@ -252,4 +254,246 @@ def test_prepare_object_to_dump():
b: str
assert (prepare_object_to_dump(CustomClass(
-1, 'b')) == "CustomClass(a=1, b='b')")
+1, "b")) == "CustomClass(a=1, b='b')")
def test_request_logger_log_outputs():
"""Test the new log_outputs functionality."""
# Create a mock logger to capture log calls
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test basic output logging
request_logger.log_outputs(
request_id="test-123",
outputs="Hello, world!",
output_token_ids=[1, 2, 3, 4],
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-123"
assert call_args[3] == "Hello, world!"
assert call_args[4] == [1, 2, 3, 4]
assert call_args[5] == "stop"
def test_request_logger_log_outputs_streaming_delta():
"""Test log_outputs with streaming delta mode."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test streaming delta logging
request_logger.log_outputs(
request_id="test-456",
outputs="Hello",
output_token_ids=[1],
finish_reason=None,
is_streaming=True,
delta=True,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-456"
assert call_args[2] == " (streaming delta)"
assert call_args[3] == "Hello"
assert call_args[4] == [1]
assert call_args[5] is None
def test_request_logger_log_outputs_streaming_complete():
"""Test log_outputs with streaming complete mode."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test streaming complete logging
request_logger.log_outputs(
request_id="test-789",
outputs="Complete response",
output_token_ids=[1, 2, 3],
finish_reason="length",
is_streaming=True,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-789"
assert call_args[2] == " (streaming complete)"
assert call_args[3] == "Complete response"
assert call_args[4] == [1, 2, 3]
assert call_args[5] == "length"
def test_request_logger_log_outputs_with_truncation():
"""Test log_outputs respects max_log_len setting."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
# Set max_log_len to 10
request_logger = RequestLogger(max_log_len=10)
# Test output truncation
long_output = "This is a very long output that should be truncated"
long_token_ids = list(range(20)) # 20 tokens
request_logger.log_outputs(
request_id="test-truncate",
outputs=long_output,
output_token_ids=long_token_ids,
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args
# Check that output was truncated to first 10 characters
logged_output = call_args[0][3]
assert logged_output == "This is a "
assert len(logged_output) == 10
# Check that token IDs were truncated to first 10 tokens
logged_token_ids = call_args[0][4]
assert logged_token_ids == list(range(10))
assert len(logged_token_ids) == 10
def test_request_logger_log_outputs_none_values():
"""Test log_outputs handles None values correctly."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test with None output_token_ids
request_logger.log_outputs(
request_id="test-none",
outputs="Test output",
output_token_ids=None,
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-none"
assert call_args[3] == "Test output"
assert call_args[4] is None
assert call_args[5] == "stop"
def test_request_logger_log_outputs_empty_output():
"""Test log_outputs handles empty output correctly."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=5)
# Test with empty output
request_logger.log_outputs(
request_id="test-empty",
outputs="",
output_token_ids=[],
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-empty"
assert call_args[3] == ""
assert call_args[4] == []
assert call_args[5] == "stop"
def test_request_logger_log_outputs_integration():
"""Test that log_outputs can be called alongside log_inputs."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test that both methods can be called without interference
request_logger.log_inputs(
request_id="test-integration",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_embeds=None,
params=None,
lora_request=None,
)
request_logger.log_outputs(
request_id="test-integration",
outputs="Test output",
output_token_ids=[4, 5, 6],
finish_reason="stop",
is_streaming=False,
delta=False,
)
# Should have been called twice - once for inputs, once for outputs
assert mock_logger.info.call_count == 2
# Check that the calls were made with correct patterns
input_call = mock_logger.info.call_args_list[0][0]
output_call = mock_logger.info.call_args_list[1][0]
assert "Received request %s" in input_call[0]
assert input_call[1] == "test-integration"
assert "Generated response %s%s" in output_call[0]
assert output_call[1] == "test-integration"
def test_streaming_complete_logs_full_text_content():
"""Test that streaming complete logging includes
full accumulated text, not just token count."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test with actual content instead of token count format
full_response = "This is a complete response from streaming"
request_logger.log_outputs(
request_id="test-streaming-full-text",
outputs=full_response,
output_token_ids=None,
finish_reason="streaming_complete",
is_streaming=True,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
# Verify the logged output is the full text, not a token count format
logged_output = call_args[3]
assert logged_output == full_response
assert "tokens>" not in logged_output
assert "streaming_complete" not in logged_output
# Verify other parameters
assert call_args[1] == "test-streaming-full-text"
assert call_args[2] == " (streaming complete)"
assert call_args[5] == "streaming_complete"
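
The tests above read positional arguments back through mock_logger.info.call_args. For reference, a minimal self-contained sketch of that MagicMock convention and of how the positions map onto the log_outputs call (the format string mirrors the implementation in the next file; all literal values here are illustrative):

from unittest.mock import MagicMock

mock_logger = MagicMock()
# Simulate the call that RequestLogger.log_outputs makes internally.
mock_logger.info(
    "Generated response %s%s: output: %r, "
    "output_token_ids: %s, finish_reason: %s",
    "req-1",                # args[1]: request_id
    " (streaming delta)",   # args[2]: stream_info
    "Hello",                # args[3]: outputs (possibly truncated)
    [1, 2, 3],              # args[4]: output_token_ids (possibly truncated)
    "stop",                 # args[5]: finish_reason
)

args = mock_logger.info.call_args.args
assert "Generated response %s%s" in args[0]
assert args[1] == "req-1"
assert args[3] == "Hello"
assert args[5] == "stop"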

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional, Union
import torch
@@ -16,8 +17,6 @@ logger = init_logger(__name__)
class RequestLogger:
def __init__(self, *, max_log_len: Optional[int]) -> None:
super().__init__()
self.max_log_len = max_log_len
def log_inputs(
@@ -45,3 +44,36 @@ class RequestLogger:
"lora_request: %s.", request_id, prompt, params, prompt_token_ids,
prompt_embeds.shape if prompt_embeds is not None else None,
lora_request)
def log_outputs(
self,
request_id: str,
outputs: str,
output_token_ids: Optional[Sequence[int]],
finish_reason: Optional[str] = None,
is_streaming: bool = False,
delta: bool = False,
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
if outputs is not None:
outputs = outputs[:max_log_len]
if output_token_ids is not None:
# Convert to list and apply truncation
output_token_ids = list(output_token_ids)[:max_log_len]
stream_info = ""
if is_streaming:
stream_info = (" (streaming delta)"
if delta else " (streaming complete)")
logger.info(
"Generated response %s%s: output: %r, "
"output_token_ids: %s, finish_reason: %s",
request_id,
stream_info,
outputs,
output_token_ids,
finish_reason,
)
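
A minimal usage sketch of the new method (the request id, text, and token ids below are made up); given the format string above, the resulting record reads roughly like the trailing comment:

from vllm.entrypoints.logger import RequestLogger

request_logger = RequestLogger(max_log_len=512)  # truncate logged text/token ids to 512 items
request_logger.log_outputs(
    request_id="chatcmpl-123",  # hypothetical request id
    outputs="The capital of France is Paris.",
    output_token_ids=[791, 6864, 315, 9822, 374, 12366, 13],
    finish_reason="stop",
    is_streaming=False,
    delta=False,
)
# Emitted via logger.info, roughly:
#   Generated response chatcmpl-123: output: 'The capital of France is Paris.',
#   output_token_ids: [791, 6864, 315, 9822, 374, 12366, 13], finish_reason: stop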

View File

@@ -44,10 +44,10 @@ class LoRAParserAction(argparse.Action):
lora_list: list[LoRAModulePath] = []
for item in values:
-if item in [None, '']: # Skip if item is None or empty string
+if item in [None, ""]: # Skip if item is None or empty string
continue
-if '=' in item and ',' not in item: # Old format: name=path
-name, path = item.split('=')
+if "=" in item and "," not in item: # Old format: name=path
+name, path = item.split("=")
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
@@ -167,6 +167,9 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
enable_tokenizer_info_endpoint: bool = False
"""Enable the /get_tokenizer_info endpoint. May expose chat
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If set to True, enable logging of model outputs (generations)
in addition to the input logging that is enabled by default."""
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
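
The new enable_log_outputs field is what the commit title's --enable-log-outputs flag toggles. As a rough illustration only (not vLLM's actual CLI wiring, which presumably derives flags from these dataclass fields in add_cli_args), a boolean field like this corresponds to a store_true argument:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-log-outputs",
    action="store_true",
    default=False,
    help="Log model outputs (generations) in addition to the inputs "
    "that are logged by default.",
)

args = parser.parse_args(["--enable-log-outputs"])
assert args.enable_log_outputs is True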

View File

@@ -73,6 +73,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
enable_log_outputs: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
@@ -84,6 +85,7 @@ class OpenAIServingChat(OpenAIServing):
self.response_role = response_role
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.enable_log_outputs = enable_log_outputs
# set up tool use
self.enable_auto_tools: bool = enable_auto_tools
@@ -489,20 +491,21 @@ class OpenAIServingChat(OpenAIServing):
all_previous_token_ids: Optional[list[list[int]]]
function_name_returned = [False] * num_choices
+# Always track previous_texts for comprehensive output logging
+previous_texts = [""] * num_choices
# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
if tool_choice_auto or self.reasoning_parser:
# These are only required in "auto" tool choice case
-previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices
# For reasoning parser and tool call all enabled
added_content_delta_arr = [False] * num_choices
reasoning_end_arr = [False] * num_choices
elif request.tool_choice == "required":
-previous_texts = [""] * num_choices
all_previous_token_ids = None
else:
-previous_texts, all_previous_token_ids = None, None
+all_previous_token_ids = None
try:
if self.reasoning_parser:
@@ -844,6 +847,7 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids,
request=request))
# when only reasoning
elif self.reasoning_parser:
delta_message = (reasoning_parser.
@@ -865,6 +869,10 @@ class OpenAIServingChat(OpenAIServing):
assert all_previous_token_ids is not None
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
else:
# Update for comprehensive logging even in simple case
assert previous_texts is not None
previous_texts[i] += delta_text
# set the previous values for the next iteration
previous_num_tokens[i] += len(output.token_ids)
@@ -876,6 +884,27 @@ class OpenAIServingChat(OpenAIServing):
if delta_message is None:
continue
# Log streaming delta if output logging is enabled
if self.enable_log_outputs and self.request_logger:
delta_content = ""
if delta_message.content:
delta_content = delta_message.content
elif delta_message.tool_calls:
delta_content = "".join(
tc.function.arguments
for tc in delta_message.tool_calls
if tc.function and tc.function.arguments)
if delta_content:
self.request_logger.log_outputs(
request_id=request_id,
outputs=delta_content,
output_token_ids=list(output.token_ids),
finish_reason=output.finish_reason,
is_streaming=True,
delta=True,
)
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
@@ -994,7 +1023,27 @@ class OpenAIServingChat(OpenAIServing):
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_completion_tokens,
-total_tokens=num_prompt_tokens + num_completion_tokens)
+total_tokens=num_prompt_tokens + num_completion_tokens,
+)
# Log complete streaming response if output logging is enabled
if self.enable_log_outputs and self.request_logger:
# Log the complete response for each choice
for i in range(num_choices):
full_text = (
previous_texts[i]
if previous_texts and i < len(previous_texts) else
f"<streaming_complete: {previous_num_tokens[i]} tokens>"
)
self.request_logger.log_outputs(
request_id=request_id,
outputs=full_text,
output_token_ids=
None, # Consider also logging all token IDs
finish_reason="streaming_complete",
is_streaming=True,
delta=False,
)
except Exception as e:
# TODO: Use a vllm-specific Validation Error
@@ -1121,8 +1170,10 @@ class OpenAIServingChat(OpenAIServing):
tool_calls=[
tool_call_class(function=FunctionCall(
name=request.tool_choice.function.name,
-arguments=content))
-])
+arguments=content,
+))
+],
+)
elif request.tool_choice and request.tool_choice == "required":
tool_call_class = MistralToolCall if isinstance(
@@ -1209,12 +1260,13 @@ class OpenAIServingChat(OpenAIServing):
finish_reason="tool_calls" if auto_tools_called else
output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason)
choices.append(choice_data)
if request.echo:
last_msg_content: Union[str, list[dict[str, str]]] = ""
-if conversation and "content" in conversation[-1] and conversation[
--1].get("role") == role:
+if (conversation and "content" in conversation[-1]
+and conversation[-1].get("role") == role):
last_msg_content = conversation[-1]["content"] or ""
if isinstance(last_msg_content, list):
last_msg_content = "\n".join(msg['text']
@@ -1251,6 +1303,40 @@ class OpenAIServingChat(OpenAIServing):
kv_transfer_params=final_res.kv_transfer_params,
)
# Log complete response if output logging is enabled
if self.enable_log_outputs and self.request_logger:
for choice in choices:
output_text = ""
if choice.message.content:
output_text = choice.message.content
elif choice.message.tool_calls:
# For tool calls, log the function name and arguments
tool_call_descriptions = []
for tool_call in choice.message.tool_calls:
if hasattr(tool_call.function, "name") and hasattr(
tool_call.function, "arguments"):
tool_call_descriptions.append(
f"{tool_call.function.name}({tool_call.function.arguments})"
)
tool_calls_str = ", ".join(tool_call_descriptions)
output_text = f"[tool_calls: {tool_calls_str}]"
if output_text:
# Get the corresponding output token IDs
output_token_ids = None
if choice.index < len(final_res.outputs):
output_token_ids = final_res.outputs[
choice.index].token_ids
self.request_logger.log_outputs(
request_id=request_id,
outputs=output_text,
output_token_ids=output_token_ids,
finish_reason=choice.finish_reason,
is_streaming=False,
delta=False,
)
return response
def _get_top_logprobs(
@@ -1258,15 +1344,16 @@ class OpenAIServingChat(OpenAIServing):
tokenizer: AnyTokenizer,
should_return_as_token_id: bool) -> list[ChatCompletionLogProb]:
return [
-ChatCompletionLogProb(token=(token := self._get_decoded_token(
-p[1],
-p[0],
-tokenizer,
-return_as_token_id=should_return_as_token_id)),
-logprob=max(p[1].logprob, -9999.0),
-bytes=list(
-token.encode("utf-8", errors="replace")))
-for i, p in enumerate(logprobs.items())
+ChatCompletionLogProb(
+token=(token := self._get_decoded_token(
+p[1],
+p[0],
+tokenizer,
+return_as_token_id=should_return_as_token_id,
+)),
+logprob=max(p[1].logprob, -9999.0),
+bytes=list(token.encode("utf-8", errors="replace")),
+) for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
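
In the streaming path added above, each delta is logged as it is produced and previous_texts accumulates per-choice text, so the final "streaming complete" record carries the full generation rather than only a token count. A condensed sketch of that pattern (not the actual handler, which interleaves this with tool/reasoning parsing and response construction):

from vllm.entrypoints.logger import RequestLogger

def log_streaming_outputs(request_logger: RequestLogger, request_id: str,
                          deltas_per_choice: list[list[str]]) -> None:
    num_choices = len(deltas_per_choice)
    previous_texts = [""] * num_choices  # accumulated text per choice

    for i, deltas in enumerate(deltas_per_choice):
        for delta_text in deltas:
            # Log every streaming delta as it is emitted.
            request_logger.log_outputs(
                request_id=request_id,
                outputs=delta_text,
                output_token_ids=None,  # token ids omitted in this sketch
                finish_reason=None,
                is_streaming=True,
                delta=True,
            )
            previous_texts[i] += delta_text

    # After the stream ends, log the accumulated text once per choice.
    for i in range(num_choices):
        request_logger.log_outputs(
            request_id=request_id,
            outputs=previous_texts[i],
            output_token_ids=None,
            finish_reason="streaming_complete",
            is_streaming=True,
            delta=False,
        )

# e.g. log_streaming_outputs(RequestLogger(max_log_len=None), "chatcmpl-123",
#                            [["Hel", "lo", ", world!"]])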

View File

@@ -65,6 +65,7 @@ class OpenAIServingResponses(OpenAIServing):
tool_server: Optional[ToolServer] = None,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
enable_log_outputs: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
@@ -77,6 +78,7 @@ class OpenAIServingResponses(OpenAIServing):
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.enable_log_outputs = enable_log_outputs
self.reasoning_parser: Optional[Callable[[AnyTokenizer],
ReasoningParser]] = None
@@ -428,6 +430,24 @@ class OpenAIServingResponses(OpenAIServing):
usage=usage,
)
# Log complete response if output logging is enabled
if self.enable_log_outputs and self.request_logger:
output_text = ""
if content:
output_text = content
elif reasoning_content:
output_text = f"[reasoning: {reasoning_content}]"
if output_text:
self.request_logger.log_outputs(
request_id=request.request_id,
outputs=output_text,
output_token_ids=final_output.token_ids,
finish_reason=final_output.finish_reason,
is_streaming=False,
delta=False,
)
if request.store:
async with self.response_store_lock:
stored_response = self.response_store.get(response.id)
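
One formatting choice in the non-streaming responses path above is worth noting: when the final output has no regular content but does have reasoning content, the logged text is wrapped in a [reasoning: ...] marker, and nothing is logged if both are empty. A tiny illustration with made-up values:

content = ""
reasoning_content = "Let me work through the question step by step..."

output_text = ""
if content:
    output_text = content
elif reasoning_content:
    output_text = f"[reasoning: {reasoning_content}]"

assert output_text == "[reasoning: Let me work through the question step by step...]"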