[Frontend] Tool calling parser for Granite 3.0 models (#9027)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Maximilien de Bayser
2024-11-07 12:09:02 -03:00
committed by GitHub
parent a62bc0109c
commit ae62fd17c0
6 changed files with 314 additions and 33 deletions

View File

@ -160,14 +160,7 @@ this, unless explicitly specified.
:func: create_parser_for_docs :func: create_parser_for_docs
:prog: vllm serve :prog: vllm serve
``` ```
## Tool Calling in the Chat Completion API
### Named Function Calling
vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is
enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a
high-quality one.
To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.
### Config file ### Config file
@ -196,12 +189,22 @@ The order of priorities is `command line > config file values > defaults`.
--- ---
## Tool calling in the chat completion API ## Tool calling in the chat completion API
vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap.
vLLM supports named function calling and `auto` tool choice in the chat completion API. The `tool_choice` options `required` is **not yet supported** but on the roadmap.
It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt.
### Named Function Calling
vLLM supports named function calling in the chat completion API by default. It does so using Outlines, so this is
enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a
high-quality one.
vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.
### Automatic Function Calling ### Automatic Function Calling
To enable this feature, you should set the following flags: To enable this feature, you should set the following flags:
@ -275,6 +278,21 @@ it works better with vLLM.
Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`
#### IBM Granite
Supported models:
* `ibm-granite/granite-3.0-8b-instruct`
Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja`
`examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported.
* `ibm-granite/granite-20b-functioncalling`
Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja`
`examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.
#### InternLM Models (`internlm`) #### InternLM Models (`internlm`)
@ -297,16 +315,6 @@ AI21's Jamba-1.5 models are supported.
Flags: `--tool-call-parser jamba` Flags: `--tool-call-parser jamba`
#### IBM Granite (`granite-20b-fc`)
Supported models:
* `ibm-granite/granite-20b-functioncalling`
Flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja`
The example chat template deviates slightly from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.
### How to write a tool parser plugin ### How to write a tool parser plugin
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.

View File

@ -0,0 +1,40 @@
{%- if tools %}
{{- '<|start_of_role|>available_tools<|end_of_role|>
' }}
{%- for tool in tools %}
{{- tool | tojson(indent=4) }}
{%- if not loop.last %}
{{- '
' }}
{%- endif %}
{%- endfor %}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- for message in messages %}
{%- if message['role'] == 'system' %}
{{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- elif message['role'] == 'user' %}
{{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- elif message['role'] == 'assistant_tool_call' or (message['role'] == 'assistant' and message.tool_calls is defined) %}
{{- '<|start_of_role|>assistant<|end_of_role|>' }}
{% for tc in message.tool_calls %}
{{- '<|tool_call|> ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson }}
{% endfor %}
{{- '<|end_of_text|>
' }}
{%- elif message['role'] == 'assistant' %}
{{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- elif message['role'] == 'tool_response' or message['role'] == 'tool' %}
{{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- endif %}
{%- if loop.last and add_generation_prompt %}
{{- '<|start_of_role|>assistant<|end_of_role|>' }}
{%- endif %}
{%- endfor %}

View File

@ -3,6 +3,7 @@ import pytest_asyncio
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform
from .utils import ARGS, CONFIGS, ServerConfig from .utils import ARGS, CONFIGS, ServerConfig
@ -11,6 +12,11 @@ from .utils import ARGS, CONFIGS, ServerConfig
@pytest.fixture(scope="session", params=CONFIGS.keys()) @pytest.fixture(scope="session", params=CONFIGS.keys())
def server_config(request): def server_config(request):
config = CONFIGS[request.param] config = CONFIGS[request.param]
if current_platform.is_rocm() and not config.get("supports_rocm", True):
pytest.skip("The {} model can't be tested on the ROCm platform".format(
config["model"]))
# download model and tokenizer using transformers # download model and tokenizer using transformers
snapshot_download(config["model"]) snapshot_download(config["model"])
yield CONFIGS[request.param] yield CONFIGS[request.param]

View File

@ -13,6 +13,7 @@ class ServerConfig(TypedDict, total=False):
arguments: List[str] arguments: List[str]
system_prompt: Optional[str] system_prompt: Optional[str]
supports_parallel: Optional[bool] supports_parallel: Optional[bool]
supports_rocm: Optional[bool]
def patch_system_prompt(messages: List[Dict[str, Any]], def patch_system_prompt(messages: List[Dict[str, Any]],
@ -36,7 +37,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
# universal args for all models go here. also good if you need to test locally # universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something. # and change type or KV cache quantization or something.
ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"]
CONFIGS: Dict[str, ServerConfig] = { CONFIGS: Dict[str, ServerConfig] = {
"hermes": { "hermes": {
@ -88,18 +89,28 @@ CONFIGS: Dict[str, ServerConfig] = {
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally." "to the user's question - just respond to it normally."
}, },
## FIXME: temporary disabled due to lack of hardware specification "granite20b": {
## for individual runs "model":
#"granite20b": { "mbayser/granite-20b-functioncalling-FP8-KV",
# "model": "arguments": [
# "ibm-granite/granite-20b-functioncalling", "--tool-call-parser", "granite-20b-fc", "--chat-template",
# "arguments": [ str(VLLM_PATH /
# "--tool-call-parser", "granite-20b-fc", "--chat-template", "examples/tool_chat_template_granite_20b_fc.jinja"),
# str(VLLM_PATH / "examples/tool_chat_template_granite_20b_fc.jinja") "--max_num_seqs", "1", "--enforce-eager", "--cpu-offload-gb", "20"
# ], ],
# "supports_parallel": "supports_parallel":
# False, False,
#}, "supports_rocm":
False,
},
"granite8b": {
"model":
"ibm-granite/granite-3.0-8b-instruct",
"arguments": [
"--tool-call-parser", "granite", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_granite.jinja")
],
},
"internlm": { "internlm": {
"model": "model":
"internlm/internlm2_5-7b-chat", "internlm/internlm2_5-7b-chat",

View File

@ -1,5 +1,6 @@
from .abstract_tool_parser import ToolParser, ToolParserManager from .abstract_tool_parser import ToolParser, ToolParserManager
from .granite_20b_fc_tool_parser import Granite20bFCToolParser from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .granite_tool_parser import GraniteToolParser
from .hermes_tool_parser import Hermes2ProToolParser from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser from .jamba_tool_parser import JambaToolParser
@ -8,6 +9,6 @@ from .mistral_tool_parser import MistralToolParser
__all__ = [ __all__ = [
"ToolParser", "ToolParserManager", "Granite20bFCToolParser", "ToolParser", "ToolParserManager", "Granite20bFCToolParser",
"Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Llama3JsonToolParser", "JambaToolParser" "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser"
] ]

View File

@ -0,0 +1,215 @@
import json
from typing import Dict, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
find_common_prefix,
is_complete_json,
partial_json_loads)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module("granite")
class GraniteToolParser(ToolParser):
"""
Tool call parser for the granite 3.0 models. Intended
for use with the examples/tool_chat_template_granite.jinja
template.
Used when --enable-auto-tool-choice --tool-call-parser granite
are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
stripped = model_output.strip()
if not stripped or stripped[0] != '[':
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
raw_function_calls = json.loads(stripped)
if not isinstance(raw_function_calls, list):
raise Exception(
f"Expected dict or list, got {type(raw_function_calls)}")
logger.debug("Extracted %d tool calls", len(raw_function_calls))
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"]),
),
) for function_call in raw_function_calls
]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=None,
)
except Exception as e:
logger.error("Error in extracting tool call from response %s", e)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
start_idx = consume_space(0, current_text)
if not current_text or current_text[start_idx] != '[':
return DeltaMessage(content=delta_text)
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
tool_call_arr = None
is_complete = None
try:
tool_calls, end_idx = partial_json_loads(
current_text[start_idx:], flags)
if type(tool_calls) is list:
tool_call_arr = tool_calls
else:
return DeltaMessage(content=delta_text)
is_complete = [True] * len(tool_calls)
if not is_complete_json(
current_text[start_idx:start_idx + end_idx]):
is_complete[-1] = False
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if not tool_call_arr:
return None
# select as the current tool call the one we're on the state at
current_tool_call: Dict = tool_call_arr[self.current_tool_id]
delta = None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
if len(tool_call_arr) > self.current_tool_id + 1:
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(
prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
if argument_diff is not None:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None