[GPT-OSS] Structure_Tag support for gpt-oss tool-call in cot (#25515)

Signed-off-by: Hanchenli <lihanc2002@gmail.com>
Signed-off-by: Hanchenli <61769611+Hanchenli@users.noreply.github.com>
Signed-off-by: Wei Wei <wwei6@meta.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Wei Wei <wwei6@meta.com>
Co-authored-by: Wei Wei <weiweinpu@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Hanchenli
2025-10-17 21:55:54 -07:00
committed by GitHub
parent c312320764
commit 7c572544e4
14 changed files with 911 additions and 32 deletions

View File

@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for GPT-OSS structural tags functionality (PR #25515)."""
import json
from unittest.mock import Mock
import pytest
from vllm.entrypoints.openai.protocol import (
StructuredOutputsParams,
)
from vllm.entrypoints.tool_server import ToolServer
from vllm.reasoning.gptoss_reasoning_parser import (
GptOssReasoningParser,
)
class TestGptOssStructuralTagsIntegration:
"""Integration tests for structural tags in GPT-OSS tool calls."""
@pytest.fixture
def mock_tokenizer(self):
"""Create a mock tokenizer."""
tokenizer = Mock()
tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
return tokenizer
@pytest.fixture
def gptoss_parser(self, mock_tokenizer):
"""Create a real GptOssReasoningParser instance."""
return GptOssReasoningParser(mock_tokenizer)
@pytest.fixture
def tool_server_with_python(self):
"""Create a tool server with Python tool enabled."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")
return tool_server
@pytest.fixture
def tool_server_empty(self):
"""Create a tool server with no tools."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(return_value=False)
return tool_server
def test_end_to_end_no_tools(self, gptoss_parser):
"""Test end-to-end flow when no tools are available."""
# Test the parser directly
result = gptoss_parser.prepare_structured_tag(None, None)
parsed_result = json.loads(result)
# Verify basic structure
assert parsed_result["type"] == "structural_tag"
assert parsed_result["format"]["type"] == "triggered_tags"
assert len(parsed_result["format"]["tags"]) == 1
# Verify only analysis channel is allowed
analysis_tag = parsed_result["format"]["tags"][0]
assert analysis_tag["begin"] == "<|channel|>analysis<|message|>"
assert analysis_tag["content"]["type"] == "any_text"
assert analysis_tag["end"] == "<|end|>"
# Verify triggers
assert parsed_result["format"]["triggers"] == ["<|channel|>analysis"]
assert parsed_result["format"]["stop_after_first"] is False
def test_end_to_end_with_python_tool(self, gptoss_parser, tool_server_with_python):
"""Test end-to-end flow with Python tool enabled."""
result = gptoss_parser.prepare_structured_tag(None, tool_server_with_python)
parsed_result = json.loads(result)
# Should have analysis tag + 2 python tags
assert len(parsed_result["format"]["tags"]) == 3
# Verify all expected tags are present
tag_begins = [tag["begin"] for tag in parsed_result["format"]["tags"]]
expected_begins = [
"<|channel|>analysis<|message|>",
"<|channel|>commentary to=python",
"<|channel|>analysis to=python",
]
for expected in expected_begins:
assert expected in tag_begins
# Verify triggers include commentary
assert "<|channel|>analysis" in parsed_result["format"]["triggers"]
assert "<|channel|>commentary to=" in parsed_result["format"]["triggers"]
def test_structured_outputs_params_integration(
self, gptoss_parser, tool_server_with_python
):
"""Test integration with StructuredOutputsParams."""
# Generate structural tag
structural_tag = gptoss_parser.prepare_structured_tag(
None, tool_server_with_python
)
# Create StructuredOutputsParams
params = StructuredOutputsParams(structural_tag=structural_tag)
# Verify the tag is properly stored and accessible
assert params.structural_tag == structural_tag
# Verify the tag is valid JSON
parsed_tag = json.loads(params.structural_tag)
assert parsed_tag["type"] == "structural_tag"
@pytest.mark.parametrize(
"browser, python, container, expected_tags",
[
# No tools
(False, False, False, 1),
# Single tool
(True, False, False, 3),
# Multiple tools
(True, True, False, 5),
# All tools
(True, True, True, 7),
],
)
def test_tool_server_interaction_flow(
self, gptoss_parser, browser, python, container, expected_tags
):
"""Test the complete tool server interaction flow."""
# Create a mock ToolServer
tool_server = Mock(spec=ToolServer)
# Simulate tool availability based on parameters
tool_server.has_tool = Mock(
side_effect=lambda tool: {
"browser": browser,
"python": python,
"container": container,
}.get(tool, False)
)
# Run the parser and verify results
result = gptoss_parser.prepare_structured_tag(None, tool_server)
parsed_result = json.loads(result)
# Validate number of tags
assert len(parsed_result["format"]["tags"]) == expected_tags
# Verify tool-specific tags exist for enabled tools
tag_begins = [tag["begin"] for tag in parsed_result["format"]["tags"]]
for tool, enabled in {
"browser": browser,
"python": python,
"container": container,
}.items():
if enabled:
assert f"<|channel|>commentary to={tool}" in tag_begins
assert f"<|channel|>analysis to={tool}" in tag_begins
def test_original_tag_preservation(self, gptoss_parser, tool_server_with_python):
"""Test that original tags are preserved when provided."""
original_tag = '{"type": "custom_tag", "data": "preserved"}'
result = gptoss_parser.prepare_structured_tag(
original_tag, tool_server_with_python
)
# Should return original tag unchanged
assert result == original_tag
@pytest.mark.parametrize(
"tools",
[
[],
["browser"],
["python"],
["container"],
["browser", "python"],
["browser", "container"],
["python", "container"],
["browser", "python", "container"],
],
)
def test_json_validity_comprehensive(self, gptoss_parser, tools):
"""Test JSON validity across all possible tool combinations."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool in tools)
result = gptoss_parser.prepare_structured_tag(None, tool_server)
# Should be valid JSON
parsed_result = json.loads(result)
# Should have correct structure
assert parsed_result["type"] == "structural_tag"
assert "format" in parsed_result
assert "tags" in parsed_result["format"]
assert "triggers" in parsed_result["format"]
# Tag count should be: 1 (analysis) + 2 * len(tools)
expected_tag_count = 1 + (2 * len(tools))
assert len(parsed_result["format"]["tags"]) == expected_tag_count
def test_error_handling_invalid_tool_server(self, gptoss_parser):
"""Test error handling with invalid tool server."""
# Tool server that raises exceptions
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=Exception("Tool server error"))
# Tool server errors are not swallowed; they propagate to the caller
with pytest.raises(Exception, match="Tool server error"):
gptoss_parser.prepare_structured_tag(None, tool_server)
def test_concurrent_requests_isolation(self, gptoss_parser):
"""Test that concurrent requests don't interfere with each other."""
# Simulate concurrent requests with different tool servers
tool_server_1 = Mock(spec=ToolServer)
tool_server_1.has_tool = Mock(side_effect=lambda tool: tool == "python")
tool_server_2 = Mock(spec=ToolServer)
tool_server_2.has_tool = Mock(side_effect=lambda tool: tool == "browser")
# Generate tags concurrently
result_1 = gptoss_parser.prepare_structured_tag(None, tool_server_1)
result_2 = gptoss_parser.prepare_structured_tag(None, tool_server_2)
# Parse results
parsed_1 = json.loads(result_1)
parsed_2 = json.loads(result_2)
# Verify they have different tool configurations
tags_1 = [tag["begin"] for tag in parsed_1["format"]["tags"]]
tags_2 = [tag["begin"] for tag in parsed_2["format"]["tags"]]
# Result 1 should have python tags
assert "<|channel|>commentary to=python" in tags_1
assert "<|channel|>commentary to=browser" not in tags_1
# Result 2 should have browser tags
assert "<|channel|>commentary to=browser" in tags_2
assert "<|channel|>commentary to=python" not in tags_2
def test_tag_format_consistency(self, gptoss_parser):
"""Test that all generated tags follow consistent format."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(
side_effect=lambda tool: tool in ["python", "browser"]
)
result = gptoss_parser.prepare_structured_tag(None, tool_server)
parsed_result = json.loads(result)
# Verify all tags have required fields
for tag in parsed_result["format"]["tags"]:
assert "begin" in tag
assert "content" in tag
assert "end" in tag
assert tag["content"]["type"] == "any_text"
assert tag["end"] == "<|end|>"
# Verify begin format
assert tag["begin"].startswith("<|channel|>")
def test_trigger_configuration(self, gptoss_parser):
"""Test trigger configuration for different tool setups."""
# Test with no tools
result_no_tools = gptoss_parser.prepare_structured_tag(None, None)
parsed_no_tools = json.loads(result_no_tools)
assert parsed_no_tools["format"]["triggers"] == ["<|channel|>analysis"]
# Test with tools
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")
result_with_tools = gptoss_parser.prepare_structured_tag(None, tool_server)
parsed_with_tools = json.loads(result_with_tools)
expected_triggers = ["<|channel|>analysis", "<|channel|>commentary to="]
assert set(parsed_with_tools["format"]["triggers"]) == set(expected_triggers)
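
For reference, the default payload these assertions describe — the tag that `prepare_structured_tag(None, None)` serializes, reproduced from `no_func_reaonsing_tag` in the parser diff further below:

default_tag = {
    "type": "structural_tag",
    "format": {
        "type": "triggered_tags",
        "tags": [
            {
                "begin": "<|channel|>analysis<|message|>",
                "content": {"type": "any_text"},
                "end": "<|end|>",
            }
        ],
        "triggers": ["<|channel|>analysis"],
        "stop_after_first": False,
    },
}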

View File

@@ -864,3 +864,49 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
        # non-structured outputs requests should not return a valid JSON here
        with pytest.raises(ValueError):
            output_json = json.loads(generated_text)
+
+
+@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
+def test_structured_output_with_structural_tag(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    llm = LLM(
+        model="Qwen/Qwen2.5-1.5B-Instruct",
+        guided_decoding_backend=guided_decoding_backend,
+    )
+    structural_tag_config = {
+        "type": "structural_tag",
+        "format": {
+            "type": "triggered_tags",
+            "tags": [
+                {"begin": "hello_flag", "content": {"type": "any_text"}, "end": "hello"}
+            ],
+            "triggers": ["hello"],
+            "stop_after_first": False,
+        },
+    }
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=500,
+        guided_decoding=StructuredOutputsParams(
+            structural_tag=json.dumps(structural_tag_config)
+        ),
+    )
+    prompt = "Hello and repeat hello 10 times, do not say anything else. Only say hello hello hello, now start"
+    outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True)
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        assert "hello_flag" in generated_text, (
+            f"Expected 'hello_flag' to be in generated text, but got: {generated_text}"
+        )

View File

@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for GPT-OSS structural tag support in reasoning (PR #25515)."""
import json
from unittest.mock import Mock
import pytest
from vllm.entrypoints.tool_server import ToolServer
from vllm.reasoning.gptoss_reasoning_parser import (
GptOssReasoningParser,
from_builtin_tool_to_tag,
no_func_reaonsing_tag,
tag_with_builtin_funcs,
)
class TestGptOssReasoningParser:
"""Test cases for GptOssReasoningParser structural tag functionality."""
@pytest.fixture
def mock_tokenizer(self):
"""Create a mock tokenizer for testing."""
tokenizer = Mock()
tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
return tokenizer
@pytest.fixture
def reasoning_parser(self, mock_tokenizer):
"""Create a GptOssReasoningParser instance."""
return GptOssReasoningParser(mock_tokenizer)
@pytest.fixture
def mock_tool_server_empty(self):
"""Create a mock ToolServer with no tools."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(return_value=False)
return tool_server
@pytest.fixture
def mock_tool_server_with_browser(self):
"""Create a mock ToolServer with browser tool."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool == "browser")
return tool_server
@pytest.fixture
def mock_tool_server_with_all_tools(self):
"""Create a mock ToolServer with all builtin tools."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(
side_effect=lambda tool: tool in ["browser", "python", "container"]
)
return tool_server
def test_prepare_structured_tag_no_tool_server(self, reasoning_parser):
"""Test prepare_structured_tag with no tool server."""
result = reasoning_parser.prepare_structured_tag(None, None)
expected = json.dumps(no_func_reaonsing_tag)
assert result == expected
# Verify the structure is correct
parsed = json.loads(result)
assert parsed["type"] == "structural_tag"
assert parsed["format"]["type"] == "triggered_tags"
assert len(parsed["format"]["tags"]) == 1
assert parsed["format"]["tags"][0]["begin"] == "<|channel|>analysis<|message|>"
assert parsed["format"]["triggers"] == ["<|channel|>analysis"]
def test_prepare_structured_tag_with_all_tools(
self, reasoning_parser, mock_tool_server_with_all_tools
):
"""Test prepare_structured_tag with all builtin tools."""
result = reasoning_parser.prepare_structured_tag(
None, mock_tool_server_with_all_tools
)
parsed = json.loads(result)
# Should have analysis tag + tags for all 3 tools (2 tags each)
assert len(parsed["format"]["tags"]) == 7 # 1 analysis + 6 tool tags
# Check all tool tags are present
tag_begins = [tag["begin"] for tag in parsed["format"]["tags"]]
for tool in ["browser", "python", "container"]:
assert f"<|channel|>commentary to={tool}" in tag_begins
assert f"<|channel|>analysis to={tool}" in tag_begins
def test_prepare_structured_tag_with_original_tag(self, reasoning_parser):
"""Test prepare_structured_tag when original_tag is provided."""
original_tag = '{"custom": "tag"}'
result = reasoning_parser.prepare_structured_tag(original_tag, None)
# Should return the original tag unchanged
assert result == original_tag
def test_from_builtin_tool_to_tag(self):
"""Test from_builtin_tool_to_tag function."""
tags = from_builtin_tool_to_tag("python")
assert len(tags) == 2
assert tags[0]["begin"] == "<|channel|>commentary to=python"
assert tags[0]["content"]["type"] == "any_text"
assert tags[0]["end"] == "<|end|>"
assert tags[1]["begin"] == "<|channel|>analysis to=python"
assert tags[1]["content"]["type"] == "any_text"
assert tags[1]["end"] == "<|end|>"
def test_tag_with_builtin_funcs(self):
"""Test tag_with_builtin_funcs function."""
builtin_tools = ["browser", "python"]
result = tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tools)
assert result["type"] == "structural_tag"
# Should have original analysis tag + 2 tags per tool
assert len(result["format"]["tags"]) == 5 # 1 + 2*2
# Should have added commentary trigger
assert "<|channel|>commentary to=" in result["format"]["triggers"]
assert "<|channel|>analysis" in result["format"]["triggers"]
def test_tag_structure_invariants(self):
"""Test that the basic tag structure follows expected format."""
# Test the base no_func_reaonsing_tag structure
assert no_func_reaonsing_tag["type"] == "structural_tag"
assert no_func_reaonsing_tag["format"]["type"] == "triggered_tags"
assert no_func_reaonsing_tag["format"]["stop_after_first"] is False
# Verify analysis tag structure
analysis_tag = no_func_reaonsing_tag["format"]["tags"][0]
assert analysis_tag["begin"] == "<|channel|>analysis<|message|>"
assert analysis_tag["content"]["type"] == "any_text"
assert analysis_tag["end"] == "<|end|>"
def test_json_serialization_valid(
self, reasoning_parser, mock_tool_server_with_all_tools
):
"""Test that all generated tags produce valid JSON."""
# Test with no tool server
result1 = reasoning_parser.prepare_structured_tag(None, None)
json.loads(result1) # Should not raise
# Test with empty tool server
empty_server = Mock(spec=ToolServer)
empty_server.has_tool = Mock(return_value=False)
result2 = reasoning_parser.prepare_structured_tag(None, empty_server)
json.loads(result2) # Should not raise
# Test with tools
result3 = reasoning_parser.prepare_structured_tag(
None, mock_tool_server_with_all_tools
)
json.loads(result3) # Should not raise
@pytest.mark.parametrize("tool_name", ["browser", "python", "container"])
def test_single_tool_integration(self, reasoning_parser, tool_name):
"""Test integration with individual tools."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool == tool_name)
result = reasoning_parser.prepare_structured_tag(None, tool_server)
parsed = json.loads(result)
# Should have 1 analysis + 2 tool-specific tags
assert len(parsed["format"]["tags"]) == 3
tag_begins = [tag["begin"] for tag in parsed["format"]["tags"]]
assert f"<|channel|>commentary to={tool_name}" in tag_begins
assert f"<|channel|>analysis to={tool_name}" in tag_begins
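
A quick sketch tying the two helpers together; the expected values follow directly from the assertions above (assumes the module is importable):

from vllm.reasoning.gptoss_reasoning_parser import (
    no_func_reaonsing_tag,
    tag_with_builtin_funcs,
)

tag = tag_with_builtin_funcs(no_func_reaonsing_tag, ["python"])
assert [t["begin"] for t in tag["format"]["tags"]] == [
    "<|channel|>analysis<|message|>",
    "<|channel|>commentary to=python",
    "<|channel|>analysis to=python",
]
assert tag["format"]["triggers"] == [
    "<|channel|>analysis",
    "<|channel|>commentary to=",
]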

View File

@@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for reasoning-aware structured output functionality (PR #25515)."""
from unittest.mock import Mock
import pytest
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.reasoning import ReasoningParser
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
class TestReasoningStructuredOutput:
"""Test reasoning-aware structured output functionality."""
@pytest.fixture
def mock_model_config(self):
"""Create a mock ModelConfig."""
config = Mock(spec=ModelConfig)
config.skip_tokenizer_init = True # Skip tokenizer init to avoid network calls
config.get_vocab_size = Mock(return_value=50000)
# Add missing runner_type attribute that tokenizer initialization expects
config.runner_type = "generate"
# Add other attributes that tokenizer initialization might need
config.tokenizer = "test-tokenizer"
config.tokenizer_mode = "auto"
config.trust_remote_code = False
config.tokenizer_revision = None
return config
@pytest.fixture
def mock_scheduler_config(self):
"""Create a mock SchedulerConfig."""
config = Mock(spec=SchedulerConfig)
config.max_num_seqs = 128
return config
@pytest.fixture
def mock_vllm_config(self, mock_model_config, mock_scheduler_config):
"""Create a mock VllmConfig."""
config = Mock(spec=VllmConfig)
config.model_config = mock_model_config
config.scheduler_config = mock_scheduler_config
config.structured_outputs_config = Mock()
config.structured_outputs_config.reasoning_parser = None
config.structured_outputs_config.enable_in_reasoning = False
config.speculative_config = None
return config
@pytest.fixture
def mock_reasoning_parser(self):
"""Create a mock ReasoningParser."""
parser = Mock(spec=ReasoningParser)
parser.is_reasoning_end = Mock(return_value=False)
return parser
@pytest.fixture
def mock_request_with_structured_output(self):
"""Create a mock request with structured output."""
request = Mock(spec=Request)
request.structured_output_request = Mock()
request.structured_output_request.reasoning_ended = None
request.structured_output_request.grammar = Mock()
request.structured_output_request.grammar.is_terminated = Mock(
return_value=False
)
request.use_structured_output = True
request.prompt_token_ids = [1, 2, 3, 4, 5]
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
return request
def test_should_fill_bitmask_with_enable_in_reasoning(
self, mock_vllm_config, mock_request_with_structured_output
):
"""Test should_fill_bitmask when enable_in_reasoning is True."""
# Enable enable_in_reasoning
mock_vllm_config.structured_outputs_config.enable_in_reasoning = True
manager = StructuredOutputManager(mock_vllm_config)
# Should always return True when enable_in_reasoning is enabled
result = manager.should_fill_bitmask(mock_request_with_structured_output)
assert result is True
def test_should_fill_bitmask_without_enable_in_reasoning(
self,
mock_vllm_config,
mock_request_with_structured_output,
mock_reasoning_parser,
):
"""Test should_fill_bitmask when enable_in_reasoning is False."""
# Keep enable_in_reasoning as False (default)
config = mock_vllm_config.structured_outputs_config
assert config.enable_in_reasoning is False
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = mock_reasoning_parser
# Mock reasoning not ended
mock_reasoning_parser.is_reasoning_end.return_value = False
result = manager.should_fill_bitmask(mock_request_with_structured_output)
# Should set reasoning_ended and return its value
assert (
mock_request_with_structured_output.structured_output_request.reasoning_ended
is False
)
assert result is False
def test_should_fill_bitmask_no_reasoner(
self, mock_vllm_config, mock_request_with_structured_output
):
"""Test should_fill_bitmask when no reasoner is configured."""
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = None
result = manager.should_fill_bitmask(mock_request_with_structured_output)
# Should default to True when no reasoner
assert result is True
def test_should_advance_with_enable_in_reasoning(
self,
mock_vllm_config,
mock_request_with_structured_output,
mock_reasoning_parser,
):
"""Test should_advance when enable_in_reasoning is True."""
# Enable enable_in_reasoning
mock_vllm_config.structured_outputs_config.enable_in_reasoning = True
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = mock_reasoning_parser
# Should always return True when enable_in_reasoning is enabled
result = manager.should_advance(mock_request_with_structured_output)
assert result is True
def test_should_advance_reasoning_not_ended(
self,
mock_vllm_config,
mock_request_with_structured_output,
mock_reasoning_parser,
):
"""Test should_advance when reasoning has not ended."""
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = mock_reasoning_parser
# Set reasoning as not ended
(
mock_request_with_structured_output.structured_output_request
).reasoning_ended = False
mock_reasoning_parser.is_reasoning_end.return_value = False
result = manager.should_advance(mock_request_with_structured_output)
# Should return False since reasoning hasn't ended
assert result is False
def test_should_advance_reasoning_just_ended(
self,
mock_vllm_config,
mock_request_with_structured_output,
mock_reasoning_parser,
):
"""Test should_advance when reasoning ends in current step."""
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = mock_reasoning_parser
# Set reasoning as not ended initially, but ends in this step
(
mock_request_with_structured_output.structured_output_request
).reasoning_ended = False
mock_reasoning_parser.is_reasoning_end.return_value = True
result = manager.should_advance(mock_request_with_structured_output)
# Should set reasoning_ended to True but return False for this step
assert (
mock_request_with_structured_output.structured_output_request.reasoning_ended
is True
)
assert result is False
def test_should_advance_reasoning_already_ended(
self,
mock_vllm_config,
mock_request_with_structured_output,
mock_reasoning_parser,
):
"""Test should_advance when reasoning has already ended."""
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = mock_reasoning_parser
# Set reasoning as already ended
(
mock_request_with_structured_output.structured_output_request
).reasoning_ended = True
result = manager.should_advance(mock_request_with_structured_output)
# Should return True since reasoning has ended
assert result is True
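
The four `should_advance` tests above pin down a small state machine. A condensed sketch of the logic they exercise, simplified from the `StructuredOutputManager` diff later in this commit (the argument passed to `is_reasoning_end` is an assumption here):

def should_advance_sketch(manager, request) -> bool:
    # No reasoning parser configured: always advance the grammar.
    if manager.reasoner is None:
        return True
    # Constraints also apply during reasoning (enable_in_reasoning=True).
    if manager.enable_in_reasoning:
        return True
    so_req = request.structured_output_request
    if so_req.reasoning_ended:
        return True
    # Reasoning ends on this step: record it, but only start advancing
    # from the next step onward.
    if manager.reasoner.is_reasoning_end(request.all_token_ids):
        so_req.reasoning_ended = True
    return False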

View File

@@ -35,6 +35,8 @@ class StructuredOutputsConfig:
    reasoning_parser: str = ""
    """Select the reasoning parser depending on the model that you're using.
    This is used to parse the reasoning content into OpenAI API format."""
+    enable_in_reasoning: bool = False
+    """Whether to apply structured-output constraints during reasoning."""

    def compute_hash(self) -> str:
        """

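The new flag defaults to off. A minimal sketch of enabling it programmatically, assuming `StructuredOutputsConfig` is exported from `vllm.config` as the EngineArgs diff below suggests:

from vllm.config import StructuredOutputsConfig

# Constrain token sampling even while the model is still reasoning.
so_config = StructuredOutputsConfig(
    reasoning_parser="openai_gptoss",
    enable_in_reasoning=True,
)
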
View File

@@ -479,6 +479,7 @@ class EngineArgs:
        VllmConfig, "structured_outputs_config"
    )
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser

    # Deprecated guided decoding fields
    guided_decoding_backend: str | None = None
    guided_decoding_disable_fallback: bool | None = None

View File

@@ -200,7 +200,7 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
    strict: bool | None = None


-class StructuralTag(OpenAIBaseModel):
+class LegacyStructuralTag(OpenAIBaseModel):
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
@@ -208,10 +208,20 @@ class StructuralTag(OpenAIBaseModel):
    end: str


+class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
+    type: Literal["structural_tag"]
+    structures: list[LegacyStructuralTag]
+    triggers: list[str]


class StructuralTagResponseFormat(OpenAIBaseModel):
    type: Literal["structural_tag"]
-    structures: list[StructuralTag]
-    triggers: list[str]
+    format: Any


+AnyStructuralTagResponseFormat: TypeAlias = (
+    LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
+)


class ResponseFormat(OpenAIBaseModel):
@@ -220,7 +230,9 @@ class ResponseFormat(OpenAIBaseModel):
    json_schema: JsonSchemaResponseFormat | None = None


-AnyResponseFormat: TypeAlias = ResponseFormat | StructuralTagResponseFormat
+AnyResponseFormat: TypeAlias = (
+    ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
+)


class StreamOptions(OpenAIBaseModel):
@@ -823,7 +835,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
            elif response_format.type == "structural_tag":
                structural_tag = response_format
            assert structural_tag is not None and isinstance(
-                structural_tag, StructuralTagResponseFormat
+                structural_tag,
+                (
+                    LegacyStructuralTagResponseFormat,
+                    StructuralTagResponseFormat,
+                ),
            )
            s_tag_obj = structural_tag.model_dump(by_alias=True)
            self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

View File

@@ -98,7 +98,7 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
@@ -365,6 +365,19 @@ class OpenAIServingResponses(OpenAIServing):
                context = HarmonyContext(messages, available_tools)
            else:
                context = SimpleContext()
+
+            if self.reasoning_parser is not None:
+                reasoning_parser = self.reasoning_parser(tokenizer)
+                if sampling_params.structured_outputs is None:
+                    sampling_params.structured_outputs = StructuredOutputsParams()
+                struct_out = sampling_params.structured_outputs
+                if struct_out.all_non_structural_tag_constraints_none():
+                    sampling_params.structured_outputs.structural_tag = (
+                        reasoning_parser.prepare_structured_tag(
+                            sampling_params.structured_outputs.structural_tag,
+                            self.tool_server,
+                        )
+                    )
            generator = self._generate_with_builtin_tools(
                request_id=request.request_id,
                request_prompt=request_prompts[i],
@@ -1919,6 +1932,7 @@ class OpenAIServingResponses(OpenAIServing):
            processer = self._process_harmony_streaming_events
        else:
            processer = self._process_simple_streaming_events
+        # TODO(Hanchen): make sampling params include the structural tag
        initial_response = ResponsesResponse.from_request(
            request,

View File

@@ -7,6 +7,7 @@ from collections.abc import Callable, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

+from vllm.entrypoints.tool_server import ToolServer
from vllm.logger import init_logger
from vllm.utils.collections import is_list_of
from vllm.utils.import_utils import import_from_path
@@ -115,6 +116,17 @@ class ReasoningParser:
        previously been parsed and extracted (see constructor)
        """

+    def prepare_structured_tag(
+        self,
+        original_tag: str | None,
+        tool_server: ToolServer | None,
+    ) -> str | None:
+        """
+        Prepare the structural tag used to constrain reasoning output.
+        The base implementation returns None; parsers that support
+        structural tags override this method.
+        """
+        return None


class ReasoningParserManager:
    reasoning_parsers: dict[str, type] = {}
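
For third-party parsers, overriding the new hook looks like this — a hypothetical sketch; the class name, registration key, and default tag below are illustrative only:

import json

from vllm.entrypoints.tool_server import ToolServer
from vllm.reasoning import ReasoningParser, ReasoningParserManager

# Hypothetical default tag, shaped like the GPT-OSS one further below.
_DEFAULT_TAG = {
    "type": "structural_tag",
    "format": {
        "type": "triggered_tags",
        "tags": [],
        "triggers": [],
        "stop_after_first": False,
    },
}


@ReasoningParserManager.register_module("my_parser")  # hypothetical key
class MyReasoningParser(ReasoningParser):
    def prepare_structured_tag(
        self, original_tag: str | None, tool_server: ToolServer | None
    ) -> str | None:
        # Respect a caller-supplied tag; otherwise fall back to the default.
        if original_tag is not None:
            return original_tag
        return json.dumps(_DEFAULT_TAG)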

View File

@@ -1,17 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
+import json
from collections.abc import Sequence

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.harmony_utils import parse_chat_output
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from vllm.entrypoints.tool_server import ToolServer
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)

+no_func_reaonsing_tag = {
+    "type": "structural_tag",
+    "format": {
+        "type": "triggered_tags",
+        "tags": [
+            {
+                "begin": "<|channel|>analysis<|message|>",
+                "content": {"type": "any_text"},
+                "end": "<|end|>",
+            }
+        ],
+        "triggers": ["<|channel|>analysis"],
+        "stop_after_first": False,
+    },
+}
+
+
+def from_builtin_tool_to_tag(tool: str) -> list[dict]:
+    tag = [
+        {
+            "begin": f"<|channel|>commentary to={tool}",
+            "content": {"type": "any_text"},
+            "end": "<|end|>",
+        },
+        {
+            "begin": f"<|channel|>analysis to={tool}",
+            "content": {"type": "any_text"},
+            "end": "<|end|>",
+        },
+    ]
+    return tag
+
+
+def tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tool_list: list[str]) -> dict:
+    new_tag = copy.deepcopy(no_func_reaonsing_tag)
+    new_tag["format"]["triggers"].append("<|channel|>commentary to=")
+    for tool in builtin_tool_list:
+        new_tag["format"]["tags"].extend(from_builtin_tool_to_tag(tool))
+    return new_tag


@ReasoningParserManager.register_module("openai_gptoss")
class GptOssReasoningParser(ReasoningParser):
@@ -81,3 +125,33 @@ class GptOssReasoningParser(ReasoningParser):
        raise NotImplementedError(
            "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used."  # noqa: E501
        )
+
+    # Prepares the structural tag that constrains reasoning output.
+    def prepare_structured_tag(
+        self, original_tag: str | None, tool_server: ToolServer | None
+    ) -> str:
+        if original_tag is None:
+            if tool_server is None:
+                return json.dumps(no_func_reaonsing_tag)
+            else:
+                builtin_tool_list: list[str] = []
+                if tool_server.has_tool("browser"):
+                    builtin_tool_list.append("browser")
+                if tool_server.has_tool("python"):
+                    builtin_tool_list.append("python")
+                if tool_server.has_tool("container"):
+                    builtin_tool_list.append("container")
+                if len(builtin_tool_list) > 0:
+                    logger.info("Builtin_tool_list: %s", builtin_tool_list)
+                    func_tag = json.dumps(
+                        tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tool_list)
+                    )
+                else:
+                    logger.info("Builtin_tool_list is empty")
+                    func_tag = json.dumps(no_func_reaonsing_tag)
+                return func_tag
+        else:
+            # Appending to a caller-provided tag is risky, so return it unchanged.
+            return original_tag
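
A usage sketch mirroring the unit tests earlier in this commit — generating the tag for a server that only exposes the python tool (a Mock stands in for a real tokenizer and ToolServer, as in those tests):

import json
from unittest.mock import Mock

from vllm.entrypoints.tool_server import ToolServer
from vllm.reasoning.gptoss_reasoning_parser import GptOssReasoningParser

parser = GptOssReasoningParser(Mock())  # mock tokenizer, as in the tests

tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")

tag = json.loads(parser.prepare_structured_tag(None, tool_server))
print([t["begin"] for t in tag["format"]["tags"]])
# ['<|channel|>analysis<|message|>',
#  '<|channel|>commentary to=python',
#  '<|channel|>analysis to=python']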

View File

@@ -58,6 +58,7 @@ class StructuredOutputsParams:
                self.choice is not None,
                self.grammar is not None,
                self.json_object is not None,
+                self.structural_tag is not None,
            ]
        )
        if count > 1:
@@ -66,6 +67,37 @@ class StructuredOutputsParams:
                f"but multiple are specified: {self.__dict__}"
            )

+    def all_constraints_none(self) -> bool:
+        """
+        Returns True if all structured-output constraint fields are None.
+        """
+        return all(
+            getattr(self, field) is None
+            for field in (
+                "json",
+                "regex",
+                "choice",
+                "grammar",
+                "json_object",
+                "structural_tag",
+            )
+        )
+
+    def all_non_structural_tag_constraints_none(self) -> bool:
+        """
+        Returns True if all constraint fields other than structural_tag
+        are None.
+        """
+        return all(
+            getattr(self, field) is None
+            for field in (
+                "json",
+                "regex",
+                "choice",
+                "grammar",
+                "json_object",
+            )
+        )


@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
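
The distinction between the two predicates matters in the serving path above: a request whose only constraint is a structural tag still passes the second check. A quick sketch:

from vllm.sampling_params import StructuredOutputsParams

params = StructuredOutputsParams(structural_tag='{"type": "structural_tag"}')
assert not params.all_constraints_none()                  # a constraint is set
assert params.all_non_structural_tag_constraints_none()   # but only the tag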

View File

@@ -73,6 +73,10 @@ class StructuredOutputManager:
            )
            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

+        self.enable_in_reasoning = (
+            self.vllm_config.structured_outputs_config.enable_in_reasoning
+        )
+
    def grammar_init(self, request: Request) -> None:
        if request.structured_output_request is None:
            return
@@ -274,7 +278,13 @@ class StructuredOutputManager:
        return bitmask_tensor.numpy()

    def should_fill_bitmask(self, request: Request) -> bool:
+        # NOTE (Hanchen): if enable_in_reasoning is True, the model must be
+        # constrained during reasoning as well, so always fill the bitmask.
        if self.reasoner is not None:
+            if self.enable_in_reasoning:
+                return True
            assert request.structured_output_request is not None
            if request.structured_output_request.reasoning_ended is None:
                request.structured_output_request.reasoning_ended = (
@@ -297,6 +307,10 @@ class StructuredOutputManager:
        if self.reasoner is None:
            return True

+        # If the model needs structured output during reasoning, always advance.
+        if self.enable_in_reasoning:
+            return True
+
        structured_req = request.structured_output_request
        if structured_req.reasoning_ended:
            return True

View File

@@ -91,18 +91,20 @@ class XgrammarBackend(StructuredOutputBackend):
            ctx = self.compiler.compile_regex(grammar_spec)
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            s_tag = json.loads(grammar_spec)
-            tags = [
-                xgr.StructuralTagItem(
-                    begin=s["begin"],
-                    schema=json.dumps(s["schema"]),
-                    end=s["end"],
-                )
-                for s in s_tag["structures"]
-            ]
-            structural_tag = xgr.StructuralTag.from_legacy_structural_tag(
-                tags, s_tag["triggers"]
-            )
-            ctx = self.compiler.compile_structural_tag(structural_tag)
+            if "structures" in s_tag:
+                # Falling back to deprecated method of compiling structural tag
+                tags = [
+                    xgr.StructuralTagItem(
+                        begin=s["begin"],
+                        schema=json.dumps(s["schema"]),
+                        end=s["end"],
+                    )
+                    for s in s_tag["structures"]
+                ]
+                ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"])
+            else:
+                logger.info("Compiling structural tag grammar_spec: %s", grammar_spec)
+                ctx = self.compiler.compile_structural_tag(grammar_spec)
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
@@ -320,17 +322,19 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    if so_params.structural_tag:
        try:
            s_tag = json.loads(so_params.structural_tag)
-            tags = [
-                xgr.StructuralTagItem(
-                    begin=s["begin"],
-                    schema=json.dumps(s["schema"]),
-                    end=s["end"],
-                )
-                for s in s_tag["structures"]
-            ]
-            structural_tag = xgr.StructuralTag.from_legacy_structural_tag(
-                tags, s_tag["triggers"]
-            )
-            xgr.Grammar.from_structural_tag(structural_tag)
+            # Using the deprecated method of compiling structural tag
+            if "structures" in s_tag:
+                tags = [
+                    xgr.StructuralTagItem(
+                        begin=s["begin"],
+                        schema=json.dumps(s["schema"]),
+                        end=s["end"],
+                    )
+                    for s in s_tag["structures"]
+                ]
+                xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
+            else:
+                xgr.Grammar.from_structural_tag(so_params.structural_tag)
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e
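
The `"structures"` check above is what separates the two accepted spec shapes. For reference, a sketch of both payloads — field names come from the protocol classes earlier in this commit; the concrete begin/end strings are illustrative:

# Legacy shape (LegacyStructuralTagResponseFormat): top-level
# structures/triggers, each structure carrying a JSON schema.
legacy_spec = {
    "type": "structural_tag",
    "structures": [
        {"begin": "<tool>", "schema": {"type": "object"}, "end": "</tool>"}
    ],
    "triggers": ["<tool>"],
}

# New shape (StructuralTagResponseFormat): a "format" object handed
# directly to xgrammar's compile_structural_tag.
new_spec = {
    "type": "structural_tag",
    "format": {
        "type": "triggered_tags",
        "tags": [
            {"begin": "<tool>", "content": {"type": "any_text"}, "end": "</tool>"}
        ],
        "triggers": ["<tool>"],
        "stop_after_first": False,
    },
}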

View File

@@ -28,7 +28,12 @@ class StructuredOutputRequest:
        if sampling_params is None:
            return None
        params = sampling_params.structured_outputs
-        return StructuredOutputRequest(params=params) if params else None
+        if params:
+            if params.all_constraints_none():
+                return None
+            else:
+                return StructuredOutputRequest(params=params)
+        return None

    def _check_grammar_completion(self) -> bool:
        # NOTE: We have to lazy import to gate circular imports