[gpt-oss][1][bugfix] fix streaming final output (#24466)

Signed-off-by: Andrew Xia <axia@meta.com>
Author: Andrew Xia
Date: 2025-09-16 12:56:16 -07:00
Committed by: GitHub
parent dcf2f3ec06
commit 86daa875fe

3 changed files with 91 additions and 5 deletions


@@ -364,6 +364,8 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
         events.append(event)
 
     assert len(events) > 0
+    response_completed_event = events[-1]
+    assert len(response_completed_event.response.output) > 0
 
     if background:
         starting_after = 5
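
For context, the new assertion mirrors what a client observes: the final response.completed event must carry a non-empty output list. A minimal client-side sketch of that check against a vLLM OpenAI-compatible server; the base URL, API key, and model name are placeholders, not part of this commit:

from openai import OpenAI

# Placeholders: point the client at a locally running vLLM server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

events = []
stream = client.responses.create(
    model="openai/gpt-oss-20b",  # placeholder model name
    input="Say hello",
    stream=True,
)
for event in stream:
    events.append(event)

final = events[-1]  # expected to be the response.completed event
# Before this fix, the final streamed response could come back empty.
assert len(final.response.output) > 0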


@@ -4,7 +4,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-from openai_harmony import StreamState
+from openai_harmony import Author, Message, Role, StreamState, TextContent
 
 from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -312,9 +312,9 @@ async def test_negative_tool_tokens_edge_case():
 
 @pytest.mark.asyncio
 async def test_streaming_multi_turn_token_counting(mock_parser):
     """Test token counting for streaming multi-turn conversations.
-    This test focuses on how StreamingHarmonyContext counts tokens in a
-    multi-turn conversation with streaming (token-by-token) outputs and
+    This test focuses on how StreamingHarmonyContext counts tokens in a
+    multi-turn conversation with streaming (token-by-token) outputs and
     message boundaries.
     """
     # Create a streaming context
@@ -423,3 +423,78 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
     additional_tool_tokens = 13 - 8 - 3  # = 2
     assert context.num_tool_output_tokens == expected_tool_tokens \
         + additional_tool_tokens
+
+
+@pytest.mark.asyncio
+async def test_streaming_message_synchronization(mock_parser):
+    """Test message synchronization logic from lines 413-417 in context.py.
+
+    This test verifies that when parser.messages contains more messages than
+    the context's _messages (minus initial messages), the context properly
+    extends its message list with the new parser messages.
+    """
+    # Create a streaming context with one initial (user) message
+    initial_messages = [
+        Message(
+            author=Author(role=Role.USER, name="user"),
+            content=[TextContent(text="Hello")],
+            recipient=Role.ASSISTANT,
+        )
+    ]
+    context = StreamingHarmonyContext(messages=initial_messages,
+                                      available_tools=[])
+
+    # Verify initial state
+    assert len(context._messages) == 1
+    assert context.num_init_messages == 1
+
+    # Mock the parser having processed one new assistant message that the
+    # context has not seen yet
+    mock_parser.messages = [
+        Message(
+            author=Author(role=Role.ASSISTANT, name="assistant"),
+            content=[TextContent(text="Response 1")],
+            recipient=Role.USER,
+        ),
+    ]
+
+    # This should trigger the message synchronization logic
+    context.append_output(
+        create_mock_request_output(prompt_token_ids=[1, 2, 3],
+                                   output_token_ids=[101],
+                                   finished=False))
+
+    # Verify that messages were synchronized
+    assert len(context._messages) == 2
+    # Verify the new message was added correctly
+    assert context._messages[1].content[0].text == "Response 1"
+
+    # Test the specific condition from lines 413-414:
+    # len(self._messages) - self.num_init_messages < len(self.parser.messages)
+    messages_minus_init = len(context._messages) - context.num_init_messages
+    parser_messages_count = len(mock_parser.messages)
+    # After synchronization, they should be equal (no longer less than)
+    assert messages_minus_init == parser_messages_count
+
+    # Edge case: the parser produces one more message
+    mock_parser.messages.append(
+        Message(
+            author=Author(role=Role.ASSISTANT, name="assistant"),
+            content=[TextContent(text="Response 2")],
+            recipient=Role.USER,
+        ))
+
+    # Create another output to trigger synchronization again
+    mock_output2 = create_mock_request_output(prompt_token_ids=[1, 2, 3],
+                                              output_token_ids=[102],
+                                              finished=True)
+    context.append_output(mock_output2)
+
+    # Verify the new message was appended; num_init_messages is still 1
+    assert len(context._messages) == 3
+    assert context.num_init_messages == 1
+    assert context._messages[2].content[0].text == "Response 2"
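
The test leans on a create_mock_request_output helper defined earlier in the test module but not shown in this diff. A rough sketch of the shape such a helper would need, with every attribute choice inferred from how the test uses it (an assumption, not the module's actual implementation):

from unittest.mock import MagicMock

def create_mock_request_output(prompt_token_ids=None,
                               output_token_ids=None,
                               finished=False):
    """Assumed shape: a RequestOutput-like mock with one completion."""
    completion = MagicMock()
    completion.token_ids = output_token_ids or []
    completion.finish_reason = "stop" if finished else None

    request_output = MagicMock()
    request_output.prompt_token_ids = prompt_token_ids or []
    request_output.outputs = [completion]
    request_output.finished = finished
    return request_output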


@@ -151,6 +151,9 @@ class HarmonyContext(ConversationContext):
             self._update_decode_token_usage(output)
             # Move current turn to previous turn for next turn's calculations
             self.previous_turn = self.current_turn.copy()
+            # append_output is called only once before tool calling in the
+            # non-streaming case, so we can append all of the parser's
+            # messages to _messages at once
             output_msgs = self.parser.messages
         # The responses finish reason is set in the last message
         self.finish_reason = output.outputs[0].finish_reason
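
The new comment states the key invariant: in the non-streaming path, append_output runs once per turn, so parser.messages is already complete when it is copied. A schematic contrast of the two call patterns (illustrative only, not the vLLM API):

def non_streaming_turn(context, turn_output):
    # One call per turn: parser.messages is final, copy it wholesale.
    context.append_output(turn_output)

def streaming_turn(context, token_outputs):
    # One call per generated token: parser.messages grows incrementally,
    # so the context must keep re-synchronizing (see the
    # StreamingHarmonyContext hunk below).
    for token_output in token_outputs:
        context.append_output(token_output)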
@@ -387,7 +390,7 @@ class StreamingHarmonyContext(HarmonyContext):
 
     @property
     def messages(self) -> list:
-        return self.parser.messages
+        return self._messages
 
     def append_output(self, output: Union[RequestOutput,
                                           list[Message]]) -> None:
@@ -412,6 +415,11 @@ class StreamingHarmonyContext(HarmonyContext):
                 # Check if the current token is part of reasoning content
                 self._update_num_reasoning_tokens()
                 self.last_tok = tok
+            if len(self._messages) - self.num_init_messages < len(
+                    self.parser.messages):
+                self._messages.extend(
+                    self.parser.messages[len(self._messages) -
+                                         self.num_init_messages:])
         else:
             # Handle the case of tool output in direct message format
             assert len(output) == 1, "Tool output should be a single message"
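
The slice arithmetic above is easy to misread, so here is the same synchronization step with plain lists standing in for harmony Message objects (values invented for illustration):

_messages = ["user: hello"]   # the context's list, includes init messages
num_init_messages = 1         # how many of those came from the prompt
parser_messages = ["assistant: hi", "assistant: follow-up"]

# Same condition and slice as the diff above: the context has seen
# len(_messages) - num_init_messages parser messages so far, so extend
# with only the parser messages it has not seen yet.
if len(_messages) - num_init_messages < len(parser_messages):
    _messages.extend(parser_messages[len(_messages) - num_init_messages:])

assert _messages == ["user: hello", "assistant: hi", "assistant: follow-up"]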
@@ -424,6 +432,7 @@ class StreamingHarmonyContext(HarmonyContext):
             for tok in toks:
                 self.parser.process(tok)
             self.last_tok = toks[-1]
+            # TODO: add tool_output messages to self._messages
 
     def is_expecting_start(self) -> bool:
         return self.parser.state == StreamState.EXPECT_START