Olmo 3 tool parser and tests (#26143)

Signed-off-by: Pradeep Dasigi <pradeepd@allenai.org>
This commit is contained in:
Pradeep Dasigi
2025-10-15 09:36:12 -07:00
committed by GitHub
parent d3cbaa08dc
commit 4794c2bd92
4 changed files with 623 additions and 0 deletions

View File

@ -352,6 +352,16 @@ Supported models:
Flags: `--tool-call-parser qwen3_xml`
### Olmo 3 Models (`olmo3`)
Olmo 3 models output tool calls in a format that is very similar to the one expected by the `pythonic` parser (see below), with a few differences. Each tool call is a pythonic string, but the parallel tool calls are newline-delimited, and the calls are wrapped within XML tags as `<function_calls>..</function_calls>`. In addition, the parser also allows JSON boolean and null literals (`true`, `false`, and `null`) in addition to the pythonic ones (`True`, `False`, and `None`).
Supported models:
* TODO (will be updated after Olmo 3 release)
Flags: `--tool-call-parser olmo3`
### Models with Pythonic Tool Calls (`pythonic`)
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.

View File

@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "San Francisco", "metric": "celsius"}',
)
MORE_TYPES_FUNCTION_OUTPUT = (
"register_user(name='John Doe', "
"age=37, "
"address={'city': 'San Francisco', 'state': 'CA'}, "
"role=None, "
"passed_test=True, "
"aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = (
"register_user(name='John Doe', "
"age=37, "
"address={'city': 'San Francisco', 'state': 'CA'}, "
"role=null, "
"passed_test=true, "
"aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
name="register_user",
arguments='{"name": "John Doe", '
'"age": 37, '
'"address": {"city": "San Francisco", "state": "CA"}, '
'"role": null, '
'"passed_test": true, '
'"aliases": ["John", "Johnny"]}',
)
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments="{}",
)
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"additional_data": {}}',
)
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"steps": []}',
)
ESCAPED_STRING_FUNCTION_OUTPUT = (
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
model_output = "How can I help you today?"
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content == model_output
assert len(tool_calls) == 0
TEST_CASES = [
pytest.param(
True,
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
[SIMPLE_FUNCTION_CALL],
id="simple_streaming",
),
pytest.param(
False,
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
[SIMPLE_FUNCTION_CALL],
id="simple_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
[MORE_TYPES_FUNCTION_CALL],
id="more_types_streaming",
),
pytest.param(
False,
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
[MORE_TYPES_FUNCTION_CALL],
id="more_types_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
[MORE_TYPES_FUNCTION_CALL],
id="more_types_streaming_json_literals",
),
pytest.param(
False,
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
[MORE_TYPES_FUNCTION_CALL],
id="more_types_nonstreaming_json_literals",
),
pytest.param(
True,
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_streaming",
),
pytest.param(
False,
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
[EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_streaming",
),
pytest.param(
False,
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
[EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
[EMPTY_LIST_FUNCTION_CALL],
id="empty_list_streaming",
),
pytest.param(
False,
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
[EMPTY_LIST_FUNCTION_CALL],
id="empty_list_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_streaming",
),
pytest.param(
False,
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_streaming",
),
pytest.param(
False,
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_nonstreaming",
),
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content is None
assert len(tool_calls) == len(expected_tool_calls)
for actual, expected in zip(tool_calls, expected_tool_calls):
assert actual.type == "function"
assert actual.function == expected
def test_streaming_tool_call_with_large_steps():
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
model_output_deltas = [
"<function_calls>get_weather(city='San",
" Francisco', metric='celsius')\n"
f"{PARAMETERLESS_FUNCTION_OUTPUT}\n"
f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
]
reconstructor = run_tool_extraction_streaming(
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
)
assert reconstructor.other_content == ""
assert len(reconstructor.tool_calls) == 3
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool):
"""test regex timeout is handled gracefully"""
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
# create a mock regex that raises TimeoutError
mock_regex = MagicMock()
mock_regex.match.side_effect = TimeoutError("Regex timeout")
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
content, tool_calls = run_tool_extraction(
tool_parser, fake_problematic_input, streaming=streaming
)
# should treat as regular text when regex times out
assert content == fake_problematic_input
assert len(tool_calls) == 0
mock_regex.match.assert_called_once()

View File

@ -18,6 +18,7 @@ from .llama_tool_parser import Llama3JsonToolParser
from .longcat_tool_parser import LongcatFlashToolParser
from .minimax_tool_parser import MinimaxToolParser
from .mistral_tool_parser import MistralToolParser
from .olmo3_tool_parser import Olmo3PythonicToolParser
from .openai_tool_parser import OpenAIToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
@ -45,6 +46,7 @@ __all__ = [
"DeepSeekV31ToolParser",
"Ernie45ToolParser",
"xLAMToolParser",
"Olmo3PythonicToolParser",
"MinimaxToolParser",
"KimiK2ToolParser",
"HunyuanA13BToolParser",

View File

@ -0,0 +1,368 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
@ToolParserManager.register_module("olmo3")
class Olmo3PythonicToolParser(ToolParser):
"""
Tool call parser for Olmo 3 models that produce tool calls as
newline-separated pythonic strings.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
Code copied from pythonic_tool_parser.py and updated to handle
- newline separated pythonic tool calls.
- argument values being null/true/false instead of Pythonic literals.
"""
# TODO(mdepinet): Possible future improvements:
# 1. Support text + tools separated by either <|python_tag|> or \n\n
# 2. Support tools outside of a list (or separated by a semicolon).
# This depends on item 1 for consistent streaming.
# Neither of these are necessary for e.g. ToolACE, but both would help make
# Llama3.2 models more reliable.
TOOL_CALL_REGEX = re.compile(
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
re.DOTALL,
)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# Rename for readability. This is NOT a tool id.
@property
def current_tool_index(self) -> int:
return self.current_tool_id
@current_tool_index.setter
def current_tool_index(self, value: int) -> None:
self.current_tool_id = value
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
original_model_output = model_output
# Remove xml tags.
match = re.search(
r"<function_calls>(.*?)</function_calls>", model_output, re.DOTALL
)
if match:
model_output = match.group(1).strip()
# Make the newline separated function calls into a list.
model_output = ", ".join(
[line.strip() for line in model_output.splitlines() if line.strip()]
)
model_output = f"[{model_output}]"
is_tool_call_pattern = False
try:
is_tool_call_pattern = (
self.TOOL_CALL_REGEX.match(
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
)
is not None
)
except TimeoutError:
logger.warning("Regex timeout occurred when matching tool call pattern.")
logger.debug(
"Regex timeout occurred when matching user input: %s", model_output
)
if not is_tool_call_pattern:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=original_model_output
)
try:
module = ast.parse(model_output)
parsed = getattr(module.body[0], "value", None)
if isinstance(parsed, ast.List) and all(
isinstance(e, ast.Call) for e in parsed.elts
):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None,
)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=original_model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# All function calls start with the <function_calls> tag.
# But since this is streaming, we may have seen only part of the tag.
if not current_text.startswith("<"):
return DeltaMessage(content=delta_text)
try:
# Remove xml tags.
if current_text.startswith("<function_calls>"):
current_text = current_text[len("<function_calls>") :]
if current_text.endswith("</function_calls>"):
current_text = current_text[: -len("</function_calls>")]
valid_and_added_text = _make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
# Make the newline separated function calls into a list.
valid_text = ", ".join(
[line.strip() for line in valid_text.splitlines() if line.strip()]
)
valid_text = f"[{valid_text}]"
module = ast.parse(valid_text)
parsed = getattr(module.body[0], "value", None)
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts
):
raise _UnexpectedAstError(
"Tool output must be a sequence of newline-separated calls"
)
tool_calls = [
_handle_single_tool(e) # type: ignore
for e in parsed.elts
]
tool_deltas = []
for index, new_call in enumerate(tool_calls):
if index < self.current_tool_index:
continue
self.current_tool_index = index
if len(self.streamed_args_for_tool) == index:
self.streamed_args_for_tool.append("")
new_call_complete = index < len(tool_calls) - 1 or ")" not in added_text
if new_call_complete:
self.current_tool_index += 1
withheld_suffix = added_text[:-1] if not new_call_complete else ""
if not new_call_complete and added_text[-1] == ")":
# Function call is incomplete. Withhold the closing bracket.
withheld_suffix = withheld_suffix + "}"
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
)
if delta is not None:
tool_deltas.append(delta)
if (
delta.function is not None
and delta.function.arguments is not None
):
self.streamed_args_for_tool[index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining its final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if tool_deltas and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if tool_deltas:
return DeltaMessage(tool_calls=tool_deltas)
elif not added_text and self.current_tool_id > 0:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage(content="")
else:
return None
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
# The model may return function calls where the values are null/true/false
# because the system prompt has API description in json.
elif isinstance(val, ast.Name) and val.id in ["null", "true", "false"]:
if val.id == "null":
return None
elif val.id == "true":
return True
elif val.id == "false":
return False
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
),
)
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)