mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	[Frontend] Tool calling parser for Granite 3.0 models (#9027)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							a62bc0109c
						
					
				
				
					commit
					ae62fd17c0
				
			| @ -160,14 +160,7 @@ this, unless explicitly specified. | |||||||
| :func: create_parser_for_docs | :func: create_parser_for_docs | ||||||
| :prog: vllm serve | :prog: vllm serve | ||||||
| ``` | ``` | ||||||
| ## Tool Calling in the Chat Completion API |  | ||||||
| ### Named Function Calling |  | ||||||
| vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is  |  | ||||||
| enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a  |  | ||||||
| high-quality one.  |  | ||||||
|  |  | ||||||
| To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and  |  | ||||||
| specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.  |  | ||||||
|  |  | ||||||
| ### Config file | ### Config file | ||||||
|  |  | ||||||
| @ -196,12 +189,22 @@ The order of priorities is `command line > config file values > defaults`. | |||||||
| --- | --- | ||||||
|  |  | ||||||
| ## Tool calling in the chat completion API | ## Tool calling in the chat completion API | ||||||
| vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap. |  | ||||||
|  | vLLM supports named function calling and `auto` tool choice  in the chat completion API. The `tool_choice` options `required` is **not yet supported** but on the roadmap. | ||||||
|  |  | ||||||
| It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. | It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ### Named Function Calling | ||||||
|  | vLLM supports named function calling in the chat completion API by default. It does so using Outlines, so this is  | ||||||
|  | enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a  | ||||||
|  | high-quality one.  | ||||||
|  |  | ||||||
| vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. | vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. | ||||||
|  |  | ||||||
|  | To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and  | ||||||
|  | specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.  | ||||||
|  |  | ||||||
|  |  | ||||||
| ### Automatic Function Calling | ### Automatic Function Calling | ||||||
| To enable this feature, you should set the following flags: | To enable this feature, you should set the following flags: | ||||||
| @ -275,6 +278,21 @@ it works better with vLLM. | |||||||
|  |  | ||||||
| Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` | Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` | ||||||
|  |  | ||||||
|  | #### IBM Granite | ||||||
|  |  | ||||||
|  | Supported models: | ||||||
|  | * `ibm-granite/granite-3.0-8b-instruct` | ||||||
|  |  | ||||||
|  | Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` | ||||||
|  |  | ||||||
|  | `examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. | ||||||
|  |  | ||||||
|  | * `ibm-granite/granite-20b-functioncalling` | ||||||
|  |  | ||||||
|  | Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` | ||||||
|  |  | ||||||
|  | `examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. | ||||||
|  |  | ||||||
|  |  | ||||||
| #### InternLM Models (`internlm`) | #### InternLM Models (`internlm`) | ||||||
|  |  | ||||||
| @ -297,16 +315,6 @@ AI21's Jamba-1.5 models are supported. | |||||||
| Flags: `--tool-call-parser jamba` | Flags: `--tool-call-parser jamba` | ||||||
|  |  | ||||||
|  |  | ||||||
| #### IBM Granite (`granite-20b-fc`) |  | ||||||
|  |  | ||||||
| Supported models: |  | ||||||
| * `ibm-granite/granite-20b-functioncalling` |  | ||||||
|  |  | ||||||
| Flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` |  | ||||||
|  |  | ||||||
| The example chat template deviates slightly from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ### How to write a tool parser plugin | ### How to write a tool parser plugin | ||||||
|  |  | ||||||
| A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. | A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. | ||||||
|  | |||||||
							
								
								
									
										40
									
								
								examples/tool_chat_template_granite.jinja
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								examples/tool_chat_template_granite.jinja
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,40 @@ | |||||||
|  | {%- if tools %} | ||||||
|  |     {{- '<|start_of_role|>available_tools<|end_of_role|> | ||||||
|  | ' }} | ||||||
|  |     {%- for tool in tools %} | ||||||
|  |     {{- tool | tojson(indent=4) }} | ||||||
|  |     {%- if not loop.last %} | ||||||
|  |         {{- ' | ||||||
|  |  | ||||||
|  | ' }} | ||||||
|  |     {%- endif %} | ||||||
|  |     {%- endfor %} | ||||||
|  |     {{- '<|end_of_text|> | ||||||
|  | ' }} | ||||||
|  | {%- endif %} | ||||||
|  |  | ||||||
|  | {%- for message in messages %} | ||||||
|  |     {%- if message['role'] == 'system' %} | ||||||
|  |     {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|> | ||||||
|  | ' }} | ||||||
|  |     {%- elif message['role'] == 'user' %} | ||||||
|  |     {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|> | ||||||
|  | ' }} | ||||||
|  |     {%- elif message['role'] == 'assistant_tool_call' or (message['role'] == 'assistant' and message.tool_calls is defined) %} | ||||||
|  |     {{- '<|start_of_role|>assistant<|end_of_role|>' }} | ||||||
|  |         {% for tc in message.tool_calls %} | ||||||
|  |             {{- '<|tool_call|> ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson  }} | ||||||
|  |         {% endfor %} | ||||||
|  |     {{- '<|end_of_text|> | ||||||
|  | ' }} | ||||||
|  |     {%- elif message['role'] == 'assistant' %} | ||||||
|  |     {{- '<|start_of_role|>assistant<|end_of_role|>'  + message['content'] + '<|end_of_text|> | ||||||
|  | ' }} | ||||||
|  |     {%- elif message['role'] == 'tool_response' or  message['role'] == 'tool' %} | ||||||
|  |     {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|> | ||||||
|  | ' }} | ||||||
|  |     {%- endif %} | ||||||
|  |     {%- if loop.last and add_generation_prompt %} | ||||||
|  |     {{- '<|start_of_role|>assistant<|end_of_role|>' }} | ||||||
|  |     {%- endif %} | ||||||
|  | {%- endfor %} | ||||||
| @ -3,6 +3,7 @@ import pytest_asyncio | |||||||
| from huggingface_hub import snapshot_download | from huggingface_hub import snapshot_download | ||||||
|  |  | ||||||
| from tests.utils import RemoteOpenAIServer | from tests.utils import RemoteOpenAIServer | ||||||
|  | from vllm.platforms import current_platform | ||||||
|  |  | ||||||
| from .utils import ARGS, CONFIGS, ServerConfig | from .utils import ARGS, CONFIGS, ServerConfig | ||||||
|  |  | ||||||
| @ -11,6 +12,11 @@ from .utils import ARGS, CONFIGS, ServerConfig | |||||||
| @pytest.fixture(scope="session", params=CONFIGS.keys()) | @pytest.fixture(scope="session", params=CONFIGS.keys()) | ||||||
| def server_config(request): | def server_config(request): | ||||||
|     config = CONFIGS[request.param] |     config = CONFIGS[request.param] | ||||||
|  |  | ||||||
|  |     if current_platform.is_rocm() and not config.get("supports_rocm", True): | ||||||
|  |         pytest.skip("The {} model can't be tested on the ROCm platform".format( | ||||||
|  |             config["model"])) | ||||||
|  |  | ||||||
|     # download model and tokenizer using transformers |     # download model and tokenizer using transformers | ||||||
|     snapshot_download(config["model"]) |     snapshot_download(config["model"]) | ||||||
|     yield CONFIGS[request.param] |     yield CONFIGS[request.param] | ||||||
|  | |||||||
| @ -13,6 +13,7 @@ class ServerConfig(TypedDict, total=False): | |||||||
|     arguments: List[str] |     arguments: List[str] | ||||||
|     system_prompt: Optional[str] |     system_prompt: Optional[str] | ||||||
|     supports_parallel: Optional[bool] |     supports_parallel: Optional[bool] | ||||||
|  |     supports_rocm: Optional[bool] | ||||||
|  |  | ||||||
|  |  | ||||||
| def patch_system_prompt(messages: List[Dict[str, Any]], | def patch_system_prompt(messages: List[Dict[str, Any]], | ||||||
| @ -36,7 +37,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], | |||||||
|  |  | ||||||
| # universal args for all models go here. also good if you need to test locally | # universal args for all models go here. also good if you need to test locally | ||||||
| # and change type or KV cache quantization or something. | # and change type or KV cache quantization or something. | ||||||
| ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] | ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"] | ||||||
|  |  | ||||||
| CONFIGS: Dict[str, ServerConfig] = { | CONFIGS: Dict[str, ServerConfig] = { | ||||||
|     "hermes": { |     "hermes": { | ||||||
| @ -88,18 +89,28 @@ CONFIGS: Dict[str, ServerConfig] = { | |||||||
|         "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " |         "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " | ||||||
|         "to the user's question - just respond to it normally." |         "to the user's question - just respond to it normally." | ||||||
|     }, |     }, | ||||||
|     ## FIXME: temporary disabled due to lack of hardware specification |     "granite20b": { | ||||||
|     ## for individual runs |         "model": | ||||||
|     #"granite20b": { |         "mbayser/granite-20b-functioncalling-FP8-KV", | ||||||
|     #    "model": |         "arguments": [ | ||||||
|     #    "ibm-granite/granite-20b-functioncalling", |             "--tool-call-parser", "granite-20b-fc", "--chat-template", | ||||||
|     #    "arguments": [ |             str(VLLM_PATH / | ||||||
|     #        "--tool-call-parser", "granite-20b-fc", "--chat-template", |                 "examples/tool_chat_template_granite_20b_fc.jinja"), | ||||||
|     #        str(VLLM_PATH / "examples/tool_chat_template_granite_20b_fc.jinja") |             "--max_num_seqs", "1", "--enforce-eager", "--cpu-offload-gb", "20" | ||||||
|     #    ], |         ], | ||||||
|     #    "supports_parallel": |         "supports_parallel": | ||||||
|     #    False, |         False, | ||||||
|     #}, |         "supports_rocm": | ||||||
|  |         False, | ||||||
|  |     }, | ||||||
|  |     "granite8b": { | ||||||
|  |         "model": | ||||||
|  |         "ibm-granite/granite-3.0-8b-instruct", | ||||||
|  |         "arguments": [ | ||||||
|  |             "--tool-call-parser", "granite", "--chat-template", | ||||||
|  |             str(VLLM_PATH / "examples/tool_chat_template_granite.jinja") | ||||||
|  |         ], | ||||||
|  |     }, | ||||||
|     "internlm": { |     "internlm": { | ||||||
|         "model": |         "model": | ||||||
|         "internlm/internlm2_5-7b-chat", |         "internlm/internlm2_5-7b-chat", | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| from .abstract_tool_parser import ToolParser, ToolParserManager | from .abstract_tool_parser import ToolParser, ToolParserManager | ||||||
| from .granite_20b_fc_tool_parser import Granite20bFCToolParser | from .granite_20b_fc_tool_parser import Granite20bFCToolParser | ||||||
|  | from .granite_tool_parser import GraniteToolParser | ||||||
| from .hermes_tool_parser import Hermes2ProToolParser | from .hermes_tool_parser import Hermes2ProToolParser | ||||||
| from .internlm2_tool_parser import Internlm2ToolParser | from .internlm2_tool_parser import Internlm2ToolParser | ||||||
| from .jamba_tool_parser import JambaToolParser | from .jamba_tool_parser import JambaToolParser | ||||||
| @ -8,6 +9,6 @@ from .mistral_tool_parser import MistralToolParser | |||||||
|  |  | ||||||
| __all__ = [ | __all__ = [ | ||||||
|     "ToolParser", "ToolParserManager", "Granite20bFCToolParser", |     "ToolParser", "ToolParserManager", "Granite20bFCToolParser", | ||||||
|     "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", |     "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", | ||||||
|     "Llama3JsonToolParser", "JambaToolParser" |     "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser" | ||||||
| ] | ] | ||||||
|  | |||||||
							
								
								
									
										215
									
								
								vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										215
									
								
								vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,215 @@ | |||||||
|  | import json | ||||||
|  | from typing import Dict, Sequence, Union | ||||||
|  |  | ||||||
|  | import partial_json_parser | ||||||
|  | from partial_json_parser.core.options import Allow | ||||||
|  |  | ||||||
|  | from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, | ||||||
|  |                                               DeltaFunctionCall, DeltaMessage, | ||||||
|  |                                               DeltaToolCall, | ||||||
|  |                                               ExtractedToolCallInformation, | ||||||
|  |                                               FunctionCall, ToolCall) | ||||||
|  | from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( | ||||||
|  |     ToolParser, ToolParserManager) | ||||||
|  | from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, | ||||||
|  |                                                         find_common_prefix, | ||||||
|  |                                                         is_complete_json, | ||||||
|  |                                                         partial_json_loads) | ||||||
|  | from vllm.logger import init_logger | ||||||
|  | from vllm.transformers_utils.tokenizer import AnyTokenizer | ||||||
|  | from vllm.utils import random_uuid | ||||||
|  |  | ||||||
|  | logger = init_logger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @ToolParserManager.register_module("granite") | ||||||
|  | class GraniteToolParser(ToolParser): | ||||||
|  |     """ | ||||||
|  |     Tool call parser for the granite 3.0 models. Intended | ||||||
|  |     for use with the examples/tool_chat_template_granite.jinja | ||||||
|  |     template. | ||||||
|  |  | ||||||
|  |     Used when --enable-auto-tool-choice --tool-call-parser granite | ||||||
|  |     are all set | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, tokenizer: AnyTokenizer): | ||||||
|  |         super().__init__(tokenizer) | ||||||
|  |  | ||||||
|  |     def extract_tool_calls( | ||||||
|  |             self, model_output: str, | ||||||
|  |             request: ChatCompletionRequest) -> ExtractedToolCallInformation: | ||||||
|  |         stripped = model_output.strip() | ||||||
|  |         if not stripped or stripped[0] != '[': | ||||||
|  |             return ExtractedToolCallInformation(tools_called=False, | ||||||
|  |                                                 tool_calls=[], | ||||||
|  |                                                 content=model_output) | ||||||
|  |         try: | ||||||
|  |             raw_function_calls = json.loads(stripped) | ||||||
|  |             if not isinstance(raw_function_calls, list): | ||||||
|  |                 raise Exception( | ||||||
|  |                     f"Expected dict or list, got {type(raw_function_calls)}") | ||||||
|  |  | ||||||
|  |             logger.debug("Extracted %d tool calls", len(raw_function_calls)) | ||||||
|  |             tool_calls = [ | ||||||
|  |                 ToolCall( | ||||||
|  |                     type="function", | ||||||
|  |                     function=FunctionCall( | ||||||
|  |                         name=function_call["name"], | ||||||
|  |                         # function call args are JSON but as a string | ||||||
|  |                         arguments=json.dumps(function_call["arguments"]), | ||||||
|  |                     ), | ||||||
|  |                 ) for function_call in raw_function_calls | ||||||
|  |             ] | ||||||
|  |  | ||||||
|  |             return ExtractedToolCallInformation( | ||||||
|  |                 tools_called=True, | ||||||
|  |                 tool_calls=tool_calls, | ||||||
|  |                 content=None, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error("Error in extracting tool call from response %s", e) | ||||||
|  |             return ExtractedToolCallInformation(tools_called=False, | ||||||
|  |                                                 tool_calls=[], | ||||||
|  |                                                 content=model_output) | ||||||
|  |  | ||||||
|  |     def extract_tool_calls_streaming( | ||||||
|  |         self, | ||||||
|  |         previous_text: str, | ||||||
|  |         current_text: str, | ||||||
|  |         delta_text: str, | ||||||
|  |         previous_token_ids: Sequence[int], | ||||||
|  |         current_token_ids: Sequence[int], | ||||||
|  |         delta_token_ids: Sequence[int], | ||||||
|  |         request: ChatCompletionRequest, | ||||||
|  |     ) -> Union[DeltaMessage, None]: | ||||||
|  |  | ||||||
|  |         start_idx = consume_space(0, current_text) | ||||||
|  |         if not current_text or current_text[start_idx] != '[': | ||||||
|  |             return DeltaMessage(content=delta_text) | ||||||
|  |  | ||||||
|  |         # bit mask flags for partial JSON parsing. If the name hasn't been | ||||||
|  |         # sent yet, don't allow sending | ||||||
|  |         # an incomplete string since OpenAI only ever (as far as I have | ||||||
|  |         # seen) allows sending the entire tool/ function name at once. | ||||||
|  |         flags = Allow.ALL if self.current_tool_name_sent \ | ||||||
|  |             else Allow.ALL & ~Allow.STR | ||||||
|  |         try: | ||||||
|  |             tool_call_arr = None | ||||||
|  |             is_complete = None | ||||||
|  |             try: | ||||||
|  |                 tool_calls, end_idx = partial_json_loads( | ||||||
|  |                     current_text[start_idx:], flags) | ||||||
|  |                 if type(tool_calls) is list: | ||||||
|  |                     tool_call_arr = tool_calls | ||||||
|  |                 else: | ||||||
|  |                     return DeltaMessage(content=delta_text) | ||||||
|  |  | ||||||
|  |                 is_complete = [True] * len(tool_calls) | ||||||
|  |                 if not is_complete_json( | ||||||
|  |                         current_text[start_idx:start_idx + end_idx]): | ||||||
|  |                     is_complete[-1] = False | ||||||
|  |             except partial_json_parser.core.exceptions.MalformedJSON: | ||||||
|  |                 logger.debug('not enough tokens to parse into JSON yet') | ||||||
|  |                 return None | ||||||
|  |  | ||||||
|  |             # case -- if no tokens have been streamed for the tool, e.g. | ||||||
|  |             #   only the array brackets, stream nothing | ||||||
|  |             if not tool_call_arr: | ||||||
|  |                 return None | ||||||
|  |  | ||||||
|  |             # select as the current tool call the one we're on the state at | ||||||
|  |             current_tool_call: Dict = tool_call_arr[self.current_tool_id] | ||||||
|  |  | ||||||
|  |             delta = None | ||||||
|  |             # case: we are starting a new tool in the array | ||||||
|  |             #   -> array has > 0 length AND length has moved past cursor | ||||||
|  |             if len(tool_call_arr) > self.current_tool_id + 1: | ||||||
|  |  | ||||||
|  |                 # if we're moving on to a new call, first make sure we | ||||||
|  |                 # haven't missed anything in the previous one that was | ||||||
|  |                 # auto-generated due to JSON completions, but wasn't | ||||||
|  |                 # streamed to the client yet. | ||||||
|  |                 if self.current_tool_id >= 0: | ||||||
|  |                     cur_arguments = current_tool_call.get("arguments") | ||||||
|  |                     if cur_arguments: | ||||||
|  |                         cur_args_json = json.dumps(cur_arguments) | ||||||
|  |                         sent = len( | ||||||
|  |                             self.streamed_args_for_tool[self.current_tool_id]) | ||||||
|  |                         argument_diff = cur_args_json[sent:] | ||||||
|  |  | ||||||
|  |                         logger.debug("got arguments diff: %s", argument_diff) | ||||||
|  |                         delta = DeltaMessage(tool_calls=[ | ||||||
|  |                             DeltaToolCall(index=self.current_tool_id, | ||||||
|  |                                           function=DeltaFunctionCall( | ||||||
|  |                                               arguments=argument_diff). | ||||||
|  |                                           model_dump(exclude_none=True)) | ||||||
|  |                         ]) | ||||||
|  |                         self.streamed_args_for_tool[ | ||||||
|  |                             self.current_tool_id] += argument_diff | ||||||
|  |  | ||||||
|  |                 # re-set stuff pertaining to progress in the current tool | ||||||
|  |                 self.current_tool_id = len(tool_call_arr) - 1 | ||||||
|  |                 self.current_tool_name_sent = False | ||||||
|  |                 self.streamed_args_for_tool.append("") | ||||||
|  |                 logger.debug("starting on new tool %d", self.current_tool_id) | ||||||
|  |                 return delta | ||||||
|  |  | ||||||
|  |             # if the current tool name hasn't been sent, send if available | ||||||
|  |             # - otherwise send nothing | ||||||
|  |             elif not self.current_tool_name_sent: | ||||||
|  |                 function_name = current_tool_call.get("name") | ||||||
|  |                 if function_name: | ||||||
|  |  | ||||||
|  |                     delta = DeltaMessage(tool_calls=[ | ||||||
|  |                         DeltaToolCall(index=self.current_tool_id, | ||||||
|  |                                       type="function", | ||||||
|  |                                       id=f"chatcmpl-tool-{random_uuid()}", | ||||||
|  |                                       function=DeltaFunctionCall( | ||||||
|  |                                           name=function_name).model_dump( | ||||||
|  |                                               exclude_none=True)) | ||||||
|  |                     ]) | ||||||
|  |                     self.current_tool_name_sent = True | ||||||
|  |  | ||||||
|  |             # now we know we're on the same tool call and we're streaming | ||||||
|  |             # arguments | ||||||
|  |             else: | ||||||
|  |                 cur_arguments = current_tool_call.get("arguments") | ||||||
|  |  | ||||||
|  |                 if cur_arguments: | ||||||
|  |                     sent = len( | ||||||
|  |                         self.streamed_args_for_tool[self.current_tool_id]) | ||||||
|  |                     cur_args_json = json.dumps(cur_arguments) | ||||||
|  |                     prev_arguments = self.prev_tool_call_arr[ | ||||||
|  |                         self.current_tool_id].get("arguments") | ||||||
|  |  | ||||||
|  |                     argument_diff = None | ||||||
|  |                     if is_complete[self.current_tool_id]: | ||||||
|  |                         argument_diff = cur_args_json[sent:] | ||||||
|  |                     elif prev_arguments: | ||||||
|  |                         prev_args_json = json.dumps(prev_arguments) | ||||||
|  |                         if cur_args_json != prev_args_json: | ||||||
|  |                             prefix = find_common_prefix( | ||||||
|  |                                 prev_args_json, cur_args_json) | ||||||
|  |                             argument_diff = prefix[sent:] | ||||||
|  |  | ||||||
|  |                     if argument_diff is not None: | ||||||
|  |                         delta = DeltaMessage(tool_calls=[ | ||||||
|  |                             DeltaToolCall(index=self.current_tool_id, | ||||||
|  |                                           function=DeltaFunctionCall( | ||||||
|  |                                               arguments=argument_diff). | ||||||
|  |                                           model_dump(exclude_none=True)) | ||||||
|  |                         ]) | ||||||
|  |                         self.streamed_args_for_tool[ | ||||||
|  |                             self.current_tool_id] += argument_diff | ||||||
|  |  | ||||||
|  |             self.prev_tool_call_arr = tool_call_arr | ||||||
|  |             return delta | ||||||
|  |  | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error("Error trying to handle streaming tool call: %s", e) | ||||||
|  |             logger.debug( | ||||||
|  |                 "Skipping chunk as a result of tool streaming extraction " | ||||||
|  |                 "error") | ||||||
|  |             return None | ||||||
		Reference in New Issue
	
	Block a user