mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Olmo 3 tool parser and tests (#26143)
Signed-off-by: Pradeep Dasigi <pradeepd@allenai.org>
This commit is contained in:
@ -352,6 +352,16 @@ Supported models:
|
||||
|
||||
Flags: `--tool-call-parser qwen3_xml`
|
||||
|
||||
### Olmo 3 Models (`olmo3`)
|
||||
|
||||
Olmo 3 models output tool calls in a format that is very similar to the one expected by the `pythonic` parser (see below), with a few differences. Each tool call is a pythonic string, but the parallel tool calls are newline-delimited, and the calls are wrapped within XML tags as `<function_calls>..</function_calls>`. In addition, the parser also allows JSON boolean and null literals (`true`, `false`, and `null`) in addition to the pythonic ones (`True`, `False`, and `None`).
|
||||
|
||||
Supported models:
|
||||
|
||||
* TODO (will be updated after Olmo 3 release)
|
||||
|
||||
Flags: `--tool-call-parser olmo3`
|
||||
|
||||
### Models with Pythonic Tool Calls (`pythonic`)
|
||||
|
||||
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
|
||||
|
243
tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py
Normal file
243
tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py
Normal file
@ -0,0 +1,243 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
|
||||
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
|
||||
SIMPLE_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "San Francisco", "metric": "celsius"}',
|
||||
)
|
||||
MORE_TYPES_FUNCTION_OUTPUT = (
|
||||
"register_user(name='John Doe', "
|
||||
"age=37, "
|
||||
"address={'city': 'San Francisco', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = (
|
||||
"register_user(name='John Doe', "
|
||||
"age=37, "
|
||||
"address={'city': 'San Francisco', 'state': 'CA'}, "
|
||||
"role=null, "
|
||||
"passed_test=true, "
|
||||
"aliases=['John', 'Johnny'])"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
name="register_user",
|
||||
arguments='{"name": "John Doe", '
|
||||
'"age": 37, '
|
||||
'"address": {"city": "San Francisco", "state": "CA"}, '
|
||||
'"role": null, '
|
||||
'"passed_test": true, '
|
||||
'"aliases": ["John", "Johnny"]}',
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments="{}",
|
||||
)
|
||||
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
|
||||
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
|
||||
name="do_something_cool",
|
||||
arguments='{"additional_data": {}}',
|
||||
)
|
||||
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
|
||||
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
|
||||
name="do_something_cool",
|
||||
arguments='{"steps": []}',
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT = (
|
||||
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
|
||||
model_output = "How can I help you today?"
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming_json_literals",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming_json_literals",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
|
||||
def test_tool_call(
|
||||
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
|
||||
):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content is None
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
for actual, expected in zip(tool_calls, expected_tool_calls):
|
||||
assert actual.type == "function"
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps():
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
|
||||
model_output_deltas = [
|
||||
"<function_calls>get_weather(city='San",
|
||||
" Francisco', metric='celsius')\n"
|
||||
f"{PARAMETERLESS_FUNCTION_OUTPUT}\n"
|
||||
f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
|
||||
]
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
assert reconstructor.other_content == ""
|
||||
assert len(reconstructor.tool_calls) == 3
|
||||
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
# create a mock regex that raises TimeoutError
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, fake_problematic_input, streaming=streaming
|
||||
)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
assert len(tool_calls) == 0
|
||||
mock_regex.match.assert_called_once()
|
@ -18,6 +18,7 @@ from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .longcat_tool_parser import LongcatFlashToolParser
|
||||
from .minimax_tool_parser import MinimaxToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
from .olmo3_tool_parser import Olmo3PythonicToolParser
|
||||
from .openai_tool_parser import OpenAIToolParser
|
||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||
from .pythonic_tool_parser import PythonicToolParser
|
||||
@ -45,6 +46,7 @@ __all__ = [
|
||||
"DeepSeekV31ToolParser",
|
||||
"Ernie45ToolParser",
|
||||
"xLAMToolParser",
|
||||
"Olmo3PythonicToolParser",
|
||||
"MinimaxToolParser",
|
||||
"KimiK2ToolParser",
|
||||
"HunyuanA13BToolParser",
|
||||
|
368
vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py
Normal file
368
vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py
Normal file
@ -0,0 +1,368 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
ToolParserManager,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@ToolParserManager.register_module("olmo3")
|
||||
class Olmo3PythonicToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Olmo 3 models that produce tool calls as
|
||||
newline-separated pythonic strings.
|
||||
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
|
||||
Code copied from pythonic_tool_parser.py and updated to handle
|
||||
- newline separated pythonic tool calls.
|
||||
- argument values being null/true/false instead of Pythonic literals.
|
||||
"""
|
||||
|
||||
# TODO(mdepinet): Possible future improvements:
|
||||
# 1. Support text + tools separated by either <|python_tag|> or \n\n
|
||||
# 2. Support tools outside of a list (or separated by a semicolon).
|
||||
# This depends on item 1 for consistent streaming.
|
||||
# Neither of these are necessary for e.g. ToolACE, but both would help make
|
||||
# Llama3.2 models more reliable.
|
||||
|
||||
TOOL_CALL_REGEX = re.compile(
|
||||
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Rename for readability. This is NOT a tool id.
|
||||
@property
|
||||
def current_tool_index(self) -> int:
|
||||
return self.current_tool_id
|
||||
|
||||
@current_tool_index.setter
|
||||
def current_tool_index(self, value: int) -> None:
|
||||
self.current_tool_id = value
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
original_model_output = model_output
|
||||
# Remove xml tags.
|
||||
match = re.search(
|
||||
r"<function_calls>(.*?)</function_calls>", model_output, re.DOTALL
|
||||
)
|
||||
if match:
|
||||
model_output = match.group(1).strip()
|
||||
# Make the newline separated function calls into a list.
|
||||
model_output = ", ".join(
|
||||
[line.strip() for line in model_output.splitlines() if line.strip()]
|
||||
)
|
||||
model_output = f"[{model_output}]"
|
||||
|
||||
is_tool_call_pattern = False
|
||||
try:
|
||||
is_tool_call_pattern = (
|
||||
self.TOOL_CALL_REGEX.match(
|
||||
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
|
||||
)
|
||||
is not None
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning("Regex timeout occurred when matching tool call pattern.")
|
||||
logger.debug(
|
||||
"Regex timeout occurred when matching user input: %s", model_output
|
||||
)
|
||||
|
||||
if not is_tool_call_pattern:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=original_model_output
|
||||
)
|
||||
|
||||
try:
|
||||
module = ast.parse(model_output)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if isinstance(parsed, ast.List) and all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts
|
||||
):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=[
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
else:
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# Treat as regular text
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=original_model_output
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
# All function calls start with the <function_calls> tag.
|
||||
# But since this is streaming, we may have seen only part of the tag.
|
||||
if not current_text.startswith("<"):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
# Remove xml tags.
|
||||
if current_text.startswith("<function_calls>"):
|
||||
current_text = current_text[len("<function_calls>") :]
|
||||
if current_text.endswith("</function_calls>"):
|
||||
current_text = current_text[: -len("</function_calls>")]
|
||||
|
||||
valid_and_added_text = _make_valid_python(current_text)
|
||||
if valid_and_added_text is None:
|
||||
return None
|
||||
valid_text, added_text = valid_and_added_text
|
||||
|
||||
# Make the newline separated function calls into a list.
|
||||
valid_text = ", ".join(
|
||||
[line.strip() for line in valid_text.splitlines() if line.strip()]
|
||||
)
|
||||
valid_text = f"[{valid_text}]"
|
||||
module = ast.parse(valid_text)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if not isinstance(parsed, ast.List) or not all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts
|
||||
):
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a sequence of newline-separated calls"
|
||||
)
|
||||
tool_calls = [
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
]
|
||||
|
||||
tool_deltas = []
|
||||
for index, new_call in enumerate(tool_calls):
|
||||
if index < self.current_tool_index:
|
||||
continue
|
||||
|
||||
self.current_tool_index = index
|
||||
if len(self.streamed_args_for_tool) == index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
new_call_complete = index < len(tool_calls) - 1 or ")" not in added_text
|
||||
if new_call_complete:
|
||||
self.current_tool_index += 1
|
||||
|
||||
withheld_suffix = added_text[:-1] if not new_call_complete else ""
|
||||
if not new_call_complete and added_text[-1] == ")":
|
||||
# Function call is incomplete. Withhold the closing bracket.
|
||||
withheld_suffix = withheld_suffix + "}"
|
||||
# Strings get single quotes in the model-produced string.
|
||||
# JSON requires double quotes.
|
||||
withheld_suffix = withheld_suffix.replace("'", '"')
|
||||
delta = _compute_tool_delta(
|
||||
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
|
||||
)
|
||||
|
||||
if delta is not None:
|
||||
tool_deltas.append(delta)
|
||||
if (
|
||||
delta.function is not None
|
||||
and delta.function.arguments is not None
|
||||
):
|
||||
self.streamed_args_for_tool[index] += delta.function.arguments
|
||||
|
||||
# HACK: serving_chat.py inspects the internal state of tool parsers
|
||||
# when determining its final streaming delta, automatically
|
||||
# adding autocompleted JSON.
|
||||
# These two lines avoid that nonsense while ensuring finish_reason
|
||||
# is set to tool_calls when at least one tool is called.
|
||||
if tool_deltas and not self.prev_tool_call_arr:
|
||||
self.prev_tool_call_arr = [{"arguments": {}}]
|
||||
|
||||
if tool_deltas:
|
||||
return DeltaMessage(tool_calls=tool_deltas)
|
||||
elif not added_text and self.current_tool_id > 0:
|
||||
# Return an empty DeltaMessage once the tool calls are all done
|
||||
# so that finish_reason gets set.
|
||||
return DeltaMessage(content="")
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction error"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
|
||||
if isinstance(val, ast.Constant):
|
||||
return val.value
|
||||
elif isinstance(val, ast.Dict):
|
||||
if not all(isinstance(k, ast.Constant) for k in val.keys):
|
||||
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
|
||||
return {
|
||||
k.value: _get_parameter_value(v) # type: ignore
|
||||
for k, v in zip(val.keys, val.values)
|
||||
}
|
||||
elif isinstance(val, ast.List):
|
||||
return [_get_parameter_value(v) for v in val.elts]
|
||||
# The model may return function calls where the values are null/true/false
|
||||
# because the system prompt has API description in json.
|
||||
elif isinstance(val, ast.Name) and val.id in ["null", "true", "false"]:
|
||||
if val.id == "null":
|
||||
return None
|
||||
elif val.id == "true":
|
||||
return True
|
||||
elif val.id == "false":
|
||||
return False
|
||||
else:
|
||||
raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
|
||||
if not isinstance(call.func, ast.Name):
|
||||
raise _UnexpectedAstError("Invalid tool call name")
|
||||
function_name = call.func.id
|
||||
arguments = {}
|
||||
for keyword in call.keywords:
|
||||
arguments[keyword.arg] = _get_parameter_value(keyword.value)
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> tuple[str, str] | None:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[: text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[: text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if (
|
||||
bracket_stack
|
||||
and bracket_stack[-1] == "["
|
||||
and not text.endswith("[")
|
||||
and not text.endswith(")")
|
||||
):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(
|
||||
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
|
||||
) -> DeltaToolCall | None:
|
||||
new_call_args = new_call.function.arguments
|
||||
if withheld_suffix:
|
||||
assert new_call_args.endswith(withheld_suffix)
|
||||
new_call_args = new_call_args[: -len(withheld_suffix)]
|
||||
if not previously_sent_args:
|
||||
return DeltaToolCall(
|
||||
id=new_call.id,
|
||||
type="function",
|
||||
index=index,
|
||||
function=DeltaFunctionCall(
|
||||
name=new_call.function.name,
|
||||
arguments=new_call_args,
|
||||
),
|
||||
)
|
||||
|
||||
arg_diff = new_call_args[len(previously_sent_args) :]
|
||||
return (
|
||||
DeltaToolCall(
|
||||
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
|
||||
)
|
||||
if arg_diff
|
||||
else None
|
||||
)
|
Reference in New Issue
Block a user