[GPT-OSS] Structure_Tag support for gpt-oss tool-call in cot (#25515)

Signed-off-by: Hanchenli <lihanc2002@gmail.com>
Signed-off-by: Hanchenli <61769611+Hanchenli@users.noreply.github.com>
Signed-off-by: Wei Wei <wwei6@meta.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Wei Wei <wwei6@meta.com>
Co-authored-by: Wei Wei <weiweinpu@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Author: Hanchenli
Date: 2025-10-17 21:55:54 -07:00 (committed by GitHub)
Parent: c312320764
Commit: 7c572544e4
14 changed files with 911 additions and 32 deletions
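For orientation, here is a minimal sketch of how the pieces added in this commit fit together. Module paths and attribute names are taken from the diffs below; the mock tokenizer mirrors the new unit tests (serving code uses the real tokenizer and tool server).

import json
from unittest.mock import Mock

from vllm.reasoning.gptoss_reasoning_parser import GptOssReasoningParser
from vllm.sampling_params import SamplingParams, StructuredOutputsParams

# The reasoning parser builds a structural tag; with no tool server only the
# analysis channel is allowed.
tokenizer = Mock()
tokenizer.encode = Mock(return_value=[1, 2, 3])
parser = GptOssReasoningParser(tokenizer)
structural_tag = parser.prepare_structured_tag(None, None)

# Serving code attaches the tag to the request's structured-output params
# (see the OpenAIServingResponses hunk below).
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
sampling_params.structured_outputs = StructuredOutputsParams(
    structural_tag=structural_tag
)

# The xgrammar backend later compiles the tag; new-style specs carry a
# "format" key, while legacy specs carry "structures" and use the
# deprecated compilation path.
assert json.loads(structural_tag)["format"]["type"] == "triggered_tags"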

View File

@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for GPT-OSS structural tags functionality (PR #25515)."""
import json
from unittest.mock import Mock
import pytest
from vllm.entrypoints.openai.protocol import (
StructuredOutputsParams,
)
from vllm.entrypoints.tool_server import ToolServer
from vllm.reasoning.gptoss_reasoning_parser import (
GptOssReasoningParser,
)
class TestGptOssStructuralTagsIntegration:
"""Integration tests for structural tags in GPT-OSS tool calls."""
@pytest.fixture
def mock_tokenizer(self):
"""Create a mock tokenizer."""
tokenizer = Mock()
tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
return tokenizer
@pytest.fixture
def gptoss_parser(self, mock_tokenizer):
"""Create a real GptOssReasoningParser instance."""
return GptOssReasoningParser(mock_tokenizer)
@pytest.fixture
def tool_server_with_python(self):
"""Create a tool server with Python tool enabled."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")
return tool_server
@pytest.fixture
def tool_server_empty(self):
"""Create a tool server with no tools."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(return_value=False)
return tool_server
def test_end_to_end_no_tools(self, gptoss_parser):
"""Test end-to-end flow when no tools are available."""
# Test the parser directly
result = gptoss_parser.prepare_structured_tag(None, None)
parsed_result = json.loads(result)
# Verify basic structure
assert parsed_result["type"] == "structural_tag"
assert parsed_result["format"]["type"] == "triggered_tags"
assert len(parsed_result["format"]["tags"]) == 1
# Verify only analysis channel is allowed
analysis_tag = parsed_result["format"]["tags"][0]
assert analysis_tag["begin"] == "<|channel|>analysis<|message|>"
assert analysis_tag["content"]["type"] == "any_text"
assert analysis_tag["end"] == "<|end|>"
# Verify triggers
assert parsed_result["format"]["triggers"] == ["<|channel|>analysis"]
assert parsed_result["format"]["stop_after_first"] is False
def test_end_to_end_with_python_tool(self, gptoss_parser, tool_server_with_python):
"""Test end-to-end flow with Python tool enabled."""
result = gptoss_parser.prepare_structured_tag(None, tool_server_with_python)
parsed_result = json.loads(result)
# Should have analysis tag + 2 python tags
assert len(parsed_result["format"]["tags"]) == 3
# Verify all expected tags are present
tag_begins = [tag["begin"] for tag in parsed_result["format"]["tags"]]
expected_begins = [
"<|channel|>analysis<|message|>",
"<|channel|>commentary to=python",
"<|channel|>analysis to=python",
]
for expected in expected_begins:
assert expected in tag_begins
# Verify triggers include commentary
assert "<|channel|>analysis" in parsed_result["format"]["triggers"]
assert "<|channel|>commentary to=" in parsed_result["format"]["triggers"]
def test_structured_outputs_params_integration(
self, gptoss_parser, tool_server_with_python
):
"""Test integration with StructuredOutputsParams."""
# Generate structural tag
structural_tag = gptoss_parser.prepare_structured_tag(
None, tool_server_with_python
)
# Create StructuredOutputsParams
params = StructuredOutputsParams(structural_tag=structural_tag)
# Verify the tag is properly stored and accessible
assert params.structural_tag == structural_tag
# Verify the tag is valid JSON
parsed_tag = json.loads(params.structural_tag)
assert parsed_tag["type"] == "structural_tag"
@pytest.mark.parametrize(
"browser, python, container, expected_tags",
[
# No tools
(False, False, False, 1),
# Single tool
(True, False, False, 3),
# Multiple tools
(True, True, False, 5),
# All tools
(True, True, True, 7),
],
)
def test_tool_server_interaction_flow(
self, gptoss_parser, browser, python, container, expected_tags
):
"""Test the complete tool server interaction flow."""
# Create a mock ToolServer
tool_server = Mock(spec=ToolServer)
# Simulate tool availability based on parameters
tool_server.has_tool = Mock(
side_effect=lambda tool: {
"browser": browser,
"python": python,
"container": container,
}.get(tool, False)
)
# Run the parser and verify results
result = gptoss_parser.prepare_structured_tag(None, tool_server)
parsed_result = json.loads(result)
# Validate number of tags
assert len(parsed_result["format"]["tags"]) == expected_tags
# Verify tool-specific tags exist for enabled tools
tag_begins = [tag["begin"] for tag in parsed_result["format"]["tags"]]
for tool, enabled in {
"browser": browser,
"python": python,
"container": container,
}.items():
if enabled:
assert f"<|channel|>commentary to={tool}" in tag_begins
assert f"<|channel|>analysis to={tool}" in tag_begins
def test_original_tag_preservation(self, gptoss_parser, tool_server_with_python):
"""Test that original tags are preserved when provided."""
original_tag = '{"type": "custom_tag", "data": "preserved"}'
result = gptoss_parser.prepare_structured_tag(
original_tag, tool_server_with_python
)
# Should return original tag unchanged
assert result == original_tag
@pytest.mark.parametrize(
"tools",
[
[],
["browser"],
["python"],
["container"],
["browser", "python"],
["browser", "container"],
["python", "container"],
["browser", "python", "container"],
],
)
def test_json_validity_comprehensive(self, gptoss_parser, tools):
"""Test JSON validity across all possible tool combinations."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool in tools)
result = gptoss_parser.prepare_structured_tag(None, tool_server)
# Should be valid JSON
parsed_result = json.loads(result)
# Should have correct structure
assert parsed_result["type"] == "structural_tag"
assert "format" in parsed_result
assert "tags" in parsed_result["format"]
assert "triggers" in parsed_result["format"]
# Tag count should be: 1 (analysis) + 2 * len(tools)
expected_tag_count = 1 + (2 * len(tools))
assert len(parsed_result["format"]["tags"]) == expected_tag_count
def test_error_handling_invalid_tool_server(self, gptoss_parser):
"""Test error handling with invalid tool server."""
# Tool server that raises exceptions
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=Exception("Tool server error"))
# The parser propagates tool-server errors to the caller
with pytest.raises(Exception, match="Tool server error"):
gptoss_parser.prepare_structured_tag(None, tool_server)
def test_concurrent_requests_isolation(self, gptoss_parser):
"""Test that concurrent requests don't interfere with each other."""
# Simulate concurrent requests with different tool servers
tool_server_1 = Mock(spec=ToolServer)
tool_server_1.has_tool = Mock(side_effect=lambda tool: tool == "python")
tool_server_2 = Mock(spec=ToolServer)
tool_server_2.has_tool = Mock(side_effect=lambda tool: tool == "browser")
# Generate tags concurrently
result_1 = gptoss_parser.prepare_structured_tag(None, tool_server_1)
result_2 = gptoss_parser.prepare_structured_tag(None, tool_server_2)
# Parse results
parsed_1 = json.loads(result_1)
parsed_2 = json.loads(result_2)
# Verify they have different tool configurations
tags_1 = [tag["begin"] for tag in parsed_1["format"]["tags"]]
tags_2 = [tag["begin"] for tag in parsed_2["format"]["tags"]]
# Result 1 should have python tags
assert "<|channel|>commentary to=python" in tags_1
assert "<|channel|>commentary to=browser" not in tags_1
# Result 2 should have browser tags
assert "<|channel|>commentary to=browser" in tags_2
assert "<|channel|>commentary to=python" not in tags_2
def test_tag_format_consistency(self, gptoss_parser):
"""Test that all generated tags follow consistent format."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(
side_effect=lambda tool: tool in ["python", "browser"]
)
result = gptoss_parser.prepare_structured_tag(None, tool_server)
parsed_result = json.loads(result)
# Verify all tags have required fields
for tag in parsed_result["format"]["tags"]:
assert "begin" in tag
assert "content" in tag
assert "end" in tag
assert tag["content"]["type"] == "any_text"
assert tag["end"] == "<|end|>"
# Verify begin format
assert tag["begin"].startswith("<|channel|>")
def test_trigger_configuration(self, gptoss_parser):
"""Test trigger configuration for different tool setups."""
# Test with no tools
result_no_tools = gptoss_parser.prepare_structured_tag(None, None)
parsed_no_tools = json.loads(result_no_tools)
assert parsed_no_tools["format"]["triggers"] == ["<|channel|>analysis"]
# Test with tools
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")
result_with_tools = gptoss_parser.prepare_structured_tag(None, tool_server)
parsed_with_tools = json.loads(result_with_tools)
expected_triggers = ["<|channel|>analysis", "<|channel|>commentary to="]
assert set(parsed_with_tools["format"]["triggers"]) == set(expected_triggers)

View File

@ -864,3 +864,49 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
# non-structured outputs requests should not return a valid JSON here
with pytest.raises(ValueError):
output_json = json.loads(generated_text)
@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
def test_structured_output_with_structural_tag(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(
model="Qwen/Qwen2.5-1.5B-Instruct",
guided_decoding_backend=guided_decoding_backend,
)
structural_tag_config = {
"type": "structural_tag",
"format": {
"type": "triggered_tags",
"tags": [
{"begin": "hello_flag", "content": {"type": "any_text"}, "end": "hello"}
],
"triggers": ["hello"],
"stop_after_first": False,
},
}
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=500,
guided_decoding=StructuredOutputsParams(
structural_tag=json.dumps(structural_tag_config)
),
)
prompt = "Hello and repete hello 10 times, do not say anything else. Only say hello hello hello, now start"
outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
assert "hello_flag" in generated_text, (
f"Expected 'hello_flag' to be in generated text, but got: {generated_text}"
)

View File

@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for GPT-OSS structural tag support in reasoning (PR #25515)."""
import json
from unittest.mock import Mock
import pytest
from vllm.entrypoints.tool_server import ToolServer
from vllm.reasoning.gptoss_reasoning_parser import (
GptOssReasoningParser,
from_builtin_tool_to_tag,
no_func_reaonsing_tag,
tag_with_builtin_funcs,
)
class TestGptOssReasoningParser:
"""Test cases for GptOssReasoningParser structural tag functionality."""
@pytest.fixture
def mock_tokenizer(self):
"""Create a mock tokenizer for testing."""
tokenizer = Mock()
tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
return tokenizer
@pytest.fixture
def reasoning_parser(self, mock_tokenizer):
"""Create a GptOssReasoningParser instance."""
return GptOssReasoningParser(mock_tokenizer)
@pytest.fixture
def mock_tool_server_empty(self):
"""Create a mock ToolServer with no tools."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(return_value=False)
return tool_server
@pytest.fixture
def mock_tool_server_with_browser(self):
"""Create a mock ToolServer with browser tool."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool == "browser")
return tool_server
@pytest.fixture
def mock_tool_server_with_all_tools(self):
"""Create a mock ToolServer with all builtin tools."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(
side_effect=lambda tool: tool in ["browser", "python", "container"]
)
return tool_server
def test_prepare_structured_tag_no_tool_server(self, reasoning_parser):
"""Test prepare_structured_tag with no tool server."""
result = reasoning_parser.prepare_structured_tag(None, None)
expected = json.dumps(no_func_reaonsing_tag)
assert result == expected
# Verify the structure is correct
parsed = json.loads(result)
assert parsed["type"] == "structural_tag"
assert parsed["format"]["type"] == "triggered_tags"
assert len(parsed["format"]["tags"]) == 1
assert parsed["format"]["tags"][0]["begin"] == "<|channel|>analysis<|message|>"
assert parsed["format"]["triggers"] == ["<|channel|>analysis"]
def test_prepare_structured_tag_with_all_tools(
self, reasoning_parser, mock_tool_server_with_all_tools
):
"""Test prepare_structured_tag with all builtin tools."""
result = reasoning_parser.prepare_structured_tag(
None, mock_tool_server_with_all_tools
)
parsed = json.loads(result)
# Should have analysis tag + tags for all 3 tools (2 tags each)
assert len(parsed["format"]["tags"]) == 7 # 1 analysis + 6 tool tags
# Check all tool tags are present
tag_begins = [tag["begin"] for tag in parsed["format"]["tags"]]
for tool in ["browser", "python", "container"]:
assert f"<|channel|>commentary to={tool}" in tag_begins
assert f"<|channel|>analysis to={tool}" in tag_begins
def test_prepare_structured_tag_with_original_tag(self, reasoning_parser):
"""Test prepare_structured_tag when original_tag is provided."""
original_tag = '{"custom": "tag"}'
result = reasoning_parser.prepare_structured_tag(original_tag, None)
# Should return the original tag unchanged
assert result == original_tag
def test_from_builtin_tool_to_tag(self):
"""Test from_builtin_tool_to_tag function."""
tags = from_builtin_tool_to_tag("python")
assert len(tags) == 2
assert tags[0]["begin"] == "<|channel|>commentary to=python"
assert tags[0]["content"]["type"] == "any_text"
assert tags[0]["end"] == "<|end|>"
assert tags[1]["begin"] == "<|channel|>analysis to=python"
assert tags[1]["content"]["type"] == "any_text"
assert tags[1]["end"] == "<|end|>"
def test_tag_with_builtin_funcs(self):
"""Test tag_with_builtin_funcs function."""
builtin_tools = ["browser", "python"]
result = tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tools)
assert result["type"] == "structural_tag"
# Should have original analysis tag + 2 tags per tool
assert len(result["format"]["tags"]) == 5 # 1 + 2*2
# Should have added commentary trigger
assert "<|channel|>commentary to=" in result["format"]["triggers"]
assert "<|channel|>analysis" in result["format"]["triggers"]
def test_tag_structure_invariants(self):
"""Test that the basic tag structure follows expected format."""
# Test the base no_func_reaonsing_tag structure
assert no_func_reaonsing_tag["type"] == "structural_tag"
assert no_func_reaonsing_tag["format"]["type"] == "triggered_tags"
assert no_func_reaonsing_tag["format"]["stop_after_first"] is False
# Verify analysis tag structure
analysis_tag = no_func_reaonsing_tag["format"]["tags"][0]
assert analysis_tag["begin"] == "<|channel|>analysis<|message|>"
assert analysis_tag["content"]["type"] == "any_text"
assert analysis_tag["end"] == "<|end|>"
def test_json_serialization_valid(
self, reasoning_parser, mock_tool_server_with_all_tools
):
"""Test that all generated tags produce valid JSON."""
# Test with no tool server
result1 = reasoning_parser.prepare_structured_tag(None, None)
json.loads(result1) # Should not raise
# Test with empty tool server
empty_server = Mock(spec=ToolServer)
empty_server.has_tool = Mock(return_value=False)
result2 = reasoning_parser.prepare_structured_tag(None, empty_server)
json.loads(result2) # Should not raise
# Test with tools
result3 = reasoning_parser.prepare_structured_tag(
None, mock_tool_server_with_all_tools
)
json.loads(result3) # Should not raise
@pytest.mark.parametrize("tool_name", ["browser", "python", "container"])
def test_single_tool_integration(self, reasoning_parser, tool_name):
"""Test integration with individual tools."""
tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool == tool_name)
result = reasoning_parser.prepare_structured_tag(None, tool_server)
parsed = json.loads(result)
# Should have 1 analysis + 2 tool-specific tags
assert len(parsed["format"]["tags"]) == 3
tag_begins = [tag["begin"] for tag in parsed["format"]["tags"]]
assert f"<|channel|>commentary to={tool_name}" in tag_begins
assert f"<|channel|>analysis to={tool_name}" in tag_begins

View File

@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for reasoning-aware structured output functionality (PR #25515)."""
from unittest.mock import Mock
import pytest
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.reasoning import ReasoningParser
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
class TestReasoningStructuredOutput:
"""Test reasoning-aware structured output functionality."""
@pytest.fixture
def mock_model_config(self):
"""Create a mock ModelConfig."""
config = Mock(spec=ModelConfig)
config.skip_tokenizer_init = True # Skip tokenizer init to avoid network calls
config.get_vocab_size = Mock(return_value=50000)
# Add missing runner_type attribute that tokenizer initialization expects
config.runner_type = "generate"
# Add other attributes that tokenizer initialization might need
config.tokenizer = "test-tokenizer"
config.tokenizer_mode = "auto"
config.trust_remote_code = False
config.tokenizer_revision = None
return config
@pytest.fixture
def mock_scheduler_config(self):
"""Create a mock SchedulerConfig."""
config = Mock(spec=SchedulerConfig)
config.max_num_seqs = 128
return config
@pytest.fixture
def mock_vllm_config(self, mock_model_config, mock_scheduler_config):
"""Create a mock VllmConfig."""
config = Mock(spec=VllmConfig)
config.model_config = mock_model_config
config.scheduler_config = mock_scheduler_config
config.structured_outputs_config = Mock()
config.structured_outputs_config.reasoning_parser = None
config.structured_outputs_config.enable_in_reasoning = False
config.speculative_config = None
return config
@pytest.fixture
def mock_reasoning_parser(self):
"""Create a mock ReasoningParser."""
parser = Mock(spec=ReasoningParser)
parser.is_reasoning_end = Mock(return_value=False)
return parser
@pytest.fixture
def mock_request_with_structured_output(self):
"""Create a mock request with structured output."""
request = Mock(spec=Request)
request.structured_output_request = Mock()
request.structured_output_request.reasoning_ended = None
request.structured_output_request.grammar = Mock()
request.structured_output_request.grammar.is_terminated = Mock(
return_value=False
)
request.use_structured_output = True
request.prompt_token_ids = [1, 2, 3, 4, 5]
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
return request
def test_should_fill_bitmask_with_enable_in_reasoning(
self, mock_vllm_config, mock_request_with_structured_output
):
"""Test should_fill_bitmask when enable_in_reasoning is True."""
# Enable enable_in_reasoning
mock_vllm_config.structured_outputs_config.enable_in_reasoning = True
manager = StructuredOutputManager(mock_vllm_config)
# Should always return True when enable_in_reasoning is enabled
result = manager.should_fill_bitmask(mock_request_with_structured_output)
assert result is True
def test_should_fill_bitmask_without_enable_in_reasoning(
self,
mock_vllm_config,
mock_request_with_structured_output,
mock_reasoning_parser,
):
"""Test should_fill_bitmask when enable_in_reasoning is False."""
# Keep enable_in_reasoning as False (default)
config = mock_vllm_config.structured_outputs_config
assert config.enable_in_reasoning is False
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = mock_reasoning_parser
# Mock reasoning not ended
mock_reasoning_parser.is_reasoning_end.return_value = False
result = manager.should_fill_bitmask(mock_request_with_structured_output)
# Should set reasoning_ended and return its value
assert (
mock_request_with_structured_output.structured_output_request.reasoning_ended
is False
)
assert result is False
def test_should_fill_bitmask_no_reasoner(
self, mock_vllm_config, mock_request_with_structured_output
):
"""Test should_fill_bitmask when no reasoner is configured."""
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = None
result = manager.should_fill_bitmask(mock_request_with_structured_output)
# Should default to True when no reasoner
assert result is True
def test_should_advance_with_enable_in_reasoning(
self,
mock_vllm_config,
mock_request_with_structured_output,
mock_reasoning_parser,
):
"""Test should_advance when enable_in_reasoning is True."""
# Enable enable_in_reasoning
mock_vllm_config.structured_outputs_config.enable_in_reasoning = True
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = mock_reasoning_parser
# Should always return True when enable_in_reasoning is enabled
result = manager.should_advance(mock_request_with_structured_output)
assert result is True
def test_should_advance_reasoning_not_ended(
self,
mock_vllm_config,
mock_request_with_structured_output,
mock_reasoning_parser,
):
"""Test should_advance when reasoning has not ended."""
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = mock_reasoning_parser
# Set reasoning as not ended
(
mock_request_with_structured_output.structured_output_request
).reasoning_ended = False
mock_reasoning_parser.is_reasoning_end.return_value = False
result = manager.should_advance(mock_request_with_structured_output)
# Should return False since reasoning hasn't ended
assert result is False
def test_should_advance_reasoning_just_ended(
self,
mock_vllm_config,
mock_request_with_structured_output,
mock_reasoning_parser,
):
"""Test should_advance when reasoning ends in current step."""
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = mock_reasoning_parser
# Set reasoning as not ended initially, but ends in this step
(
mock_request_with_structured_output.structured_output_request
).reasoning_ended = False
mock_reasoning_parser.is_reasoning_end.return_value = True
result = manager.should_advance(mock_request_with_structured_output)
# Should set reasoning_ended to True but return False for this step
assert (
mock_request_with_structured_output.structured_output_request.reasoning_ended
is True
)
assert result is False
def test_should_advance_reasoning_already_ended(
self,
mock_vllm_config,
mock_request_with_structured_output,
mock_reasoning_parser,
):
"""Test should_advance when reasoning has already ended."""
manager = StructuredOutputManager(mock_vllm_config)
manager.reasoner = mock_reasoning_parser
# Set reasoning as already ended
(
mock_request_with_structured_output.structured_output_request
).reasoning_ended = True
result = manager.should_advance(mock_request_with_structured_output)
# Should return True since reasoning has ended
assert result is True

View File

@ -35,6 +35,8 @@ class StructuredOutputsConfig:
reasoning_parser: str = ""
"""Select the reasoning parser depending on the model that you're using.
This is used to parse the reasoning content into OpenAI API format."""
enable_in_reasoning: bool = False
"""Whether to use structured input for reasoning."""
def compute_hash(self) -> str:
"""

View File

@ -479,6 +479,7 @@ class EngineArgs:
VllmConfig, "structured_outputs_config"
)
reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
# Deprecated guided decoding fields
guided_decoding_backend: str | None = None
guided_decoding_disable_fallback: bool | None = None

View File

@ -200,7 +200,7 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
strict: bool | None = None
-class StructuralTag(OpenAIBaseModel):
class LegacyStructuralTag(OpenAIBaseModel):
begin: str
# schema is the field, but that causes conflicts with pydantic so
# instead use structural_tag_schema with an alias
@ -208,10 +208,20 @@ class StructuralTag(OpenAIBaseModel):
end: str
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
type: Literal["structural_tag"]
structures: list[LegacyStructuralTag]
triggers: list[str]
class StructuralTagResponseFormat(OpenAIBaseModel):
type: Literal["structural_tag"]
-structures: list[StructuralTag]
-triggers: list[str]
format: Any
AnyStructuralTagResponseFormat: TypeAlias = (
LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
)
class ResponseFormat(OpenAIBaseModel):
@ -220,7 +230,9 @@ class ResponseFormat(OpenAIBaseModel):
json_schema: JsonSchemaResponseFormat | None = None
-AnyResponseFormat: TypeAlias = ResponseFormat | StructuralTagResponseFormat
AnyResponseFormat: TypeAlias = (
ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
)
class StreamOptions(OpenAIBaseModel):
@ -823,7 +835,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
-structural_tag, StructuralTagResponseFormat
structural_tag,
(
LegacyStructuralTagResponseFormat,
StructuralTagResponseFormat,
),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

View File

@ -98,7 +98,7 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
-from vllm.sampling_params import SamplingParams
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
@ -365,6 +365,19 @@ class OpenAIServingResponses(OpenAIServing):
context = HarmonyContext(messages, available_tools)
else:
context = SimpleContext()
if self.reasoning_parser is not None:
reasoning_parser = self.reasoning_parser(tokenizer)
if sampling_params.structured_outputs is None:
sampling_params.structured_outputs = StructuredOutputsParams()
struct_out = sampling_params.structured_outputs
if struct_out.all_non_structural_tag_constraints_none():
sampling_params.structured_outputs.structural_tag = (
reasoning_parser.prepare_structured_tag(
sampling_params.structured_outputs.structural_tag,
self.tool_server,
)
)
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
request_prompt=request_prompts[i],
@ -1919,6 +1932,7 @@ class OpenAIServingResponses(OpenAIServing):
processer = self._process_harmony_streaming_events
else:
processer = self._process_simple_streaming_events
# TODO(Hanchen): make sampling params include the structural tag
initial_response = ResponsesResponse.from_request(
request,

View File

@ -7,6 +7,7 @@ from collections.abc import Callable, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any
from vllm.entrypoints.tool_server import ToolServer
from vllm.logger import init_logger
from vllm.utils.collections import is_list_of
from vllm.utils.import_utils import import_from_path
@ -115,6 +116,17 @@ class ReasoningParser:
previously been parsed and extracted (see constructor)
"""
def prepare_structured_tag(
self,
original_tag: str | None,
tool_server: ToolServer | None,
) -> str | None:
"""
Prepare the structural tag used for structured output.
The base implementation returns None; parsers that support structural
tags (e.g. GPT-OSS) override this method.
"""
return None
class ReasoningParserManager:
reasoning_parsers: dict[str, type] = {}

View File

@ -1,17 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.harmony_utils import parse_chat_output
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.entrypoints.tool_server import ToolServer
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
no_func_reaonsing_tag = {
"type": "structural_tag",
"format": {
"type": "triggered_tags",
"tags": [
{
"begin": "<|channel|>analysis<|message|>",
"content": {"type": "any_text"},
"end": "<|end|>",
}
],
"triggers": ["<|channel|>analysis"],
"stop_after_first": False,
},
}
def from_builtin_tool_to_tag(tool: str) -> list[dict]:
tag = [
{
"begin": f"<|channel|>commentary to={tool}",
"content": {"type": "any_text"},
"end": "<|end|>",
},
{
"begin": f"<|channel|>analysis to={tool}",
"content": {"type": "any_text"},
"end": "<|end|>",
},
]
return tag
def tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tool_list: list[str]) -> dict:
import copy
new_tag = copy.deepcopy(no_func_reaonsing_tag)
new_tag["format"]["triggers"].append("<|channel|>commentary to=")
for tool in builtin_tool_list:
new_tag["format"]["tags"].extend(from_builtin_tool_to_tag(tool))
return new_tag
@ReasoningParserManager.register_module("openai_gptoss")
class GptOssReasoningParser(ReasoningParser):
@ -81,3 +125,33 @@ class GptOssReasoningParser(ReasoningParser):
raise NotImplementedError(
"gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." # noqa: E501
)
# This function prepares the structural tag to format reasoning output
def prepare_structured_tag(
self, original_tag: str | None, tool_server: ToolServer | None
) -> str:
if original_tag is None:
if tool_server is None:
return json.dumps(no_func_reaonsing_tag)
else:
builtin_tool_list: list[str] = []
if tool_server.has_tool("browser"):
builtin_tool_list.append("browser")
if tool_server.has_tool("python"):
builtin_tool_list.append("python")
if tool_server.has_tool("container"):
builtin_tool_list.append("container")
if len(builtin_tool_list) > 0:
logger.info("Builtin_tool_list: %s", builtin_tool_list)
func_tag = json.dumps(
tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tool_list)
)
else:
logger.info("Builtin_tool_list is empty")
func_tag = json.dumps(no_func_reaonsing_tag)
return func_tag
else:
# There is potential risk for appending the tag to the original tag
return original_tag
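A quick usage sketch for the parser above, mirroring the new tests: with only the python tool available, the tag gains two python channels plus the commentary trigger.

import json
from unittest.mock import Mock

from vllm.entrypoints.tool_server import ToolServer
from vllm.reasoning.gptoss_reasoning_parser import GptOssReasoningParser

tokenizer = Mock()
tokenizer.encode = Mock(return_value=[1, 2, 3])
parser = GptOssReasoningParser(tokenizer)

tool_server = Mock(spec=ToolServer)
tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")

tag = json.loads(parser.prepare_structured_tag(None, tool_server))
assert len(tag["format"]["tags"]) == 3  # analysis + commentary/analysis to=python
assert "<|channel|>commentary to=" in tag["format"]["triggers"]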

View File

@ -58,6 +58,7 @@ class StructuredOutputsParams:
self.choice is not None,
self.grammar is not None,
self.json_object is not None,
self.structural_tag is not None,
]
)
if count > 1:
@ -66,6 +67,37 @@ class StructuredOutputsParams:
f"but multiple are specified: {self.__dict__}"
)
def all_constraints_none(self) -> bool:
"""
Returns True if all structured-output constraint fields are None.
"""
return all(
getattr(self, field) is None
for field in (
"json",
"regex",
"choice",
"grammar",
"json_object",
"structural_tag",
)
)
def all_non_structural_tag_constraints_none(self) -> bool:
"""
Returns True if all structured-output constraint fields except structural_tag are None.
"""
return all(
getattr(self, field) is None
for field in (
"json",
"regex",
"choice",
"grammar",
"json_object",
)
)
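A minimal sketch of how the two helpers above differ (field names taken from this diff; every constraint field is assumed to default to None):

from vllm.sampling_params import StructuredOutputsParams

empty = StructuredOutputsParams()
assert empty.all_constraints_none()

tagged = StructuredOutputsParams(structural_tag='{"type": "structural_tag"}')
# A structural tag counts as a constraint but leaves the other fields unset,
# which is exactly what the serving code checks before injecting its own tag.
assert not tagged.all_constraints_none()
assert tagged.all_non_structural_tag_constraints_none()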
@dataclass
class GuidedDecodingParams(StructuredOutputsParams):

View File

@ -73,6 +73,10 @@ class StructuredOutputManager:
)
self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
self.enable_in_reasoning = (
self.vllm_config.structured_outputs_config.enable_in_reasoning
)
def grammar_init(self, request: Request) -> None:
if request.structured_output_request is None:
return
@ -274,7 +278,13 @@ class StructuredOutputManager:
return bitmask_tensor.numpy()
def should_fill_bitmask(self, request: Request) -> bool:
# NOTE (Hanchen) if enable_in_reasoning is True, it means that
# the model needs to be constrained in reasoning. So we should always
# enable the bitmask filling.
if self.reasoner is not None:
if self.enable_in_reasoning:
return True
assert request.structured_output_request is not None
if request.structured_output_request.reasoning_ended is None:
request.structured_output_request.reasoning_ended = (
@ -297,6 +307,10 @@ class StructuredOutputManager:
if self.reasoner is None:
return True
# If structured output is enforced during reasoning, always advance the grammar
if self.enable_in_reasoning:
return True
structured_req = request.structured_output_request
if structured_req.reasoning_ended:
return True
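Reading the two truncated hunks above together, the decision logic is roughly as follows (a summary sketch, not the verbatim implementation):

# should_fill_bitmask(request):
#   no reasoner configured        -> True
#   enable_in_reasoning is True   -> True (constrain tokens during reasoning)
#   otherwise                     -> True only once reasoning has ended
#
# should_advance(request):
#   no reasoner configured        -> True
#   enable_in_reasoning is True   -> True (advance the grammar during reasoning)
#   reasoning already ended       -> True
#   reasoning ends at this step   -> mark it ended, return False for this step
#   otherwise                     -> False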

View File

@ -91,6 +91,8 @@ class XgrammarBackend(StructuredOutputBackend):
ctx = self.compiler.compile_regex(grammar_spec)
elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
s_tag = json.loads(grammar_spec)
if "structures" in s_tag:
# Fall back to the deprecated method of compiling the structural tag
tags = [
xgr.StructuralTagItem(
begin=s["begin"],
@ -99,10 +101,10 @@ class XgrammarBackend(StructuredOutputBackend):
)
for s in s_tag["structures"]
]
-structural_tag = xgr.StructuralTag.from_legacy_structural_tag(
-tags, s_tag["triggers"]
-)
-ctx = self.compiler.compile_structural_tag(structural_tag)
ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"])
else:
logger.info("Compiling structural tag grammar_spec: %s", grammar_spec)
ctx = self.compiler.compile_structural_tag(grammar_spec)
else:
logger.error(
"Validation should have already occurred. Please file an issue."
@ -320,6 +322,9 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
if so_params.structural_tag:
try:
s_tag = json.loads(so_params.structural_tag)
# Use the deprecated legacy method of compiling the structural tag
if "structures" in s_tag:
tags = [
xgr.StructuralTagItem(
begin=s["begin"],
@ -328,9 +333,8 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
)
for s in s_tag["structures"]
]
-structural_tag = xgr.StructuralTag.from_legacy_structural_tag(
-tags, s_tag["triggers"]
-)
-xgr.Grammar.from_structural_tag(structural_tag)
xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
else:
xgr.Grammar.from_structural_tag(so_params.structural_tag)
except Exception as e:
raise ValueError("Invalid structural tag specification.") from e
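For reference, the two spec shapes the backend now distinguishes: the legacy shape follows LegacyStructuralTagResponseFormat, and the new shape is what the GPT-OSS parser emits. The field values below are illustrative only.

# Legacy spec: has a "structures" key and is compiled through
# xgr.StructuralTagItem (the deprecated path above).
legacy_spec = {
    "type": "structural_tag",
    "structures": [
        {"begin": "<fn>", "schema": {"type": "object"}, "end": "</fn>"},
    ],
    "triggers": ["<fn>"],
}

# New spec: has a "format" key and is handed to
# compiler.compile_structural_tag() unchanged.
new_spec = {
    "type": "structural_tag",
    "format": {
        "type": "triggered_tags",
        "tags": [
            {
                "begin": "<|channel|>analysis<|message|>",
                "content": {"type": "any_text"},
                "end": "<|end|>",
            },
        ],
        "triggers": ["<|channel|>analysis"],
        "stop_after_first": False,
    },
}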

View File

@ -28,7 +28,12 @@ class StructuredOutputRequest:
if sampling_params is None:
return None
params = sampling_params.structured_outputs
-return StructuredOutputRequest(params=params) if params else None
if params:
if params.all_constraints_none():
return None
else:
return StructuredOutputRequest(params=params)
return None
def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports