Compare commits

...

118 Commits

Author SHA1 Message Date
ff44fe848d Make sure remote tool can be loaded 2023-05-08 15:50:00 +00:00
fe909e6650 Quality 2023-05-07 20:36:02 -04:00
13cc6e534f Merge branch 'test_composition' of github.com:huggingface/transformers into test_composition 2023-05-07 20:23:11 -04:00
9be1243a7d Remote tools can be loaded 2023-05-07 20:23:02 -04:00
c2df74b98d Text summarization tests 2023-05-07 23:07:15 +00:00
a9c92a7a4a Tests 2023-05-07 23:06:39 +00:00
4258f453fd Avoid redundancy between global variables 2023-05-07 17:50:30 -04:00
e84bcc7997 return_code adds tool creation 2023-05-07 17:33:55 -04:00
4157b35304 Make a requirements 2023-05-06 17:21:18 -04:00
5bdc6c96ec old remote tool and new remote tool 2023-05-06 14:38:09 -04:00
1d574aef29 Clean up 2023-05-06 13:50:08 -04:00
1bd7848d9e Custom inference API for endpoints too 2023-05-06 13:49:11 -04:00
9d706e8342 Clean push_to_hub and add app file 2023-05-06 12:51:38 -04:00
7a9137b3bf Tool save/from_hub/push_to_hub and tool->load_tool 2023-05-06 12:13:59 -04:00
4d991ed2ba Fix tests 2023-05-05 23:21:38 +00:00
69c1abbc55 Tests 2023-05-05 22:48:21 +00:00
43bddabbb4 Fix init 2023-05-05 16:57:38 -04:00
a873db85e7 save_pretrained and push_to_hub for tool 2023-05-05 16:54:50 -04:00
2711a393d9 Tool clean up 2023-05-05 16:16:04 -04:00
2234ff1126 Custom tools, custom prompt 2023-05-05 15:49:29 -04:00
2be32455ee Fixes in chat prompt 2023-05-05 14:01:34 -04:00
3304f1714c Update description of the tool 2023-05-05 13:57:54 -04:00
ac48681ad0 Fix link displayed 2023-05-05 13:44:44 -04:00
749a041984 Change summarization model (#23172) 2023-05-05 12:17:08 -04:00
6293e28d16 Small fixes 2023-05-05 11:49:13 +00:00
0f1cf429b1 correct some spelling 2023-05-05 11:05:53 +02:00
dc6743db19 Evaluation for chat agents 2023-05-04 21:24:56 -04:00
e9b166e8b6 Merge branch 'test_composition' of github.com:huggingface/transformers into test_composition 2023-05-04 16:58:16 -04:00
7f26a004a1 Use last result/assign for evaluation 2023-05-04 16:58:08 -04:00
aedbcd2469 Remove hardcoded selection 2023-05-04 20:56:06 +00:00
a740e1a3a4 Prompt 2023-05-04 20:53:42 +00:00
4ec7077510 Fix everything 2023-05-04 15:18:34 -04:00
47040aec29 tool 2023-05-04 19:06:41 +00:00
3d141d02c0 Merge branch 'test_composition' of github.com:huggingface/transformers into test_composition 2023-05-04 15:05:38 -04:00
4cafdb1098 Changes 2023-05-04 15:05:35 -04:00
fea2e6ce65 Tools 2023-05-04 19:03:01 +00:00
be0a949ad6 Tools 2023-05-04 18:51:11 +00:00
1d00d759f9 Clean up eval 2023-05-04 14:44:38 -04:00
d4fa69d995 Work 2023-05-04 11:52:14 -04:00
b317a33b6c Use remote tools descriptions 2023-05-04 10:41:26 -04:00
1239cdd59c Really do it 2023-05-03 16:45:22 -04:00
33903dad2f Remove nestedness in tool config 2023-05-03 16:43:33 -04:00
85e89c890b Add method to reset state 2023-05-03 16:34:49 -04:00
c7609790a7 New format for tools on the Hub 2023-05-03 16:04:13 -04:00
325d60ba9c Fix evaluation for agents 2023-05-03 16:03:13 -04:00
ed376943bb Merge branch 'test_composition' of github.com:huggingface/transformers into test_composition 2023-05-03 15:57:01 -04:00
0d65be6720 Cache agents and clean up 2023-05-03 15:56:54 -04:00
ff3240b93b Blank init 2023-05-03 19:37:43 +00:00
f1dc3b3da1 Merge branch 'test_composition' of github.com:huggingface/transformers into test_composition 2023-05-03 15:18:00 -04:00
ffadafb9f9 Let's chat! 2023-05-03 15:17:54 -04:00
51561d9c13 Temporary bs4 safeguard 2023-05-03 19:10:46 +00:00
a61d5b66ee Cleanup 2023-05-03 18:54:55 +00:00
b08e4af920 Tools on the Hub 2023-05-03 11:40:57 -04:00
e4e95f1cf3 Build prompt with tools descriptions 2023-05-02 16:52:27 -04:00
48ccb1aabe Add more tools 2023-05-02 19:36:19 +00:00
7f4aaf2f7e New evaluation 2023-05-02 15:34:09 -04:00
2906d2509b Add tools 2023-05-02 18:40:56 +00:00
cef1b96ea1 Harmonize 2023-05-02 13:38:12 -04:00
a255eaa3d0 Add to big prompt 2023-05-02 13:34:49 -04:00
895e472484 Add problems 2023-05-02 13:25:53 -04:00
8ad381efb7 Typo 2023-05-02 11:52:32 -04:00
6df0c15033 Make all tools a dict variable 2023-05-02 11:47:44 -04:00
d8abc32ef6 New endpoints agents 2023-05-02 11:12:22 -04:00
85e3acd1d6 Evaluate new agents 2023-05-01 21:23:34 -04:00
5589cec4c7 New version of the agent 2023-05-01 17:00:52 -04:00
6b64832bbb Add prompts 2023-05-01 14:54:34 -04:00
34801f1e67 Be consistent 2023-05-01 14:43:32 -04:00
afbed38879 Remove dict for translation 2023-05-01 14:38:23 -04:00
0c3fd13884 Back to one prompt 2023-05-01 14:15:09 -04:00
290e8306c4 More problems, add python primitives 2023-05-01 13:18:22 -04:00
28fae8b6fa Make problems easier - interface to debug 2023-05-01 11:50:15 -04:00
8b2221a3aa Big refactor of descriptions, batch generation and evaluation for agents 2023-04-28 16:04:30 -04:00
e819766fdf Style post-rebase 2023-04-27 15:35:19 -04:00
74c98e5fef Tool on the Hub 2023-04-27 15:34:52 -04:00
731029b03b Add tool wrapper 2023-04-27 15:34:52 -04:00
8c7c78a8f2 Fix args eval in interpreter 2023-04-27 15:34:51 -04:00
c70005aa11 Better prompts 2023-04-27 15:34:51 -04:00
0e80e1f75c Clean description 2023-04-27 15:34:51 -04:00
2868ea4c9d Male Basic optional in token 2023-04-27 15:34:50 -04:00
f85f4c6b0e No randomness 2023-04-27 15:34:50 -04:00
38d71024ab Remove accelerate and try to be reproducible 2023-04-27 15:34:50 -04:00
37f6a02f1e Style 2023-04-27 15:34:49 -04:00
d88b770d9a Cleanup 2023-04-27 15:34:49 -04:00
44bea11f6f ControlNet description 2023-04-27 15:34:48 -04:00
d9caa406c2 Lib protection 2023-04-27 15:34:48 -04:00
c136c19712 Diffusers protection 2023-04-27 15:34:01 -04:00
8b1aa74904 Gradio demo 2023-04-27 15:34:01 -04:00
42e16b3761 ControlNet 2023-04-27 15:34:01 -04:00
b4ebe76fea Image segmentation 2023-04-27 15:34:00 -04:00
cc187725b7 Add option to return code and update doc 2023-04-27 15:34:00 -04:00
4cd8879812 More remote tools 2023-04-27 15:34:00 -04:00
739b7c0d0e Clean up remote tools 2023-04-27 15:33:59 -04:00
1bf544a036 Description 2023-04-27 15:33:59 -04:00
3d673f0602 SD 2023-04-27 15:33:58 -04:00
a463897305 One prompt to rule them all. 2023-04-27 15:33:58 -04:00
49766e665d Make sure everyone has a default 2023-04-27 15:33:58 -04:00
360cddd78c Unwanted change 2023-04-27 15:33:57 -04:00
9cef2daeae Update prompt 2023-04-27 15:33:57 -04:00
259b0af144 Fixes 2023-04-27 15:33:57 -04:00
6c1f823c83 Some rename + README 2023-04-27 15:33:56 -04:00
77225aa758 Deal with typos + example of inference API 2023-04-27 15:33:56 -04:00
bd2a928909 Add setup 2023-04-27 15:33:55 -04:00
d6a36c0343 Support errors and rename OpenAssistantAgent 2023-04-27 15:33:55 -04:00
c8c766b566 Refactor descriptions and remove chain 2023-04-27 15:33:55 -04:00
7fa6db98a2 Style 2023-04-27 15:33:54 -04:00
0d4600e77d captioning + s2t fixes 2023-04-27 15:33:54 -04:00
fd140c0724 Add open assistance, support f-strings in evaluate 2023-04-27 15:33:53 -04:00
e9e68e856d Quality + word missing in translation 2023-04-27 15:33:53 -04:00
e1e47d67f1 GenQA + LID + S2T 2023-04-27 15:33:53 -04:00
9f3da1d6aa temp 2023-04-27 15:33:52 -04:00
b8354f845b Add translation tool 2023-04-27 15:33:52 -04:00
67f63d0f91 Quality 2023-04-27 15:33:51 -04:00
3149597939 Add agents 2023-04-27 15:33:51 -04:00
98c15ff46f Basic python interpreter 2023-04-27 15:33:51 -04:00
e00438ffa7 Rename 2023-04-27 15:33:50 -04:00
8d066a02e8 J'ai pris des libertés 2023-04-27 15:33:50 -04:00
30221a8d5f Text to speech 2023-04-27 15:33:50 -04:00
4f9b256b31 PoC for some chaining API 2023-04-27 15:33:49 -04:00
37 changed files with 3338 additions and 8 deletions

View File

@ -608,6 +608,7 @@ _import_structure = {
"SpecialTokensMixin",
"TokenSpan",
],
"tools": [],
"trainer_callback": [
"DefaultFlowCallback",
"EarlyStoppingCallback",

View File

@ -115,9 +115,9 @@ def get_relative_import_files(module_file):
return all_relative_imports
def check_imports(filename):
def get_imports(filename):
"""
Check if the current Python environment contains all the libraries that are imported in a file.
Extracts all the libraries that are imported in a file.
"""
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
@ -131,9 +131,14 @@ def check_imports(filename):
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
# Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
return list(set(imports))
# Unique-ify and test we got them all
imports = list(set(imports))
def check_imports(filename):
"""
Check if the current Python environment contains all the libraries that are imported in a file.
"""
imports = get_imports(filename)
missing_packages = []
for imp in imports:
try:
@ -169,6 +174,7 @@ def get_cached_module_file(
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
repo_type: Optional[str] = None,
_commit_hash: Optional[str] = None,
):
"""
@ -207,6 +213,8 @@ def get_cached_module_file(
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
repo_type (`str`, *optional*):
Specify the repo type (useful when downloading from a space for instance).
<Tip>
@ -229,7 +237,7 @@ def get_cached_module_file(
else:
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
cached_module = try_to_load_from_cache(
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
)
new_files = []
@ -245,6 +253,7 @@ def get_cached_module_file(
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
repo_type=repo_type,
_commit_hash=_commit_hash,
)
if not is_local and cached_module != resolved_module_file:
@ -309,8 +318,10 @@ def get_cached_module_file(
if len(new_files) > 0:
new_files = "\n".join([f"- {f}" for f in new_files])
repo_type_str = "" if repo_type is None else f"{repo_type}/"
url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
logger.warning(
f"A new version of the following files was downloaded from {pretrained_model_name_or_path}:\n{new_files}"
f"A new version of the following files was downloaded from {url}:\n{new_files}"
"\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
"versions of the code file, you can pin a revision."
)
@ -328,6 +339,7 @@ def get_class_from_dynamic_module(
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
repo_type: Optional[str] = None,
**kwargs,
):
"""
@ -377,6 +389,8 @@ def get_class_from_dynamic_module(
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
repo_type (`str`, *optional*):
Specify the repo type (useful when downloading from a space for instance).
<Tip>
@ -418,6 +432,7 @@ def get_class_from_dynamic_module(
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
repo_type=repo_type,
)
return get_class_in_module(class_name, final_module.replace(".py", ""))
@ -478,12 +493,17 @@ def custom_object_save(obj, folder, config=None):
elif config is not None:
_set_auto_map_in_config(config)
result = []
# Copy module file to the output folder.
object_file = sys.modules[obj.__module__].__file__
dest_file = Path(folder) / (Path(object_file).name)
shutil.copy(object_file, dest_file)
result.append(dest_file)
# Gather all relative imports recursively and make sure they are copied as well.
for needed_file in get_relative_import_files(object_file):
dest_file = Path(folder) / (Path(needed_file).name)
shutil.copy(needed_file, dest_file)
result.append(dest_file)
return result

View File

@ -64,6 +64,10 @@ class ChannelDimension(ExplicitEnum):
LAST = "channels_last"
def is_pil_image(img):
return is_vision_available() and isinstance(img, PIL.Image.Image)
def is_valid_image(img):
return (
(is_vision_available() and isinstance(img, PIL.Image.Image))

View File

@ -0,0 +1,141 @@
# Do anything with Transformers
Transformers supports all modalities and has many models performing many different types of tasks. But it can get confusing to mix and match them to solve the problem at hand, which is why we have developed a new API of **tools** and **agents**. Given a prompt in natural language and a set of tools, an agent determines the right code to run with those tools, chaining them properly to give you the result you expect.
Let's start with examples!
## Examples
First we need an agent, which is a fancy word for an LLM tasked with writing the code you need. We support the traditional OpenAI LLMs, but you should really try the open-source alternatives developed by the community, which:
- clearly state the data they have been trained on
- can be run on your own cloud or hardware
- have built-in versioning
<!--TODO for the release we should have a publicly available agent and if token is none, we grab the HF token-->
```py
from transformers.tools import EndpointAgent
agent = EndpointAgent(
url_endpoint=your_endpoint,
token=your_hf_token,
)
# from transformers.tools import OpenAiAgent
# agent = OpenAiAgent(api_key=your_openai_api_key)
```
### Task 1: Classifying text in (almost) any language
Now, to execute a given task, we need to pick a set of tools in `transformers` and send them to our agent. Let's say you want to classify a text in a non-English language, and you have trouble finding a model trained on that language. You can pick a translation tool and a standard text classification tool:
```py
from transformers.tools import TextClassificationTool, TranslationTool
tools = [TextClassificationTool(), TranslationTool(src_lang="fra_Latn", tgt_lang="eng_Latn")]
```
Then you just run this by your agent:
```py
agent.run(
"Determine if the following `text` (in French) is positive or negative.",
tools=tools,
text="J'aime beaucoup Hugging Face!"
)
```
Note that you can send any additional input via a variable named in your prompt (between backticks, because it helps the LLM). For text inputs, you can also put them directly in the prompt:
```py
agent.run(
"""Determine if the following text: "J'aime beaucoup Hugging Face!" (in French) is positive or negative.""",
tools=tools,
)
```
In both cases, you should see the agent generate code using your set of tools, which is then executed to provide the answer you were looking for. Neat!
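For illustration, the code generated by the agent for this task might look something like the snippet below (purely hypothetical; the exact code depends on the LLM and on the names and descriptions of the tools):
```py
# Hypothetical agent output: translate first, then classify the result.
translated_text = translator(text)
label = text_classifier(translated_text)
print(label)
```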
If you don't have the hardware to run the models that translate and classify the text, you can use the inference API by selecting a remote tool:
```py
from transformers.tools import RemoteTextClassificationTool, TranslationTool
tools = [RemoteTextClassificationTool(), TranslationTool(src_lang="fra_Latn", tgt_lang="eng_Latn")]
agent.run(
"Determine if the following `text` (in French) is positive or negative.",
tools=tools,
text="J'aime beaucoup Hugging Face!"
)
```
This was still all text-based. Let's now get to something more exciting, combining vision and speech.
### Task 2: Combining vision and speech
Let's say we want to hear out loud what is in a given image. There are models that do image captioning in Transformers, and other models that generate speech from text, but how do we combine them? Quite easily:
<!--TODO add the audio reader tool once it exists-->
```py
import requests
from PIL import Image
from transformers.tools import ImageCaptioningTool, TextToSpeechTool
tools = [ImageCaptioningTool(), TextToSpeechTool()]
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
speech = agent.run(
"Tell me out loud what the `image` contains.",
tools=tools,
image=image
)
```
Note that here you have to pass your input as a separate variable since you can't really embed your image in the text.
In all those examples, we have been using the default checkpoint for a given tool, but you can specify the one you want! For instance, the image-captioning tool uses BLIP by default, but let's upgrade to BLIP-2:
<!--TODO Once it works, use the inference API for BLIP-2 here as it's heavy-->
```py
tools = [ImageCaptioningTool("Salesforce/blip2-opt-2.7b"), TextToSpeechTool()]
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
speech = agent.run(
"Tell me out loud what the `image` contains.",
tools=tools,
image=image
)
```
<!--TODO Add more examples-->
## How does it work?
LLMs are pretty good at generating small samples of code, so this API takes advantage of that by prompting the LLM to write a small sample of code performing a task with a set of tools. This prompt is then completed with the task you give your agent and the descriptions of the tools you provide. This way it gets access to the documentation of the tools you are using, especially their expected inputs and outputs, and can generate the relevant code.
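Concretely, here is a minimal sketch of how the prompt is assembled (mirroring `Agent.format_prompt` in the diff further down; `toolbox` maps tool names to tools and `run_prompt_template` is assumed to contain the `<<all_tools>>` and `<<prompt>>` placeholders):
```py
# Describe every tool on its own line, then splice the descriptions and the task into the template.
description = "\n".join([f"- {name}: {tool.description}" for name, tool in toolbox.items()])
prompt = run_prompt_template.replace("<<all_tools>>", description)
prompt = prompt.replace("<<prompt>>", task)
```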
This API uses brand-new tools rather than pipelines, because the agent writes better code with very atomic tools. Pipelines are more refactored and often combine several tasks into one; tools are really meant to be focused on one very simple task only.
This code is then executed with our small Python interpreter on the set of inputs passed along with your tools. I hear you screaming "Arbitrary code execution!" in the back, but calm down a minute and let me explain.
The only functions that can be called are the tools you provided and the `print` function, so you're already limited in what can be executed. You should be safe if it's limited to Hugging Face tools. We also don't allow any attribute lookups or imports (which shouldn't be needed anyway for passing inputs and outputs along to a small set of functions), so the most obvious attacks (which you'd need to prompt the LLM to output anyway) shouldn't be an issue. If you want to be on the super safe side, you can execute the `run()` method with the additional argument `return_code=True`, in which case the agent will just return the code to execute and you can decide whether to run it or not.
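For instance, reusing the classification example from above (a sketch, assuming the same agent and `tools` as before):
```py
code = agent.run(
    "Determine if the following `text` (in French) is positive or negative.",
    tools=tools,
    text="J'aime beaucoup Hugging Face!",
    return_code=True,
)
print(code)  # inspect the generated code, then decide whether to execute it
```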
Note that LLMs are still not *that* good at producing the small amount of code needed to chain the tools, so we added some logic to fix typos during evaluation: there are often misnamed variables or dictionary keys.
The execution will stop at any line trying to perform an illegal operation or if there is a regular Python error with the code generated by the agent.
## Future developments
We hope you're as excited by this new API as we are. Here are a few things we are thinking of adding next if we see the community is interested:
- Make the agent pick the tools itself in a first step.
- Make the run command more chat-based, so you can copy-paste any error message you see in a next step to have the LLM fix its code, or ask for some improvements.
- Add support for more types of agents.

View File

@ -0,0 +1,13 @@
from .agents import Agent, EndpointAgent, OpenAiAgent
from .base import PipelineTool, RemoteTool, Tool, load_tool
from .document_question_answering import DocumentQuestionAnsweringTool
from .generative_question_answering import GenerativeQuestionAnsweringTool, RemoteGenerativeQuestionAnsweringTool
from .image_captioning import ImageCaptioningTool, RemoteImageCaptioningTool
from .image_question_answering import ImageQuestionAnsweringTool
from .image_segmentation import ImageSegmentationTool
from .language_identifier import LanguageIdentificationTool
from .speech_to_text import RemoteSpeechToTextTool, SpeechToTextTool
from .text_classification import RemoteTextClassificationTool, TextClassificationTool
from .text_summarization import RemoteTextSummarizationTool, TextSummarizationTool
from .text_to_speech import TextToSpeechTool
from .translation import TranslationTool

View File

@ -0,0 +1,382 @@
import importlib.util
import json
import os
import time
from dataclasses import dataclass
import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces
from ..utils import logging
from .base import TASK_MAPPING, Tool, load_tool
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
from .python_interpreter import evaluate
logger = logging.get_logger(__name__)
# Move to util when this branch is ready to merge
def is_openai_available():
return importlib.util.find_spec("openai") is not None
if is_openai_available():
import openai
_tools_are_initialized = False
BASE_PYTHON_TOOLS = {
"print": print,
"float": float,
"int": int,
"bool": bool,
"str": str,
}
@dataclass
class PreTool:
task: str
description: str
repo_id: str
HUGGINGFACE_DEFAULT_TOOLS = {}
HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
"image-transformation",
"text-download",
"text-to-image",
"text-to-video",
"image-inpainting",
]
def get_remote_tools(organization="huggingface-tools"):
spaces = list_spaces(author=organization)
tools = {}
for space_info in spaces:
repo_id = space_info.id
resolved_config_file = hf_hub_download(repo_id, "tool_config.json", repo_type="space")
with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader)
for task, task_info in config.items():
tools[task_info["name"]] = PreTool(task=task, description=task_info["description"], repo_id=repo_id)
return tools
def _setup_default_tools():
global HUGGINGFACE_DEFAULT_TOOLS
global _tools_are_initialized
if _tools_are_initialized:
return
main_module = importlib.import_module("transformers")
tools_module = main_module.tools
remote_tools = get_remote_tools()
for task_name in TASK_MAPPING:
tool_class_name = TASK_MAPPING.get(task_name)
tool_class = getattr(tools_module, tool_class_name)
description = tool_class.description
HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(task=task_name, description=description, repo_id=None)
for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB:
found = False
for tool_name, tool in remote_tools.items():
if tool.task == task_name:
HUGGINGFACE_DEFAULT_TOOLS[tool_name] = tool
found = True
break
if not found:
raise ValueError(f"{task_name} is not implemented on the Hub.")
_tools_are_initialized = True
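# Only instantiate the tools whose names actually appear in the generated code, reusing previously cached instances.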
def resolve_tools(code, toolbox, remote=False, cached_tools=None):
if cached_tools is None:
resolved_tools = BASE_PYTHON_TOOLS.copy()
else:
resolved_tools = cached_tools
for name, tool in toolbox.items():
if name not in code or name in resolved_tools:
continue
if isinstance(tool, Tool):
resolved_tools[name] = tool
else:
resolved_tools[name] = load_tool(tool.task, repo_id=tool.repo_id, remote=remote)
return resolved_tools
def get_tool_creation_code(code, toolbox, remote=False):
code_lines = ["from transformers import load_tool", ""]
for name, tool in toolbox.items():
if name not in code or isinstance(tool, Tool):
continue
line = f'{name} = load_tool("{tool.task}"'
if tool.repo_id is not None:
line += f', repo_id="{tool.repo_id}"'
if remote:
line += ", remote=True)"
line += ")"
code_lines.append(line)
return "\n".join(code_lines) + "\n"
def clean_code_for_chat(result):
lines = result.split("\n")
idx = 0
while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
idx += 1
explanation = "\n".join(lines[:idx]).strip()
if idx == len(lines):
return explanation, None
idx += 1
start_idx = idx
while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
idx += 1
code = "\n".join(lines[start_idx:idx]).strip()
return explanation, code
def clean_code_for_run(result):
result = f"I will use the following {result}"
explanation, code = result.split("Answer:")
explanation = explanation.strip()
code = code.strip()
code_lines = code.split("\n")
if code_lines[0] in ["```", "```py"]:
code_lines = code_lines[1:]
if code_lines[-1] == "```":
code_lines = code_lines[:-1]
code = "\n".join(code_lines)
return explanation, code
class Agent:
def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
_setup_default_tools()
self.chat_prompt_template = CHAT_PROMPT_TEMPLATE if chat_prompt_template is None else chat_prompt_template
self.run_prompt_template = RUN_PROMPT_TEMPLATE if run_prompt_template is None else run_prompt_template
self.toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
if additional_tools is not None:
if isinstance(additional_tools, (list, tuple)):
additional_tools = {t.name: t for t in additional_tools}
elif not isinstance(additional_tools, dict):
additional_tools = {additional_tools.name: additional_tools}
replacements = {name: tool for name, tool in additional_tools.items() if name in HUGGINGFACE_DEFAULT_TOOLS}
self.toolbox.update(additional_tools)
if len(replacements) > 1:
names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()])
logger.warning(
f"The following tools have been replaced by the ones provided in `additional_tools`:\n{names}."
)
elif len(replacements) == 1:
name = list(replacements.keys())[0]
logger.warning(f"{name} has been replaced by {replacements[name]} as provided in `additional_tools`.")
self.prepare_for_new_chat()
def format_prompt(self, task, chat_mode=False):
description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()])
if chat_mode:
if self.chat_history is None:
prompt = self.chat_prompt_template.replace("<<all_tools>>", description)
else:
prompt = self.chat_history
prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
else:
prompt = self.run_prompt_template.replace("<<all_tools>>", description)
prompt = prompt.replace("<<prompt>>", task)
return prompt
def chat(self, task, return_code=False, remote=False, **kwargs):
prompt = self.format_prompt(task, chat_mode=True)
result = self._generate_one(prompt, stop=["Human:", "====="])
self.chat_history = prompt + result + "\n"
explanation, code = clean_code_for_chat(result)
print(f"==Explanation from the agent==\n{explanation}")
if code is not None:
print(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
print("\n\n==Result==")
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
self.chat_state.update(kwargs)
return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
else:
tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
return f"{tool_code}\n{code}"
def prepare_for_new_chat(self):
self.chat_history = None
self.chat_state = {}
self.cached_tools = None
def run(self, task, return_code=False, remote=False, **kwargs):
prompt = self.format_prompt(task)
result = self._generate_one(prompt, stop=["Task:"])
explanation, code = clean_code_for_run(result)
print(f"==Explanation from the agent==\n{explanation}")
print(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
print("\n\n==Result==")
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
return evaluate(code, self.cached_tools, state=kwargs.copy())
else:
tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
return f"{tool_code}\n{code}"
class OpenAiAgent(Agent):
"""
Example:
```py
from transformers.tools.agents import OpenAiAgent
agent = OpenAiAgent(model="text-davinci-003", api_key=xxx)
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
```
"""
def __init__(
self,
model="gpt-3.5-turbo",
api_key=None,
chat_prompt_template=None,
run_prompt_template=None,
additional_tools=None,
):
if not is_openai_available():
raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY", None)
if api_key is None:
raise ValueError(
"You need an openai key to use `OpenAIAgent`. You can get one here: Get one here "
"https://openai.com/api/`. If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = "
"xxx."
)
else:
openai.api_key = api_key
self.model = model
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
def generate_code(self, task):
is_batched = isinstance(task, list)
if is_batched:
prompts = [self.format_prompt(one_task) for one_task in task]
else:
prompts = [self.format_prompt(task)]
if "gpt" in self.model:
results = [self._chat_generate(prompt, stop="Task:") for prompt in prompts]
else:
results = self._completion_generate(prompts, stop="Task:")
return results if is_batched else results[0]
def _generate_one(self, prompt, stop):
if "gpt" in self.model:
return self._chat_generate(prompt, stop)
else:
return self._completion_generate([prompt], stop)[0]
def _chat_generate(self, prompt, stop):
result = openai.ChatCompletion.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0,
stop=stop,
)
return result["choices"][0]["message"]["content"]
def _completion_generate(self, prompts, stop):
result = openai.Completion.create(
model=self.model,
prompt=prompts,
temperature=0,
stop=stop,
max_tokens=200,
)
return [answer["text"] for answer in result["choices"]]
class EndpointAgent(Agent):
def __init__(
self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
):
self.url_endpoint = url_endpoint
if token is None:
self.token = f"Bearer {HfFolder().get_token()}"
elif token.startswith("Bearer") or token.startswith("Basic"):
self.token = token
else:
self.token = f"Bearer {token}"
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
def generate_code(self, task):
is_batched = isinstance(task, list)
if is_batched:
prompts = [self.format_prompt(one_task) for one_task in task]
else:
prompts = [self.format_prompt(task)]
# Can probably batch those but can't test anymore right now as the endpoint has been limited in length.
results = [self._generate_one(prompt, stop=["Task:"]) for prompt in prompts]
return results if is_batched else results[0]
def _generate_one(self, prompt, stop):
headers = {"Authorization": self.token}
inputs = {
"inputs": prompt,
"parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
}
response = requests.post(self.url_endpoint, json=inputs, headers=headers)
if response.status_code == 429:
print("Getting rate-limited, waiting a tiny bit before trying again.")
time.sleep(1)
return self._generate_one(prompt, stop)
elif response.status_code != 200:
raise ValueError(f"Error {response.status_code}: {response.json()}")
result = response.json()[0]["generated_text"]
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq):
result = result[: -len(stop_seq)]
return result

View File

@ -0,0 +1,548 @@
import base64
import importlib
import io
import json
import os
from typing import Any, Dict, List, Optional, Union
from huggingface_hub import CommitOperationAdd, HfFolder, InferenceApi, create_commit, create_repo, hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError, get_session
from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
from ..image_utils import is_pil_image
from ..models.auto import AutoProcessor
from ..utils import (
CONFIG_NAME,
cached_file,
is_accelerate_available,
is_torch_available,
is_vision_available,
logging,
working_or_temp_dir,
)
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate.utils import send_to_device
TOOL_CONFIG_FILE = "tool_config.json"
def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
if repo_type is not None:
return repo_type
try:
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
return "space"
except RepositoryNotFoundError:
try:
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
return "model"
except RepositoryNotFoundError:
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
except Exception:
return "model"
except Exception:
return "space"
APP_FILE_TEMPLATE = """from transformers.tools.base import launch_gradio_demo from {module_name} import {class_name}
launch_gradio_demo({class_name})
"""
class Tool:
"""
Example of a super 'Tool' class that could live in huggingface_hub
"""
description: str = "This is a tool that ..."
name: str = ""
inputs: List[str]
outputs: List[str]
def __init__(self, *args, **kwargs):
self.is_initialized = False
def __call__(self, *args, **kwargs): # Might become run?
raise NotImplementedError("Write this method in your subclass of `Tool`.")
def setup(self):
# Do here any operation that is expensive and needs to be executed before you start using your tool. Such as
# loading a big model.
self.is_initialized = True
def save(self, output_dir, task_name=None):
os.makedirs(output_dir, exist_ok=True)
# Save module file
module_files = custom_object_save(self, output_dir)
module_name = self.__class__.__module__
last_module = module_name.split(".")[-1]
full_name = f"{last_module}.{self.__class__.__name__}"
# Save config file
config_file = os.path.join(output_dir, "tool_config.json")
if os.path.isfile(config_file):
with open(config_file, "r", encoding="utf-8") as f:
tool_config = json.load(f)
else:
tool_config = {}
if task_name is None:
class_name = self.__class__.__name__.replace("Tool", "")
chars = [f"_{c.lower()}" if c.isupper() else c for c in class_name]
task_name = "".join(chars)[1:]
tool_config[task_name] = {"tool_class": full_name, "description": self.description, "name": self.name}
with open(config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
# Save app file
app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f:
f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))
# Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt")
imports = []
for module in module_files:
imports.extend(get_imports(module))
imports = list(set(imports))
with open(requirements_file, "w", encoding="utf-8") as f:
f.write("\n".join(imports) + "\n")
@classmethod
def from_hub(cls, task_or_repo_id, repo_id=None, model_repo_id=None, token=None, remote=False, **kwargs):
if remote and model_repo_id is None:
raise ValueError("To use this tool remotely, please pass along the url endpoint to `model_repo_id`.")
hub_kwargs_names = [
"cache_dir",
"force_download",
"resume_download",
"proxies",
"revision",
"repo_type",
"subfolder",
"local_files_only",
]
hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
if repo_id is None:
repo_id = task_or_repo_id
task = None
else:
task = task_or_repo_id
# Try to get the tool config first.
hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
resolved_config_file = cached_file(
repo_id,
TOOL_CONFIG_FILE,
use_auth_token=token,
**hub_kwargs,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
is_tool_config = resolved_config_file is not None
if resolved_config_file is None:
resolved_config_file = cached_file(
repo_id,
CONFIG_NAME,
use_auth_token=token,
**hub_kwargs,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if resolved_config_file is None:
raise EnvironmentError(
f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
)
with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader)
if not is_tool_config:
if "custom_tools" not in config:
raise EnvironmentError(
f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
)
custom_tools = config["custom_tools"]
else:
custom_tools = config
if task is None:
if len(custom_tools) == 1:
task = list(custom_tools.keys())[0]
else:
tasks_available = "\n".join([f"- {t}" for t in custom_tools.keys()])
raise ValueError(f"Please select a task among the one available in {repo_id}:\n{tasks_available}")
tool_class = custom_tools[task]["tool_class"]
tool_class = get_class_from_dynamic_module(tool_class, repo_id, use_auth_token=token, **hub_kwargs)
if model_repo_id is not None:
repo_id = model_repo_id
elif hub_kwargs["repo_type"] == "space":
repo_id = None
if remote:
return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
return tool_class(repo_id, token=token, **kwargs)
def push_to_hub(
self,
repo_id: str,
use_temp_dir: Optional[bool] = None,
commit_message: str = "Upload tool",
private: Optional[bool] = None,
token: Optional[Union[bool, str]] = None,
create_pr: bool = False,
) -> str:
"""
Upload the tool files to the 🤗 Hub.
Parameters:
repo_id (`str`):
The name of the repository you want to push your tool to. It should contain your organization name when
pushing to a given organization.
use_temp_dir (`bool`, *optional*):
Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
commit_message (`str`, *optional*, defaults to `"Upload tool"`):
Message to commit while pushing.
private (`bool`, *optional*):
Whether or not the repository created should be private.
token (`bool` or `str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
create_pr (`bool`, *optional*, defaults to `False`):
Whether or not to create a PR with the uploaded files or directly commit.
"""
if os.path.isdir(repo_id):
working_dir = repo_id
repo_id = repo_id.split(os.path.sep)[-1]
else:
working_dir = repo_id.split("/")[-1]
repo_url = create_repo(
repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
)
repo_id = repo_url.repo_id
if use_temp_dir is None:
use_temp_dir = not os.path.isdir(working_dir)
with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
files_timestamps = self._get_files_timestamps(work_dir)
# Save all files.
self.save(work_dir)
modified_files = [
f
for f in os.listdir(work_dir)
if f not in files_timestamps or os.path.getmtime(os.path.join(work_dir, f)) > files_timestamps[f]
]
operations = []
for file in modified_files:
operations.append(CommitOperationAdd(path_or_fileobj=os.path.join(work_dir, file), path_in_repo=file))
logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
return create_commit(
repo_id=repo_id,
operations=operations,
commit_message=commit_message,
token=token,
create_pr=create_pr,
repo_type="space",
)
class OldRemoteTool(Tool):
default_checkpoint = None
def __init__(self, repo_id=None, token=None):
if repo_id is None:
repo_id = self.default_checkpoint
self.repo_id = repo_id
self.client = InferenceApi(repo_id, token=token)
self.is_initialized = False
def prepare_inputs(self, *args, **kwargs):
if len(args) > 1:
raise ValueError("A `RemoteTool` can only accept one positional input.")
elif len(args) == 1:
return {"data": args[0]}
return {"inputs": kwargs}
def extract_outputs(self, outputs):
return outputs
def __call__(self, *args, **kwargs):
if not self.is_initialized:
self.setup()
inputs = self.prepare_inputs(*args, **kwargs)
if isinstance(inputs, dict):
outputs = self.client(**inputs)
else:
outputs = self.client(inputs)
if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
outputs = outputs[0]
return self.extract_outputs(outputs)
class RemoteTool(Tool):
default_url = None
def __init__(self, endpoint_url=None, token=None, tool_class=None):
if endpoint_url is None:
endpoint_url = self.default_url
self.endpoint_url = endpoint_url
self.client = EndpointClient(endpoint_url, token=token)
self.tool_class = tool_class
def prepare_inputs(self, *args, **kwargs):
if len(args) > 1:
raise ValueError("A `RemoteTool` can only accept one positional input.")
elif len(args) == 1:
if is_pil_image(args[0]):
byte_io = io.BytesIO()
args[0].save(byte_io, format="PNG")
return {"inputs": byte_io.getvalue()}
return {"inputs": args[0]}
inputs = kwargs.copy()
for key, value in inputs.items():
if is_pil_image(value):
byte_io = io.BytesIO()
value.save(byte_io, format="PNG")
inputs[key] = byte_io.getvalue()
return {"inputs": kwargs}
def extract_outputs(self, outputs):
return outputs
def __call__(self, *args, **kwargs):
output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
inputs = self.prepare_inputs(*args, **kwargs)
if isinstance(inputs, dict):
outputs = self.client(**inputs, output_image=output_image)
else:
outputs = self.client(inputs, output_image=output_image)
if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
outputs = outputs[0]
return self.extract_outputs(outputs)
class PipelineTool(Tool):
pre_processor_class = AutoProcessor
model_class = None
post_processor_class = AutoProcessor
default_checkpoint = None
def __init__(
self,
model=None,
pre_processor=None,
post_processor=None,
device=None,
device_map=None,
model_kwargs=None,
token=None,
**hub_kwargs,
):
if not is_torch_available():
raise ImportError("Please install torch in order to use this tool.")
if not is_accelerate_available():
raise ImportError("Please install accelerate in order to use this tool.")
if model is None:
if self.default_checkpoint is None:
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
model = self.default_checkpoint
if pre_processor is None:
pre_processor = model
self.model = model
self.pre_processor = pre_processor
self.post_processor = post_processor
self.device = device
self.device_map = device_map
self.model_kwargs = {} if model_kwargs is None else model_kwargs
if device_map is not None:
self.model_kwargs["device_map"] = device_map
self.hub_kwargs = hub_kwargs
self.hub_kwargs["use_auth_token"] = token
self.is_initialized = False
def setup(self):
# Instantiate me maybe
if isinstance(self.pre_processor, str):
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
if isinstance(self.model, str):
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
if self.post_processor is None:
self.post_processor = self.pre_processor
elif isinstance(self.post_processor, str):
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
if self.device is None:
if self.device_map is not None:
self.device = list(self.model.hf_device_map.values())[0]
else:
self.device = get_default_device()
if self.device_map is None:
self.model.to(self.device)
def post_init(self):
pass
def encode(self, raw_inputs):
return self.pre_processor(raw_inputs)
def forward(self, inputs):
return self.model(**inputs)
def decode(self, outputs):
return self.post_processor(outputs)
def __call__(self, *args, **kwargs):
if not self.is_initialized:
self.setup()
encoded_inputs = self.encode(*args, **kwargs)
encoded_inputs = send_to_device(encoded_inputs, self.device)
outputs = self.forward(encoded_inputs)
outputs = send_to_device(outputs, "cpu")
return self.decode(outputs)
def launch_gradio_demo(tool_class: Tool):
try:
import gradio as gr
except ImportError:
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
tool = tool_class()
def fn(*args, **kwargs):
return tool(*args, **kwargs)
gr.Interface(
fn=fn,
inputs=tool_class.inputs,
outputs=tool_class.outputs,
title=tool_class.__name__,
article=tool.description,
).launch()
# TODO: Migrate to Accelerate for this once `PartialState.default_device` makes its way into a release.
def get_default_device():
if not is_torch_available():
raise ImportError("Please install torch in order to use this tool.")
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
return torch.device("mps")
elif torch.cuda.is_available():
return torch.device("cuda")
else:
return torch.device("cpu")
TASK_MAPPING = {
"generative-qa": "GenerativeQuestionAnsweringTool",
"image-captioning": "ImageCaptioningTool",
"image-segmentation": "ImageSegmentationTool",
# "language-identification": "LanguageIdentificationTool",
"speech-to-text": "SpeechToTextTool",
"text-classification": "TextClassificationTool",
"text-to-speech": "TextToSpeechTool",
"translation": "TranslationTool",
"summarization": "TextSummarizationTool",
"image-question-answering": "ImageQuestionAnsweringTool",
"document-question-answering": "DocumentQuestionAnsweringTool",
}
def load_tool(task_or_repo_id, repo_id=None, remote=False, token=None, **kwargs):
if task_or_repo_id in TASK_MAPPING:
tool_class_name = TASK_MAPPING[task_or_repo_id]
main_module = importlib.import_module("transformers")
tools_module = main_module.tools
tool_class = getattr(tools_module, tool_class_name)
if remote:
return RemoteTool(repo_id, token=token, tool_class=tool_class)
else:
return tool_class(repo_id, token=token, **kwargs)
else:
return Tool.from_hub(task_or_repo_id, repo_id=repo_id, token=token, remote=remote, **kwargs)
def add_description(description):
"""
A decorator that adds a description to a function.
"""
def inner(func):
func.description = description
return func
return inner
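# Example usage (hypothetical):
# @add_description("This is a tool that reverses a `text`.")
# def text_reverser(text):
#     return text[::-1]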
## Will move to the Hub
class EndpointClient:
def __init__(self, endpoint_url: str, token: Optional[str] = None):
if token is None:
token = HfFolder().get_token()
self.headers = {"authorization": f"Bearer {token}", "Content-Type": "application/json"}
self.endpoint_url = endpoint_url
def __call__(
self,
inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
params: Optional[Dict] = None,
data: Optional[bytes] = None,
output_image: bool = False,
) -> Any:
# Build payload
payload = {}
if inputs:
payload["inputs"] = inputs
if params:
payload["parameters"] = params
# Make API call
response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)
# By default, parse the response for the user.
if output_image:
if not is_vision_available():
raise ImportError(
"This endpoint returned an image but Pillow is not installed."
" Please install it (`pip install Pillow`) to be able to parse the image."
)
from PIL import Image
return Image.open(io.BytesIO(base64.b64decode(response.content)))
else:
return response.json()

View File

@ -0,0 +1,69 @@
import re
from transformers import VisionEncoderDecoderModel
from ..models.auto import AutoProcessor
from ..utils import is_vision_available
from .base import PipelineTool
if is_vision_available():
from PIL import Image
DOCUMENT_QUESTION_ANSWERING_DESCRIPTION = (
"This is a tool that answers a question about an document (pdf). It takes an input named `document` which should be the "
"document containing the information, as well as a `question` that is the question about the document. It returns a text "
"that contains the answer to the question."
)
class DocumentQuestionAnsweringTool(PipelineTool):
default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
description = DOCUMENT_QUESTION_ANSWERING_DESCRIPTION
name = "document_qa"
pre_processor_class = AutoProcessor
model_class = VisionEncoderDecoderModel
inputs = ["image", "text"]
outputs = ["text"]
def __init__(self, *args, **kwargs):
if not is_vision_available():
raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")
super().__init__(*args, **kwargs)
def encode(self, image: "Image", question: str):
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
prompt = task_prompt.replace("{user_input}", question)
decoder_input_ids = self.pre_processor.tokenizer(
prompt, add_special_tokens=False, return_tensors="pt"
).input_ids
pixel_values = self.pre_processor(image, return_tensors="pt").pixel_values
return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
def forward(self, inputs):
return self.model.generate(
inputs["pixel_values"].to(self.device),
decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
max_length=self.model.decoder.config.max_position_embeddings,
early_stopping=True,
pad_token_id=self.pre_processor.tokenizer.pad_token_id,
eos_token_id=self.pre_processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=1,
bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
).sequences
def decode(self, outputs):
sequence = self.pre_processor.batch_decode(outputs)[0]
sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "").replace(
self.pre_processor.tokenizer.pad_token, ""
)
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
sequence = self.pre_processor.token2json(sequence)
return sequence["answer"]

View File

@ -0,0 +1,686 @@
from .agents import BASE_PYTHON_TOOLS, clean_code_for_chat, clean_code_for_run
from .python_interpreter import InterpretorError, evaluate
### Fake tools for test
def classifier(text, labels):
return f"This is the classification of {text} along {labels}."
def translator(text, src_lang, tgt_lang):
return f"This is the translation of {text} from {src_lang} to {tgt_lang}."
def speaker(text):
return f"This is actually a sound reading {text}."
def transcriber(audio):
if "sound" not in audio:
raise ValueError(f"`audio` ({audio}) is not a sound.")
return f"This is the transcribed text from {audio}."
def image_generator(prompt):
return f"This is actually an image representing {prompt}."
def image_captioner(image):
if "image" not in image:
raise ValueError(f"`image` ({image}) is not an image.")
return f"This is a description of {image}."
def image_transformer(image, prompt):
if "image" not in image:
raise ValueError(f"`image` ({image}) is not an image.")
return f"This is a transformation of {image} according to {prompt}."
def question_answerer(text, question):
return f"This is the answer to {question} from {text}."
def image_qa(image, question):
if "image" not in image:
raise ValueError(f"`image` ({image}) is not an image.")
return f"This is the answer to {question} from {image}."
def text_downloader(url):
return f"This is the content of {url}."
def summarizer(text):
return f"This is a summary of {text}."
def video_generator(prompt, seconds=2):
return f"A video of {prompt}"
def document_qa(image, question):
return f"This is the answer to {question} from the document {image}."
def image_segmenter(image, prompt):
return f"This is the mask of {prompt} in {image}"
def image_inpainter(image, mask, prompt):
return f"This is the inpainted of {image} using prompt {prompt} and mask {mask}"
TEST_TOOLS = {
"text_classifier": classifier,
"translator": translator,
"text_reader": speaker,
"summarizer": summarizer,
"transcriber": transcriber,
"image_generator": image_generator,
"image_captioner": image_captioner,
"image_transformer": image_transformer,
"text_qa": question_answerer,
"text_downloader": text_downloader,
"image_qa": image_qa,
"video_generator": video_generator,
"document_qa": document_qa,
"image_segmenter": image_segmenter,
"image_inpainter": image_inpainter,
}
class Problem:
"""
A class gathering all the information needed to solve a problem on which we will evaluate agents.
Args:
task (`str` or `list[str]`):
One or several descriptions of the task to perform. If a list, it should contain variations on the
phrasing, but for the same task.
inputs (`list[str]` or `dict[str, str]`):
The inputs that will be fed to the tools. For this testing environment, only strings are accepted as
values. Pass along a dictionary when you want to specify the value of each input, or just the list of
inputs expected (the value used will be `<<input_name>>` in this case).
answer (`str` or `list[str]`):
The theoretical answer (or list of possible valid answers) to the problem, as code.
"""
def __init__(self, task, inputs, answer):
self.task = task
self.inputs = inputs
self.answer = answer
### The list of problems the agent will be evaluated on.
EVALUATION_TASKS = [
Problem(
task=[
"Is the following `text` (in Spanish) positive or negative?",
"Is the text in the variable `text` (in Spanish) positive or negative?",
"Translate the following `text` from Spanish to English then tell me if its positive or negative.",
],
inputs=["text"],
answer="""text_classifier(translator(text, src_lang="Spanish", tgt_lang="English"), labels=["positive", "negative"])""",
),
Problem(
task=[
"Tell me out loud what the `image` contains.",
"Describe the following `image` out loud.",
"Determine what is in the pictured stored in `image` then read it out loud.",
],
inputs=["image"],
answer=[
"text_reader(image_captioner(image))",
"text_reader(image_qa(image, question='What is in the image?'))",
],
),
Problem(
task=[
"Generate an image from the text given in `text_input`. Then transform it according to the text in `prompt`.",
"Use the following `text_input` to generate an image, then transform it by using the text in `prompt`.",
],
inputs=["text_input", "prompt"],
answer="image_transformer(image_generator(text_input), prompt)",
),
Problem(
task=[
"Download the content of `url`, summarize it then generate an image from its content.",
"Use a summary of the web page at `url` to generate an image.",
"Summarize the content of the web page at `url`, and use the result to generate an image.",
],
inputs=["url"],
answer="image_generator(summarizer(text_downloader(url)))",
),
Problem(
task=[
"Transform the following `image` using the prompt in `text`. The prompt is in Spanish.",
"Use the text prompt in `text` (in Spanish) to transform the following `image`.",
"Translate the `text` from Spanish to English then use it to transform the picture in `image`.",
],
inputs=["text", "image"],
answer="image_transformer(image, translator(text, src_lang='Spanish', tgt_lang='English'))",
),
Problem(
task=[
"Download the content of `url`, summarize it then read it out loud to me.",
"Read me a summary of the web page at `url`.",
],
inputs=["url"],
answer="text_reader(summarizer(text_downloader(url)))",
),
Problem(
task=[
"Generate an image from the text given in `text_input`.",
],
inputs=["text_input"],
answer="image_generator(text_input)",
),
Problem(
task=[
"Replace the beaver in the `image` by the `prompt`.",
"Transform the `image` so that it contains the `prompt`.",
"Use `prompt` to transform this `image`.",
],
inputs=["image", "prompt"],
answer=[
"image_transformer(image, prompt)",
"image_inpainter(image, image_segmenter(image, 'beaver'), prompt)",
],
),
Problem(
task=[
"Provide me the summary of the `text`, then read it to me before transcribing it and translating it in French.",
"Summarize `text`, read it out loud then transcribe the audio and translate it in French.",
"Read me a summary of the the `text` out loud. Transcribe this and translate it in French.",
],
inputs=["text"],
answer="translator(transcriber(text_reader(summarizer(text))), src_lang='English', tgt_lang='French')",
),
Problem(
task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
inputs={"prompt": "A lobster swimming"},
answer="video_generator('A lobster swimming')",
),
Problem(
task=[
"Download the following file `url`, summarize it in a few words and generate a video from it."
"Fetch the file at this `url`, summarize it, and create an animation out of it."
],
inputs=["url"],
answer="video_generator(summarizer(text_downloader(url)))",
),
]
EVALUATION_CHATS = [
[
Problem(
task=[
"Translate the following `text` from Spanish to English.",
"Translate the following `text` from Spanish to English.",
],
inputs=["text"],
answer="translated_text=translator(text, src_lang='Spanish', tgt_lang='English')",
),
Problem(
task=[
"Is it positive or negative?",
"Tell me if its positive or negative.",
],
inputs=[],
answer="text_classifier(translated_text, labels=['positive', 'negative'])",
),
],
[
Problem(
task=[
"What does this `image` contain?",
"Describe the following `image`.",
"Determine what is in the picture stored in `image`",
],
inputs=["image"],
answer=[
"description=image_captioner(image)",
"description=image_qa(image, question='What is in the image?')",
],
),
Problem(
task=["Now, read the description out loud.", "Great! Can you read it out loud?", "Read it out loud."],
inputs=[],
answer=["audio=text_reader(description)", "audio=text_reader(description)"],
),
],
[
Problem(
task=[
"Generate an image from the text given in `text_input`.",
"Use the following `text_input` to generate an image",
],
inputs=["text_input"],
answer="image = image_generator(text_input)",
),
Problem(
task=[
"Transform it according to the text in `prompt`.",
"Transform it by using the text in `prompt`.",
],
inputs=["prompt"],
answer="image_transformer(image, prompt)",
),
],
[
Problem(
task=[
"Download the content of `url` and summarize it.",
"Summarize the content of the web page at `url`.",
],
inputs=["url"],
answer="summary = summarizer(text_downloader(url))",
),
Problem(
task=[
"Generate an image from its content.",
"Use the previous result to generate an image.",
],
inputs=[],
answer="image_generator(summary)",
),
],
[
Problem(
task=[
"Translate this Spanish `text` in English.",
"Translate the `text` from Spanish to English.",
],
inputs=["text"],
answer="translated_text = translator(text, src_lang='Spanish', tgt_lang='English')",
),
Problem(
task=[
"Transform the following `image` using the translated `text`.",
"Use the previous result to transform the following `image`.",
],
inputs=["image"],
answer="image_transformer(image, translated_text)",
),
],
[
Problem(
task=["Download the content of `url`.", "Get me the text on the weg page `url`."],
inputs=["url"],
answer="text = text_downloader(url)",
),
Problem(
task=["Summarize this text.", "Summarize this text."],
inputs=[],
answer="summary = summarizer(text)",
),
Problem(
task=["Read it out loud to me.", "Read me the previous result."],
inputs=[],
answer="text_reader(summary)",
),
],
[
Problem(
task=[
"Generate an image from the text given in `text_input`.",
],
inputs=["text_input"],
answer="image_generator(text_input)",
),
],
[
Problem(
task=[
"Replace the beaver in the `image` by the `prompt`.",
"Transform the `image` so that it contains the `prompt`.",
"Use `prompt` to transform this `image`.",
],
inputs=["image", "prompt"],
answer=[
"image_transformer(image, prompt)",
"image_inpainter(image, image_segmenter(image, 'beaver'), prompt)",
],
),
],
[
Problem(
task=["Provide me the summary of the `text`.", "Summarize `text`."],
inputs=["text"],
answer="summary = summarizer(text)",
),
Problem(
task=["Read this summary to me.", "Read it out loud."],
inputs=[],
answer="audio = text_reader(summarizer(text))",
),
Problem(
task=["Transcribing the previous result back in text.", "Transcribe the audio."],
inputs=[],
answer="text = transcriber(audio)",
),
Problem(
task=["Translating the last result in French.", "Translate this in French."],
inputs=[],
answer="translator(text, src_lang='English', tgt_lang='French')",
),
],
[
Problem(
task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
inputs={"prompt": "A lobster swimming"},
answer="video_generator('A lobster swimming')",
),
],
[
Problem(
task=[
"Download the content of `url` and summarize it.",
"Summarize the content of the web page at `url`.",
],
inputs=["url"],
answer="summary = summarizer(text_downloader(url))",
),
Problem(
task=["generate a video from it.", "Create an animation from the last result."],
inputs=[],
answer="video_generator(summary)",
),
],
]
def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
if not isinstance(theoretical_answer, list):
return {name for name in TEST_TOOLS if name in code_answer}
for one_answer, one_code in zip(theoretical_answer, code_answer):
if agent_answer == one_answer:
return {name for name in TEST_TOOLS if name in one_code}
if isinstance(agent_answer, dict):
for one_answer, one_code in zip(theoretical_answer, code_answer):
if one_answer in agent_answer.values():
return {name for name in TEST_TOOLS if name in one_code}
return {name for name in TEST_TOOLS if name in code_answer[0]}
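# Note on the heuristic above: the "theoretical tools" are recovered by scanning
# the reference code answer(s) for tool names; when several reference answers
# exist, the one matching the agent's result (directly or through its state
# values) wins, and the first answer is used as a fallback.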
def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpretor_error=False):
tools = BASE_PYTHON_TOOLS.copy()
for name, tool in TEST_TOOLS.items():
if name not in code:
continue
tools[name] = tool
if isinstance(inputs, dict):
inputs = inputs.copy()
elif inputs is not None:
inputs = {inp: f"<<{inp}>>" for inp in inputs}
if state is not None:
state.update(inputs)
else:
state = inputs
try:
return evaluate(code, tools, state)
except InterpretorError as e:
return str(e)
except Exception as e:
if verbose:
print(e)
return None
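# Usage sketch for evaluate_code above: with inputs=["url"], the state starts as
# {"url": "<<url>>"} and the code is interpreted against that placeholder, so two
# programs chaining the same fake tools should land on the same placeholder
# result (this assumes the TEST_TOOLS stubs are deterministic).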
def score_code(agent_answer, theoretical_answer, verbose: bool = False):
if verbose:
print(agent_answer, theoretical_answer)
theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]
if agent_answer in theoretical_answer:
if verbose:
print("Perfect!")
return 1
elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
if verbose:
print("Almsot perfect, result in state!")
return 0.75
else:
if verbose:
print("Result is not the right one but code executed.")
return 0.3
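# Scoring sketch: an exact match against any reference answer scores 1.0; an
# agent answer returned as a state dict whose values contain a reference answer
# scores 0.75; any other code that still executed scores 0.3. For instance,
# score_code("<<audio>>", "<<audio>>") == 1 and
# score_code({"audio": "<<audio>>"}, "<<audio>>") == 0.75.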
def evaluate_one_result(explanation, code, agent_answer, theoretical_answer, answer, verbose=False):
tools_in_explanation = {name for name in TEST_TOOLS if f"`{name}`" in explanation}
theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)
if tools_in_explanation == theoretical_tools:
tool_selection_score = 1.0
tool_selection_errors = None
else:
missing_tools = len(theoretical_tools - tools_in_explanation)
unexpected_tools = len(tools_in_explanation - theoretical_tools)
tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
tool_selection_errors = {
"selected_tools": tools_in_explanation,
"theoretical_tools": theoretical_tools,
}
tools_in_code = {name for name in TEST_TOOLS if name in code}
if tools_in_code == theoretical_tools:
tool_used_score = 1.0
tool_used_errors = None
else:
missing_tools = len(theoretical_tools - tools_in_code)
unexpected_tools = len(tools_in_code - theoretical_tools)
tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
tool_used_errors = {
"selected_tools": tools_in_explanation,
"theoretical_tools": theoretical_tools,
}
score = score_code(agent_answer, theoretical_answer, verbose=verbose)
if score < 1.0:
code_errors = {
"code_produced": code,
"evaluation": agent_answer,
"theoretical_answer": theoretical_answer,
}
else:
code_errors = None
return (tool_selection_score, tool_used_score, score), (tool_selection_errors, tool_used_errors, code_errors)
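# Penalty sketch: each missing or unexpected tool costs 0.25, floored at 0. With
# theoretical tools {summarizer, text_reader} and an explanation mentioning only
# `summarizer` plus an unexpected `translator`, the selection score is
# max(0, 1.0 - 0.25 - 0.25) = 0.5.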
def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
"""
Evaluates a new agent on all `EVALUATION_TASKS`.
Example:
```py
agent = OpenAiAgent(model="text-davinci-003", api_key=your_api_key)
scores = evaluate_agent(agent)
print(scores)
```
"""
# Sanity check
agent_tools = set(agent.toolbox.keys())
if agent_tools != set(TEST_TOOLS):
missing_tools = set(TEST_TOOLS) - agent_tools
unexpected_tools = agent_tools - set(TEST_TOOLS)
raise ValueError(
f"Fix the test tools in the evaluate_agent module. Tools mising: {missing_tools}. Extra tools: {unexpected_tools}."
)
eval_tasks = []
eval_idx = []
for idx, pb in enumerate(EVALUATION_TASKS):
if isinstance(pb.task, list):
eval_tasks.extend(pb.task)
eval_idx.extend([idx] * len(pb.task))
else:
eval_tasks.append(pb.task)
eval_idx.append(idx)
tool_selection_score = 0
tool_used_score = 0
code_score = 0
if return_errors:
tool_selection_errors = {}
tool_used_errors = {}
code_errors = {}
for start_idx in range(0, len(eval_tasks), batch_size):
end_idx = min(start_idx + batch_size, len(eval_tasks))
batch_tasks = eval_tasks[start_idx:end_idx]
results = agent.generate_code(batch_tasks)
for idx, result in enumerate(results):
problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
if verbose:
print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")
explanation, code = clean_code_for_run(result)
# Evaluate agent answer and code answer
agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
if isinstance(problem.answer, list):
theoretical_answer = [evaluate_code(answer, problem.inputs) for answer in problem.answer]
else:
theoretical_answer = evaluate_code(problem.answer, problem.inputs)
scores, errors = evaluate_one_result(
explanation, code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
)
tool_selection_score += scores[0]
tool_used_score += scores[1]
code_score += scores[2]
if return_errors:
if errors[0] is not None:
tool_selection_errors[batch_tasks[idx]] = errors[0]
if errors[1] is not None:
tool_used_errors[batch_tasks[idx]] = errors[1]
if errors[2] is not None:
code_errors[batch_tasks[idx]] = errors[2]
scores = {
"tool selection score": 100 * (tool_selection_score / len(eval_tasks)),
"tool used score": 100 * (tool_used_score / len(eval_tasks)),
"code score": 100 * (code_score / len(eval_tasks)),
}
if return_errors:
return scores, tool_selection_errors, tool_used_errors, code_errors
else:
return scores
def evaluate_chat_agent(agent, verbose=False, return_errors=False):
"""
Evaluates a new agent on all `EVALUATION_CHATS`.
Example:
```py
agent = OpenAiAgent(model="text-davinci-003", api_key=your_api_key)
scores = evaluate_chat_agent(agent)
print(scores)
```
"""
# Sanity check
agent_tools = set(agent.toolbox.keys())
if agent_tools != set(TEST_TOOLS):
missing_tools = set(TEST_TOOLS) - agent_tools
unexpected_tools = agent_tools - set(TEST_TOOLS)
raise ValueError(
f"Fix the test tools in the evaluate_agent module. Tools mising: {missing_tools}. Extra tools: {unexpected_tools}."
)
tool_selection_score = 0
tool_used_score = 0
code_score = 0
total_steps = 0
if return_errors:
tool_selection_errors = {}
tool_used_errors = {}
code_errors = {}
for chat_problem in EVALUATION_CHATS:
if isinstance(chat_problem[0].task, str):
resolved_problems = [chat_problem]
else:
resolved_problems = [
[Problem(task=pb.task[i], inputs=pb.inputs, answer=pb.answer) for pb in chat_problem]
for i in range(len(chat_problem[0].task))
]
for problem in resolved_problems:
agent.prepare_for_new_chat()
agent_state = {}
theoretical_state = (
[{} for _ in range(len(problem[0].answer))] if isinstance(problem[0].answer, list) else {}
)
for step, step_problem in enumerate(problem):
if verbose:
print(step_problem.task)
total_steps += 1
prompt = agent.format_prompt(step_problem.task, chat_mode=True)
result = agent._generate_one(prompt, stop=["Human:", "====="])
agent.chat_history = prompt + result + "\n"
explanation, code = clean_code_for_chat(result)
if verbose:
print(f"==Explanation from the agent==\n{explanation}")
print(f"\n==Code generated by the agent==\n{code}")
# Evaluate agent answer and code answer
agent_answer = evaluate_code(code, step_problem.inputs, state=agent_state, verbose=verbose)
answer = step_problem.answer
if isinstance(answer, list):
theoretical_answer = [
evaluate_code(a, step_problem.inputs, state=state)
for a, state in zip(answer, theoretical_state)
]
else:
theoretical_answer = evaluate_code(answer, step_problem.inputs, state=theoretical_state)
scores, errors = evaluate_one_result(
explanation, code, agent_answer, theoretical_answer, answer, verbose=verbose
)
tool_selection_score += scores[0]
tool_used_score += scores[1]
code_score += scores[2]
if return_errors:
if errors[0] is not None:
tool_selection_errors[step_problem.task] = errors[0]
if errors[1] is not None:
tool_used_errors[step_problem.task] = errors[1]
if errors[2] is not None:
code_errors[step_problem.task] = errors[2]
scores = {
"tool selection score": 100 * (tool_selection_score / total_steps),
"tool used score": 100 * (tool_used_score / total_steps),
"code score": 100 * (code_score / total_steps),
}
if return_errors:
return scores, tool_selection_errors, tool_used_errors, code_errors
else:
return scores

View File

@ -0,0 +1,51 @@
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import OldRemoteTool, PipelineTool
QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.
Can you answer this question about the text: '{question}'"""
GENERATIVE_QUESTION_ANSWERING_DESCRIPTION = (
"This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
"text where to find the answer, and `question`, which is the question, and returns the answer to the question."
)
class GenerativeQuestionAnsweringTool(PipelineTool):
default_checkpoint = "google/flan-t5-base"
description = GENERATIVE_QUESTION_ANSWERING_DESCRIPTION
name = "text_qa"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = ["text", "text"]
outputs = ["text"]
def encode(self, text: str, question: str):
prompt = QA_PROMPT.format(text=text, question=question)
return self.pre_processor(prompt, return_tensors="pt")
def forward(self, inputs):
output_ids = self.model.generate(**inputs)
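# generate() returns (batch_size * num_return_sequences, seq_len); regroup per
# input and keep the first sequence of the first (and only) input.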
in_b, _ = inputs["input_ids"].shape
out_b = output_ids.shape[0]
return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]
def decode(self, outputs):
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
class RemoteGenerativeQuestionAnsweringTool(OldRemoteTool):
default_checkpoint = "google/flan-t5-base"
description = GENERATIVE_QUESTION_ANSWERING_DESCRIPTION
def extract_outputs(self, outputs):
return outputs[0]["generated_text"]
def prepare_inputs(self, text, question):
prompt = QA_PROMPT.format(text=text, question=question)
return prompt

View File

@ -0,0 +1,57 @@
import io
from ..models.auto import AutoModelForVision2Seq, AutoProcessor
from ..utils import is_vision_available
from .base import OldRemoteTool, PipelineTool
if is_vision_available():
from PIL import Image
IMAGE_CAPTIONING_DESCRIPTION = (
"This is a tool that generates a description of an image. It takes an input named `image` which should be the "
"image to caption, and returns a text that contains the description in English."
)
class ImageCaptioningTool(PipelineTool):
default_checkpoint = "Salesforce/blip-image-captioning-base"
description = IMAGE_CAPTIONING_DESCRIPTION
name = "image_captioner"
pre_processor_class = AutoProcessor
model_class = AutoModelForVision2Seq
inputs = ["image"]
outputs = ["text"]
def __init__(self, *args, **kwargs):
if not is_vision_available():
raise ValueError("Pillow must be installed to use the ImageCaptioningTool.")
super().__init__(*args, **kwargs)
def encode(self, image: "Image"):
return self.pre_processor(images=image, return_tensors="pt")
def forward(self, inputs):
return self.model.generate(**inputs)
def decode(self, outputs):
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
class RemoteImageCaptioningTool(OldRemoteTool):
default_checkpoint = "Salesforce/blip-image-captioning-large"
description = IMAGE_CAPTIONING_DESCRIPTION
def extract_outputs(self, outputs):
return outputs[0]["generated_text"]
def prepare_inputs(self, image):
if isinstance(image, bytes):
return {"data": image}
byte_io = io.BytesIO()
image.save(byte_io, format="PNG")
return {"data": byte_io.getvalue()}

View File

@ -0,0 +1,41 @@
from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
from ..utils import is_vision_available
from .base import PipelineTool
if is_vision_available():
from PIL import Image
IMAGE_QUESTION_ANSWERING_DESCRIPTION = (
"This is a tool that answers a question about an image. It takes an input named `image` which should be the "
"image containing the information, as well as a `question` which should be the question in English. It returns a "
"text that is the answer to the question."
)
class ImageQuestionAnsweringTool(PipelineTool):
default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
description = IMAGE_QUESTION_ANSWERING_DESCRIPTION
name = "image_qa"
pre_processor_class = AutoProcessor
model_class = AutoModelForVisualQuestionAnswering
inputs = ["image", "text"]
outputs = ["text"]
def __init__(self, *args, **kwargs):
if not is_vision_available():
raise ValueError("Pillow must be installed to use the ImageQuestionAnsweringTool.")
super().__init__(*args, **kwargs)
def encode(self, image: "Image", question: str):
return self.pre_processor(image, question, return_tensors="pt")
def forward(self, inputs):
return self.model(**inputs).logits
def decode(self, outputs):
idx = outputs.argmax(-1).item()
return self.model.config.id2label[idx]

View File

@ -0,0 +1,47 @@
import numpy as np
from transformers import AutoProcessor, CLIPSegForImageSegmentation, is_vision_available
from .base import PipelineTool
if is_vision_available():
from PIL import Image
IMAGE_SEGMENTATION_DESCRIPTION = (
"This is a tool that creates a segmentation mask using an image and a prompt. It takes two arguments named "
"`image` which should be the original image, and `prompt` which should be a text describing what should be "
"identified in the segmentation mask. The tool returns the mask as a black-and-white image."
)
class ImageSegmentationTool(PipelineTool):
description = IMAGE_SEGMENTATION_DESCRIPTION
default_checkpoint = "CIDAS/clipseg-rd64-refined"
name = "image_segmenter"
pre_processor_class = AutoProcessor
model_class = CLIPSegForImageSegmentation
inputs = ["image", "text"]
outputs = ["image"]
def __init__(self, *args, **kwargs):
if not is_vision_available():
raise ImportError("Pillow should be installed in order to use the ImageSegmentationTool.")
super().__init__(*args, **kwargs)
def encode(self, image: "Image", prompt: str):
self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]}
return self.pre_processor(text=[prompt], images=[image], padding=True, return_tensors="pt")
def forward(self, inputs):
logits = self.model(**inputs).logits
return logits
def decode(self, outputs):
array = outputs.cpu().detach().numpy()
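# Binarize at logit 0 (a 0.5 probability after sigmoid), then stack the mask
# into three identical channels to return a black-and-white RGB image.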
array[array <= 0] = 0
array[array > 0] = 1
return Image.fromarray((np.dstack([array, array, array]) * 255).astype(np.uint8))

View File

@ -0,0 +1,47 @@
from transformers.tools.base import Tool
from transformers.utils import is_accelerate_available, is_diffusers_available
if is_accelerate_available():
from accelerate.state import PartialState
if is_diffusers_available():
from diffusers import DiffusionPipeline
TEXT_TO_IMAGE_DESCRIPTION = (
"This is a tool that creates an image according to a prompt, which is a text description. It takes an input named `prompt` which "
"contains the image description and outputs an image."
)
class TextToImageTool(Tool):
default_checkpoint = "runwayml/stable-diffusion-v1-5"
description = TEXT_TO_IMAGE_DESCRIPTION
def __init__(self, device=None, **hub_kwargs) -> None:
if not is_accelerate_available():
raise ImportError("Accelerate should be installed in order to use tools.")
if not is_diffusers_available():
raise ImportError("Diffusers should be installed in order to use the StableDiffusionTool.")
super().__init__()
self.device = device
self.pipeline = None
self.hub_kwargs = hub_kwargs
def setup(self):
if self.device is None:
self.device = PartialState().default_device
self.pipeline = DiffusionPipeline.from_pretrained(self.default_checkpoint)
self.pipeline.to(self.device)
self.is_initialized = True
def __call__(self, prompt):
if not self.is_initialized:
self.setup()
return self.pipeline(prompt).images[0]
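# Hedged usage sketch (device selection is deferred to accelerate's PartialState):
# tool = TextToImageTool()
# image = tool("An astronaut riding a horse")  # lazily runs setup(), returns a PIL image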

View File

@ -0,0 +1,26 @@
from .text_classification import TextClassificationTool
LANGUAGE_IDENTIFIER_DESCRIPTION = (
"This is a tool that identifies the language of the text passed as input. It takes one input named `text` and "
"returns the two-letter label of the identified language."
)
class LanguageIdentificationTool(TextClassificationTool):
"""
Example:
```py
from transformers.tools import LanguageIdentificationTool
classifier = LanguageIdentificationTool()
classifier("This is a super nice API!")
```
"""
default_checkpoint = "papluca/xlm-roberta-base-language-detection"
description = LANGUAGE_IDENTIFIER_DESCRIPTION
def decode(self, outputs):
return super().decode(outputs)["label"]

View File

@ -0,0 +1,169 @@
# docstyle-ignore
RUN_PROMPT_TEMPLATE = """I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
Tools:
<<all_tools>>
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
Answer:
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(image=image, question=translated_question)
print(f"The answer is {answer}")
```
Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
Answer:
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
image = image_generator("A banner showing " + answer)
```
Task: "Generate an image using the text given in the variable `caption`."
I will use the following tool: `image_generator` to generate an image.
Answer:
```py
image = image_generator(prompt=caption)
```
Task: "Summarize the text given in the variable `text` and read it out loud."
I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
Answer:
```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(summarized_text)
```
Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
Answer:
```py
answer = text_qa(text=text, question=question)
print(f"The answer is {answer}.")
image = image_generator(answer)
```
Task: "Caption the following `image`."
I will use the following tool: `image_captioner` to generate a caption for the image.
Answer:
```py
caption = image_captioner(image)
```
Task: "<<prompt>>"
I will use the following"""
# docstyle-ignore
CHAT_PROMPT_TEMPLATE = """Below are a series of dialogues between various people and an AI assistant specialized in coding. The AI assistant tries to be helpful, polite, honest, and humble-but-knowledgeable.
The job of the AI assistant is to come up with a series of simple commands in Python that will perform the task the human wants to perform.
To help with that, the AI assistant has access to a set of tools. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
The AI assistant should first explain the tools it will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. The AI assistant can print intermediate results if it makes sense to do so.
Tools:
<<all_tools>>
=====
Human: Answer the question in the variable `question` about the image stored in the variable `image`.
Assistant: I will use the tool `image_qa` to answer the question on the input image.
```py
answer = image_qa(text=question, image=image)
print(f"The answer is {answer}")
```
Human: I tried this code, it worked but didn't give me a good result. The question is in French
Assistant: In this case, the question needs to be translated first. I will use the tool `translator` to do this.
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(text=translated_question, image=image)
print(f"The answer is {answer}")
```
=====
Human: Identify the oldest person in the `document`.
Assistant: I will use the tool `document_qa` to find the oldest person in the document.
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
```
Human: Can you generate an image with the result?
Assistant: I will use the tool `image_generator` to do that.
```py
image = image_generator(prompt="A banner showing " + answer)
```
=====
Human: Summarize the text given in the variable `text` and read it out loud.
Assistant: I will use the tool `summarizer` to create a summary of the input text, then the tool `text_reader` to read it out loud.
```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(text=summary)
```
Human: I got the following error: "The variable `summary` is not defined."
Assistant: My bad! Let's try this code instead.
```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(text=summarized_text)
```
Human: It worked! Can you translate the summary in German?
Assistant: I will use the tool `translator` to translate the text in German.
```py
translated_summary = translator(summarized_text, src_lang="English", tgt_lang="German")
```
=====
"""
# docstyle-ignore
CHAT_MESSAGE_PROMPT = """
Human: <<task>>
Assistant: """

View File

@ -0,0 +1,210 @@
import ast
import difflib
from collections.abc import Mapping
from typing import Any, Callable, Dict
class InterpretorError(ValueError):
"""
An error raised when the interpreter cannot evaluate a Python expression, due to a syntax error or unsupported
operations.
"""
pass
def evaluate(code: str, tools: Dict[str, Callable], state=None, chat_mode=False):
"""
Evaluate a Python expression using the content of the variables stored in a state and only evaluating a given set
of functions.
This function will recurse through the nodes of the tree provided.
Args:
code (`str`):
The code to evaluate.
tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpretorError`.
state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
updated by this function to contain all variables as they are evaluated.
chat_mode (`bool`, *optional*, defaults to `False`):
Whether the evaluation is run in chat mode; this only changes how the error message is formatted when the
evaluation fails.
"""
try:
expression = ast.parse(code)
except SyntaxError as e:
print("The code generated by the agent is not valid.\n", e)
return
if state is None:
state = {}
result = None
for idx, node in enumerate(expression.body):
try:
line_result = evaluate_ast(node, state, tools)
except InterpretorError as e:
msg = f"Evaluation of the code stopped at line {idx} before the end because of the following error"
if chat_mode:
msg += (
f". Copy paste the following error message and send it back to the agent:\nI get an error: '{e}'"
)
else:
msg += f":\n{e}"
print(msg)
break
if line_result is not None:
result = line_result
return result
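# Example sketch with a toy tool (binary operators are unsupported, so the
# arithmetic goes through the tool):
# state = {}
# result = evaluate("x = add(1, 2)\nx", {"add": lambda a, b: a + b}, state)
# assert result == 3 and state["x"] == 3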
def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
"""
Evaluate an absract syntax tree using the content of the variables stored in a state and only evaluating a given
set of functions.
This function will recurse trough the nodes of the tree provided.
Args:
expression (`ast.AST`):
The code to evaluate, as an abastract syntax tree.
state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
encounters assignements.
tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpretorError`.
"""
if isinstance(expression, ast.Assign):
# Assignment -> we evaluate the assignment, which should update the state
# We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(expression, state, tools)
elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call
return evaluate_call(expression, state, tools)
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = [evaluate_ast(k, state, tools) for k in expression.keys]
values = [evaluate_ast(v, state, tools) for v in expression.values]
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content
return evaluate_ast(expression.value, state, tools)
elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(expression.value, state, tools)
elif isinstance(expression, ast.If):
# If -> execute the right branch
evaluate_if(expression, state, tools)
elif isinstance(expression, ast.JoinedStr):
return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [evaluate_ast(elt, state, tools) for elt in expression.elts]
elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state
return evaluate_name(expression, state, tools)
elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing
return evaluate_subscript(expression, state, tools)
else:
# For now we refuse anything else. Let's add things as we need them.
raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
def evaluate_assign(assign, state, tools):
var_names = assign.targets
result = evaluate_ast(assign.value, state, tools)
if len(var_names) == 1:
state[var_names[0].id] = result
else:
if len(result) != len(var_names):
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
for var_name, r in zip(var_names, result):
state[var_name.id] = r
return result
def evaluate_call(call, state, tools):
if not isinstance(call.func, ast.Name):
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
f"type {type(call.func)}."
)
func_name = call.func.id
if func_name not in tools:
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
)
func = tools[func_name]
# Todo deal with args
args = [evaluate_ast(arg, state, tools) for arg in call.args]
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
return func(*args, **kwargs)
def evaluate_subscript(subscript, state, tools):
index = evaluate_ast(subscript.slice, state, tools)
value = evaluate_ast(subscript.value, state, tools)
if isinstance(value, (list, tuple)):
# `in` on a sequence tests membership, not index validity, so index directly.
return value[int(index)]
if index in value:
return value[index]
if isinstance(index, str) and isinstance(value, Mapping):
close_matches = difflib.get_close_matches(index, list(value.keys()))
if len(close_matches) > 0:
return value[close_matches[0]]
raise InterpretorError(f"Could not index {value} with '{index}'.")
def evaluate_name(name, state, tools):
if name.id in state:
return state[name.id]
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
if len(close_matches) > 0:
return state[close_matches[0]]
raise InterpretorError(f"The variable `{name.id}` is not defined.")
def evaluate_condition(condition, state, tools):
if len(condition.ops) > 1:
raise InterpretorError("Cannot evaluate conditions with multiple operators")
left = evaluate_ast(condition.left, state, tools)
comparator = condition.ops[0]
right = evaluate_ast(condition.comparators[0], state, tools)
if isinstance(comparator, ast.Eq):
return left == right
elif isinstance(comparator, ast.NotEq):
return left != right
elif isinstance(comparator, ast.Lt):
return left < right
elif isinstance(comparator, ast.LtE):
return left <= right
elif isinstance(comparator, ast.Gt):
return left > right
elif isinstance(comparator, ast.GtE):
return left >= right
elif isinstance(comparator, ast.Is):
return left is right
elif isinstance(comparator, ast.IsNot):
return left is not right
elif isinstance(comparator, ast.In):
return left in right
elif isinstance(comparator, ast.NotIn):
return left not in right
else:
raise InterpretorError(f"Operator not supported: {comparator}")
def evaluate_if(if_statement, state, tools):
if evaluate_condition(if_statement.test, state, tools):
for line in if_statement.body:
evaluate_ast(line, state, tools)
else:
for line in if_statement.orelse:
evaluate_ast(line, state, tools)
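# If-statement sketch: only single-operator comparisons are supported, e.g.
# state = {"x": 1}
# evaluate("if x == 1:\n    y = 2\nelse:\n    y = 3", {}, state)
# assert state["y"] == 2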

View File

@ -0,0 +1,39 @@
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
from .base import OldRemoteTool, PipelineTool
SPEECH_TO_TEXT_DESCRIPTION = (
"This is a tool that transcribes an audio into text. It takes an input named `audio` and returns the "
"transcribed text."
)
class SpeechToTextTool(PipelineTool):
default_checkpoint = "openai/whisper-base"
description = SPEECH_TO_TEXT_DESCRIPTION
name = "transcriber"
pre_processor_class = WhisperProcessor
model_class = WhisperForConditionalGeneration
inputs = ["audio"]
outputs = ["text"]
def encode(self, audio):
return self.pre_processor(audio, return_tensors="pt").input_features
def forward(self, inputs):
return self.model.generate(inputs=inputs)
def decode(self, outputs):
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
class RemoteSpeechToTextTool(OldRemoteTool):
default_checkpoint = "openai/whisper-base"
description = SPEECH_TO_TEXT_DESCRIPTION
def extract_outputs(self, outputs):
return outputs["text"]
def prepare_inputs(self, audio):
return {"data": audio}

View File

@ -0,0 +1,86 @@
import torch
from ..models.auto import AutoModelForSequenceClassification, AutoTokenizer
from .base import OldRemoteTool, PipelineTool
TEXT_CLASSIFIER_DESCRIPTION = (
"This is a tool that classifies an English text using provided labels. It takes two inputs: `text`, which should "
"be the text to classify, and `labels`, which should be the list of labels to use for classification. It returns "
"the most likely label in the list of provided `labels` for the input text."
)
class TextClassificationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TextClassificationTool
classifier = TextClassificationTool()
classifier("This is a super nice API!", labels=["positive", "negative"])
```
"""
default_checkpoint = "facebook/bart-large-mnli"
description = TEXT_CLASSIFIER_DESCRIPTION
name = "text-classifier"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSequenceClassification
inputs = ["text", ["text"]]
outputs = ["text"]
def setup(self):
super().setup()
config = self.model.config
self.entailment_id = -1
for idx, label in config.id2label.items():
if label.lower().startswith("entail"):
self.entailment_id = int(idx)
if self.entailment_id == -1:
raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")
def encode(self, text, labels):
self._labels = labels
return self.pre_processor(
[text] * len(labels),
[f"This example is {label}" for label in labels],
return_tensors="pt",
padding="max_length",
)
def decode(self, outputs):
logits = outputs.logits
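# Pick the label whose "entailment" logit is highest across the (text, label) pairs.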
label_id = torch.argmax(logits[:, self.entailment_id]).item()
return self._labels[label_id]
class RemoteTextClassificationTool(OldRemoteTool):
"""
Example:
```py
from transformers.tools import RemoteTextClassificationTool
classifier = RemoteTextClassificationTool()
classifier("This is a super nice API!", labels=["positive", "negative"])
```
"""
default_checkpoint = "facebook/bart-large-mnli"
description = TEXT_CLASSIFIER_DESCRIPTION
def prepare_inputs(self, text, labels):
return {"inputs": text, "params": {"candidate_labels": labels}}
def extract_outputs(self, outputs):
label = None
max_score = 0
for lbl, score in zip(outputs["labels"], outputs["scores"]):
if score > max_score:
label = lbl
max_score = score
return label

View File

@ -0,0 +1,59 @@
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import OldRemoteTool, PipelineTool
TEXT_SUMMARIZATION_DESCRIPTION = "This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, and returns a summary of the text."
class TextSummarizationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TextSummarizationTool
summarizer = TextSummarizationTool()
summarizer(long_text)
```
"""
default_checkpoint = "philschmid/bart-large-cnn-samsum"
description = TEXT_SUMMARIZATION_DESCRIPTION
name = "sumamrizer"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = ["text"]
outputs = ["text"]
def encode(self, text):
return self.pre_processor(text, return_tensors="pt", truncation=True)
def forward(self, inputs):
return self.model.generate(**inputs)[0]
def decode(self, outputs):
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
class RemoteTextSummarizationTool(OldRemoteTool):
"""
Example:
```py
from transformers.tools import RemoteTextSummarizationTool
summarizer = RemoteTextSummarizationTool()
summarizer(long_text)
```
"""
default_checkpoint = "philschmid/flan-t5-base-samsum"
description = TEXT_SUMMARIZATION_DESCRIPTION
def prepare_inputs(self, text):
return {"inputs": text}
def extract_outputs(self, outputs):
return outputs[0]["summary_text"]

View File

@ -0,0 +1,53 @@
import torch
from transformers.utils import is_datasets_available
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from .base import PipelineTool
if is_datasets_available():
from datasets import load_dataset
TEXT_TO_SPEECH_DESCRIPTION = (
"This is a tool that reads an English text out loud. It takes an input named `text` which whould contain the "
"text to read (in English) and returns a waveform object containing the sound."
)
class TextToSpeechTool(PipelineTool):
default_checkpoint = "microsoft/speecht5_tts"
description = TEXT_TO_SPEECH_DESCRIPTION
name = "text_reader"
pre_processor_class = SpeechT5Processor
model_class = SpeechT5ForTextToSpeech
post_processor_class = SpeechT5HifiGan
inputs = ["text"]
outputs = ["audio"]
def setup(self):
if self.post_processor is None:
self.post_processor = "microsoft/speecht5_hifigan"
super().setup()
def encode(self, text, speaker_embeddings=None):
inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)
if speaker_embeddings is None:
if not is_datasets_available():
raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
def forward(self, inputs):
with torch.no_grad():
return self.model.generate_speech(**inputs)
def decode(self, outputs):
with torch.no_grad():
return self.post_processor(outputs).cpu().detach()
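# Usage sketch: returns a 1-D waveform tensor (SpeechT5 vocoder output at 16 kHz).
# reader = TextToSpeechTool()
# waveform = reader("Hello world")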

View File

@ -0,0 +1,51 @@
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
TRANSLATION_DESCRIPTION = (
"This is a tool that translates text from a language to another. It takes three inputs: `text`, which should be "
"the text to translate, `src_lang`, which should be the language of the text to translate and `tgt_lang`, which "
"should be the language for the desired ouput language. Both `src_lang` and `tgt_lang` are written in plain "
"English, such as 'Romanian', or 'Albanian'. It returns the text translated in `tgt_lang`."
)
class TranslationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TranslationTool
translator = TranslationTool()
translator("This is a super nice API!", src_lang="English", tgt_lang="French")
```
"""
default_checkpoint = "facebook/nllb-200-distilled-600M"
description = TRANSLATION_DESCRIPTION
name = "translator"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
# TODO add all other languages
lang_to_code = {"French": "fra_Latn", "English": "eng_Latn", "Spanish": "spa_Latn"}
inputs = ["text", "text", "text"]
outputs = ["text"]
def encode(self, text, src_lang, tgt_lang):
if src_lang not in self.lang_to_code:
raise ValueError(f"{src_lang} is not a supported language.")
if tgt_lang not in self.lang_to_code:
raise ValueError(f"{tgt_lang} is not a supported language.")
src_lang = self.lang_to_code[src_lang]
tgt_lang = self.lang_to_code[tgt_lang]
return self.pre_processor._build_translation_inputs(
text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
)
def forward(self, inputs):
return self.model.generate(**inputs)
def decode(self, outputs):
return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)

View File

@ -108,6 +108,7 @@ from .import_utils import (
is_datasets_available,
is_decord_available,
is_detectron2_available,
is_diffusers_available,
is_faiss_available,
is_flax_available,
is_ftfy_available,
@ -121,6 +122,7 @@ from .import_utils import (
is_natten_available,
is_ninja_available,
is_onnx_available,
is_opencv_available,
is_optimum_available,
is_pandas_available,
is_peft_available,

View File

@ -235,6 +235,7 @@ def try_to_load_from_cache(
filename: str,
cache_dir: Union[str, Path, None] = None,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
) -> Optional[str]:
"""
Explores the cache to return the latest cached file for a given revision if found.
@ -251,6 +252,8 @@ def try_to_load_from_cache(
revision (`str`, *optional*):
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
provided either.
repo_type (`str`, *optional*):
The type of the repo.
Returns:
`Optional[str]` or `_CACHED_NO_EXIST`:
@ -266,7 +269,9 @@ def try_to_load_from_cache(
cache_dir = TRANSFORMERS_CACHE
object_id = repo_id.replace("/", "--")
repo_cache = os.path.join(cache_dir, f"models--{object_id}")
if repo_type is None:
repo_type = "model"
repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
if not os.path.isdir(repo_cache):
# No cache for this model
return None
@ -303,6 +308,7 @@ def cached_file(
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
repo_type: Optional[str] = None,
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True,
@ -342,6 +348,8 @@ def cached_file(
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
repo_type (`str`, *optional*):
Specify the repo type (useful when downloading from a space for instance).
<Tip>
@ -393,7 +401,7 @@ def cached_file(
if _commit_hash is not None and not force_download:
# If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache(
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
)
if resolved_file is not None:
if resolved_file is not _CACHED_NO_EXIST:
@ -410,6 +418,7 @@ def cached_file(
path_or_repo_id,
filename,
subfolder=None if len(subfolder) == 0 else subfolder,
repo_type=repo_type,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,

View File

@ -125,6 +125,14 @@ except importlib_metadata.PackageNotFoundError:
_datasets_available = False
_diffusers_available = importlib.util.find_spec("diffusers") is not None
try:
_diffusers_version = importlib_metadata.version("diffusers")
logger.debug(f"Successfully imported diffusers version {_diffusers_version}")
except importlib_metadata.PackageNotFoundError:
_diffusers_available = False
_detectron2_available = importlib.util.find_spec("detectron2") is not None
try:
_detectron2_version = importlib_metadata.version("detectron2")
@ -185,6 +193,9 @@ except importlib_metadata.PackageNotFoundError:
_onnx_available = False
_opencv_available = importlib.util.find_spec("cv2") is not None
_pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None
try:
_pytorch_quantization_version = importlib_metadata.version("pytorch_quantization")
@ -431,6 +442,10 @@ def is_onnx_available():
return _onnx_available
def is_opencv_available():
return _opencv_available
def is_flax_available():
return _flax_available
@ -497,6 +512,10 @@ def is_datasets_available():
return _datasets_available
def is_diffusers_available():
return _diffusers_available
def is_detectron2_available():
return _detectron2_available

0
tests/tools/__init__.py Normal file
View File

View File

@ -0,0 +1,42 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from datasets import load_dataset
from transformers.tools import DocumentQuestionAnsweringTool
from .test_tools_common import ToolTesterMixin
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = DocumentQuestionAnsweringTool()
self.tool.setup()
def test_exact_match_arg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
result = self.tool(image, "When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_kwarg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
result = self.tool(image=image, question="When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")

View File

@ -0,0 +1,43 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.tools import GenerativeQuestionAnsweringTool
from .test_tools_common import ToolTesterMixin
TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""
class GenerativeQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = GenerativeQuestionAnsweringTool()
self.tool.setup()
def test_exact_match_arg(self):
result = self.tool(TEXT, "What did Hugging Face do in April 2021?")
self.assertEqual(result, "launched the BigScience Research Workshop")
def test_exact_match_kwarg(self):
result = self.tool(text=TEXT, question="What did Hugging Face do in April 2021?")
self.assertEqual(result, "launched the BigScience Research Workshop")

View File

@ -0,0 +1,40 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from PIL import Image
from transformers.testing_utils import get_tests_dir
from transformers.tools import ImageCaptioningTool
from .test_tools_common import ToolTesterMixin
class ImageCaptioningToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = ImageCaptioningTool()
self.tool.setup()
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image)
self.assertEqual(result, "two cats sleeping on a couch")
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image=image)
self.assertEqual(result, "two cats sleeping on a couch")

View File

@ -0,0 +1,40 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from PIL import Image
from transformers.testing_utils import get_tests_dir
from transformers.tools import ImageQuestionAnsweringTool
from .test_tools_common import ToolTesterMixin
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = ImageQuestionAnsweringTool()
self.tool.setup()
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image, "How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
self.assertEqual(result, "2")

View File

@ -0,0 +1,40 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from PIL import Image
from transformers.testing_utils import get_tests_dir
from transformers.tools import ImageSegmentationTool
from .test_tools_common import ToolTesterMixin
class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = ImageSegmentationTool()
self.tool.setup()
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
result = self.tool(image, "cat")
self.assertTrue(isinstance(result, Image.Image))
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
result = self.tool(image=image, prompt="cat")
self.assertTrue(isinstance(result, Image.Image))

View File

@ -0,0 +1,36 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers.tools import SpeechToTextTool
from .test_tools_common import ToolTesterMixin
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = SpeechToTextTool()
self.tool.setup()
def test_exact_match_arg(self):
result = self.tool(torch.ones(3000))
self.assertEqual(result, " you")
def test_exact_match_kwarg(self):
result = self.tool(audio=torch.ones(3000))
self.assertEqual(result, " you")

View File

@ -0,0 +1,94 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import List
import torch
from PIL import Image
from transformers.testing_utils import get_tests_dir
authorized_types = ["text", "image", "audio"]
def create_inputs(input_types: List[str]):
inputs = []
for input_type in input_types:
if input_type == "text":
inputs.append("Text input")
elif input_type == "image":
inputs.append(
Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
)
elif input_type == "audio":
inputs.append(torch.ones(3000))
elif isinstance(input_type, list):
inputs.append(create_inputs(input_type))
else:
raise ValueError(f"Invalid type requested: {input_type}")
return inputs
def output_types(outputs: List):
output_types = []
for output in outputs:
if isinstance(output, str):
output_types.append("text")
elif isinstance(output, Image.Image):
output_types.append("image")
elif isinstance(output, torch.Tensor):
output_types.append("audio")
else:
raise ValueError(f"Invalid output: {output}")
return output_types
class ToolTesterMixin:
def test_inputs_outputs(self):
self.assertTrue(hasattr(self.tool, "inputs"))
self.assertTrue(hasattr(self.tool, "outputs"))
inputs = self.tool.inputs
for _input in inputs:
if isinstance(_input, list):
for __input in _input:
self.assertTrue(__input in authorized_types)
else:
self.assertTrue(_input in authorized_types)
outputs = self.tool.outputs
for _output in outputs:
self.assertTrue(_output in authorized_types)
def test_call(self):
inputs = create_inputs(self.tool.inputs)
outputs = self.tool(*inputs)
# There is a single output
if len(self.tool.outputs) == 1:
outputs = [outputs]
self.assertListEqual(output_types(outputs), self.tool.outputs)
def test_common_attributes(self):
self.assertTrue(hasattr(self.tool, "description"))
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
self.assertTrue(self.tool.description.startswith("This is a tool that"))

View File

@ -0,0 +1,34 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.tools import TextClassificationTool
from .test_tools_common import ToolTesterMixin
class TextClassificationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = TextClassificationTool()
self.tool.setup()
def test_exact_match_arg(self):
result = self.tool("That's quite cool", ["positive", "negative"])
self.assertEqual(result, "positive")
def test_exact_match_kwarg(self):
result = self.tool(text="That's quite cool", labels=["positive", "negative"])
self.assertEqual(result, "positive")

View File

@ -0,0 +1,49 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.tools import TextSummarizationTool

from .test_tools_common import ToolTesterMixin


TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.

In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]

On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""


class TextSummarizationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = TextSummarizationTool()
self.tool.setup()

    def test_exact_match_arg(self):
result = self.tool(TEXT)
self.assertEqual(
result,
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
)

    def test_exact_match_kwarg(self):
result = self.tool(text=TEXT)
self.assertEqual(
result,
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
)

@@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.tools import TextToSpeechTool

from .test_tools_common import ToolTesterMixin


class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
self.tool = TextToSpeechTool()
self.tool.setup()

    def test_exact_match_arg(self):
result = self.tool("hey")
print(result.shape)
# TODO check for real values

    def test_exact_match_kwarg(self):
        # Use the keyword form so the test actually matches its name
        result = self.tool(text="hey")
print(result.shape)
# TODO check for real values

@@ -0,0 +1,44 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.tools import TranslationTool

from .test_tools_common import ToolTesterMixin, output_types


class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
self.tool = TranslationTool()
self.tool.setup()

    def test_exact_match_arg(self):
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
self.assertEqual(result, "- Hé, comment ça va?")

    def test_exact_match_kwarg(self):
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
self.assertEqual(result, "- Hé, comment ça va?")

    # The generic `test_call` from ToolTesterMixin cannot guess valid language
    # names for this tool, so it is overridden here with concrete arguments.
    def test_call(self):
inputs = ["Hey, what's up?", "English", "Spanish"]
outputs = self.tool(*inputs)
# There is a single output
if len(self.tool.outputs) == 1:
outputs = [outputs]
self.assertListEqual(output_types(outputs), self.tool.outputs)
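
# Hedged note (added for illustration): assuming these files sit in the same
# tests package as test_tools_common.py, a single tester can be run with
# pytest, e.g.:
#   python -m pytest <path-to-tests> -k "TranslationToolTester"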