diff --git a/tests/entrypoints/openai/test_gptoss_structural_tags_integration.py b/tests/entrypoints/openai/test_gptoss_structural_tags_integration.py
new file mode 100644
index 0000000000..fbfae4f268
--- /dev/null
+++ b/tests/entrypoints/openai/test_gptoss_structural_tags_integration.py
@@ -0,0 +1,280 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Integration tests for GPT-OSS structural tags functionality (PR #25515)."""
+
+import json
+from unittest.mock import Mock
+
+import pytest
+
+from vllm.entrypoints.openai.protocol import (
+    StructuredOutputsParams,
+)
+from vllm.entrypoints.tool_server import ToolServer
+from vllm.reasoning.gptoss_reasoning_parser import (
+    GptOssReasoningParser,
+)
+
+
+class TestGptOssStructuralTagsIntegration:
+    """Integration tests for structural tags in GPT-OSS tool calls."""
+
+    @pytest.fixture
+    def mock_tokenizer(self):
+        """Create a mock tokenizer."""
+        tokenizer = Mock()
+        tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
+        return tokenizer
+
+    @pytest.fixture
+    def gptoss_parser(self, mock_tokenizer):
+        """Create a real GptOssReasoningParser instance."""
+        return GptOssReasoningParser(mock_tokenizer)
+
+    @pytest.fixture
+    def tool_server_with_python(self):
+        """Create a tool server with the Python tool enabled."""
+        tool_server = Mock(spec=ToolServer)
+        tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")
+        return tool_server
+
+    @pytest.fixture
+    def tool_server_empty(self):
+        """Create a tool server with no tools."""
+        tool_server = Mock(spec=ToolServer)
+        tool_server.has_tool = Mock(return_value=False)
+        return tool_server
+
+    def test_end_to_end_no_tools(self, gptoss_parser):
+        """Test end-to-end flow when no tools are available."""
+        # Test the parser directly
+        result = gptoss_parser.prepare_structured_tag(None, None)
+        parsed_result = json.loads(result)
+
+        # Verify basic structure
+        assert parsed_result["type"] == "structural_tag"
+        assert parsed_result["format"]["type"] == "triggered_tags"
+        assert len(parsed_result["format"]["tags"]) == 1
+
+        # Verify only the analysis channel is allowed
+        analysis_tag = parsed_result["format"]["tags"][0]
+        assert analysis_tag["begin"] == "<|channel|>analysis<|message|>"
+        assert analysis_tag["content"]["type"] == "any_text"
+        assert analysis_tag["end"] == "<|end|>"
+
+        # Verify triggers
+        assert parsed_result["format"]["triggers"] == ["<|channel|>analysis"]
+        assert parsed_result["format"]["stop_after_first"] is False
+
+    def test_end_to_end_with_python_tool(self, gptoss_parser, tool_server_with_python):
+        """Test end-to-end flow with the Python tool enabled."""
+        result = gptoss_parser.prepare_structured_tag(None, tool_server_with_python)
+        parsed_result = json.loads(result)
+
+        # Should have the analysis tag + 2 python tags
+        assert len(parsed_result["format"]["tags"]) == 3
+
+        # Verify all expected tags are present
+        tag_begins = [tag["begin"] for tag in parsed_result["format"]["tags"]]
+        expected_begins = [
+            "<|channel|>analysis<|message|>",
+            "<|channel|>commentary to=python",
+            "<|channel|>analysis to=python",
+        ]
+
+        for expected in expected_begins:
+            assert expected in tag_begins
+
+        # Verify triggers include commentary
+        assert "<|channel|>analysis" in parsed_result["format"]["triggers"]
+        assert "<|channel|>commentary to=" in parsed_result["format"]["triggers"]
+
+    def test_structured_outputs_params_integration(
+        self, gptoss_parser, tool_server_with_python
+    ):
+        """Test integration with StructuredOutputsParams."""
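+        # StructuredOutputsParams stores the serialized tag verbatim, so the
+        # json.loads round trip below doubles as a validity check.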
+        # Generate structural tag
+        structural_tag = gptoss_parser.prepare_structured_tag(
+            None, tool_server_with_python
+        )
+
+        # Create StructuredOutputsParams
+        params = StructuredOutputsParams(structural_tag=structural_tag)
+
+        # Verify the tag is properly stored and accessible
+        assert params.structural_tag == structural_tag
+
+        # Verify the tag is valid JSON
+        parsed_tag = json.loads(params.structural_tag)
+        assert parsed_tag["type"] == "structural_tag"
+
+    @pytest.mark.parametrize(
+        "browser, python, container, expected_tags",
+        [
+            # No tools
+            (False, False, False, 1),
+            # Single tool
+            (True, False, False, 3),
+            # Multiple tools
+            (True, True, False, 5),
+            # All tools
+            (True, True, True, 7),
+        ],
+    )
+    def test_tool_server_interaction_flow(
+        self, gptoss_parser, browser, python, container, expected_tags
+    ):
+        """Test the complete tool server interaction flow."""
+
+        # Create a mock ToolServer
+        tool_server = Mock(spec=ToolServer)
+
+        # Simulate tool availability based on parameters
+        tool_server.has_tool = Mock(
+            side_effect=lambda tool: {
+                "browser": browser,
+                "python": python,
+                "container": container,
+            }.get(tool, False)
+        )
+
+        # Run the parser and verify results
+        result = gptoss_parser.prepare_structured_tag(None, tool_server)
+        parsed_result = json.loads(result)
+
+        # Validate number of tags
+        assert len(parsed_result["format"]["tags"]) == expected_tags
+
+        # Verify tool-specific tags exist for enabled tools
+        tag_begins = [tag["begin"] for tag in parsed_result["format"]["tags"]]
+        for tool, enabled in {
+            "browser": browser,
+            "python": python,
+            "container": container,
+        }.items():
+            if enabled:
+                assert f"<|channel|>commentary to={tool}" in tag_begins
+                assert f"<|channel|>analysis to={tool}" in tag_begins
+
+    def test_original_tag_preservation(self, gptoss_parser, tool_server_with_python):
+        """Test that original tags are preserved when provided."""
+        original_tag = '{"type": "custom_tag", "data": "preserved"}'
+
+        result = gptoss_parser.prepare_structured_tag(
+            original_tag, tool_server_with_python
+        )
+
+        # Should return original tag unchanged
+        assert result == original_tag
+
+    @pytest.mark.parametrize(
+        "tools",
+        [
+            [],
+            ["browser"],
+            ["python"],
+            ["container"],
+            ["browser", "python"],
+            ["browser", "container"],
+            ["python", "container"],
+            ["browser", "python", "container"],
+        ],
+    )
+    def test_json_validity_comprehensive(self, gptoss_parser, tools):
+        """Test JSON validity across all possible tool combinations."""
+
+        tool_server = Mock(spec=ToolServer)
+        tool_server.has_tool = Mock(side_effect=lambda tool: tool in tools)
+
+        result = gptoss_parser.prepare_structured_tag(None, tool_server)
+
+        # Should be valid JSON
+        parsed_result = json.loads(result)
+
+        # Should have correct structure
+        assert parsed_result["type"] == "structural_tag"
+        assert "format" in parsed_result
+        assert "tags" in parsed_result["format"]
+        assert "triggers" in parsed_result["format"]
+
+        # Tag count should be: 1 (analysis) + 2 * len(tools)
+        expected_tag_count = 1 + (2 * len(tools))
+        assert len(parsed_result["format"]["tags"]) == expected_tag_count
+
+    def test_error_handling_invalid_tool_server(self, gptoss_parser):
+        """Test error handling with invalid tool server."""
+        # Tool server that raises exceptions
+        tool_server = Mock(spec=ToolServer)
+        tool_server.has_tool = Mock(side_effect=Exception("Tool server error"))
+
+        # The error is not swallowed: it propagates unchanged to the caller.
+        with pytest.raises(Exception, match="Tool server error"):
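+            # has_tool("browser") is the first capability probe inside
+            # prepare_structured_tag, so the mock raises on the first lookup.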
+            gptoss_parser.prepare_structured_tag(None, tool_server)
+
+    def test_concurrent_requests_isolation(self, gptoss_parser):
+        """Test that requests with different tool servers don't interfere."""
+        # Simulate two requests with different tool servers
+        tool_server_1 = Mock(spec=ToolServer)
+        tool_server_1.has_tool = Mock(side_effect=lambda tool: tool == "python")
+
+        tool_server_2 = Mock(spec=ToolServer)
+        tool_server_2.has_tool = Mock(side_effect=lambda tool: tool == "browser")
+
+        # Generate a tag for each request from the same parser instance
+        result_1 = gptoss_parser.prepare_structured_tag(None, tool_server_1)
+        result_2 = gptoss_parser.prepare_structured_tag(None, tool_server_2)
+
+        # Parse results
+        parsed_1 = json.loads(result_1)
+        parsed_2 = json.loads(result_2)
+
+        # Verify they have different tool configurations
+        tags_1 = [tag["begin"] for tag in parsed_1["format"]["tags"]]
+        tags_2 = [tag["begin"] for tag in parsed_2["format"]["tags"]]
+
+        # Result 1 should have python tags
+        assert "<|channel|>commentary to=python" in tags_1
+        assert "<|channel|>commentary to=browser" not in tags_1
+
+        # Result 2 should have browser tags
+        assert "<|channel|>commentary to=browser" in tags_2
+        assert "<|channel|>commentary to=python" not in tags_2
+
+    def test_tag_format_consistency(self, gptoss_parser):
+        """Test that all generated tags follow a consistent format."""
+        tool_server = Mock(spec=ToolServer)
+        tool_server.has_tool = Mock(
+            side_effect=lambda tool: tool in ["python", "browser"]
+        )
+
+        result = gptoss_parser.prepare_structured_tag(None, tool_server)
+        parsed_result = json.loads(result)
+
+        # Verify all tags have required fields
+        for tag in parsed_result["format"]["tags"]:
+            assert "begin" in tag
+            assert "content" in tag
+            assert "end" in tag
+            assert tag["content"]["type"] == "any_text"
+            assert tag["end"] == "<|end|>"
+
+            # Verify begin format
+            assert tag["begin"].startswith("<|channel|>")
+
+    def test_trigger_configuration(self, gptoss_parser):
+        """Test trigger configuration for different tool setups."""
+        # Test with no tools
+        result_no_tools = gptoss_parser.prepare_structured_tag(None, None)
+        parsed_no_tools = json.loads(result_no_tools)
+        assert parsed_no_tools["format"]["triggers"] == ["<|channel|>analysis"]
+
+        # Test with tools
+        tool_server = Mock(spec=ToolServer)
+        tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")
+
+        result_with_tools = gptoss_parser.prepare_structured_tag(None, tool_server)
+        parsed_with_tools = json.loads(result_with_tools)
+
+        expected_triggers = ["<|channel|>analysis", "<|channel|>commentary to="]
+        assert set(parsed_with_tools["format"]["triggers"]) == set(expected_triggers)
diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py
index cca9729b9d..014e6eca2e 100644
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -864,3 +864,49 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
     # non-structured outputs requests should not return a valid JSON here
     with pytest.raises(ValueError):
         output_json = json.loads(generated_text)
+
+
+@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
+def test_structured_output_with_structural_tag(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    llm = LLM(
+        model="Qwen/Qwen2.5-1.5B-Instruct",
+        guided_decoding_backend=guided_decoding_backend,
+    )
+
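+    # The trigger "hello" is a prefix of the tag's "begin" value, so once the
+    # model emits "hello" the grammar forces the full "hello_flag ... hello" span.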
"type": "structural_tag", + "format": { + "type": "triggered_tags", + "tags": [ + {"begin": "hello_flag", "content": {"type": "any_text"}, "end": "hello"} + ], + "triggers": ["hello"], + "stop_after_first": False, + }, + } + + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=500, + guided_decoding=StructuredOutputsParams( + structural_tag=json.dumps(structural_tag_config) + ), + ) + + prompt = "Hello and repete hello 10 times, do not say anything else. Only say hello hello hello, now start" + outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True) + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + assert generated_text is not None + assert "hello_flag" in generated_text, ( + f"Expected 'hello_flag' to be in generated text, but got: {generated_text}" + ) diff --git a/tests/v1/structured_output/test_gptoss_structural_tags.py b/tests/v1/structured_output/test_gptoss_structural_tags.py new file mode 100644 index 0000000000..f0feabfb99 --- /dev/null +++ b/tests/v1/structured_output/test_gptoss_structural_tags.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for GPT-OSS structural tag support in reasoning (PR #25515).""" + +import json +from unittest.mock import Mock + +import pytest + +from vllm.entrypoints.tool_server import ToolServer +from vllm.reasoning.gptoss_reasoning_parser import ( + GptOssReasoningParser, + from_builtin_tool_to_tag, + no_func_reaonsing_tag, + tag_with_builtin_funcs, +) + + +class TestGptOssReasoningParser: + """Test cases for GptOssReasoningParser structural tag functionality.""" + + @pytest.fixture + def mock_tokenizer(self): + """Create a mock tokenizer for testing.""" + tokenizer = Mock() + tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) + return tokenizer + + @pytest.fixture + def reasoning_parser(self, mock_tokenizer): + """Create a GptOssReasoningParser instance.""" + return GptOssReasoningParser(mock_tokenizer) + + @pytest.fixture + def mock_tool_server_empty(self): + """Create a mock ToolServer with no tools.""" + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock(return_value=False) + return tool_server + + @pytest.fixture + def mock_tool_server_with_browser(self): + """Create a mock ToolServer with browser tool.""" + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock(side_effect=lambda tool: tool == "browser") + return tool_server + + @pytest.fixture + def mock_tool_server_with_all_tools(self): + """Create a mock ToolServer with all builtin tools.""" + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock( + side_effect=lambda tool: tool in ["browser", "python", "container"] + ) + return tool_server + + def test_prepare_structured_tag_no_tool_server(self, reasoning_parser): + """Test prepare_structured_tag with no tool server.""" + result = reasoning_parser.prepare_structured_tag(None, None) + expected = json.dumps(no_func_reaonsing_tag) + + assert result == expected + + # Verify the structure is correct + parsed = json.loads(result) + assert parsed["type"] == "structural_tag" + assert parsed["format"]["type"] == "triggered_tags" + assert len(parsed["format"]["tags"]) == 1 + assert parsed["format"]["tags"][0]["begin"] == "<|channel|>analysis<|message|>" + assert parsed["format"]["triggers"] == ["<|channel|>analysis"] + + def 
+
+    def test_prepare_structured_tag_with_all_tools(
+        self, reasoning_parser, mock_tool_server_with_all_tools
+    ):
+        """Test prepare_structured_tag with all builtin tools."""
+        result = reasoning_parser.prepare_structured_tag(
+            None, mock_tool_server_with_all_tools
+        )
+        parsed = json.loads(result)
+
+        # Should have analysis tag + tags for all 3 tools (2 tags each)
+        assert len(parsed["format"]["tags"]) == 7  # 1 analysis + 6 tool tags
+
+        # Check all tool tags are present
+        tag_begins = [tag["begin"] for tag in parsed["format"]["tags"]]
+        for tool in ["browser", "python", "container"]:
+            assert f"<|channel|>commentary to={tool}" in tag_begins
+            assert f"<|channel|>analysis to={tool}" in tag_begins
+
+    def test_prepare_structured_tag_with_original_tag(self, reasoning_parser):
+        """Test prepare_structured_tag when original_tag is provided."""
+        original_tag = '{"custom": "tag"}'
+        result = reasoning_parser.prepare_structured_tag(original_tag, None)
+
+        # Should return the original tag unchanged
+        assert result == original_tag
+
+    def test_from_builtin_tool_to_tag(self):
+        """Test from_builtin_tool_to_tag function."""
+        tags = from_builtin_tool_to_tag("python")
+
+        assert len(tags) == 2
+        assert tags[0]["begin"] == "<|channel|>commentary to=python"
+        assert tags[0]["content"]["type"] == "any_text"
+        assert tags[0]["end"] == "<|end|>"
+
+        assert tags[1]["begin"] == "<|channel|>analysis to=python"
+        assert tags[1]["content"]["type"] == "any_text"
+        assert tags[1]["end"] == "<|end|>"
+
+    def test_tag_with_builtin_funcs(self):
+        """Test tag_with_builtin_funcs function."""
+        builtin_tools = ["browser", "python"]
+        result = tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tools)
+
+        assert result["type"] == "structural_tag"
+        # Should have original analysis tag + 2 tags per tool
+        assert len(result["format"]["tags"]) == 5  # 1 + 2*2
+
+        # Should have added commentary trigger
+        assert "<|channel|>commentary to=" in result["format"]["triggers"]
+        assert "<|channel|>analysis" in result["format"]["triggers"]
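+
+    # tag_with_builtin_funcs deep-copies the base tag, so the module-level
+    # no_func_reaonsing_tag inspected below stays unmodified across tests.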
+
+    def test_tag_structure_invariants(self):
+        """Test that the basic tag structure follows the expected format."""
+        # Test the base no_func_reaonsing_tag structure
+        assert no_func_reaonsing_tag["type"] == "structural_tag"
+        assert no_func_reaonsing_tag["format"]["type"] == "triggered_tags"
+        assert no_func_reaonsing_tag["format"]["stop_after_first"] is False
+
+        # Verify analysis tag structure
+        analysis_tag = no_func_reaonsing_tag["format"]["tags"][0]
+        assert analysis_tag["begin"] == "<|channel|>analysis<|message|>"
+        assert analysis_tag["content"]["type"] == "any_text"
+        assert analysis_tag["end"] == "<|end|>"
+
+    def test_json_serialization_valid(
+        self, reasoning_parser, mock_tool_server_with_all_tools
+    ):
+        """Test that all generated tags produce valid JSON."""
+        # Test with no tool server
+        result1 = reasoning_parser.prepare_structured_tag(None, None)
+        json.loads(result1)  # Should not raise
+
+        # Test with an empty tool server
+        empty_server = Mock(spec=ToolServer)
+        empty_server.has_tool = Mock(return_value=False)
+        result2 = reasoning_parser.prepare_structured_tag(None, empty_server)
+        json.loads(result2)  # Should not raise
+
+        # Test with tools
+        result3 = reasoning_parser.prepare_structured_tag(
+            None, mock_tool_server_with_all_tools
+        )
+        json.loads(result3)  # Should not raise
+
+    @pytest.mark.parametrize("tool_name", ["browser", "python", "container"])
+    def test_single_tool_integration(self, reasoning_parser, tool_name):
+        """Test integration with individual tools."""
+        tool_server = Mock(spec=ToolServer)
+        tool_server.has_tool = Mock(side_effect=lambda tool: tool == tool_name)
+
+        result = reasoning_parser.prepare_structured_tag(None, tool_server)
+        parsed = json.loads(result)
+
+        # Should have 1 analysis + 2 tool-specific tags
+        assert len(parsed["format"]["tags"]) == 3
+
+        tag_begins = [tag["begin"] for tag in parsed["format"]["tags"]]
+        assert f"<|channel|>commentary to={tool_name}" in tag_begins
+        assert f"<|channel|>analysis to={tool_name}" in tag_begins
diff --git a/tests/v1/structured_output/test_reasoning_structured_output.py b/tests/v1/structured_output/test_reasoning_structured_output.py
new file mode 100644
index 0000000000..70047a993c
--- /dev/null
+++ b/tests/v1/structured_output/test_reasoning_structured_output.py
@@ -0,0 +1,207 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Unit tests for reasoning-aware structured output functionality (PR #25515)."""
+
+from unittest.mock import Mock
+
+import pytest
+
+from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
+from vllm.reasoning import ReasoningParser
+from vllm.v1.request import Request
+from vllm.v1.structured_output import StructuredOutputManager
+
+
+class TestReasoningStructuredOutput:
+    """Test reasoning-aware structured output functionality."""
+
+    @pytest.fixture
+    def mock_model_config(self):
+        """Create a mock ModelConfig."""
+        config = Mock(spec=ModelConfig)
+        config.skip_tokenizer_init = True  # Skip tokenizer init to avoid network calls
+        config.get_vocab_size = Mock(return_value=50000)
+        # Add the missing runner_type attribute that tokenizer initialization expects
+        config.runner_type = "generate"
+        # Add other attributes that tokenizer initialization might need
+        config.tokenizer = "test-tokenizer"
+        config.tokenizer_mode = "auto"
+        config.trust_remote_code = False
+        config.tokenizer_revision = None
+        return config
+
+    @pytest.fixture
+    def mock_scheduler_config(self):
+        """Create a mock SchedulerConfig."""
+        config = Mock(spec=SchedulerConfig)
+        config.max_num_seqs = 128
+        return config
+
+    @pytest.fixture
+    def mock_vllm_config(self, mock_model_config, mock_scheduler_config):
+        """Create a mock VllmConfig."""
+        config = Mock(spec=VllmConfig)
+        config.model_config = mock_model_config
+        config.scheduler_config = mock_scheduler_config
+        config.structured_outputs_config = Mock()
+        config.structured_outputs_config.reasoning_parser = None
+        config.structured_outputs_config.enable_in_reasoning = False
+        config.speculative_config = None
+        return config
+
+    @pytest.fixture
+    def mock_reasoning_parser(self):
+        """Create a mock ReasoningParser."""
+        parser = Mock(spec=ReasoningParser)
+        parser.is_reasoning_end = Mock(return_value=False)
+        return parser
+
+    @pytest.fixture
+    def mock_request_with_structured_output(self):
+        """Create a mock request with structured output."""
+        request = Mock(spec=Request)
+        request.structured_output_request = Mock()
+        request.structured_output_request.reasoning_ended = None
+        request.structured_output_request.grammar = Mock()
+        request.structured_output_request.grammar.is_terminated = Mock(
+            return_value=False
+        )
+        request.use_structured_output = True
+        request.prompt_token_ids = [1, 2, 3, 4, 5]
+        request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
+        return request
+
+    def test_should_fill_bitmask_with_enable_in_reasoning(
+        self, mock_vllm_config, mock_request_with_structured_output
+    ):
+        """Test should_fill_bitmask when enable_in_reasoning is True."""
+        # Enable enable_in_reasoning
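+        # The flag must be set before constructing the manager:
+        # StructuredOutputManager reads it once in __init__.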
+        mock_vllm_config.structured_outputs_config.enable_in_reasoning = True
+
+        manager = StructuredOutputManager(mock_vllm_config)
+
+        # Should always return True when enable_in_reasoning is enabled
+        result = manager.should_fill_bitmask(mock_request_with_structured_output)
+        assert result is True
+
+    def test_should_fill_bitmask_without_enable_in_reasoning(
+        self,
+        mock_vllm_config,
+        mock_request_with_structured_output,
+        mock_reasoning_parser,
+    ):
+        """Test should_fill_bitmask when enable_in_reasoning is False."""
+        # Keep enable_in_reasoning as False (default)
+        config = mock_vllm_config.structured_outputs_config
+        assert config.enable_in_reasoning is False
+
+        manager = StructuredOutputManager(mock_vllm_config)
+        manager.reasoner = mock_reasoning_parser
+
+        # Mock reasoning not ended
+        mock_reasoning_parser.is_reasoning_end.return_value = False
+
+        result = manager.should_fill_bitmask(mock_request_with_structured_output)
+
+        # Should set reasoning_ended and return its value
+        assert (
+            mock_request_with_structured_output.structured_output_request.reasoning_ended
+            is False
+        )
+        assert result is False
+
+    def test_should_fill_bitmask_no_reasoner(
+        self, mock_vllm_config, mock_request_with_structured_output
+    ):
+        """Test should_fill_bitmask when no reasoner is configured."""
+        manager = StructuredOutputManager(mock_vllm_config)
+        manager.reasoner = None
+
+        result = manager.should_fill_bitmask(mock_request_with_structured_output)
+
+        # Should default to True when there is no reasoner
+        assert result is True
+
+    def test_should_advance_with_enable_in_reasoning(
+        self,
+        mock_vllm_config,
+        mock_request_with_structured_output,
+        mock_reasoning_parser,
+    ):
+        """Test should_advance when enable_in_reasoning is True."""
+        # Enable enable_in_reasoning
+        mock_vllm_config.structured_outputs_config.enable_in_reasoning = True
+
+        manager = StructuredOutputManager(mock_vllm_config)
+        manager.reasoner = mock_reasoning_parser
+
+        # Should always return True when enable_in_reasoning is enabled
+        result = manager.should_advance(mock_request_with_structured_output)
+        assert result is True
+
+    def test_should_advance_reasoning_not_ended(
+        self,
+        mock_vllm_config,
+        mock_request_with_structured_output,
+        mock_reasoning_parser,
+    ):
+        """Test should_advance when reasoning has not ended."""
+        manager = StructuredOutputManager(mock_vllm_config)
+        manager.reasoner = mock_reasoning_parser
+
+        # Set reasoning as not ended
+        (
+            mock_request_with_structured_output.structured_output_request
+        ).reasoning_ended = False
+        mock_reasoning_parser.is_reasoning_end.return_value = False
+
+        result = manager.should_advance(mock_request_with_structured_output)
+
+        # Should return False since reasoning hasn't ended
+        assert result is False
+
+    def test_should_advance_reasoning_just_ended(
+        self,
+        mock_vllm_config,
+        mock_request_with_structured_output,
+        mock_reasoning_parser,
+    ):
+        """Test should_advance when reasoning ends in the current step."""
+        manager = StructuredOutputManager(mock_vllm_config)
+        manager.reasoner = mock_reasoning_parser
+
+        # Set reasoning as not ended initially, but it ends in this step
+        (
+            mock_request_with_structured_output.structured_output_request
+        ).reasoning_ended = False
+        mock_reasoning_parser.is_reasoning_end.return_value = True
+
+        result = manager.should_advance(mock_request_with_structured_output)
+
+        # Should set reasoning_ended to True but return False for this step
+        assert (
+            mock_request_with_structured_output.structured_output_request.reasoning_ended
+            is True
+        )
+        assert result is False
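+
+    # Note the two-phase hand-off: should_advance() returns False on the step
+    # where reasoning ends and only starts advancing the grammar afterwards.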
+
+    def test_should_advance_reasoning_already_ended(
+        self,
+        mock_vllm_config,
+        mock_request_with_structured_output,
+        mock_reasoning_parser,
+    ):
+        """Test should_advance when reasoning has already ended."""
+        manager = StructuredOutputManager(mock_vllm_config)
+        manager.reasoner = mock_reasoning_parser
+
+        # Set reasoning as already ended
+        (
+            mock_request_with_structured_output.structured_output_request
+        ).reasoning_ended = True
+
+        result = manager.should_advance(mock_request_with_structured_output)
+
+        # Should return True since reasoning has ended
+        assert result is True
diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py
index 5111c9c77d..76b565006e 100644
--- a/vllm/config/structured_outputs.py
+++ b/vllm/config/structured_outputs.py
@@ -35,6 +35,8 @@ class StructuredOutputsConfig:
     reasoning_parser: str = ""
     """Select the reasoning parser depending on the model that you're using.
    This is used to parse the reasoning content into OpenAI API format."""
+    enable_in_reasoning: bool = False
+    """Whether to apply structured output constraints during reasoning."""
 
     def compute_hash(self) -> str:
         """
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 432b1eca45..05958ca523 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -479,6 +479,7 @@ class EngineArgs:
         VllmConfig, "structured_outputs_config"
     )
     reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
+    # Deprecated guided decoding fields
     guided_decoding_backend: str | None = None
     guided_decoding_disable_fallback: bool | None = None
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index e73752b9d5..0d27e6707c 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -200,7 +200,7 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
     strict: bool | None = None
 
 
-class StructuralTag(OpenAIBaseModel):
+class LegacyStructuralTag(OpenAIBaseModel):
     begin: str
     # schema is the field, but that causes conflicts with pydantic so
     # instead use structural_tag_schema with an alias
@@ -208,10 +208,20 @@
     end: str
 
 
+class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
+    type: Literal["structural_tag"]
+    structures: list[LegacyStructuralTag]
+    triggers: list[str]
+
+
 class StructuralTagResponseFormat(OpenAIBaseModel):
     type: Literal["structural_tag"]
-    structures: list[StructuralTag]
-    triggers: list[str]
+    format: Any
+
+
+AnyStructuralTagResponseFormat: TypeAlias = (
+    LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
+)
 
 
 class ResponseFormat(OpenAIBaseModel):
@@ -220,7 +230,9 @@
     json_schema: JsonSchemaResponseFormat | None = None
 
 
-AnyResponseFormat: TypeAlias = ResponseFormat | StructuralTagResponseFormat
+AnyResponseFormat: TypeAlias = (
+    ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
+)
 
 
 class StreamOptions(OpenAIBaseModel):
@@ -823,7 +835,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
         elif response_format.type == "structural_tag":
             structural_tag = response_format
             assert structural_tag is not None and isinstance(
-                structural_tag, StructuralTagResponseFormat
+                structural_tag,
+                (
+                    LegacyStructuralTagResponseFormat,
+                    StructuralTagResponseFormat,
+                ),
             )
             s_tag_obj = structural_tag.model_dump(by_alias=True)
             self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index ffc692f099..1fdb6997bc 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -98,7 +98,7 @@ from vllm.logger import init_logger
 from vllm.logprobs import Logprob as SampleLogprob
 from vllm.logprobs import SampleLogprobs
 from vllm.outputs import CompletionOutput
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, StructuredOutputsParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid
@@ -365,6 +365,19 @@
             context = HarmonyContext(messages, available_tools)
         else:
             context = SimpleContext()
+
+        if self.reasoning_parser is not None:
+            reasoning_parser = self.reasoning_parser(tokenizer)
+            if sampling_params.structured_outputs is None:
+                sampling_params.structured_outputs = StructuredOutputsParams()
+            struct_out = sampling_params.structured_outputs
+            if struct_out.all_non_structural_tag_constraints_none():
+                sampling_params.structured_outputs.structural_tag = (
+                    reasoning_parser.prepare_structured_tag(
+                        sampling_params.structured_outputs.structural_tag,
+                        self.tool_server,
+                    )
+                )
         generator = self._generate_with_builtin_tools(
             request_id=request.request_id,
             request_prompt=request_prompts[i],
@@ -1919,6 +1932,7 @@
             processer = self._process_harmony_streaming_events
         else:
             processer = self._process_simple_streaming_events
+        # TODO(Hanchen): make sampling_params include the structural tag
 
         initial_response = ResponsesResponse.from_request(
             request,
diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py
index 3a595a3076..ee890e662e 100644
--- a/vllm/reasoning/abs_reasoning_parsers.py
+++ b/vllm/reasoning/abs_reasoning_parsers.py
@@ -7,6 +7,7 @@ from collections.abc import Callable, Sequence
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.logger import init_logger
 from vllm.utils.collections import is_list_of
 from vllm.utils.import_utils import import_from_path
@@ -115,6 +116,17 @@
         previously been parsed and extracted (see constructor)
         """
 
+    def prepare_structured_tag(
+        self,
+        original_tag: str | None,
+        tool_server: ToolServer | None,
+    ) -> str | None:
+        """
+        Prepare the structural tag for this reasoning parser.
+        The base implementation returns None; subclasses may override.
+        """
+        return None
+
 
 class ReasoningParserManager:
     reasoning_parsers: dict[str, type] = {}
diff --git a/vllm/reasoning/gptoss_reasoning_parser.py b/vllm/reasoning/gptoss_reasoning_parser.py
index ccb2d9553c..e6766ddcbc 100644
--- a/vllm/reasoning/gptoss_reasoning_parser.py
+++ b/vllm/reasoning/gptoss_reasoning_parser.py
@@ -1,17 +1,61 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+import copy
+import json
 from collections.abc import Sequence
 
 from transformers import PreTrainedTokenizerBase
 
 from vllm.entrypoints.harmony_utils import parse_chat_output
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 
 logger = init_logger(__name__)
 
+no_func_reaonsing_tag = {
+    "type": "structural_tag",
+    "format": {
+        "type": "triggered_tags",
+        "tags": [
+            {
+                "begin": "<|channel|>analysis<|message|>",
+                "content": {"type": "any_text"},
+                "end": "<|end|>",
+            }
+        ],
+        "triggers": ["<|channel|>analysis"],
+        "stop_after_first": False,
+    },
+}
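+# The JSON-serialized form of this dict is what ends up in
+# StructuredOutputsParams.structural_tag: free-form text is only allowed
+# between <|channel|>analysis<|message|> and <|end|>.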
"<|end|>", + } + ], + "triggers": ["<|channel|>analysis"], + "stop_after_first": False, + }, +} + + +def from_builtin_tool_to_tag(tool: str) -> list[dict]: + tag = [ + { + "begin": f"<|channel|>commentary to={tool}", + "content": {"type": "any_text"}, + "end": "<|end|>", + }, + { + "begin": f"<|channel|>analysis to={tool}", + "content": {"type": "any_text"}, + "end": "<|end|>", + }, + ] + return tag + + +def tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tool_list: list[str]) -> dict: + import copy + + new_tag = copy.deepcopy(no_func_reaonsing_tag) + new_tag["format"]["triggers"].append("<|channel|>commentary to=") + + for tool in builtin_tool_list: + new_tag["format"]["tags"].extend(from_builtin_tool_to_tag(tool)) + return new_tag + @ReasoningParserManager.register_module("openai_gptoss") class GptOssReasoningParser(ReasoningParser): @@ -81,3 +125,33 @@ class GptOssReasoningParser(ReasoningParser): raise NotImplementedError( "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." # noqa: E501 ) + + # This function prepares the structural tag to format reasoning output + def prepare_structured_tag( + self, original_tag: str | None, tool_server: ToolServer | None + ) -> str: + if original_tag is None: + if tool_server is None: + return json.dumps(no_func_reaonsing_tag) + else: + builtin_tool_list: list[str] = [] + if tool_server.has_tool("browser"): + builtin_tool_list.append("browser") + if tool_server.has_tool("python"): + builtin_tool_list.append("python") + if tool_server.has_tool("container"): + builtin_tool_list.append("container") + + if len(builtin_tool_list) > 0: + logger.info("Builtin_tool_list: %s", builtin_tool_list) + func_tag = json.dumps( + tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tool_list) + ) + else: + logger.info("Builtin_tool_list is empty") + func_tag = json.dumps(no_func_reaonsing_tag) + + return func_tag + else: + # There is potential risk for appending the tag to the original tag + return original_tag diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 3f583b393e..4b2a3bc4db 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -58,6 +58,7 @@ class StructuredOutputsParams: self.choice is not None, self.grammar is not None, self.json_object is not None, + self.structural_tag is not None, ] ) if count > 1: @@ -66,6 +67,37 @@ class StructuredOutputsParams: f"but multiple are specified: {self.__dict__}" ) + def all_constraints_none(self) -> bool: + """ + Returns True if all structured-output constraint fields are None. + """ + return all( + getattr(self, field) is None + for field in ( + "json", + "regex", + "choice", + "grammar", + "json_object", + "structural_tag", + ) + ) + + def all_non_structural_tag_constraints_none(self) -> bool: + """ + Returns True if all structured-output constraint fields are None. 
+    def all_non_structural_tag_constraints_none(self) -> bool:
+        """
+        Returns True if every constraint field except structural_tag is None.
+        """
+        return all(
+            getattr(self, field) is None
+            for field in (
+                "json",
+                "regex",
+                "choice",
+                "grammar",
+                "json_object",
+            )
+        )
+
 
 @dataclass
 class GuidedDecodingParams(StructuredOutputsParams):
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 4fb26ab1ce..6f9dbeabd8 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -73,6 +73,10 @@ class StructuredOutputManager:
             )
             self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
 
+        self.enable_in_reasoning = (
+            self.vllm_config.structured_outputs_config.enable_in_reasoning
+        )
+
     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
             return
@@ -274,7 +278,13 @@
         return bitmask_tensor.numpy()
 
     def should_fill_bitmask(self, request: Request) -> bool:
+        # NOTE (Hanchen): if enable_in_reasoning is True, the model must be
+        # constrained during reasoning as well, so we should always enable
+        # bitmask filling.
+        if self.reasoner is not None:
+            if self.enable_in_reasoning:
+                return True
         assert request.structured_output_request is not None
         if request.structured_output_request.reasoning_ended is None:
             request.structured_output_request.reasoning_ended = (
@@ -297,6 +307,10 @@
         if self.reasoner is None:
             return True
 
+        # If structured output applies during reasoning too, always advance.
+        if self.enable_in_reasoning:
+            return True
+
         structured_req = request.structured_output_request
         if structured_req.reasoning_ended:
             return True
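Taken together, the gate order in should_advance() is now: enable_in_reasoning
forces the grammar to advance on every step; otherwise the grammar only starts
advancing on the step after is_reasoning_end() first reports True, which is
exactly the hand-off the unit tests above exercise.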
diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py
index 074781fb66..4fe4f8848d 100644
--- a/vllm/v1/structured_output/backend_xgrammar.py
+++ b/vllm/v1/structured_output/backend_xgrammar.py
@@ -91,18 +91,20 @@ class XgrammarBackend(StructuredOutputBackend):
             ctx = self.compiler.compile_regex(grammar_spec)
         elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
             s_tag = json.loads(grammar_spec)
-            tags = [
-                xgr.StructuralTagItem(
-                    begin=s["begin"],
-                    schema=json.dumps(s["schema"]),
-                    end=s["end"],
-                )
-                for s in s_tag["structures"]
-            ]
-            structural_tag = xgr.StructuralTag.from_legacy_structural_tag(
-                tags, s_tag["triggers"]
-            )
-            ctx = self.compiler.compile_structural_tag(structural_tag)
+            if "structures" in s_tag:
+                # Fall back to the deprecated method of compiling the structural tag
+                tags = [
+                    xgr.StructuralTagItem(
+                        begin=s["begin"],
+                        schema=json.dumps(s["schema"]),
+                        end=s["end"],
+                    )
+                    for s in s_tag["structures"]
+                ]
+                ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"])
+            else:
+                logger.info("Compiling structural tag grammar_spec: %s", grammar_spec)
+                ctx = self.compiler.compile_structural_tag(grammar_spec)
         else:
             logger.error(
                 "Validation should have already occurred. Please file an issue."
@@ -320,17 +322,19 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
     if so_params.structural_tag:
         try:
             s_tag = json.loads(so_params.structural_tag)
-            tags = [
-                xgr.StructuralTagItem(
-                    begin=s["begin"],
-                    schema=json.dumps(s["schema"]),
-                    end=s["end"],
-                )
-                for s in s_tag["structures"]
-            ]
-            structural_tag = xgr.StructuralTag.from_legacy_structural_tag(
-                tags, s_tag["triggers"]
-            )
-            xgr.Grammar.from_structural_tag(structural_tag)
+
+            # Use the deprecated method of compiling the structural tag
+            if "structures" in s_tag:
+                tags = [
+                    xgr.StructuralTagItem(
+                        begin=s["begin"],
+                        schema=json.dumps(s["schema"]),
+                        end=s["end"],
+                    )
+                    for s in s_tag["structures"]
+                ]
+                xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
+            else:
+                xgr.Grammar.from_structural_tag(so_params.structural_tag)
         except Exception as e:
             raise ValueError("Invalid structural tag specification.") from e
diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py
index afe0e4b3f3..94ae36a1ab 100644
--- a/vllm/v1/structured_output/request.py
+++ b/vllm/v1/structured_output/request.py
@@ -28,7 +28,12 @@ class StructuredOutputRequest:
         if sampling_params is None:
             return None
         params = sampling_params.structured_outputs
-        return StructuredOutputRequest(params=params) if params else None
+        if params:
+            if params.all_constraints_none():
+                return None
+            else:
+                return StructuredOutputRequest(params=params)
+        return None
 
     def _check_grammar_completion(self) -> bool:
         # NOTE: We have to lazy import to gate circular imports
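Putting the pieces together, a minimal end-to-end sketch (the model name,
prompt, and token budget are illustrative; the call pattern mirrors
test_structured_output_with_structural_tag above):

    import json

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import StructuredOutputsParams

    # triggered_tags format: once the model emits the trigger "hello", the
    # grammar forces the complete "hello_flag ... hello" tag span.
    tag = {
        "type": "structural_tag",
        "format": {
            "type": "triggered_tags",
            "tags": [
                {"begin": "hello_flag", "content": {"type": "any_text"}, "end": "hello"}
            ],
            "triggers": ["hello"],
            "stop_after_first": False,
        },
    }

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", guided_decoding_backend="xgrammar")
    params = SamplingParams(
        temperature=0.0,
        max_tokens=100,
        guided_decoding=StructuredOutputsParams(structural_tag=json.dumps(tag)),
    )
    outputs = llm.generate("Please say hello a few times.", sampling_params=params)
    print(outputs[0].outputs[0].text)  # constrained output contains "hello_flag"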