[gpt-oss][1][bugfix] fix streaming final output (#24466)
Signed-off-by: Andrew Xia <axia@meta.com>
@@ -364,6 +364,8 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
             events.append(event)
 
     assert len(events) > 0
+    response_completed_event = events[-1]
+    assert len(response_completed_event.response.output) > 0
 
     if background:
         starting_after = 5
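Note: the two added assertions check that the final response.completed event carries a non-empty output list — the streaming final output this commit fixes. A minimal sketch of the client-side pattern the test exercises, inside an async test and assuming an AsyncOpenAI-style client pointed at a vLLM server (event and field names follow the OpenAI Responses SDK; the prompt is illustrative):

    stream = await client.responses.create(
        model=model_name,
        input="Tell me a story",
        stream=True,
        background=background,
    )
    events = []
    async for event in stream:
        events.append(event)

    # The stream must end with response.completed, and with this fix the
    # completed response carries the generated output items.
    assert events[-1].type == "response.completed"
    assert len(events[-1].response.output) > 0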
@@ -4,7 +4,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-from openai_harmony import StreamState
+from openai_harmony import Author, Message, Role, StreamState, TextContent
 
 from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -312,9 +312,9 @@ async def test_negative_tool_tokens_edge_case():
 @pytest.mark.asyncio
 async def test_streaming_multi_turn_token_counting(mock_parser):
     """Test token counting for streaming multi-turn conversations.
 
-    This test focuses on how StreamingHarmonyContext counts tokens in a
-    multi-turn conversation with streaming (token-by-token) outputs and
+    This test focuses on how StreamingHarmonyContext counts tokens in a
+    multi-turn conversation with streaming (token-by-token) outputs and
     message boundaries.
     """
     # Create a streaming context
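Note: the token accounting this test checks attributes to tool output any prompt tokens of the current turn that were not already in the previous turn's prompt or generated output. This is an inference from the `13 - 8 - 3  # = 2` arithmetic in the next hunk, not a quote of the implementation:

    # Hedged sketch of the accounting rule, not vLLM's code:
    def tool_output_tokens(cur_prompt_tokens: int,
                           prev_prompt_tokens: int,
                           prev_decode_tokens: int) -> int:
        # Whatever appeared in this turn's prompt beyond last turn's prompt
        # and last turn's generated tokens must have come from a tool.
        return cur_prompt_tokens - prev_prompt_tokens - prev_decode_tokens

    assert tool_output_tokens(13, 8, 3) == 2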
@@ -423,3 +423,78 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
     additional_tool_tokens = 13 - 8 - 3  # = 2
     assert context.num_tool_output_tokens == expected_tool_tokens \
         + additional_tool_tokens
+
+
+@pytest.mark.asyncio
+async def test_streaming_message_synchronization(mock_parser):
+    """Test message synchronization logic from lines 413-417 in context.py.
+
+    This test verifies that when parser.messages contains more messages than
+    the context's _messages (minus initial messages), the context properly
+    extends its message list with the new parser messages.
+    """
+    # Create a streaming context with some initial messages
+    initial_messages = [
+        Message(
+            author=Author(role=Role.USER, name="user"),
+            content=[TextContent(text="Hello")],
+            recipient=Role.ASSISTANT,
+        )
+    ]
+    context = StreamingHarmonyContext(messages=initial_messages,
+                                      available_tools=[])
+
+    # Verify initial state
+    assert len(context._messages) == 1
+    assert context.num_init_messages == 1
+
+    # Mock parser to have more messages than context:
+    # simulate the parser having produced a new message
+    mock_parser.messages = [
+        Message(
+            author=Author(role=Role.ASSISTANT, name="assistant"),
+            content=[TextContent(text="Response 1")],
+            recipient=Role.USER,
+        ),
+    ]
+
+    # This should trigger the message synchronization logic
+    context.append_output(
+        create_mock_request_output(prompt_token_ids=[1, 2, 3],
+                                   output_token_ids=[101],
+                                   finished=False))
+
+    # Verify that messages were synchronized
+    assert len(context._messages) == 2
+
+    # Verify the new messages were added correctly
+    assert context._messages[1].content[0].text == "Response 1"
+
+    # Test the specific condition from lines 413-414:
+    # len(self._messages) - self.num_init_messages < len(self.parser.messages)
+    messages_minus_init = len(context._messages) - context.num_init_messages
+    parser_messages_count = len(mock_parser.messages)
+
+    # After synchronization, they should be equal (no longer less than)
+    assert messages_minus_init == parser_messages_count
+
+    # Test edge case: add one more parser message
+    mock_parser.messages.append(
+        Message(
+            author=Author(role=Role.ASSISTANT, name="assistant"),
+            content=[TextContent(text="Response 4")],
+            recipient=Role.USER,
+        ))
+
+    # Create another output to trigger synchronization again
+    mock_output2 = create_mock_request_output(prompt_token_ids=[1, 2, 3],
+                                              output_token_ids=[102],
+                                              finished=True)
+
+    context.append_output(mock_output2)
+
+    # Verify the fourth message was added, num_init_messages is still 1
+    assert len(context._messages) == 3
+    assert context.num_init_messages == 1
+    assert context._messages[2].content[0].text == "Response 4"
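The test leans on a create_mock_request_output helper defined earlier in the test module, outside this diff. A plausible sketch built only from the imports visible above (MagicMock, CompletionOutput, RequestOutput); the exact fields set here are assumptions:

    def create_mock_request_output(prompt_token_ids, output_token_ids, finished):
        # Hypothetical reconstruction of the helper, not the committed code.
        completion = MagicMock(spec=CompletionOutput)
        completion.token_ids = output_token_ids
        completion.finish_reason = "stop" if finished else None
        request_output = MagicMock(spec=RequestOutput)
        request_output.prompt_token_ids = prompt_token_ids
        request_output.outputs = [completion]
        request_output.finished = finished
        return request_output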
@@ -151,6 +151,9 @@ class HarmonyContext(ConversationContext):
         self._update_decode_token_usage(output)
         # Move current turn to previous turn for next turn's calculations
         self.previous_turn = self.current_turn.copy()
+        # append_output is called only once before tool calling
+        # in the non-streaming case,
+        # so we can append all the parser messages to _messages
         output_msgs = self.parser.messages
         # The responses finish reason is set in the last message
         self.finish_reason = output.outputs[0].finish_reason
@@ -387,7 +390,7 @@ class StreamingHarmonyContext(HarmonyContext):
 
     @property
     def messages(self) -> list:
-        return self.parser.messages
+        return self._messages
 
     def append_output(self, output: Union[RequestOutput,
                                           list[Message]]) -> None:
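This one-line property change is the heart of the fix: messages now reports the context's own accumulated list (the initial messages plus everything synced over from the parser) rather than the parser's internal list. A hypothetical illustration of a consumer that benefits, using only the attributes visible in this diff:

    def collect_generated_messages(context: StreamingHarmonyContext) -> list:
        # Skip the caller-supplied initial messages; with the property now
        # backed by _messages, this view stays stable across parser turns.
        return context.messages[context.num_init_messages:]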
@@ -412,6 +415,11 @@ class StreamingHarmonyContext(HarmonyContext):
             # Check if the current token is part of reasoning content
             self._update_num_reasoning_tokens()
             self.last_tok = tok
+            if len(self._messages) - self.num_init_messages < len(
+                    self.parser.messages):
+                self._messages.extend(
+                    self.parser.messages[len(self._messages) -
+                                         self.num_init_messages:])
         else:
             # Handle the case of tool output in direct message format
             assert len(output) == 1, "Tool output should be a single message"
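The new five-line block keeps _messages aligned with the parser after every appended token. Stripped of the vLLM plumbing, the rule is (a sketch, not the project's code):

    def sync_messages(messages: list, num_init: int, parser_messages: list) -> None:
        produced = len(messages) - num_init      # parser messages already copied over
        if produced < len(parser_messages):      # parser has completed new message(s)
            messages.extend(parser_messages[produced:])

For example, with one initial message and one newly parsed message, produced is 0 and the slice parser_messages[0:] appends that message — exactly the scenario the new test drives.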
@@ -424,6 +432,7 @@ class StreamingHarmonyContext(HarmonyContext):
             for tok in toks:
                 self.parser.process(tok)
             self.last_tok = toks[-1]
+            # TODO: add tool_output messages to self._messages
 
     def is_expecting_start(self) -> bool:
         return self.parser.state == StreamState.EXPECT_START