mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Compare commits: 740f952218...fix_remote (118 commits)
Commits (SHA1):
ff44fe848d
fe909e6650
13cc6e534f
9be1243a7d
c2df74b98d
a9c92a7a4a
4258f453fd
e84bcc7997
4157b35304
5bdc6c96ec
1d574aef29
1bd7848d9e
9d706e8342
7a9137b3bf
4d991ed2ba
69c1abbc55
43bddabbb4
a873db85e7
2711a393d9
2234ff1126
2be32455ee
3304f1714c
ac48681ad0
749a041984
6293e28d16
0f1cf429b1
dc6743db19
e9b166e8b6
7f26a004a1
aedbcd2469
a740e1a3a4
4ec7077510
47040aec29
3d141d02c0
4cafdb1098
fea2e6ce65
be0a949ad6
1d00d759f9
d4fa69d995
b317a33b6c
1239cdd59c
33903dad2f
85e89c890b
c7609790a7
325d60ba9c
ed376943bb
0d65be6720
ff3240b93b
f1dc3b3da1
ffadafb9f9
51561d9c13
a61d5b66ee
b08e4af920
e4e95f1cf3
48ccb1aabe
7f4aaf2f7e
2906d2509b
cef1b96ea1
a255eaa3d0
895e472484
8ad381efb7
6df0c15033
d8abc32ef6
85e3acd1d6
5589cec4c7
6b64832bbb
34801f1e67
afbed38879
0c3fd13884
290e8306c4
28fae8b6fa
8b2221a3aa
e819766fdf
74c98e5fef
731029b03b
8c7c78a8f2
c70005aa11
0e80e1f75c
2868ea4c9d
f85f4c6b0e
38d71024ab
37f6a02f1e
d88b770d9a
44bea11f6f
d9caa406c2
c136c19712
8b1aa74904
42e16b3761
b4ebe76fea
cc187725b7
4cd8879812
739b7c0d0e
1bf544a036
3d673f0602
a463897305
49766e665d
360cddd78c
9cef2daeae
259b0af144
6c1f823c83
77225aa758
bd2a928909
d6a36c0343
c8c766b566
7fa6db98a2
0d4600e77d
fd140c0724
e9e68e856d
e1e47d67f1
9f3da1d6aa
b8354f845b
67f63d0f91
3149597939
98c15ff46f
e00438ffa7
8d066a02e8
30221a8d5f
4f9b256b31
@@ -608,6 +608,7 @@ _import_structure = {
         "SpecialTokensMixin",
         "TokenSpan",
     ],
+    "tools": [],
     "trainer_callback": [
         "DefaultFlowCallback",
         "EarlyStoppingCallback",
@@ -115,9 +115,9 @@ def get_relative_import_files(module_file):
     return all_relative_imports


-def check_imports(filename):
+def get_imports(filename):
     """
-    Check if the current Python environment contains all the libraries that are imported in a file.
+    Extracts all the libraries that are imported in a file.
     """
     with open(filename, "r", encoding="utf-8") as f:
         content = f.read()
@@ -131,9 +131,14 @@ def check_imports(filename):
     imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
     # Only keep the top-level module
     imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
-
-    # Unique-ify and test we got them all
-    imports = list(set(imports))
+    return list(set(imports))
+
+
+def check_imports(filename):
+    """
+    Check if the current Python environment contains all the libraries that are imported in a file.
+    """
+    imports = get_imports(filename)
     missing_packages = []
     for imp in imports:
         try:
@@ -169,6 +174,7 @@ def get_cached_module_file(
     use_auth_token: Optional[Union[bool, str]] = None,
     revision: Optional[str] = None,
     local_files_only: bool = False,
+    repo_type: Optional[str] = None,
     _commit_hash: Optional[str] = None,
 ):
     """
@@ -207,6 +213,8 @@ def get_cached_module_file(
             identifier allowed by git.
         local_files_only (`bool`, *optional*, defaults to `False`):
             If `True`, will only try to load the tokenizer configuration from local files.
+        repo_type (`str`, *optional*):
+            Specify the repo type (useful when downloading from a space for instance).

     <Tip>

@@ -229,7 +237,7 @@ def get_cached_module_file(
     else:
         submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
         cached_module = try_to_load_from_cache(
-            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash
+            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
         )

     new_files = []
@@ -245,6 +253,7 @@ def get_cached_module_file(
             local_files_only=local_files_only,
             use_auth_token=use_auth_token,
             revision=revision,
+            repo_type=repo_type,
             _commit_hash=_commit_hash,
         )
         if not is_local and cached_module != resolved_module_file:
@@ -309,8 +318,10 @@ def get_cached_module_file(

     if len(new_files) > 0:
         new_files = "\n".join([f"- {f}" for f in new_files])
+        repo_type_str = "" if repo_type is None else f"{repo_type}/"
+        url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
         logger.warning(
-            f"A new version of the following files was downloaded from {pretrained_model_name_or_path}:\n{new_files}"
+            f"A new version of the following files was downloaded from {url}:\n{new_files}"
             "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
             "versions of the code file, you can pin a revision."
         )
@@ -328,6 +339,7 @@ def get_class_from_dynamic_module(
     use_auth_token: Optional[Union[bool, str]] = None,
     revision: Optional[str] = None,
     local_files_only: bool = False,
+    repo_type: Optional[str] = None,
     **kwargs,
 ):
     """
@@ -377,6 +389,8 @@ def get_class_from_dynamic_module(
             identifier allowed by git.
         local_files_only (`bool`, *optional*, defaults to `False`):
             If `True`, will only try to load the tokenizer configuration from local files.
+        repo_type (`str`, *optional*):
+            Specify the repo type (useful when downloading from a space for instance).

     <Tip>

@@ -418,6 +432,7 @@ def get_class_from_dynamic_module(
         use_auth_token=use_auth_token,
         revision=revision,
         local_files_only=local_files_only,
+        repo_type=repo_type,
     )
     return get_class_in_module(class_name, final_module.replace(".py", ""))

@@ -478,12 +493,17 @@ def custom_object_save(obj, folder, config=None):
     elif config is not None:
         _set_auto_map_in_config(config)

+    result = []
     # Copy module file to the output folder.
     object_file = sys.modules[obj.__module__].__file__
     dest_file = Path(folder) / (Path(object_file).name)
     shutil.copy(object_file, dest_file)
+    result.append(dest_file)

     # Gather all relative imports recursively and make sure they are copied as well.
     for needed_file in get_relative_import_files(object_file):
         dest_file = Path(folder) / (Path(needed_file).name)
         shutil.copy(needed_file, dest_file)
+        result.append(dest_file)
+
+    return result
|
@ -64,6 +64,10 @@ class ChannelDimension(ExplicitEnum):
|
||||
LAST = "channels_last"
|
||||
|
||||
|
||||
def is_pil_image(img):
|
||||
return is_vision_available() and isinstance(img, PIL.Image.Image)
|
||||
|
||||
|
||||
def is_valid_image(img):
|
||||
return (
|
||||
(is_vision_available() and isinstance(img, PIL.Image.Image))
|
||||
|
src/transformers/tools/README.md (new file, 141 lines)
@@ -0,0 +1,141 @@
# Do anything with Transformers

Transformers supports all modalities and has many models performing many different types of tasks. But it can get confusing to mix and match them to solve the problem at hand, which is why we have developed a new API of **tools** and **agents**. Given a prompt in natural language and a set of tools, an agent will determine the right code to run with the tools and chain them properly to give you the result you expected.

Let's start with examples!

## Examples

First we need an agent, which is a fancy name for an LLM tasked with writing the code you will need. We support the traditional OpenAI LLMs, but you should really try the open-source alternatives developed by the community, which:
- clearly state the data they have been trained on
- can be run on your own cloud or hardware
- have built-in versioning

<!--TODO for the release we should have a publicly available agent and if token is none, we grab the HF token-->

```py
from transformers.tools import EndpointAgent

agent = EndpointAgent(
    url_endpoint=your_endpoint,
    token=your_hf_token,
)

# from transformers.tools import OpenAiAgent

# agent = OpenAiAgent(api_key=your_openai_api_key)
```

### Task 1: Classifying text in (almost) any language

Now to execute a given task, we need to pick a set of tools in `transformers` and send them to our agent. Let's say you want to classify a text in a non-English language, and you have trouble finding a model trained in that language. You can pick a translation tool and a standard text classification tool:

```py
from transformers.tools import TextClassificationTool, TranslationTool

tools = [TextClassificationTool(), TranslationTool(src_lang="fra_Latn", tgt_lang="eng_Latn")]
```

Then you just run this by your agent:

```py
agent.run(
    "Determine if the following `text` (in French) is positive or negative.",
    tools=tools,
    text="J'aime beaucoup Hugging Face!"
)
```

Note that you can send any additional inputs in a variable that you named in your prompt (between backticks, because it helps the LLM). For text inputs, you can just put them in the prompt:

```py
agent.run(
    """Determine if the following text: "J'aime beaucoup Hugging Face!" (in French) is positive or negative.""",
    tools=tools,
)
```

In both cases, you should see the agent generate code using your set of tools, which is then executed to provide the answer you were looking for. Neat!

If you don't have the hardware to run the models translating and classifying the text, you can use the Inference API by selecting a remote tool:

```py
from transformers.tools import RemoteTextClassificationTool, TranslationTool

tools = [RemoteTextClassificationTool(), TranslationTool(src_lang="fra_Latn", tgt_lang="eng_Latn")]

agent.run(
    "Determine if the following `text` (in French) is positive or negative.",
    tools=tools,
    text="J'aime beaucoup Hugging Face!"
)
```

This was still all text-based. Let's now get to something more exciting, combining vision and speech.

### Task 2: Describing an image out loud

Let's say we want to hear out loud what is in a given image. There are models that do image captioning in Transformers, and other models that generate speech from text, but how do we combine them? Quite easily:

<!--TODO add the audio reader tool once it exists-->

```py
import requests
from PIL import Image
from transformers.tools import ImageCaptioningTool, TextToSpeechTool

tools = [ImageCaptioningTool(), TextToSpeechTool()]

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

speech = agent.run(
    "Tell me out loud what the `image` contains.",
    tools=tools,
    image=image
)
```

Note that here you have to pass your input as a separate variable, since you can't really embed your image in the text.

In all those examples, we have been using the default checkpoint for a given tool, but you can specify the one you want! For instance, the image-captioning tool uses BLIP by default, but let's upgrade to BLIP-2:

<!--TODO Once it works, use the inference API for BLIP-2 here as it's heavy-->

```py
tools = [ImageCaptioningTool("Salesforce/blip2-opt-2.7b"), TextToSpeechTool()]

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

speech = agent.run(
    "Tell me out loud what the `image` contains.",
    tools=tools,
    image=image
)
```

<!--TODO: Add more examples?-->

## How does it work?

LLMs are pretty good at generating small samples of code, so this API takes advantage of that by prompting the LLM to give a small sample of code performing a task with a set of tools. This prompt is then completed by the task you give your agent and the description of the tools you give it. This way it gets access to the doc of the tools you are using, especially their expected inputs and outputs, and can generate the relevant code.
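
For instance, for the French sentiment task above, the generated code might look something like the following (illustrative only; the actual tool names come from the descriptions placed in the prompt):

```py
translated_text = translator(text, src_lang="fra_Latn", tgt_lang="eng_Latn")
label = text_classifier(translated_text, labels=["positive", "negative"])
print(label)
```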

This is using brand-new tools and not pipelines, because the agent writes better code with very atomic tools. Pipelines are more refactored and often combine several tasks in one. Tools are really meant to be focused on one very simple task only, as the sketch below illustrates.
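
To make "atomic" concrete, here is a minimal sketch of a custom tool following the `Tool` interface from `transformers.tools` (the class itself is hypothetical):

```py
from transformers.tools import Tool


class ReverseTextTool(Tool):
    name = "text_reverser"
    description = "This is a tool that reverses the `text` given to it. It returns the reversed text."

    inputs = ["text"]
    outputs = ["text"]

    def __call__(self, text):
        # One small, self-contained operation: that's the whole tool.
        return text[::-1]
```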

This code is then executed with our small Python interpreter on the set of inputs passed along with your tools. I hear you screaming "Arbitrary code execution!" in the back, but calm down a minute and let me explain.

The only functions that can be called are the tools you provided and the `print` function, so you're already limited in what can be executed. You should be safe if it's limited to Hugging Face tools. We also don't allow any attribute lookups or imports (which shouldn't be needed anyway for passing along inputs/outputs to a small set of functions), so all the most obvious attacks (and you'd need to prompt the LLM to output them anyway) shouldn't be an issue. If you want to be on the super safe side, you can execute the `run()` method with the additional argument `return_code=True`, in which case the agent will just return the code to execute and you can decide whether to do it or not.
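
For example (a minimal sketch, reusing the `tools` defined earlier):

```py
# Ask the agent for the code only, without executing it.
code = agent.run(
    "Determine if the following `text` (in French) is positive or negative.",
    tools=tools,
    text="J'aime beaucoup Hugging Face!",
    return_code=True,
)
print(code)  # review the generated code before deciding to run it
```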

Note that LLMs are still not *that* good at producing the small amount of code to chain the tools, so we added some logic to fix typos during the evaluation: variable names or dictionary keys are often misnamed.

The execution will stop at any line trying to perform an illegal operation, or if there is a regular Python error with the code generated by the agent.

## Future developments

We hope you're as excited by this new API as we are. Here are a few things we are thinking of adding next if we see the community is interested:
- Make the agent pick the tools itself in a first step.
- Make the run command more chat-based, so you can copy-paste any error message you see in a next step to have the LLM fix its code, or ask for some improvements.
- Add support for more types of agents.

src/transformers/tools/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from .agents import Agent, EndpointAgent, OpenAiAgent
from .base import PipelineTool, RemoteTool, Tool, load_tool
from .document_question_answering import DocumentQuestionAnsweringTool
from .generative_question_answering import GenerativeQuestionAnsweringTool, RemoteGenerativeQuestionAnsweringTool
from .image_captioning import ImageCaptioningTool, RemoteImageCaptioningTool
from .image_question_answering import ImageQuestionAnsweringTool
from .image_segmentation import ImageSegmentationTool
from .language_identifier import LanguageIdentificationTool
from .speech_to_text import RemoteSpeechToTextTool, SpeechToTextTool
from .text_classification import RemoteTextClassificationTool, TextClassificationTool
from .text_summarization import RemoteTextSummarizationTool, TextSummarizationTool
from .text_to_speech import TextToSpeechTool
from .translation import TranslationTool

src/transformers/tools/agents.py (new file, 382 lines)
@@ -0,0 +1,382 @@
import importlib.util
import json
import os
import time
from dataclasses import dataclass

import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces

from ..utils import logging
from .base import TASK_MAPPING, Tool, load_tool
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
from .python_interpreter import evaluate


logger = logging.get_logger(__name__)


# Move to util when this branch is ready to merge
def is_openai_available():
    return importlib.util.find_spec("openai") is not None


if is_openai_available():
    import openai

_tools_are_initialized = False


BASE_PYTHON_TOOLS = {
    "print": print,
    "float": float,
    "int": int,
    "bool": bool,
    "str": str,
}


@dataclass
class PreTool:
    task: str
    description: str
    repo_id: str


HUGGINGFACE_DEFAULT_TOOLS = {}


HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
    "image-transformation",
    "text-download",
    "text-to-image",
    "text-to-video",
    "image-inpainting",
]


def get_remote_tools(organization="huggingface-tools"):
    spaces = list_spaces(author=organization)
    tools = {}
    for space_info in spaces:
        repo_id = space_info.id
        resolved_config_file = hf_hub_download(repo_id, "tool_config.json", repo_type="space")
        with open(resolved_config_file, encoding="utf-8") as reader:
            config = json.load(reader)

        for task, task_info in config.items():
            tools[task_info["name"]] = PreTool(task=task, description=task_info["description"], repo_id=repo_id)

    return tools


def _setup_default_tools():
    global HUGGINGFACE_DEFAULT_TOOLS
    global _tools_are_initialized

    if _tools_are_initialized:
        return

    main_module = importlib.import_module("transformers")
    tools_module = main_module.tools

    remote_tools = get_remote_tools()
    for task_name in TASK_MAPPING:
        tool_class_name = TASK_MAPPING.get(task_name)
        tool_class = getattr(tools_module, tool_class_name)
        description = tool_class.description
        HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(task=task_name, description=description, repo_id=None)

    for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB:
        found = False
        for tool_name, tool in remote_tools.items():
            if tool.task == task_name:
                HUGGINGFACE_DEFAULT_TOOLS[tool_name] = tool
                found = True
                break

        if not found:
            raise ValueError(f"{task_name} is not implemented on the Hub.")

    _tools_are_initialized = True


def resolve_tools(code, toolbox, remote=False, cached_tools=None):
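    # Only instantiate the tools that are actually referenced in the generated code, reusing cached instances.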
    if cached_tools is None:
        resolved_tools = BASE_PYTHON_TOOLS.copy()
    else:
        resolved_tools = cached_tools
    for name, tool in toolbox.items():
        if name not in code or name in resolved_tools:
            continue

        if isinstance(tool, Tool):
            resolved_tools[name] = tool
        else:
            resolved_tools[name] = load_tool(tool.task, repo_id=tool.repo_id, remote=remote)

    return resolved_tools


def get_tool_creation_code(code, toolbox, remote=False):
    code_lines = ["from transformers import load_tool", ""]
    for name, tool in toolbox.items():
        if name not in code or isinstance(tool, Tool):
            continue

        line = f'{name} = load_tool("{tool.task}"'
        if tool.repo_id is not None:
            line += f', repo_id="{tool.repo_id}"'
        if remote:
            line += ", remote=True"
        line += ")"
        code_lines.append(line)

    return "\n".join(code_lines) + "\n"


def clean_code_for_chat(result):
    lines = result.split("\n")
    idx = 0
    while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
        idx += 1
    explanation = "\n".join(lines[:idx]).strip()
    if idx == len(lines):
        return explanation, None

    idx += 1
    start_idx = idx
    while not lines[idx].lstrip().startswith("```"):
        idx += 1
    code = "\n".join(lines[start_idx:idx]).strip()

    return explanation, code


def clean_code_for_run(result):
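    # The run prompt is assumed to end with "I will use the following", so we first restore that prefix,
    # then split the completion into the explanation and the code after the "Answer:" marker.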
    result = f"I will use the following {result}"
    explanation, code = result.split("Answer:")
    explanation = explanation.strip()
    code = code.strip()

    code_lines = code.split("\n")
    if code_lines[0] in ["```", "```py"]:
        code_lines = code_lines[1:]
    if code_lines[-1] == "```":
        code_lines = code_lines[:-1]
    code = "\n".join(code_lines)

    return explanation, code


class Agent:
    def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
        _setup_default_tools()

        self.chat_prompt_template = CHAT_MESSAGE_PROMPT if chat_prompt_template is None else chat_prompt_template
        self.run_prompt_template = RUN_PROMPT_TEMPLATE if run_prompt_template is None else run_prompt_template
        self.toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
        if additional_tools is not None:
            if isinstance(additional_tools, (list, tuple)):
                additional_tools = {t.name: t for t in additional_tools}
            elif not isinstance(additional_tools, dict):
                additional_tools = {additional_tools.name: additional_tools}

            replacements = {name: tool for name, tool in additional_tools.items() if name in HUGGINGFACE_DEFAULT_TOOLS}
            self.toolbox.update(additional_tools)
            if len(replacements) > 1:
                names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()])
                logger.warn(
                    f"The following tools have been replaced by the ones provided in `additional_tools`:\n{names}."
                )
            elif len(replacements) == 1:
                name = list(replacements.keys())[0]
                logger.warn(f"{name} has been replaced by {replacements[name]} as provided in `additional_tools`.")

        self.prepare_for_new_chat()

    def format_prompt(self, task, chat_mode=False):
        description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()])
        if chat_mode:
            if self.chat_history is None:
                prompt = CHAT_PROMPT_TEMPLATE.replace("<<all_tools>>", description)
            else:
                prompt = self.chat_history
            prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
        else:
            prompt = self.run_prompt_template.replace("<<all_tools>>", description)
            prompt = prompt.replace("<<prompt>>", task)
        return prompt

    def chat(self, task, return_code=False, remote=False, **kwargs):
        prompt = self.format_prompt(task, chat_mode=True)
        result = self._generate_one(prompt, stop=["Human:", "====="])
        self.chat_history = prompt + result + "\n"
        explanation, code = clean_code_for_chat(result)

        print(f"==Explanation from the agent==\n{explanation}")

        if code is not None:
            print(f"\n\n==Code generated by the agent==\n{code}")
            if not return_code:
                print("\n\n==Result==")
                self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
                self.chat_state.update(kwargs)
                return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
            else:
                tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
                return f"{tool_code}\n{code}"

    def prepare_for_new_chat(self):
        self.chat_history = None
        self.chat_state = {}
        self.cached_tools = None

    def run(self, task, return_code=False, remote=False, **kwargs):
        prompt = self.format_prompt(task)
        result = self._generate_one(prompt, stop=["Task:"])
        explanation, code = clean_code_for_run(result)

        print(f"==Explanation from the agent==\n{explanation}")

        print(f"\n\n==Code generated by the agent==\n{code}")
        if not return_code:
            print("\n\n==Result==")
            self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
            return evaluate(code, self.cached_tools, state=kwargs.copy())
        else:
            tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
            return f"{tool_code}\n{code}"


class OpenAiAgent(Agent):
    """
    Example:

    ```py
    from transformers.tools.agents import OpenAiAgent

    agent = OpenAiAgent(model="text-davinci-003", api_key=xxx)
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    """

    def __init__(
        self,
        model="gpt-3.5-turbo",
        api_key=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
    ):
        if not is_openai_available():
            raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")

        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY", None)
        if api_key is None:
            raise ValueError(
                "You need an openai key to use `OpenAiAgent`. You can get one here: "
                "https://openai.com/api/. If you have one, set it in your env with "
                "`os.environ['OPENAI_API_KEY'] = xxx`."
            )
        else:
            openai.api_key = api_key
        self.model = model
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_code(self, task):
        is_batched = isinstance(task, list)

        if is_batched:
            prompts = [self.format_prompt(one_task) for one_task in task]
        else:
            prompts = [self.format_prompt(task)]

        if "gpt" in self.model:
            results = [self._chat_generate(prompt, stop="Task:") for prompt in prompts]
        else:
            results = self._completion_generate(prompts, stop="Task:")

        return results if is_batched else results[0]

    def _generate_one(self, prompt, stop):
        if "gpt" in self.model:
            return self._chat_generate(prompt, stop)
        else:
            return self._completion_generate([prompt], stop)[0]

    def _chat_generate(self, prompt, stop):
        result = openai.ChatCompletion.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            stop=stop,
        )
        return result["choices"][0]["message"]["content"]

    def _completion_generate(self, prompts, stop):
        result = openai.Completion.create(
            model=self.model,
            prompt=prompts,
            temperature=0,
            stop=stop,
            max_tokens=200,
        )
        return [answer["text"] for answer in result["choices"]]


class EndpointAgent(Agent):
    def __init__(
        self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
    ):
        self.url_endpoint = url_endpoint
        if token is None:
            self.token = f"Bearer {HfFolder().get_token()}"
        elif token.startswith("Bearer") or token.startswith("Basic"):
            self.token = token
        else:
            self.token = f"Bearer {token}"
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_code(self, task):
        is_batched = isinstance(task, list)

        if is_batched:
            prompts = [self.format_prompt(one_task) for one_task in task]
        else:
            prompts = [self.format_prompt(task)]

        # Can probably batch those but can't test anymore right now as the endpoint has been limited in length.
        results = [self._generate_one(prompt, stop=["Task:"]) for prompt in prompts]
        return results if is_batched else results[0]

    def _generate_one(self, prompt, stop):
        headers = {"Authorization": self.token}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
        }

        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            return self._generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Error {response.status_code}: {response.json()}")

        result = response.json()[0]["generated_text"]
        # Inference API returns the stop sequence
        for stop_seq in stop:
            if result.endswith(stop_seq):
                result = result[: -len(stop_seq)]
        return result

src/transformers/tools/base.py (new file, 548 lines)
@@ -0,0 +1,548 @@
import base64
import importlib
import io
import json
import os
from typing import Any, Dict, List, Optional, Union

from huggingface_hub import CommitOperationAdd, HfFolder, InferenceApi, create_commit, create_repo, hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError, get_session

from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
from ..image_utils import is_pil_image
from ..models.auto import AutoProcessor
from ..utils import (
    CONFIG_NAME,
    cached_file,
    is_accelerate_available,
    is_torch_available,
    is_vision_available,
    logging,
    working_or_temp_dir,
)


logger = logging.get_logger(__name__)

if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate.utils import send_to_device


TOOL_CONFIG_FILE = "tool_config.json"


def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
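    # Infer the repo type by probing the Hub: try `repo_id` as a Space first, then as a model repo.
    # Errors other than RepositoryNotFoundError keep the type currently being probed.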
    if repo_type is not None:
        return repo_type
    try:
        hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
        return "space"
    except RepositoryNotFoundError:
        try:
            hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
            return "model"
        except RepositoryNotFoundError:
            raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
        except Exception:
            return "model"
    except Exception:
        return "space"


APP_FILE_TEMPLATE = """from transformers.tools.base import launch_gradio_demo
from {module_name} import {class_name}

launch_gradio_demo({class_name})
"""


class Tool:
    """
    Example of a super 'Tool' class that could live in huggingface_hub
    """

    description: str = "This is a tool that ..."
    name: str = ""

    inputs: List[str]
    outputs: List[str]

    def __init__(self, *args, **kwargs):
        self.is_initialized = False

    def __call__(self, *args, **kwargs):  # Might become run?
        raise NotImplementedError("Write this method in your subclass of `Tool`.")

    def setup(self):
        # Do here any operation that is expensive and needs to be executed before you start using your tool. Such as
        # loading a big model.
        self.is_initialized = True

    def save(self, output_dir, task_name=None):
        os.makedirs(output_dir, exist_ok=True)
        # Save module file
        module_files = custom_object_save(self, output_dir)

        module_name = self.__class__.__module__
        last_module = module_name.split(".")[-1]
        full_name = f"{last_module}.{self.__class__.__name__}"

        # Save config file
        config_file = os.path.join(output_dir, "tool_config.json")
        if os.path.isfile(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                tool_config = json.load(f)
        else:
            tool_config = {}

        if task_name is None:
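            # Derive a default task name from the class name, e.g. TextClassificationTool -> "text_classification".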
            class_name = self.__class__.__name__.replace("Tool", "")
            chars = [f"_{c.lower()}" if c.isupper() else c for c in class_name]
            task_name = "".join(chars)[1:]

        tool_config[task_name] = {"tool_class": full_name, "description": self.description, "name": self.name}
        with open(config_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")

        # Save app file
        app_file = os.path.join(output_dir, "app.py")
        with open(app_file, "w", encoding="utf-8") as f:
            f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))

        # Save requirements file
        requirements_file = os.path.join(output_dir, "requirements.txt")
        imports = []
        for module in module_files:
            imports.extend(get_imports(module))
        imports = list(set(imports))
        with open(requirements_file, "w", encoding="utf-8") as f:
            f.write("\n".join(imports) + "\n")

    @classmethod
    def from_hub(cls, task_or_repo_id, repo_id=None, model_repo_id=None, token=None, remote=False, **kwargs):
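        # Resolve a tool hosted on the Hub: locate its configuration, import its class dynamically,
        # then instantiate it locally or wrap it as a RemoteTool.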
        if remote and model_repo_id is None:
            raise ValueError("To use this tool remotely, please pass along the url endpoint to `model_repo_id`.")
        hub_kwargs_names = [
            "cache_dir",
            "force_download",
            "resume_download",
            "proxies",
            "revision",
            "repo_type",
            "subfolder",
            "local_files_only",
        ]
        hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
        if repo_id is None:
            repo_id = task_or_repo_id
            task = None
        else:
            task = task_or_repo_id

        # Try to get the tool config first.
        hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
        resolved_config_file = cached_file(
            repo_id,
            TOOL_CONFIG_FILE,
            use_auth_token=token,
            **hub_kwargs,
            _raise_exceptions_for_missing_entries=False,
            _raise_exceptions_for_connection_errors=False,
        )
        is_tool_config = resolved_config_file is not None
        if resolved_config_file is None:
            resolved_config_file = cached_file(
                repo_id,
                CONFIG_NAME,
                use_auth_token=token,
                **hub_kwargs,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
            )
        if resolved_config_file is None:
            raise EnvironmentError(
                f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
            )

        with open(resolved_config_file, encoding="utf-8") as reader:
            config = json.load(reader)

        if not is_tool_config:
            if "custom_tools" not in config:
                raise EnvironmentError(
                    f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
                )
            custom_tools = config["custom_tools"]
        else:
            custom_tools = config
        if task is None:
            if len(custom_tools) == 1:
                task = list(custom_tools.keys())[0]
            else:
                tasks_available = "\n".join([f"- {t}" for t in custom_tools.keys()])
                raise ValueError(f"Please select a task among the ones available in {repo_id}:\n{tasks_available}")

        tool_class = custom_tools[task]["tool_class"]
        tool_class = get_class_from_dynamic_module(tool_class, repo_id, use_auth_token=token, **hub_kwargs)
        if model_repo_id is not None:
            repo_id = model_repo_id
        elif hub_kwargs["repo_type"] == "space":
            repo_id = None

        if remote:
            return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
        return tool_class(repo_id, token=token, **kwargs)

    def push_to_hub(
        self,
        repo_id: str,
        use_temp_dir: Optional[bool] = None,
        commit_message: str = "Upload tool",
        private: Optional[bool] = None,
        token: Optional[Union[bool, str]] = None,
        create_pr: bool = False,
    ) -> str:
        """
        Upload the tool files to the 🤗 Hub while synchronizing a local clone of the repo in `repo_id`.

        Parameters:
            repo_id (`str`):
                The name of the repository you want to push your tool to. It should contain your organization name
                when pushing to a given organization.
            use_temp_dir (`bool`, *optional*):
                Whether or not to use a temporary directory to store the files saved before they are pushed to the
                Hub. Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
            commit_message (`str`, *optional*, defaults to `"Upload tool"`):
                Message to commit while pushing.
            private (`bool`, *optional*):
                Whether or not the repository created should be private.
            token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If unset, will use the token
                generated when running `huggingface-cli login` (stored in `~/.huggingface`).
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether or not to create a PR with the uploaded files or directly commit.
        """
        if os.path.isdir(repo_id):
            working_dir = repo_id
            repo_id = repo_id.split(os.path.sep)[-1]
        else:
            working_dir = repo_id.split("/")[-1]

        repo_url = create_repo(
            repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
        )
        repo_id = repo_url.repo_id

        if use_temp_dir is None:
            use_temp_dir = not os.path.isdir(working_dir)

        with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
            files_timestamps = self._get_files_timestamps(work_dir)

            # Save all files.
            self.save(work_dir)

            modified_files = [
                f
                for f in os.listdir(work_dir)
                if f not in files_timestamps or os.path.getmtime(os.path.join(work_dir, f)) > files_timestamps[f]
            ]
            operations = []
            for file in modified_files:
                operations.append(CommitOperationAdd(path_or_fileobj=os.path.join(work_dir, file), path_in_repo=file))
            logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
            return create_commit(
                repo_id=repo_id,
                operations=operations,
                commit_message=commit_message,
                token=token,
                create_pr=create_pr,
                repo_type="space",
            )


class OldRemoteTool(Tool):
    default_checkpoint = None

    def __init__(self, repo_id=None, token=None):
        if repo_id is None:
            repo_id = self.default_checkpoint
        self.repo_id = repo_id
        self.client = InferenceApi(repo_id, token=token)

    def prepare_inputs(self, *args, **kwargs):
        if len(args) > 1:
            raise ValueError("A `RemoteTool` can only accept one positional input.")
        elif len(args) == 1:
            return {"data": args[0]}

        return {"inputs": kwargs}

    def extract_outputs(self, outputs):
        return outputs

    def __call__(self, *args, **kwargs):
        if not self.is_initialized:
            self.setup()

        inputs = self.prepare_inputs(*args, **kwargs)
        if isinstance(inputs, dict):
            outputs = self.client(**inputs)
        else:
            outputs = self.client(inputs)
        if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
            outputs = outputs[0]
        return self.extract_outputs(outputs)


class RemoteTool(Tool):
    default_url = None

    def __init__(self, endpoint_url=None, token=None, tool_class=None):
        if endpoint_url is None:
            endpoint_url = self.default_url
        self.endpoint_url = endpoint_url
        self.client = EndpointClient(endpoint_url, token=token)
        self.tool_class = tool_class

    def prepare_inputs(self, *args, **kwargs):
        if len(args) > 1:
            raise ValueError("A `RemoteTool` can only accept one positional input.")
        elif len(args) == 1:
            if is_pil_image(args[0]):
                byte_io = io.BytesIO()
                args[0].save(byte_io, format="PNG")
                return {"inputs": byte_io.getvalue()}
            return {"inputs": args[0]}

        inputs = kwargs.copy()
        for key, value in inputs.items():
            if is_pil_image(value):
                byte_io = io.BytesIO()
                value.save(byte_io, format="PNG")
                inputs[key] = byte_io.getvalue()

        return {"inputs": inputs}

    def extract_outputs(self, outputs):
        return outputs

    def __call__(self, *args, **kwargs):
        output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
        inputs = self.prepare_inputs(*args, **kwargs)
        if isinstance(inputs, dict):
            outputs = self.client(**inputs, output_image=output_image)
        else:
            outputs = self.client(inputs, output_image=output_image)
        if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
            outputs = outputs[0]
        return self.extract_outputs(outputs)


class PipelineTool(Tool):
    pre_processor_class = AutoProcessor
    model_class = None
    post_processor_class = AutoProcessor
    default_checkpoint = None

    def __init__(
        self,
        model=None,
        pre_processor=None,
        post_processor=None,
        device=None,
        device_map=None,
        model_kwargs=None,
        token=None,
        **hub_kwargs,
    ):
        if not is_torch_available():
            raise ImportError("Please install torch in order to use this tool.")

        if not is_accelerate_available():
            raise ImportError("Please install accelerate in order to use this tool.")

        if model is None:
            if self.default_checkpoint is None:
                raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
            model = self.default_checkpoint
        if pre_processor is None:
            pre_processor = model

        self.model = model
        self.pre_processor = pre_processor
        self.post_processor = post_processor
        self.device = device
        self.device_map = device_map
        self.model_kwargs = {} if model_kwargs is None else model_kwargs
        if device_map is not None:
            self.model_kwargs["device_map"] = device_map
        self.hub_kwargs = hub_kwargs
        self.hub_kwargs["use_auth_token"] = token

        self.is_initialized = False

    def setup(self):
        # Instantiate me maybe
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        if isinstance(self.model, str):
            self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        if self.device is None:
            if self.device_map is not None:
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = get_default_device()

        if self.device_map is None:
            self.model.to(self.device)

    def post_init(self):
        pass

    def encode(self, raw_inputs):
        return self.pre_processor(raw_inputs)

    def forward(self, inputs):
        return self.model(**inputs)

    def decode(self, outputs):
        return self.post_processor(outputs)

    def __call__(self, *args, **kwargs):
        if not self.is_initialized:
            self.setup()

        encoded_inputs = self.encode(*args, **kwargs)
        encoded_inputs = send_to_device(encoded_inputs, self.device)
        outputs = self.forward(encoded_inputs)
        outputs = send_to_device(outputs, "cpu")
        return self.decode(outputs)


def launch_gradio_demo(tool_class: Tool):
    try:
        import gradio as gr
    except ImportError:
        raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    tool = tool_class()

    def fn(*args, **kwargs):
        return tool(*args, **kwargs)

    gr.Interface(
        fn=fn,
        inputs=tool_class.inputs,
        outputs=tool_class.outputs,
        title=tool_class.__name__,
        article=tool.description,
    ).launch()


# TODO: Migrate to Accelerate for this once `PartialState.default_device` makes its way into a release.
def get_default_device():
    if not is_torch_available():
        raise ImportError("Please install torch in order to use this tool.")

    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


TASK_MAPPING = {
    "generative-qa": "GenerativeQuestionAnsweringTool",
    "image-captioning": "ImageCaptioningTool",
    "image-segmentation": "ImageSegmentationTool",
    # "language-identification": "LanguageIdentificationTool",
    "speech-to-text": "SpeechToTextTool",
    "text-classification": "TextClassificationTool",
    "text-to-speech": "TextToSpeechTool",
    "translation": "TranslationTool",
    "summarization": "TextSummarizationTool",
    "image-question-answering": "ImageQuestionAnsweringTool",
    "document-question-answering": "DocumentQuestionAnsweringTool",
}


def load_tool(task_or_repo_id, repo_id=None, remote=False, token=None, **kwargs):
    if task_or_repo_id in TASK_MAPPING:
        tool_class_name = TASK_MAPPING[task_or_repo_id]
        main_module = importlib.import_module("transformers")
        tools_module = main_module.tools
        tool_class = getattr(tools_module, tool_class_name)

        if remote:
            return RemoteTool(repo_id, token=token, tool_class=tool_class)
        else:
            return tool_class(repo_id, token=token, **kwargs)
    else:
        return Tool.from_hub(task_or_repo_id, repo_id=repo_id, token=token, remote=remote, **kwargs)


def add_description(description):
    """
    A decorator that adds a description to a function.
    """

    def inner(func):
        func.description = description
        return func

    return inner
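

# Illustrative usage of `add_description` (the decorated function below is hypothetical):
#
# @add_description("This is a tool that reverses a string.")
# def reverse_text(text):
#     return text[::-1]
#
# reverse_text.description  # -> "This is a tool that reverses a string."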


## Will move to the Hub
class EndpointClient:
    def __init__(self, endpoint_url: str, token: Optional[str] = None):
        if token is None:
            token = HfFolder().get_token()
        self.headers = {"authorization": f"Bearer {token}", "Content-Type": "application/json"}
        self.endpoint_url = endpoint_url

    def __call__(
        self,
        inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
        params: Optional[Dict] = None,
        data: Optional[bytes] = None,
        output_image: bool = False,
    ) -> Any:
        # Build payload
        payload = {}
        if inputs:
            payload["inputs"] = inputs
        if params:
            payload["parameters"] = params

        # Make API call
        response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)

        # By default, parse the response for the user.
        if output_image:
            if not is_vision_available():
                raise ImportError(
                    "This endpoint returned an image but Pillow is not installed."
                    " Please install it (`pip install Pillow`) or pass"
                    " `raw_response=True` to get the raw `Response` object and parse"
                    " the image by yourself."
                )

            from PIL import Image

            return Image.open(io.BytesIO(base64.b64decode(response.content)))
        else:
            return response.json()

src/transformers/tools/document_question_answering.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import re

from transformers import VisionEncoderDecoderModel

from ..models.auto import AutoProcessor
from ..utils import is_vision_available
from .base import PipelineTool


if is_vision_available():
    from PIL import Image


DOCUMENT_QUESTION_ANSWERING_DESCRIPTION = (
    "This is a tool that answers a question about a document (pdf). It takes an input named `document` which "
    "should be the document containing the information, as well as a `question` that is the question about the "
    "document. It returns a text that contains the answer to the question."
)


class DocumentQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
    description = DOCUMENT_QUESTION_ANSWERING_DESCRIPTION
    name = "document_qa"
    pre_processor_class = AutoProcessor
    model_class = VisionEncoderDecoderModel

    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        if not is_vision_available():
            raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")

        super().__init__(*args, **kwargs)

    def encode(self, image: "Image", question: str):
        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
        prompt = task_prompt.replace("{user_input}", question)
        decoder_input_ids = self.pre_processor.tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids
        pixel_values = self.pre_processor(image, return_tensors="pt").pixel_values

        return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}

    def forward(self, inputs):
        return self.model.generate(
            inputs["pixel_values"].to(self.device),
            decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
            max_length=self.model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=self.pre_processor.tokenizer.pad_token_id,
            eos_token_id=self.pre_processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        ).sequences

    def decode(self, outputs):
        sequence = self.pre_processor.batch_decode(outputs)[0]
        sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "").replace(
            self.pre_processor.tokenizer.pad_token, ""
        )
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
        sequence = self.pre_processor.token2json(sequence)

        return sequence["answer"]

src/transformers/tools/evaluate_agent.py (new file, 686 lines)
@@ -0,0 +1,686 @@
from .agents import BASE_PYTHON_TOOLS, clean_code_for_chat, clean_code_for_run
|
||||
from .python_interpreter import InterpretorError, evaluate
|
||||
|
||||
|
||||
### Fake tools for test
|
||||
def classifier(text, labels):
|
||||
return f"This is the classification of {text} along {labels}."
|
||||
|
||||
|
||||
def translator(text, src_lang, tgt_lang):
|
||||
return f"This is the translation of {text} from {src_lang} to {tgt_lang}."
|
||||
|
||||
|
||||
def speaker(text):
|
||||
return f"This is actually a sound reading {text}."
|
||||
|
||||
|
||||
def transcriber(audio):
|
||||
if "sound" not in audio:
|
||||
raise ValueError(f"`audio` ({audio}) is not a sound.")
|
||||
return f"This is the transcribed text from {audio}."
|
||||
|
||||
|
||||
def image_generator(prompt):
|
||||
return f"This is actually an image representing {prompt}."
|
||||
|
||||
|
||||
def image_captioner(image):
|
||||
if "image" not in image:
|
||||
raise ValueError(f"`image` ({image}) is not an image.")
|
||||
return f"This is a description of {image}."
|
||||
|
||||
|
||||
def image_transformer(image, prompt):
|
||||
if "image" not in image:
|
||||
raise ValueError(f"`image` ({image}) is not an image.")
|
||||
return f"This is a transformation of {image} according to {prompt}."
|
||||
|
||||
|
||||
def question_answerer(text, question):
|
||||
return f"This is the answer to {question} from {text}."
|
||||
|
||||
|
||||
def image_qa(image, question):
|
||||
if "image" not in image:
|
||||
raise ValueError(f"`image` ({image}) is not an image.")
|
||||
return f"This is the answer to {question} from {image}."
|
||||
|
||||
|
||||
def text_downloader(url):
|
||||
return f"This is the content of {url}."
|
||||
|
||||
|
||||
def summarizer(text):
|
||||
return f"This is a summary of {text}."
|
||||
|
||||
|
||||
def video_generator(prompt, seconds=2):
|
||||
return f"A video of {prompt}"
|
||||
|
||||
|
||||
def document_qa(image, question):
|
||||
return f"This is the answer to {question} from the document {image}."
|
||||
|
||||
|
||||
def image_segmenter(image, prompt):
|
||||
return f"This is the mask of {prompt} in {image}"
|
||||
|
||||
|
||||
def image_inpainter(image, mask, prompt):
|
||||
return f"This is the inpainted of {image} using prompt {prompt} and mask {mask}"
|
||||
|
||||
|
||||
TEST_TOOLS = {
|
||||
"text_classifier": classifier,
|
||||
"translator": translator,
|
||||
"text_reader": speaker,
|
||||
"summarizer": summarizer,
|
||||
"transcriber": transcriber,
|
||||
"image_generator": image_generator,
|
||||
"image_captioner": image_captioner,
|
||||
"image_transformer": image_transformer,
|
||||
"text_qa": question_answerer,
|
||||
"text_downloader": text_downloader,
|
||||
"image_qa": image_qa,
|
||||
"video_generator": video_generator,
|
||||
"document_qa": document_qa,
|
||||
"image_segmenter": image_segmenter,
|
||||
"image_inpainter": image_inpainter,
|
||||
}
|
||||
|
||||
|
||||
class Problem:
|
||||
"""
|
||||
A class regrouping all the information to solve a problem on which we will evaluate agents.
|
||||
|
||||
Args:
|
||||
task (`str` ou `list[str]`):
|
||||
One or several descriptions of the task to perform. If a list, it should contain variations on the
|
||||
phrasing, but for the same task.
|
||||
inputs (`list[str]` or `dict[str, str]`):
|
||||
The inputs that will be fed to the tools. For this testing environment, only strings are accepted as
|
||||
values. Pass along a dictionary when you want to specify the values of each inputs, or just the list of
|
||||
inputs expected (the value used will be `<<input_name>>` in this case).
|
||||
answer (`str` or `list[str`]):
|
||||
The theoretical answer (or list of possible valid answers) to the problem, as code.
|
||||
"""
|
||||
|
||||
def __init__(self, task, inputs, answer):
|
||||
self.task = task
|
||||
self.inputs = inputs
|
||||
self.answer = answer
|
||||
|
||||
|
||||
### The list of problems the agent will be evaluated on.
EVALUATION_TASKS = [
    Problem(
        task=[
            "Is the following `text` (in Spanish) positive or negative?",
            "Is the text in the variable `text` (in Spanish) positive or negative?",
            "Translate the following `text` from Spanish to English then tell me if it's positive or negative.",
        ],
        inputs=["text"],
        answer="""text_classifier(translator(text, src_lang="Spanish", tgt_lang="English"), labels=["positive", "negative"])""",
    ),
    Problem(
        task=[
            "Tell me out loud what the `image` contains.",
            "Describe the following `image` out loud.",
            "Determine what is in the picture stored in `image` then read it out loud.",
        ],
        inputs=["image"],
        answer=[
            "text_reader(image_captioner(image))",
            "text_reader(image_qa(image, question='What is in the image?'))",
        ],
    ),
    Problem(
        task=[
            "Generate an image from the text given in `text_input`. Then transform it according to the text in `prompt`.",
            "Use the following `text_input` to generate an image, then transform it by using the text in `prompt`.",
        ],
        inputs=["text_input", "prompt"],
        answer="image_transformer(image_generator(text_input), prompt)",
    ),
    Problem(
        task=[
            "Download the content of `url`, summarize it then generate an image from its content.",
            "Use a summary of the web page at `url` to generate an image.",
            "Summarize the content of the web page at `url`, and use the result to generate an image.",
        ],
        inputs=["url"],
        answer="image_generator(summarizer(text_downloader(url)))",
    ),
    Problem(
        task=[
            "Transform the following `image` using the prompt in `text`. The prompt is in Spanish.",
            "Use the text prompt in `text` (in Spanish) to transform the following `image`.",
            "Translate the `text` from Spanish to English then use it to transform the picture in `image`.",
        ],
        inputs=["text", "image"],
        answer="image_transformer(image, translator(text, src_lang='Spanish', tgt_lang='English'))",
    ),
    Problem(
        task=[
            "Download the content of `url`, summarize it then read it out loud to me.",
            "Read me a summary of the web page at `url`.",
        ],
        inputs=["url"],
        answer="text_reader(summarizer(text_downloader(url)))",
    ),
    Problem(
        task=[
            "Generate an image from the text given in `text_input`.",
        ],
        inputs=["text_input"],
        answer="image_generator(text_input)",
    ),
    Problem(
        task=[
            "Replace the beaver in the `image` by the `prompt`.",
            "Transform the `image` so that it contains the `prompt`.",
            "Use `prompt` to transform this `image`.",
        ],
        inputs=["image", "prompt"],
        answer=[
            "image_transformer(image, prompt)",
            "image_inpainter(image, image_segmenter(image, 'beaver'), prompt)",
        ],
    ),
    Problem(
        task=[
            "Provide me the summary of the `text`, then read it to me before transcribing it and translating it in French.",
            "Summarize `text`, read it out loud then transcribe the audio and translate it in French.",
            "Read me a summary of the `text` out loud. Transcribe this and translate it in French.",
        ],
        inputs=["text"],
        answer="translator(transcriber(text_reader(summarizer(text))), src_lang='English', tgt_lang='French')",
    ),
    Problem(
        task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
        inputs={"prompt": "A lobster swimming"},
        answer="video_generator('A lobster swimming')",
    ),
    Problem(
        task=[
            "Download the following file `url`, summarize it in a few words and generate a video from it.",
            "Fetch the file at this `url`, summarize it, and create an animation out of it.",
        ],
        inputs=["url"],
        answer="video_generator(summarizer(text_downloader(url)))",
    ),
]


EVALUATION_CHATS = [
    [
        Problem(
            task=[
                "Translate the following `text` from Spanish to English.",
                "Translate the following `text` from Spanish to English.",
            ],
            inputs=["text"],
            answer="translated_text=translator(text, src_lang='Spanish', tgt_lang='English')",
        ),
        Problem(
            task=[
                "Is it positive or negative?",
                "Tell me if it's positive or negative.",
            ],
            inputs=[],
            answer="text_classifier(translated_text, labels=['positive', 'negative'])",
        ),
    ],
    [
        Problem(
            task=[
                "What does this `image` contain?",
                "Describe the following `image`.",
                "Determine what is in the picture stored in `image`",
            ],
            inputs=["image"],
            answer=[
                "description=image_captioner(image)",
                "description=image_qa(image, question='What is in the image?')",
            ],
        ),
        Problem(
            task=["Now, read the description out loud.", "Great! Can you read it out loud?", "Read it out loud."],
            inputs=[],
            answer=["audio=text_reader(description)", "audio=text_reader(description)"],
        ),
    ],
    [
        Problem(
            task=[
                "Generate an image from the text given in `text_input`.",
                "Use the following `text_input` to generate an image",
            ],
            inputs=["text_input"],
            answer="image = image_generator(text_input)",
        ),
        Problem(
            task=[
                "Transform it according to the text in `prompt`.",
                "Transform it by using the text in `prompt`.",
            ],
            inputs=["prompt"],
            answer="image_transformer(image, prompt)",
        ),
    ],
    [
        Problem(
            task=[
                "Download the content of `url` and summarize it.",
                "Summarize the content of the web page at `url`.",
            ],
            inputs=["url"],
            answer="summary = summarizer(text_downloader(url))",
        ),
        Problem(
            task=[
                "Generate an image from its content.",
                "Use the previous result to generate an image.",
            ],
            inputs=[],
            answer="image_generator(summary)",
        ),
    ],
    [
        Problem(
            task=[
                "Translate this Spanish `text` in English.",
                "Translate the `text` from Spanish to English.",
            ],
            inputs=["text"],
            answer="translated_text = translator(text, src_lang='Spanish', tgt_lang='English')",
        ),
        Problem(
            task=[
                "Transform the following `image` using the translated `text`.",
                "Use the previous result to transform the following `image`.",
            ],
            inputs=["image"],
            answer="image_transformer(image, translated_text)",
        ),
    ],
    [
        Problem(
            task=["Download the content of `url`.", "Get me the text on the web page `url`."],
            inputs=["url"],
            answer="text = text_downloader(url)",
        ),
        Problem(
            task=["Summarize this text.", "Summarize this text."],
            inputs=[],
            answer="summary = summarizer(text)",
        ),
        Problem(
            task=["Read it out loud to me.", "Read me the previous result."],
            inputs=[],
            answer="text_reader(summary)",
        ),
    ],
    [
        Problem(
            task=[
                "Generate an image from the text given in `text_input`.",
            ],
            inputs=["text_input"],
            answer="image_generator(text_input)",
        ),
    ],
    [
        Problem(
            task=[
                "Replace the beaver in the `image` by the `prompt`.",
                "Transform the `image` so that it contains the `prompt`.",
                "Use `prompt` to transform this `image`.",
            ],
            inputs=["image", "prompt"],
            answer=[
                "image_transformer(image, prompt)",
                "image_inpainter(image, image_segmenter(image, 'beaver'), prompt)",
            ],
        ),
    ],
    [
        Problem(
            task=["Provide me the summary of the `text`.", "Summarize `text`."],
            inputs=["text"],
            answer="summary = summarizer(text)",
        ),
        Problem(
            task=["Read this summary to me.", "Read it out loud."],
            inputs=[],
            answer="audio = text_reader(summarizer(text))",
        ),
        Problem(
            task=["Transcribe the previous result back in text.", "Transcribe the audio."],
            inputs=[],
            answer="text = transcriber(audio)",
        ),
        Problem(
            task=["Translate the last result in French.", "Translate this in French."],
            inputs=[],
            answer="translator(text, src_lang='English', tgt_lang='French')",
        ),
    ],
    [
        Problem(
            task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
            inputs={"prompt": "A lobster swimming"},
            answer="video_generator('A lobster swimming')",
        ),
    ],
    [
        Problem(
            task=[
                "Download the content of `url` and summarize it.",
                "Summarize the content of the web page at `url`.",
            ],
            inputs=["url"],
            answer="summary = summarizer(text_downloader(url))",
        ),
        Problem(
            task=["Generate a video from it.", "Create an animation from the last result."],
            inputs=[],
            answer="video_generator(summary)",
        ),
    ],
]


def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
    # Returns the set of tools appearing in the reference code that matches the agent's answer.
    if not isinstance(theoretical_answer, list):
        return {name for name in TEST_TOOLS if name in code_answer}

    for one_answer, one_code in zip(theoretical_answer, code_answer):
        if agent_answer == one_answer:
            return {name for name in TEST_TOOLS if name in one_code}

    if isinstance(agent_answer, dict):
        for one_answer, one_code in zip(theoretical_answer, code_answer):
            if one_answer in agent_answer.values():
                return {name for name in TEST_TOOLS if name in one_code}

    return {name for name in TEST_TOOLS if name in code_answer[0]}


def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpretor_error=False):
    # Only give the interpreter access to the test tools actually mentioned in the code.
    tools = BASE_PYTHON_TOOLS.copy()
    for name, tool in TEST_TOOLS.items():
        if name not in code:
            continue
        tools[name] = tool

    if isinstance(inputs, dict):
        inputs = inputs.copy()
    elif inputs is not None:
        inputs = {inp: f"<<{inp}>>" for inp in inputs}

    if state is not None:
        state.update(inputs)
    else:
        state = inputs

    try:
        return evaluate(code, tools, state)
    except InterpretorError as e:
        return str(e)
    except Exception as e:
        if verbose:
            print(e)
        return None


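# Usage sketch: evaluating a candidate answer against the fake tools above. With the
# placeholder inputs, the nested calls produce a fully deterministic string:
#
#     evaluate_code("summarizer(text_downloader(url))", inputs=["url"])
#     # -> "This is a summary of This is the content of <<url>>.."
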
def score_code(agent_answer, theoretical_answer, verbose: bool = False):
    # Grading scale: 1.0 for an exact match, 0.75 if the expected result is in the state
    # under another name, 0.3 if the code ran but produced something else.
    if verbose:
        print(agent_answer, theoretical_answer)
    theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]

    if agent_answer in theoretical_answer:
        if verbose:
            print("Perfect!")
        return 1
    elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
        if verbose:
            print("Almost perfect, result in state!")
        return 0.75
    else:
        if verbose:
            print("Result is not the right one but code executed.")
        return 0.3


def evaluate_one_result(explanation, code, agent_answer, theoretical_answer, answer, verbose=False):
    tools_in_explanation = {name for name in TEST_TOOLS if f"`{name}`" in explanation}
    theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)
    if tools_in_explanation == theoretical_tools:
        tool_selection_score = 1.0
        tool_selection_errors = None
    else:
        missing_tools = len(theoretical_tools - tools_in_explanation)
        unexpected_tools = len(tools_in_explanation - theoretical_tools)
        tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)

        tool_selection_errors = {
            "selected_tools": tools_in_explanation,
            "theoretical_tools": theoretical_tools,
        }

    tools_in_code = {name for name in TEST_TOOLS if name in code}
    if tools_in_code == theoretical_tools:
        tool_used_score = 1.0
        tool_used_errors = None
    else:
        missing_tools = len(theoretical_tools - tools_in_code)
        unexpected_tools = len(tools_in_code - theoretical_tools)
        tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)

        tool_used_errors = {
            "selected_tools": tools_in_code,
            "theoretical_tools": theoretical_tools,
        }

    score = score_code(agent_answer, theoretical_answer, verbose=verbose)
    if score < 1.0:
        code_errors = {
            "code_produced": code,
            "evaluation": agent_answer,
            "theoretical_answer": theoretical_answer,
        }
    else:
        code_errors = None

    return (tool_selection_score, tool_used_score, score), (tool_selection_errors, tool_used_errors, code_errors)


def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
    """
    Evaluates a new agent on all `EVALUATION_TASKS`.

    Example:

    ```py
    agent = NewOpenAiAgent(model="text-davinci-003", api_key=your_api_key)
    scores = evaluate_agent(agent)
    print(scores)
    ```
    """
    # Sanity check
    agent_tools = set(agent.toolbox.keys())
    if agent_tools != set(TEST_TOOLS):
        missing_tools = set(TEST_TOOLS) - agent_tools
        unexpected_tools = agent_tools - set(TEST_TOOLS)
        raise ValueError(
            f"Fix the test tools in the evaluate_agent module. Tools missing: {missing_tools}. Extra tools: {unexpected_tools}."
        )

    eval_tasks = []
    eval_idx = []
    for idx, pb in enumerate(EVALUATION_TASKS):
        if isinstance(pb.task, list):
            eval_tasks.extend(pb.task)
            eval_idx.extend([idx] * len(pb.task))
        else:
            eval_tasks.append(pb.task)
            eval_idx.append(idx)

    tool_selection_score = 0
    tool_used_score = 0
    code_score = 0

    if return_errors:
        tool_selection_errors = {}
        tool_used_errors = {}
        code_errors = {}

    for start_idx in range(0, len(eval_tasks), batch_size):
        end_idx = min(start_idx + batch_size, len(eval_tasks))
        batch_tasks = eval_tasks[start_idx:end_idx]

        results = agent.generate_code(batch_tasks)

        for idx, result in enumerate(results):
            problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
            if verbose:
                print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")
            explanation, code = clean_code_for_run(result)

            # Evaluate agent answer and code answer
            agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
            if isinstance(problem.answer, list):
                theoretical_answer = [evaluate_code(answer, problem.inputs) for answer in problem.answer]
            else:
                theoretical_answer = evaluate_code(problem.answer, problem.inputs)

            scores, errors = evaluate_one_result(
                explanation, code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
            )

            tool_selection_score += scores[0]
            tool_used_score += scores[1]
            code_score += scores[2]

            if return_errors:
                if errors[0] is not None:
                    tool_selection_errors[batch_tasks[idx]] = errors[0]
                if errors[1] is not None:
                    tool_used_errors[batch_tasks[idx]] = errors[1]
                if errors[2] is not None:
                    code_errors[batch_tasks[idx]] = errors[2]

    scores = {
        "tool selection score": 100 * (tool_selection_score / len(eval_tasks)),
        "tool used score": 100 * (tool_used_score / len(eval_tasks)),
        "code score": 100 * (code_score / len(eval_tasks)),
    }

    if return_errors:
        return scores, tool_selection_errors, tool_used_errors, code_errors
    else:
        return scores


def evaluate_chat_agent(agent, verbose=False, return_errors=False):
    """
    Evaluates a new agent on all `EVALUATION_CHATS`.

    Example:

    ```py
    agent = NewOpenAiAgent(model="text-davinci-003", api_key=your_api_key)
    scores = evaluate_chat_agent(agent)
    print(scores)
    ```
    """
    # Sanity check
    agent_tools = set(agent.toolbox.keys())
    if agent_tools != set(TEST_TOOLS):
        missing_tools = set(TEST_TOOLS) - agent_tools
        unexpected_tools = agent_tools - set(TEST_TOOLS)
        raise ValueError(
            f"Fix the test tools in the evaluate_agent module. Tools missing: {missing_tools}. Extra tools: {unexpected_tools}."
        )

    tool_selection_score = 0
    tool_used_score = 0
    code_score = 0
    total_steps = 0

    if return_errors:
        tool_selection_errors = {}
        tool_used_errors = {}
        code_errors = {}

    for chat_problem in EVALUATION_CHATS:
        if isinstance(chat_problem[0].task, str):
            resolved_problems = [chat_problem]
        else:
            resolved_problems = [
                [Problem(task=pb.task[i], inputs=pb.inputs, answer=pb.answer) for pb in chat_problem]
                for i in range(len(chat_problem[0].task))
            ]
        for problem in resolved_problems:
            agent.prepare_for_new_chat()
            agent_state = {}
            theoretical_state = (
                [{} for _ in range(len(problem[0].answer))] if isinstance(problem[0].answer, list) else {}
            )

            for step, step_problem in enumerate(problem):
                if verbose:
                    print(step_problem.task)
                total_steps += 1
                prompt = agent.format_prompt(step_problem.task, chat_mode=True)
                result = agent._generate_one(prompt, stop=["Human:", "====="])
                agent.chat_history = prompt + result + "\n"

                explanation, code = clean_code_for_chat(result)

                if verbose:
                    print(f"==Explanation from the agent==\n{explanation}")
                    print(f"\n==Code generated by the agent==\n{code}")

                # Evaluate agent answer and code answer
                agent_answer = evaluate_code(code, step_problem.inputs, state=agent_state, verbose=verbose)

                answer = step_problem.answer
                if isinstance(answer, list):
                    theoretical_answer = [
                        evaluate_code(a, step_problem.inputs, state=state)
                        for a, state in zip(answer, theoretical_state)
                    ]
                else:
                    theoretical_answer = evaluate_code(answer, step_problem.inputs, state=theoretical_state)

                scores, errors = evaluate_one_result(
                    explanation, code, agent_answer, theoretical_answer, answer, verbose=verbose
                )

                tool_selection_score += scores[0]
                tool_used_score += scores[1]
                code_score += scores[2]

                if return_errors:
                    if errors[0] is not None:
                        tool_selection_errors[step_problem.task] = errors[0]
                    if errors[1] is not None:
                        tool_used_errors[step_problem.task] = errors[1]
                    if errors[2] is not None:
                        code_errors[step_problem.task] = errors[2]

    scores = {
        "tool selection score": 100 * (tool_selection_score / total_steps),
        "tool used score": 100 * (tool_used_score / total_steps),
        "code score": 100 * (code_score / total_steps),
    }

    if return_errors:
        return scores, tool_selection_errors, tool_used_errors, code_errors
    else:
        return scores
51 src/transformers/tools/generative_question_answering.py Normal file
@ -0,0 +1,51 @@
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import OldRemoteTool, PipelineTool


QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.

Can you answer this question about the text: '{question}'"""


GENERATIVE_QUESTION_ANSWERING_DESCRIPTION = (
    "This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
    "text where to find the answer, and `question`, which is the question, and returns the answer to the question."
)


class GenerativeQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "google/flan-t5-base"
    description = GENERATIVE_QUESTION_ANSWERING_DESCRIPTION
    name = "text_qa"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    inputs = ["text", "text"]
    outputs = ["text"]

    def encode(self, text: str, question: str):
        prompt = QA_PROMPT.format(text=text, question=question)
        return self.pre_processor(prompt, return_tensors="pt")

    def forward(self, inputs):
        output_ids = self.model.generate(**inputs)

        in_b, _ = inputs["input_ids"].shape
        out_b = output_ids.shape[0]

        # Keep the first generated sequence of the first (and only) input.
        return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]

    def decode(self, outputs):
        return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)


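# Usage sketch (a hedged example, not part of the module: it assumes `PipelineTool.__call__`
# forwards its arguments to `encode`; google/flan-t5-base is downloaded on first use):
#
#     qa = GenerativeQuestionAnsweringTool()
#     qa(text="Hugging Face is based in New York and Paris.", question="Where is Hugging Face based?")
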
class RemoteGenerativeQuestionAnsweringTool(OldRemoteTool):
    default_checkpoint = "google/flan-t5-base"
    description = GENERATIVE_QUESTION_ANSWERING_DESCRIPTION

    def extract_outputs(self, outputs):
        return outputs[0]["generated_text"]

    def prepare_inputs(self, text, question):
        prompt = QA_PROMPT.format(text=text, question=question)
        return prompt
57 src/transformers/tools/image_captioning.py Normal file
@ -0,0 +1,57 @@
import io

from ..models.auto import AutoModelForVision2Seq, AutoProcessor
from ..utils import is_vision_available
from .base import OldRemoteTool, PipelineTool


if is_vision_available():
    from PIL import Image


IMAGE_CAPTIONING_DESCRIPTION = (
    "This is a tool that generates a description of an image. It takes an input named `image` which should be the "
    "image to caption, and returns a text that contains the description in English."
)


class ImageCaptioningTool(PipelineTool):
    default_checkpoint = "Salesforce/blip-image-captioning-base"
    description = IMAGE_CAPTIONING_DESCRIPTION
    name = "image_captioner"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVision2Seq

    inputs = ["image"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        if not is_vision_available():
            raise ValueError("Pillow must be installed to use the ImageCaptioningTool.")

        super().__init__(*args, **kwargs)

    def encode(self, image: "Image"):
        return self.pre_processor(images=image, return_tensors="pt")

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()


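# Usage sketch (a hedged example, not part of the module: requires Pillow and downloads
# the BLIP checkpoint on first use; "photo.png" is a placeholder path):
#
#     from PIL import Image
#     captioner = ImageCaptioningTool()
#     caption = captioner(Image.open("photo.png"))
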
class RemoteImageCaptioningTool(OldRemoteTool):
    default_checkpoint = "Salesforce/blip-image-captioning-large"
    description = IMAGE_CAPTIONING_DESCRIPTION

    def extract_outputs(self, outputs):
        return outputs[0]["generated_text"]

    def prepare_inputs(self, image):
        if isinstance(image, bytes):
            return {"data": image}

        byte_io = io.BytesIO()
        image.save(byte_io, format="PNG")
        return {"data": byte_io.getvalue()}
41 src/transformers/tools/image_question_answering.py Normal file
@ -0,0 +1,41 @@
from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
from ..utils import is_vision_available
from .base import PipelineTool


if is_vision_available():
    from PIL import Image


IMAGE_QUESTION_ANSWERING_DESCRIPTION = (
    "This is a tool that answers a question about an image. It takes an input named `image` which should be the "
    "image containing the information, as well as a `question` which should be the question in English. It returns a "
    "text that is the answer to the question."
)


class ImageQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
    description = IMAGE_QUESTION_ANSWERING_DESCRIPTION
    name = "image_qa"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVisualQuestionAnswering

    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        if not is_vision_available():
            raise ValueError("Pillow must be installed to use the ImageQuestionAnsweringTool.")

        super().__init__(*args, **kwargs)

    def encode(self, image: "Image", question: str):
        return self.pre_processor(image, question, return_tensors="pt")

    def forward(self, inputs):
        return self.model(**inputs).logits

    def decode(self, outputs):
        # The VQA model is a classifier: pick the highest-scoring answer label.
        idx = outputs.argmax(-1).item()
        return self.model.config.id2label[idx]
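# Usage sketch (a hedged example, not part of the module: requires Pillow; "photo.png"
# is a placeholder path):
#
#     from PIL import Image
#     vqa = ImageQuestionAnsweringTool()
#     answer = vqa(Image.open("photo.png"), "How many dogs are in the picture?")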
47 src/transformers/tools/image_segmentation.py Normal file
@ -0,0 +1,47 @@
import numpy as np

from transformers import AutoProcessor, CLIPSegForImageSegmentation, is_vision_available

from .base import PipelineTool


if is_vision_available():
    from PIL import Image


IMAGE_SEGMENTATION_DESCRIPTION = (
    "This is a tool that creates a segmentation mask using an image and a prompt. It takes two arguments named "
    "`image` which should be the original image, and `prompt` which should be a text describing what should be "
    "identified in the segmentation mask. The tool returns the mask as a black-and-white image."
)


class ImageSegmentationTool(PipelineTool):
    description = IMAGE_SEGMENTATION_DESCRIPTION
    default_checkpoint = "CIDAS/clipseg-rd64-refined"
    name = "image_segmenter"
    pre_processor_class = AutoProcessor
    model_class = CLIPSegForImageSegmentation

    inputs = ["image", "text"]
    outputs = ["image"]

    def __init__(self, *args, **kwargs):
        if not is_vision_available():
            raise ImportError("Pillow should be installed in order to use the ImageSegmentationTool.")

        super().__init__(*args, **kwargs)

    def encode(self, image: "Image", prompt: str):
        self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]}
        return self.pre_processor(text=[prompt], images=[image], padding=True, return_tensors="pt")

    def forward(self, inputs):
        logits = self.model(**inputs).logits
        return logits

    def decode(self, outputs):
        # Binarize the logits into a black-and-white mask.
        array = outputs.cpu().detach().numpy()
        array[array <= 0] = 0
        array[array > 0] = 1
        return Image.fromarray((np.dstack([array, array, array]) * 255).astype(np.uint8))
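# Usage sketch (a hedged example, not part of the module: requires Pillow; `image` is a
# PIL image you already have in hand):
#
#     segmenter = ImageSegmentationTool()
#     mask = segmenter(image, prompt="a cat")  # black-and-white PIL image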
47 src/transformers/tools/image_upscaling.py Normal file
@ -0,0 +1,47 @@
from transformers.tools.base import Tool
from transformers.utils import is_accelerate_available, is_diffusers_available


if is_accelerate_available():
    from accelerate.state import PartialState

if is_diffusers_available():
    from diffusers import DiffusionPipeline


TEXT_TO_IMAGE_DESCRIPTION = (
    "This is a tool that creates an image according to a prompt, which is a text description. It takes an input named `prompt` which "
    "contains the image description and outputs an image."
)


class TextToImageTool(Tool):
    default_checkpoint = "runwayml/stable-diffusion-v1-5"
    description = TEXT_TO_IMAGE_DESCRIPTION

    def __init__(self, device=None, **hub_kwargs) -> None:
        if not is_accelerate_available():
            raise ImportError("Accelerate should be installed in order to use tools.")
        if not is_diffusers_available():
            raise ImportError("Diffusers should be installed in order to use the StableDiffusionTool.")

        super().__init__()

        self.device = device
        self.pipeline = None
        self.hub_kwargs = hub_kwargs

    def setup(self):
        if self.device is None:
            self.device = PartialState().default_device

        self.pipeline = DiffusionPipeline.from_pretrained(self.default_checkpoint)
        self.pipeline.to(self.device)

        self.is_initialized = True

    def __call__(self, prompt):
        if not self.is_initialized:
            self.setup()

        return self.pipeline(prompt).images[0]
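# Usage sketch (not part of the module: requires diffusers and accelerate; the Stable
# Diffusion checkpoint is downloaded on first use):
#
#     tool = TextToImageTool()
#     image = tool("An astronaut riding a horse")  # returns a PIL image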
26 src/transformers/tools/language_identifier.py Normal file
@ -0,0 +1,26 @@
from .text_classification import TextClassificationTool


LANGUAGE_IDENTIFIER_DESCRIPTION = (
    "This is a tool that identifies the language of the text passed as input. It takes one input named `text` and "
    "returns the two-letter label of the identified language."
)


class LanguageIdentificationTool(TextClassificationTool):
    """
    Example:

    ```py
    from transformers.tools import LanguageIdentificationTool

    classifier = LanguageIdentificationTool()
    classifier("This is a super nice API!")
    ```
    """

    default_checkpoint = "papluca/xlm-roberta-base-language-detection"
    description = LANGUAGE_IDENTIFIER_DESCRIPTION

    def decode(self, outputs):
        return super().decode(outputs)["label"]
169 src/transformers/tools/prompts.py Normal file
@ -0,0 +1,169 @@
# docstyle-ignore
RUN_PROMPT_TEMPLATE = """I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.

Tools:
<<all_tools>>


Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."

I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.

Answer:
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(image=image, question=translated_question)
print(f"The answer is {answer}")
```

Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."

I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.

Answer:
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
image = image_generator("A banner showing " + answer)
```

Task: "Generate an image using the text given in the variable `caption`."

I will use the following tool: `image_generator` to generate an image.

Answer:
```py
image = image_generator(prompt=caption)
```

Task: "Summarize the text given in the variable `text` and read it out loud."

I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.

Answer:
```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(summarized_text)
```

Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."

I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.

Answer:
```py
answer = text_qa(text=text, question=question)
print(f"The answer is {answer}.")
image = image_generator(answer)
```

Task: "Caption the following `image`."

I will use the following tool: `image_captioner` to generate a caption for the image.

Answer:
```py
caption = image_captioner(image)
```

Task: "<<prompt>>"

I will use the following"""


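# The agent replaces `<<all_tools>>` with the formatted descriptions of every tool in its
# toolbox and `<<prompt>>` with the user's task before querying the LLM. A minimal sketch
# of that substitution (the exact formatting lives in the agent code, so treat this as
# illustrative only):
#
#     prompt = RUN_PROMPT_TEMPLATE.replace("<<all_tools>>", tool_descriptions)
#     prompt = prompt.replace("<<prompt>>", task)
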
# docstyle-ignore
CHAT_PROMPT_TEMPLATE = """Below are a series of dialogues between various people and an AI assistant specialized in coding. The AI assistant tries to be helpful, polite, honest, and humble-but-knowledgeable.

The job of the AI assistant is to come up with a series of simple commands in Python that will perform the task the human wants to perform.
To help with that, the AI assistant has access to a set of tools. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
The AI assistant should first explain the tools it will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. The AI assistant can print intermediate results if it makes sense to do so.

Tools:
<<all_tools>>

=====

Human: Answer the question in the variable `question` about the image stored in the variable `image`.

Assistant: I will use the tool `image_qa` to answer the question on the input image.

```py
answer = image_qa(text=question, image=image)
print(f"The answer is {answer}")
```

Human: I tried this code, it worked but didn't give me a good result. The question is in French

Assistant: In this case, the question needs to be translated first. I will use the tool `translator` to do this.

```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(text=translated_question, image=image)
print(f"The answer is {answer}")
```

=====

Human: Identify the oldest person in the `document`.

Assistant: I will use the tool `document_qa` to find the oldest person in the document.

```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
```

Human: Can you generate an image with the result?

Assistant: I will use the tool `image_generator` to do that.

```py
image = image_generator(prompt="A banner showing " + answer)
```

=====

Human: Summarize the text given in the variable `text` and read it out loud.

Assistant: I will use the tool `summarizer` to create a summary of the input text, then the tool `text_reader` to read it out loud.

```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(text=summary)
```

Human: I got the following error: "The variable `summary` is not defined."

Assistant: My bad! Let's try this code instead.

```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(text=summarized_text)
```

Human: It worked! Can you translate the summary in German?

Assistant: I will use the tool `translator` to translate the text into German.

```py
translated_summary = translator(summarized_text, src_lang="English", tgt_lang="German")
```

====
"""


# docstyle-ignore
CHAT_MESSAGE_PROMPT = """
Human: <<task>>

Assistant: """
210 src/transformers/tools/python_interpreter.py Normal file
@ -0,0 +1,210 @@
import ast
import difflib
from collections.abc import Mapping
from typing import Any, Callable, Dict


class InterpretorError(ValueError):
    """
    An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
    operations.
    """

    pass


def evaluate(code: str, tools: Dict[str, Callable], state=None, chat_mode=False):
    """
    Evaluate a Python expression using the content of the variables stored in a state and only evaluating a given set
    of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        code (`str`):
            The code to evaluate.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
            updated by this function to contain all variables as they are evaluated.
        chat_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the evaluation happens in chat mode, which changes the error message printed on failure.
    """
    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        print("The code generated by the agent is not valid.\n", e)
        return
    if state is None:
        state = {}
    result = None
    for idx, node in enumerate(expression.body):
        try:
            line_result = evaluate_ast(node, state, tools)
        except InterpretorError as e:
            msg = f"Evaluation of the code stopped at line {idx} before the end because of the following error"
            if chat_mode:
                msg += (
                    f". Copy paste the following error message and send it back to the agent:\nI get an error: '{e}'"
                )
            else:
                msg += f":\n{e}"
            print(msg)
            break
        if line_result is not None:
            result = line_result

    return result


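# Usage sketch (a hedged example, not part of the module). Only a restricted subset of
# Python is evaluated: assignments, calls to the provided tools, constants, f-strings,
# lists, dicts, subscripts, names and simple if-statements.
#
#     state = {}
#     result = evaluate("x = add(1, 2)\ny = add(x, 3)", tools={"add": lambda a, b: a + b}, state=state)
#     # state == {"x": 3, "y": 6}; `result` is 6, the value of the last evaluated line.
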
def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
    """
    Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
    set of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
    """
    if isinstance(expression, ast.Assign):
        # Assignment -> we evaluate the assignment which should update the state
        # We return the variable assigned as it may be used to determine the final result.
        return evaluate_assign(expression, state, tools)
    elif isinstance(expression, ast.Call):
        # Function call -> we return the value of the function call
        return evaluate_call(expression, state, tools)
    elif isinstance(expression, ast.Constant):
        # Constant -> just return the value
        return expression.value
    elif isinstance(expression, ast.Dict):
        # Dict -> evaluate all keys and values
        keys = [evaluate_ast(k, state, tools) for k in expression.keys]
        values = [evaluate_ast(v, state, tools) for v in expression.values]
        return dict(zip(keys, values))
    elif isinstance(expression, ast.Expr):
        # Expression -> evaluate the content
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.FormattedValue):
        # Formatted value (part of f-string) -> evaluate the content and return
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.If):
        # If -> execute the right branch
        evaluate_if(expression, state, tools)
    elif isinstance(expression, ast.JoinedStr):
        # f-string -> evaluate and join all of its parts
        return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
    elif isinstance(expression, ast.List):
        # List -> evaluate all elements
        return [evaluate_ast(elt, state, tools) for elt in expression.elts]
    elif isinstance(expression, ast.Name):
        # Name -> pick up the value in the state
        return evaluate_name(expression, state, tools)
    elif isinstance(expression, ast.Subscript):
        # Subscript -> return the value of the indexing
        return evaluate_subscript(expression, state, tools)
    else:
        # For now we refuse anything else. Let's add things as we need them.
        raise InterpretorError(f"{expression.__class__.__name__} is not supported.")


def evaluate_assign(assign, state, tools):
    var_names = assign.targets
    result = evaluate_ast(assign.value, state, tools)

    if len(var_names) == 1:
        state[var_names[0].id] = result
    else:
        if len(result) != len(var_names):
            raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
        for var_name, r in zip(var_names, result):
            state[var_name.id] = r
    return result


def evaluate_call(call, state, tools):
    if not isinstance(call.func, ast.Name):
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
            f"type {type(call.func)})."
        )
    func_name = call.func.id
    if func_name not in tools:
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
        )

    func = tools[func_name]
    # Todo deal with args
    args = [evaluate_ast(arg, state, tools) for arg in call.args]
    kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
    return func(*args, **kwargs)


def evaluate_subscript(subscript, state, tools):
    index = evaluate_ast(subscript.slice, state, tools)
    value = evaluate_ast(subscript.value, state, tools)
    if index in value:
        return value[index]
    # Tolerate near-miss keys generated by the agent by falling back to the closest match.
    if isinstance(index, str) and isinstance(value, Mapping):
        close_matches = difflib.get_close_matches(index, list(value.keys()))
        if len(close_matches) > 0:
            return value[close_matches[0]]

    raise InterpretorError(f"Could not index {value} with '{index}'.")


def evaluate_name(name, state, tools):
    if name.id in state:
        return state[name.id]
    # Same fuzzy fallback for variable names (e.g. `translated_txt` can resolve to `translated_text`).
    close_matches = difflib.get_close_matches(name.id, list(state.keys()))
    if len(close_matches) > 0:
        return state[close_matches[0]]
    raise InterpretorError(f"The variable `{name.id}` is not defined.")


def evaluate_condition(condition, state, tools):
    if len(condition.ops) > 1:
        raise InterpretorError("Cannot evaluate conditions with multiple operators")

    left = evaluate_ast(condition.left, state, tools)
    comparator = condition.ops[0]
    right = evaluate_ast(condition.comparators[0], state, tools)

    if isinstance(comparator, ast.Eq):
        return left == right
    elif isinstance(comparator, ast.NotEq):
        return left != right
    elif isinstance(comparator, ast.Lt):
        return left < right
    elif isinstance(comparator, ast.LtE):
        return left <= right
    elif isinstance(comparator, ast.Gt):
        return left > right
    elif isinstance(comparator, ast.GtE):
        return left >= right
    elif isinstance(comparator, ast.Is):
        return left is right
    elif isinstance(comparator, ast.IsNot):
        return left is not right
    elif isinstance(comparator, ast.In):
        return left in right
    elif isinstance(comparator, ast.NotIn):
        return left not in right
    else:
        raise InterpretorError(f"Operator not supported: {comparator}")


def evaluate_if(if_statement, state, tools):
    if evaluate_condition(if_statement.test, state, tools):
        for line in if_statement.body:
            evaluate_ast(line, state, tools)
    else:
        for line in if_statement.orelse:
            evaluate_ast(line, state, tools)
39 src/transformers/tools/speech_to_text.py Normal file
@ -0,0 +1,39 @@
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
from .base import OldRemoteTool, PipelineTool


SPEECH_TO_TEXT_DESCRIPTION = (
    "This is a tool that transcribes audio into text. It takes an input named `audio` and returns the "
    "transcribed text."
)


class SpeechToTextTool(PipelineTool):
    default_checkpoint = "openai/whisper-base"
    description = SPEECH_TO_TEXT_DESCRIPTION
    name = "transcriber"
    pre_processor_class = WhisperProcessor
    model_class = WhisperForConditionalGeneration

    inputs = ["audio"]
    outputs = ["text"]

    def encode(self, audio):
        return self.pre_processor(audio, return_tensors="pt").input_features

    def forward(self, inputs):
        return self.model.generate(inputs=inputs)

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]


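# Usage sketch (a hedged example, not part of the module: `waveform` is assumed to be a
# raw audio array at the sampling rate the Whisper processor expects, e.g. 16 kHz):
#
#     transcriber = SpeechToTextTool()
#     text = transcriber(waveform)
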
class RemoteSpeechToTextTool(OldRemoteTool):
    default_checkpoint = "openai/whisper-base"
    description = SPEECH_TO_TEXT_DESCRIPTION

    def extract_outputs(self, outputs):
        return outputs["text"]

    def prepare_inputs(self, audio):
        return {"data": audio}
86 src/transformers/tools/text_classification.py Normal file
@ -0,0 +1,86 @@
import torch

from ..models.auto import AutoModelForSequenceClassification, AutoTokenizer
from .base import OldRemoteTool, PipelineTool


TEXT_CLASSIFIER_DESCRIPTION = (
    "This is a tool that classifies an English text using provided labels. It takes two inputs: `text`, which should "
    "be the text to classify, and `labels`, which should be the list of labels to use for classification. It returns "
    "the most likely label in the list of provided `labels` for the input text."
)


class TextClassificationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TextClassificationTool

    classifier = TextClassificationTool()
    classifier("This is a super nice API!", labels=["positive", "negative"])
    ```
    """

    default_checkpoint = "facebook/bart-large-mnli"
    description = TEXT_CLASSIFIER_DESCRIPTION
    name = "text-classifier"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    inputs = ["text", ["text"]]
    outputs = ["text"]

    def setup(self):
        super().setup()
        config = self.model.config
        self.entailment_id = -1
        for idx, label in config.id2label.items():
            if label.lower().startswith("entail"):
                self.entailment_id = int(idx)
        if self.entailment_id == -1:
            raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")

    def encode(self, text, labels):
        self._labels = labels
        return self.pre_processor(
            [text] * len(labels),
            [f"This example is {label}" for label in labels],
            return_tensors="pt",
            padding="max_length",
        )

    def decode(self, outputs):
        logits = outputs.logits
        # Use the entailment column found in `setup` rather than a hard-coded index.
        label_id = torch.argmax(logits[:, self.entailment_id]).item()
        return self._labels[label_id]


class RemoteTextClassificationTool(OldRemoteTool):
    """
    Example:

    ```py
    from transformers.tools import RemoteTextClassificationTool

    classifier = RemoteTextClassificationTool()
    classifier("This is a super nice API!", labels=["positive", "negative"])
    ```
    """

    default_checkpoint = "facebook/bart-large-mnli"
    description = TEXT_CLASSIFIER_DESCRIPTION

    def prepare_inputs(self, text, labels):
        return {"inputs": text, "params": {"candidate_labels": labels}}

    def extract_outputs(self, outputs):
        label = None
        max_score = 0
        for lbl, score in zip(outputs["labels"], outputs["scores"]):
            if score > max_score:
                label = lbl
                max_score = score

        return label
59 src/transformers/tools/text_summarization.py Normal file
@ -0,0 +1,59 @@
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import OldRemoteTool, PipelineTool


TEXT_SUMMARIZATION_DESCRIPTION = "This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, and returns a summary of the text."


class TextSummarizationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TextSummarizationTool

    summarizer = TextSummarizationTool()
    summarizer(long_text)
    ```
    """

    default_checkpoint = "philschmid/bart-large-cnn-samsum"
    description = TEXT_SUMMARIZATION_DESCRIPTION
    name = "summarizer"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    inputs = ["text"]
    outputs = ["text"]

    def encode(self, text):
        return self.pre_processor(text, return_tensors="pt", truncation=True)

    def forward(self, inputs):
        return self.model.generate(**inputs)[0]

    def decode(self, outputs):
        return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)


class RemoteTextSummarizationTool(OldRemoteTool):
    """
    Example:

    ```py
    from transformers.tools import RemoteTextSummarizationTool

    summarizer = RemoteTextSummarizationTool()
    summarizer(long_text)
    ```
    """

    default_checkpoint = "philschmid/flan-t5-base-samsum"
    description = TEXT_SUMMARIZATION_DESCRIPTION

    def prepare_inputs(self, text):
        return {"inputs": text}

    def extract_outputs(self, outputs):
        return outputs[0]["summary_text"]
53 src/transformers/tools/text_to_speech.py Normal file
@ -0,0 +1,53 @@
import torch

from transformers.utils import is_datasets_available

from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from .base import PipelineTool


if is_datasets_available():
    from datasets import load_dataset


TEXT_TO_SPEECH_DESCRIPTION = (
    "This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
    "text to read (in English) and returns a waveform object containing the sound."
)


class TextToSpeechTool(PipelineTool):
    default_checkpoint = "microsoft/speecht5_tts"
    description = TEXT_TO_SPEECH_DESCRIPTION
    name = "text_reader"
    pre_processor_class = SpeechT5Processor
    model_class = SpeechT5ForTextToSpeech
    post_processor_class = SpeechT5HifiGan

    inputs = ["text"]
    outputs = ["audio"]

    def setup(self):
        if self.post_processor is None:
            self.post_processor = "microsoft/speecht5_hifigan"
        super().setup()

    def encode(self, text, speaker_embeddings=None):
        inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)

        if speaker_embeddings is None:
            if not is_datasets_available():
                raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")

            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)

        return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}

    def forward(self, inputs):
        with torch.no_grad():
            return self.model.generate_speech(**inputs)

    def decode(self, outputs):
        with torch.no_grad():
            return self.post_processor(outputs).cpu().detach()
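# Usage sketch (a hedged example, not part of the module: requires datasets for the
# default speaker embeddings, which are fetched from the Hub on first use):
#
#     reader = TextToSpeechTool()
#     speech = reader("Hello world")  # 1-D tensor containing the waveform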
51 src/transformers/tools/translation.py Normal file
@ -0,0 +1,51 @@
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool


TRANSLATION_DESCRIPTION = (
    "This is a tool that translates text from one language to another. It takes three inputs: `text`, which should "
    "be the text to translate, `src_lang`, which should be the language of the text to translate, and `tgt_lang`, "
    "which should be the desired output language. Both `src_lang` and `tgt_lang` are written in plain English, such "
    "as 'Romanian' or 'Albanian'. It returns the text translated into `tgt_lang`."
)


class TranslationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TranslationTool

    translator = TranslationTool()
    translator("This is a super nice API!", src_lang="English", tgt_lang="French")
    ```
    """

    default_checkpoint = "facebook/nllb-200-distilled-600M"
    description = TRANSLATION_DESCRIPTION
    name = "translator"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM
    # TODO add all other languages
    lang_to_code = {"French": "fra_Latn", "English": "eng_Latn", "Spanish": "spa_Latn"}

    inputs = ["text", "text", "text"]
    outputs = ["text"]

    def encode(self, text, src_lang, tgt_lang):
        if src_lang not in self.lang_to_code:
            raise ValueError(f"{src_lang} is not a supported language.")
        if tgt_lang not in self.lang_to_code:
            raise ValueError(f"{tgt_lang} is not a supported language.")
        src_lang = self.lang_to_code[src_lang]
        tgt_lang = self.lang_to_code[tgt_lang]
        return self.pre_processor._build_translation_inputs(
            text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
        )

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)
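Because `lang_to_code` is a plain class attribute (and flagged with a TODO above), more languages can be added by subclassing instead of editing the tool. A sketch under the assumption that the extra FLORES-200 codes below match what the NLLB checkpoint expects:

```py
from transformers.tools import TranslationTool


class ExtendedTranslationTool(TranslationTool):
    # Extend the mapping with additional FLORES-200 codes used by NLLB.
    lang_to_code = {
        **TranslationTool.lang_to_code,
        "German": "deu_Latn",
        "Romanian": "ron_Latn",
    }


translator = ExtendedTranslationTool()
translator.setup()
print(translator("Hello, how are you?", src_lang="English", tgt_lang="German"))
```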
@ -108,6 +108,7 @@ from .import_utils import (
    is_datasets_available,
    is_decord_available,
    is_detectron2_available,
    is_diffusers_available,
    is_faiss_available,
    is_flax_available,
    is_ftfy_available,
@ -121,6 +122,7 @@ from .import_utils import (
    is_natten_available,
    is_ninja_available,
    is_onnx_available,
    is_opencv_available,
    is_optimum_available,
    is_pandas_available,
    is_peft_available,
@ -235,6 +235,7 @@ def try_to_load_from_cache(
    filename: str,
    cache_dir: Union[str, Path, None] = None,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
) -> Optional[str]:
    """
    Explores the cache to return the latest cached file for a given revision if found.
@ -251,6 +252,8 @@ def try_to_load_from_cache(
        revision (`str`, *optional*):
            The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
            provided either.
        repo_type (`str`, *optional*):
            The type of the repo (`"model"`, `"dataset"` or `"space"`). Defaults to `"model"`.

    Returns:
        `Optional[str]` or `_CACHED_NO_EXIST`:
@ -266,7 +269,9 @@ def try_to_load_from_cache(
        cache_dir = TRANSFORMERS_CACHE

    object_id = repo_id.replace("/", "--")
    repo_cache = os.path.join(cache_dir, f"models--{object_id}")
    if repo_type is None:
        repo_type = "model"
    repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
    if not os.path.isdir(repo_cache):
        # No cache for this model
        return None
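With the new `repo_type` argument, the lookup maps to a `{repo_type}s--` cache folder instead of always `models--`. A quick sketch of the intended call pattern (the space name is a placeholder, not part of this diff):

```py
from transformers.utils import try_to_load_from_cache

# Model repos resolve under `models--bert-base-uncased`, as before.
config_path = try_to_load_from_cache("bert-base-uncased", "config.json")

# Space repos now resolve under `spaces--my-org--my-space`.
app_path = try_to_load_from_cache("my-org/my-space", "app.py", repo_type="space")
```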
@ -303,6 +308,7 @@ def cached_file(
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    repo_type: Optional[str] = None,
    user_agent: Optional[Union[str, Dict[str, str]]] = None,
    _raise_exceptions_for_missing_entries: bool = True,
    _raise_exceptions_for_connection_errors: bool = True,
@ -342,6 +348,8 @@ def cached_file(
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space, for instance).

    <Tip>

@ -393,7 +401,7 @@ def cached_file(
    if _commit_hash is not None and not force_download:
        # If the file is cached under that commit hash, we return it directly.
        resolved_file = try_to_load_from_cache(
            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )
        if resolved_file is not None:
            if resolved_file is not _CACHED_NO_EXIST:
@ -410,6 +418,7 @@ def cached_file(
            path_or_repo_id,
            filename,
            subfolder=None if len(subfolder) == 0 else subfolder,
            repo_type=repo_type,
            revision=revision,
            cache_dir=cache_dir,
            user_agent=user_agent,
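The keyword is threaded through both the cache lookup and the download call, so fetching from non-model repos goes through the same helper. A hedged end-to-end sketch (placeholder repo name):

```py
from transformers.utils import cached_file

# Fetches `app.py` from a Space, reusing the local cache when possible.
resolved = cached_file("my-org/my-space", "app.py", repo_type="space")
print(resolved)
```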
@ -125,6 +125,14 @@ except importlib_metadata.PackageNotFoundError:
    _datasets_available = False


_diffusers_available = importlib.util.find_spec("diffusers") is not None
try:
    _diffusers_version = importlib_metadata.version("diffusers")
    logger.debug(f"Successfully imported diffusers version {_diffusers_version}")
except importlib_metadata.PackageNotFoundError:
    _diffusers_available = False


_detectron2_available = importlib.util.find_spec("detectron2") is not None
try:
    _detectron2_version = importlib_metadata.version("detectron2")
@ -185,6 +193,9 @@ except importlib_metadata.PackageNotFoundError:
    _onnx_available = False


_opencv_available = importlib.util.find_spec("cv2") is not None


_pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None
try:
    _pytorch_quantization_version = importlib_metadata.version("pytorch_quantization")
@ -431,6 +442,10 @@ def is_onnx_available():
    return _onnx_available


def is_opencv_available():
    return _opencv_available


def is_flax_available():
    return _flax_available

@ -497,6 +512,10 @@ def is_datasets_available():
    return _datasets_available


def is_diffusers_available():
    return _diffusers_available


def is_detectron2_available():
    return _detectron2_available

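These helpers follow the library's soft-dependency pattern: check availability first, import lazily. A minimal sketch of the intended guard in downstream code (the wrapper function is illustrative, not part of this diff):

```py
from transformers.utils import is_diffusers_available


def load_diffusion_pipeline(checkpoint: str):
    # Fail fast with an actionable message when the optional dependency is missing.
    if not is_diffusers_available():
        raise ImportError("This feature requires diffusers. Install it with `pip install diffusers`.")

    # Imported lazily so that transformers itself never hard-depends on diffusers.
    from diffusers import DiffusionPipeline

    return DiffusionPipeline.from_pretrained(checkpoint)
```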
tests/tools/__init__.py (new file, empty)
tests/tools/document_question_answering.py (new file, 42 lines)
@ -0,0 +1,42 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from datasets import load_dataset

from transformers.tools import DocumentQuestionAnsweringTool

from .test_tools_common import ToolTesterMixin


class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = DocumentQuestionAnsweringTool()
        self.tool.setup()

    def test_exact_match_arg(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        image = dataset[0]["image"]

        result = self.tool(image, "When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")

    def test_exact_match_kwarg(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        image = dataset[0]["image"]

        result = self.tool(image=image, question="When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")
tests/tools/generative_question_answering.py (new file, 43 lines)
@ -0,0 +1,43 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.tools import GenerativeQuestionAnsweringTool

from .test_tools_common import ToolTesterMixin


TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.

In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]

On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""


class GenerativeQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = GenerativeQuestionAnsweringTool()
        self.tool.setup()

    def test_exact_match_arg(self):
        result = self.tool(TEXT, "What did Hugging Face do in April 2021?")
        self.assertEqual(result, "launched the BigScience Research Workshop")

    def test_exact_match_kwarg(self):
        result = self.tool(text=TEXT, question="What did Hugging Face do in April 2021?")
        self.assertEqual(result, "launched the BigScience Research Workshop")
tests/tools/image_captioning.py (new file, 40 lines)
@ -0,0 +1,40 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path

from PIL import Image

from transformers.testing_utils import get_tests_dir
from transformers.tools import ImageCaptioningTool

from .test_tools_common import ToolTesterMixin


class ImageCaptioningToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = ImageCaptioningTool()
        self.tool.setup()

    def test_exact_match_arg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image)
        self.assertEqual(result, "two cats sleeping on a couch")

    def test_exact_match_kwarg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image=image)
        self.assertEqual(result, "two cats sleeping on a couch")
tests/tools/image_question_answering.py (new file, 40 lines)
@ -0,0 +1,40 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path

from PIL import Image

from transformers.testing_utils import get_tests_dir
from transformers.tools import ImageQuestionAnsweringTool

from .test_tools_common import ToolTesterMixin


class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = ImageQuestionAnsweringTool()
        self.tool.setup()

    def test_exact_match_arg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image, "How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")

    def test_exact_match_kwarg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image=image, question="How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")
tests/tools/image_segmentation.py (new file, 40 lines)
@ -0,0 +1,40 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path

from PIL import Image

from transformers.testing_utils import get_tests_dir
from transformers.tools import ImageSegmentationTool

from .test_tools_common import ToolTesterMixin


class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = ImageSegmentationTool()
        self.tool.setup()

    def test_exact_match_arg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
        result = self.tool(image, "cat")
        self.assertTrue(isinstance(result, Image.Image))

    def test_exact_match_kwarg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
        result = self.tool(image=image, prompt="cat")
        self.assertTrue(isinstance(result, Image.Image))
tests/tools/speech_to_text.py (new file, 36 lines)
@ -0,0 +1,36 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from transformers.tools import SpeechToTextTool

from .test_tools_common import ToolTesterMixin


class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = SpeechToTextTool()
        self.tool.setup()

    def test_exact_match_arg(self):
        result = self.tool(torch.ones(3000))
        self.assertEqual(result, " you")

    def test_exact_match_kwarg(self):
        result = self.tool(audio=torch.ones(3000))
        self.assertEqual(result, " you")
tests/tools/test_tools_common.py (new file, 94 lines)
@ -0,0 +1,94 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import List

import torch
from PIL import Image

from transformers.testing_utils import get_tests_dir


authorized_types = ["text", "image", "audio"]


def create_inputs(input_types: List[str]):
    inputs = []

    for input_type in input_types:
        if input_type == "text":
            inputs.append("Text input")
        elif input_type == "image":
            inputs.append(
                Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
            )
        elif input_type == "audio":
            inputs.append(torch.ones(3000))
        elif isinstance(input_type, list):
            inputs.append(create_inputs(input_type))
        else:
            raise ValueError(f"Invalid type requested: {input_type}")

    return inputs


def output_types(outputs: List):
    output_types = []

    for output in outputs:
        if isinstance(output, str):
            output_types.append("text")
        elif isinstance(output, Image.Image):
            output_types.append("image")
        elif isinstance(output, torch.Tensor):
            output_types.append("audio")
        else:
            raise ValueError(f"Invalid output: {output}")

    return output_types


class ToolTesterMixin:
    def test_inputs_outputs(self):
        self.assertTrue(hasattr(self.tool, "inputs"))
        self.assertTrue(hasattr(self.tool, "outputs"))

        inputs = self.tool.inputs
        for _input in inputs:
            if isinstance(_input, list):
                for __input in _input:
                    self.assertTrue(__input in authorized_types)
            else:
                self.assertTrue(_input in authorized_types)

        outputs = self.tool.outputs
        for _output in outputs:
            self.assertTrue(_output in authorized_types)

    def test_call(self):
        inputs = create_inputs(self.tool.inputs)
        outputs = self.tool(*inputs)

        # There is a single output
        if len(self.tool.outputs) == 1:
            outputs = [outputs]

        self.assertListEqual(output_types(outputs), self.tool.outputs)

    def test_common_attributes(self):
        self.assertTrue(hasattr(self.tool, "description"))
        self.assertTrue(hasattr(self.tool, "default_checkpoint"))
        self.assertTrue(self.tool.description.startswith("This is a tool that"))
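`create_inputs` and `output_types` are designed to round-trip, which is what `test_call` exploits: fabricate one dummy value per declared input type, run the tool, then map the outputs back to type names. A quick illustration, needing no fixture files for text or audio (the `tests.tools` import path assumes the test suite is run from the repo root):

```py
from tests.tools.test_tools_common import create_inputs, output_types

inputs = create_inputs(["text", "audio"])
print(output_types(inputs))  # ['text', 'audio']
```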
tests/tools/text_classification.py (new file, 34 lines)
@ -0,0 +1,34 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.tools import TextClassificationTool

from .test_tools_common import ToolTesterMixin


class TextClassificationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = TextClassificationTool()
        self.tool.setup()

    def test_exact_match_arg(self):
        result = self.tool("That's quite cool", ["positive", "negative"])
        self.assertEqual(result, "positive")

    def test_exact_match_kwarg(self):
        result = self.tool(text="That's quite cool", labels=["positive", "negative"])
        self.assertEqual(result, "positive")
tests/tools/text_summarization.py (new file, 49 lines)
@ -0,0 +1,49 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.tools import TextSummarizationTool

from .test_tools_common import ToolTesterMixin


TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.

In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]

On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""


class TextSummarizationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = TextSummarizationTool()
        self.tool.setup()

    def test_exact_match_arg(self):
        result = self.tool(TEXT)
        self.assertEqual(
            result,
            "Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
        )

    def test_exact_match_kwarg(self):
        result = self.tool(text=TEXT)
        self.assertEqual(
            result,
            "Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
        )
tests/tools/text_to_speech.py (new file, 38 lines)
@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.tools import TextToSpeechTool

from .test_tools_common import ToolTesterMixin


class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = TextToSpeechTool()
        self.tool.setup()

    def test_exact_match_arg(self):
        result = self.tool("hey")
        print(result.shape)

        # TODO check for real values

    def test_exact_match_kwarg(self):
        result = self.tool(text="hey")
        print(result.shape)

        # TODO check for real values
tests/tools/translation.py (new file, 44 lines)
@ -0,0 +1,44 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.tools import TranslationTool

from .test_tools_common import ToolTesterMixin, output_types


class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = TranslationTool()
        self.tool.setup()

    def test_exact_match_arg(self):
        result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_exact_match_kwarg(self):
        result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_call(self):
        inputs = ["Hey, what's up?", "English", "Spanish"]
        outputs = self.tool(*inputs)

        # There is a single output
        if len(self.tool.outputs) == 1:
            outputs = [outputs]

        self.assertListEqual(output_types(outputs), self.tool.outputs)