mirror of https://github.com/huggingface/transformers.git
synced 2025-11-12 01:04:36 +08:00

Comparing v4.50.3...test_compo (37 commits)
src/transformers/tools/README.md (new file, 141 lines)
@@ -0,0 +1,141 @@
# Do anything with Transformers

Transformers supports all modalities and has many models performing many different types of tasks. But it can get confusing to mix and match them to solve the problem at hand, which is why we have developed a new API of **tools** and **agents**. Given a prompt in natural language and a set of tools, an agent will determine the right code to run with the tools and chain them properly to give you the result you expect.

Let's start with examples!

## Examples

First we need an agent, which is a fancy word for an LLM tasked with writing the code you will need. We support the traditional OpenAI LLMs, but you should really try the open-source alternatives developed by the community, which:

- clearly state the data they have been trained on
- can be run on your own cloud or hardware
- have built-in versioning

<!--TODO for the release we should have a publicly available agent and if token is none, we grab the HF token-->
```py
from transformers.tools import EndpointAgent


agent = EndpointAgent(
    url_endpoint=your_endpoint,
    token=your_hf_token,
)

# from transformers.tools import OpenAiAgent

# agent = OpenAiAgent(api_key=your_openai_api_key)
```
### Task 1: Classifying text in (almost) any language

Now to execute a given task, we need to pick a set of tools in `transformers` and send them to our agent. Let's say you want to classify a text in a non-English language, and you have trouble finding a model trained in that language. You can pick a translation tool and a standard text classification tool:
```py
from transformers.tools import TextClassificationTool, TranslationTool


tools = [TextClassificationTool(), TranslationTool(src_lang="fra_Latn", tgt_lang="eng_Latn")]
```
Then you just run this by your agent:
```py
agent.run(
    "Determine if the following `text` (in French) is positive or negative.",
    tools=tools,
    text="J'aime beaucoup Hugging Face!",
)
```
Note that you can send any additional inputs in a variable that you named in your prompt (between backticks, because that helps the LLM). For text inputs, you can just put them in the prompt directly:
```py
agent.run(
    """Determine if the following text: "J'aime beaucoup Hugging Face!" (in French) is positive or negative.""",
    tools=tools,
)
```
In both cases, you should see the agent generate code using your set of tools, which is then executed to provide the answer you were looking for. Neat!
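For the French classification task above, the code printed by the agent could look something like this (a hypothetical generation for illustration; `tool_0` is the classifier and `tool_1` the translator, following the order of the `tools` list, and actual generations will vary):

```py
translated_text = tool_1(text=text)["translated_text"]
result = tool_0(text=translated_text)
print(f"The text is {result['label']}")
```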
If you don't have the hardware to run the models translating and classifying the text, you can use the Inference API by selecting a remote tool:
```py
from transformers.tools import RemoteTextClassificationTool, TranslationTool


tools = [RemoteTextClassificationTool(), TranslationTool(src_lang="fra_Latn", tgt_lang="eng_Latn")]

agent.run(
    "Determine if the following `text` (in French) is positive or negative.",
    tools=tools,
    text="J'aime beaucoup Hugging Face!",
)
```
This was still all text-based. Let's now get to something more exciting, combining vision and speech.

### Task 2: Reading the content of an image out loud

Let's say we want to hear out loud what is in a given image. There are models that do image captioning in Transformers, and other models that generate speech from text, but how do we combine them? Quite easily:

<!--TODO add the audio reader tool once it exists-->
```py
import requests
from PIL import Image

from transformers.tools import ImageCaptioningTool, TextToSpeechTool


tools = [ImageCaptioningTool(), TextToSpeechTool()]

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

speech = agent.run(
    "Tell me out loud what the `image` contains.",
    tools=tools,
    image=image,
)
```
Note that here you have to pass your input as a separate variable, since you can't really embed your image in the text.

In all those examples, we have been using the default checkpoint for a given tool, but you can specify the one you want! For instance, the image-captioning tool uses BLIP by default, but let's upgrade to BLIP-2:

<!--TODO Once it works, use the inference API for BLIP-2 here as it's heavy-->
```py
tools = [ImageCaptioningTool("Salesforce/blip2-opt-2.7b"), TextToSpeechTool()]

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

speech = agent.run(
    "Tell me out loud what the `image` contains.",
    tools=tools,
    image=image,
)
```
<!--TODO: Add more examples?-->

## How does it work?
LLMs are pretty good at generating small samples of code, so this API takes advantage of that by prompting the LLM to give a small sample of code performing a task with a set of tools. The prompt is then completed with the task you give your agent and the descriptions of the tools you give it. This way the LLM gets access to the documentation of the tools you are using, especially their expected inputs and outputs, and can generate the relevant code.
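Concretely, the agents in this branch assemble the prompt along these lines (a simplified sketch of what `generate_code` does, using a shortened version of the real few-shot template):

```py
TEMPLATE = """Task: "<<prompt>>"

Tools:
<<tools>>

Answer:
"""


def build_prompt(task, tools):
    # Each tool contributes its natural-language description, so the LLM knows
    # the expected inputs and outputs of every function it is allowed to call.
    tool_descs = [f"- tool_{i}: {tool.description}" for i, tool in enumerate(tools)]
    prompt = TEMPLATE.replace("<<prompt>>", task)
    return prompt.replace("<<tools>>", "\n".join(tool_descs))
```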
This uses brand-new tools and not pipelines, because the agent writes better code with very atomic tools. Pipelines are more refactored and often combine several tasks in one. Tools are meant to be focused on one very simple task only.
This code is then executed with our small Python interpreter on the set of inputs passed along with your tools. I hear you screaming "Arbitrary code execution!" in the back, but calm down a minute and let me explain.
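Under the hood, this boils down to a call like the following (`my_classifier` is a placeholder for any tool instance, e.g. a `TextClassificationTool()`):

```py
from transformers.tools.python_interpreter import evaluate

# Evaluate the generated code with only `tool_0` callable, starting from a
# state that contains the inputs passed to `run`.
code = "result = tool_0(text=text)"
result = evaluate(code, tools={"tool_0": my_classifier}, variables={"text": "J'aime beaucoup Hugging Face!"})
```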
The only functions that can be called are the tools you provided and the `print` function, so you're already limited in what can be executed. You should be safe if it's limited to Hugging Face tools. We also don't allow any attribute lookups or imports (which shouldn't be needed anyway for passing inputs and outputs along to a small set of functions), so all the most obvious attacks (which you'd need to prompt the LLM to output anyway) shouldn't be an issue. If you want to be on the super safe side, you can execute the `run()` method with the additional argument `return_code=True`, in which case the agent will just return the code to execute and you can decide whether to run it or not.
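For example, with the agent and tools defined earlier:

```py
code = agent.run(
    "Determine if the following `text` (in French) is positive or negative.",
    tools=tools,
    text="J'aime beaucoup Hugging Face!",
    return_code=True,
)
print(code)  # Inspect the generated code, then decide whether to execute it.
```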
Note that LLMs are still not *that* good at producing the small amount of code needed to chain the tools, so we added some logic to fix typos during the evaluation: variable names or dictionary keys are often misnamed.
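A tiny illustration with the interpreter directly:

```py
from transformers.tools.python_interpreter import evaluate

state = {"text": "J'aime beaucoup Hugging Face!"}
# `txt` is misspelled, but the interpreter resolves it to the closest
# variable in the state (`text`) via difflib.get_close_matches.
evaluate("result = print(txt)", {"print": print}, state)
```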
The execution will stop at any line trying to perform an illegal operation, or if there is a regular Python error in the code generated by the agent.
## Future developments

We hope you're as excited by this new API as we are. Here are a few things we are thinking of adding next if we see the community is interested:

- Make the agent pick the tools itself in a first step.
- Make the run command more chat-based, so you can copy-paste any error message you see into a next step to have the LLM fix its code, or ask for improvements.
- Add support for more types of agents.
src/transformers/tools/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from .agents import Agent, EndpointAgent, OpenAiAgent
from .base import PipelineTool, RemoteTool
from .controlnet import ControlNetTool
from .generative_question_answering import GenerativeQuestionAnsweringTool, RemoteGenerativeQuestionAnsweringTool
from .image_captioning import ImageCaptioningTool, RemoteImageCaptioningTool
from .image_segmentation import ImageSegmentationTool
from .language_identifier import LanguageIdentificationTool
from .speech_to_text import RemoteSpeechToTextTool, SpeechToTextTool
from .stable_diffusion import StableDiffusionTool
from .text_classification import RemoteTextClassificationTool, TextClassificationTool
from .text_to_speech import TextToSpeechTool
from .translation import TranslationTool
src/transformers/tools/agents.py (new file, 137 lines)
@@ -0,0 +1,137 @@
import importlib.util
import os

import requests

from .python_interpreter import evaluate


# Move to util when this branch is ready to merge
def is_openai_available():
    return importlib.util.find_spec("openai") is not None


if is_openai_available():
    import openai

# docstyle-ignore
PROMPT_TEMPLATE = """I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
Each instruction in Python should be a simple assignment.
The final result should be stored in a variable named `result`. You can also print the result if it makes sense to do so.
You should only use the tools necessary to perform the task.

Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."

Tools:
- tool_0: This is a tool that translates text from French to English. It takes an input named `text` which should be the text in French and returns a dictionary with a single key `'translation_text'` that contains the translation in English.
- tool_1: This is a tool that generates speech from a given text in English. It takes an input named `text` which should be the text in English and returns the path to a filename containing an audio of this text read.
- tool_2: This is a tool that answers questions about images. It takes an input named `text` which should be the question in English and an input `image` which should be an image, and outputs a text that is the answer to the question.

Answer:
```py
translated_question = tool_0(text=question)['translation_text']
result = tool_2(text=translated_question, image=image)
print(f"The answer is {result}")
```

Task: "Generate an image using the text given in the variable `caption`."

Tools:
- tool_0: This is a tool that reads an English text out loud. It takes an input named `text` which should contain the text to read (in English) and returns a waveform object containing the sound.
- tool_1: This is a tool that generates a description of an image. It takes an input named `image` which should be the image to caption, and returns a text that contains the description in English.
- tool_2: This is a tool that creates an image according to a text description. It takes an input named `text` which contains the image description and outputs an image.

Answer:
```py
result = tool_2(text=caption)
```

Task: "<<prompt>>"

Tools:
<<tools>>

Answer:
"""


class Agent:
    def run(self, task, tools, return_code=False, **kwargs):
        code = self.generate_code(task, tools)

        # Clean up the code received: strip an eventual Markdown code fence and keep
        # any text after the closing fence as an additional explanation.
        code_lines = code.split("\n")
        in_block_code = "```" in code_lines[0]
        additional_explanation = []
        if in_block_code:
            code_lines = code_lines[1:]
        for idx in range(len(code_lines)):
            if in_block_code and "```" in code_lines[idx]:
                additional_explanation = code_lines[idx + 1 :]
                code_lines = code_lines[:idx]
                break

        clean_code = "\n".join(code_lines)

        # The tools are exposed to the interpreter as `tool_0`, `tool_1`, ... plus `print`.
        all_tools = {"print": print}
        all_tools.update({f"tool_{idx}": tool for idx, tool in enumerate(tools)})

        print(f"==Code generated by the agent==\n{clean_code}\n\n")
        if len(additional_explanation) > 0:
            explanation = "\n".join(additional_explanation).strip()
            print(f"==Additional explanation from the agent==\n{explanation}\n\n")
        print("==Result==")

        if not return_code:
            return evaluate(clean_code, all_tools, kwargs)
        else:
            return clean_code


class EndpointAgent(Agent):
    def __init__(self, url_endpoint, token):
        self.url_endpoint = url_endpoint
        self.token = token

    def generate_code(self, task, tools):
        headers = {"Authorization": self.token}
        tool_descs = [f"- tool_{i}: {tool.description}" for i, tool in enumerate(tools)]
        prompt = PROMPT_TEMPLATE.replace("<<prompt>>", task)
        prompt = prompt.replace("<<tools>>", "\n".join(tool_descs))
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 200, "do_sample": True, "temperature": 0.5, "return_full_text": False},
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code != 200:
            raise ValueError(f"Error {response.status_code}: {response.json()}")
        return response.json()[0]["generated_text"]


class OpenAiAgent(Agent):
    def __init__(self, model="gpt-3.5-turbo", api_key=None):
        if not is_openai_available():
            raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")

        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY", None)
        if api_key is None:
            raise ValueError(
                "You need an OpenAI key to use `OpenAiAgent`. You can get one here: https://openai.com/api/. "
                "If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = 'xxx'`."
            )
        else:
            openai.api_key = api_key
        self.model = model

    def generate_code(self, task, tools):
        tool_descs = [f"- tool_{i}: {tool.description}" for i, tool in enumerate(tools)]
        prompt = PROMPT_TEMPLATE.replace("<<prompt>>", task)
        prompt = prompt.replace("<<tools>>", "\n".join(tool_descs))

        result = openai.ChatCompletion.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
        )
        return result["choices"][0]["message"]["content"]
src/transformers/tools/base.py (new file, 169 lines)
@@ -0,0 +1,169 @@
from typing import List

from accelerate.state import PartialState
from accelerate.utils import send_to_device
from huggingface_hub import InferenceApi

from ..models.auto import AutoProcessor


class Tool:
    """
    Example of a super 'Tool' class that could live in huggingface_hub
    """

    description = "This is a tool that ..."
    is_initialized = False

    inputs: List[str]
    outputs: List[str]
    name: str

    def __call__(self, *args, **kwargs):  # Might become run?
        raise NotImplementedError("Write this method in your subclass of `Tool`.")

    def post_init(self):
        # Do here everything you need to execute after the init (to avoid overriding the init which is complex),
        # such as formatting the description with the attributes of your tools.
        pass

    def setup(self):
        # Do here any operation that is expensive and needs to be executed before you start using your tool. Such as
        # loading a big model.
        self.is_initialized = True


class RemoteTool(Tool):
    default_checkpoint = None
    description = "This is a tool that ..."

    def __init__(self, repo_id=None):
        if repo_id is None:
            repo_id = self.default_checkpoint
        self.repo_id = repo_id
        self.client = InferenceApi(repo_id)
        self.post_init()

    def prepare_inputs(self, *args, **kwargs):
        if len(args) > 1:
            raise ValueError("A `RemoteTool` can only accept one positional input.")
        elif len(args) == 1:
            return {"data": args[0]}

        return {"inputs": kwargs}

    def extract_outputs(self, outputs):
        return outputs

    def __call__(self, *args, **kwargs):
        if not self.is_initialized:
            self.setup()

        inputs = self.prepare_inputs(*args, **kwargs)
        if isinstance(inputs, dict):
            outputs = self.client(**inputs)
        else:
            outputs = self.client(inputs)
        return self.extract_outputs(outputs)


class PipelineTool(Tool):
    pre_processor_class = AutoProcessor
    model_class = None
    post_processor_class = AutoProcessor
    default_checkpoint = None
    description = "This is a tool that ..."

    def __init__(
        self,
        model=None,
        pre_processor=None,
        post_processor=None,
        device=None,
        device_map=None,
        model_kwargs=None,
        **hub_kwargs,
    ):
        if model is None:
            if self.default_checkpoint is None:
                raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
            model = self.default_checkpoint
        if pre_processor is None:
            pre_processor = model

        self.model = model
        self.pre_processor = pre_processor
        self.post_processor = post_processor
        self.device = device
        self.device_map = device_map
        self.model_kwargs = {} if model_kwargs is None else model_kwargs
        if device_map is not None:
            self.model_kwargs["device_map"] = device_map
        self.hub_kwargs = hub_kwargs

        self.is_initialized = False
        self.post_init()

    def setup(self):
        # Instantiate me maybe
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        if isinstance(self.model, str):
            self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        if self.device is None:
            if self.device_map is not None:
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = PartialState().default_device

        if self.device_map is None:
            self.model.to(self.device)

        # Mark the tool as ready so the model is only loaded once.
        self.is_initialized = True

    def post_init(self):
        pass

    def encode(self, raw_inputs):
        return self.pre_processor(raw_inputs)

    def forward(self, inputs):
        return self.model(**inputs)

    def decode(self, outputs):
        return self.post_processor(outputs)

    def __call__(self, *args, **kwargs):
        if not self.is_initialized:
            self.setup()

        encoded_inputs = self.encode(*args, **kwargs)
        encoded_inputs = send_to_device(encoded_inputs, self.device)
        outputs = self.forward(encoded_inputs)
        outputs = send_to_device(outputs, "cpu")
        return self.decode(outputs)


def launch_gradio_demo(tool_class: Tool):
    try:
        import gradio as gr
    except ImportError:
        raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    tool = tool_class()

    def fn(*args, **kwargs):
        return tool(*args, **kwargs)

    gr.Interface(
        fn=fn,
        inputs=tool_class.inputs,
        outputs=tool_class.outputs,
        title=tool_class.__name__,
        article=tool.description,
    ).launch()
src/transformers/tools/controlnet.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import numpy as np
import torch

from transformers.tools.base import Tool
from transformers.utils import (
    is_accelerate_available,
    is_diffusers_available,
    is_opencv_available,
    is_vision_available,
)


if is_vision_available():
    from PIL import Image

if is_accelerate_available():
    from accelerate.state import PartialState

if is_diffusers_available():
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler

if is_opencv_available():
    import cv2


class ControlNetTool(Tool):
    default_stable_diffusion_checkpoint = "runwayml/stable-diffusion-v1-5"
    default_controlnet_checkpoint = "lllyasviel/sd-controlnet-canny"

    description = """
    This is a tool that transforms an image according to a prompt. It takes two inputs:
    first, the image that will be transformed, and second, the prompt (or text description) that will be used.
    It returns a modified image.
    """

    def __init__(self, device=None, controlnet=None, stable_diffusion=None) -> None:
        if not is_accelerate_available():
            raise ImportError("Accelerate should be installed in order to use tools.")
        if not is_diffusers_available():
            raise ImportError("Diffusers should be installed in order to use the ControlNetTool.")
        if not is_vision_available():
            raise ImportError("Pillow should be installed in order to use the ControlNetTool.")
        if not is_opencv_available():
            raise ImportError("opencv should be installed in order to use the ControlNetTool.")

        super().__init__()

        if controlnet is None:
            controlnet = self.default_controlnet_checkpoint
        self.controlnet_checkpoint = controlnet

        if stable_diffusion is None:
            stable_diffusion = self.default_stable_diffusion_checkpoint
        self.stable_diffusion_checkpoint = stable_diffusion

        self.device = device

    def setup(self):
        if self.device is None:
            self.device = PartialState().default_device

        self.controlnet = ControlNetModel.from_pretrained(self.controlnet_checkpoint, torch_dtype=torch.float16)
        self.pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            self.stable_diffusion_checkpoint, controlnet=self.controlnet, torch_dtype=torch.float16
        )
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        self.pipeline.enable_model_cpu_offload()

        self.is_initialized = True

    def __call__(self, image, prompt):
        if not self.is_initialized:
            self.setup()

        # Prepend a quality-boosting base prompt to the user prompt.
        initial_prompt = "super-hero character, best quality, extremely detailed"
        prompt = initial_prompt + ", " + prompt

        # Extract Canny edges from the input image to condition the ControlNet.
        low_threshold = 100
        high_threshold = 200

        image = np.array(image)
        image = cv2.Canny(image, low_threshold, high_threshold)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        canny_image = Image.fromarray(image)

        generator = torch.Generator(device="cpu").manual_seed(2)

        return self.pipeline(
            prompt,
            canny_image,
            negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
            num_inference_steps=20,
            generator=generator,
        ).images[0]
src/transformers/tools/generative_question_answering.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool, RemoteTool


QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.

Can you answer this question about the text: '{question}'"""


class GenerativeQuestionAnsweringTool(PipelineTool):
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    description = (
        "This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
        "text where to find the answer, and `question`, which is the question, and returns the answer to the question."
    )
    default_checkpoint = "google/flan-t5-base"

    def encode(self, text: str, question: str):
        prompt = QA_PROMPT.format(text=text, question=question)
        return self.pre_processor(prompt, return_tensors="pt")

    def forward(self, inputs):
        output_ids = self.model.generate(**inputs)

        # Keep the first generated sequence of the first (and only) input.
        in_b, _ = inputs["input_ids"].shape
        out_b = output_ids.shape[0]

        return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]

    def decode(self, outputs):
        return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)


class RemoteGenerativeQuestionAnsweringTool(RemoteTool):
    default_checkpoint = "google/flan-t5-base"
    description = (
        "This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
        "text where to find the answer, and `question`, which is the question, and returns the answer to the question."
    )

    def extract_outputs(self, outputs):
        return outputs[0]["generated_text"]

    def prepare_inputs(self, text, question):
        prompt = QA_PROMPT.format(text=text, question=question)
        return prompt
src/transformers/tools/image_captioning.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import io

from ..models.auto import AutoModelForVision2Seq, AutoProcessor
from ..utils import is_vision_available
from .base import PipelineTool, RemoteTool


if is_vision_available():
    from PIL import Image


class ImageCaptioningTool(PipelineTool):
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVision2Seq

    description = (
        "This is a tool that generates a description of an image. It takes an input named `image` which should be "
        "the image to caption, and returns a text that contains the description in English."
    )
    default_checkpoint = "Salesforce/blip-image-captioning-base"

    inputs = ["image"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        if not is_vision_available():
            raise ImportError("Pillow must be installed to use the ImageCaptioningTool.")

        super().__init__(*args, **kwargs)

    def encode(self, image: "Image"):
        return self.pre_processor(images=image, return_tensors="pt")

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()


class RemoteImageCaptioningTool(RemoteTool):
    default_checkpoint = "Salesforce/blip-image-captioning-large"
    description = (
        "This is a tool that generates a description of an image. It takes an input named `image` which should be "
        "the image to caption, and returns a text that contains the description in English."
    )

    def extract_outputs(self, outputs):
        return outputs[0]["generated_text"]

    def prepare_inputs(self, image):
        if isinstance(image, bytes):
            return {"data": image}

        # Serialize the PIL image to PNG bytes for the Inference API.
        byte_io = io.BytesIO()
        image.save(byte_io, format="PNG")
        return {"data": byte_io.getvalue()}
src/transformers/tools/image_segmentation.py (new file, 41 lines)
@@ -0,0 +1,41 @@
from typing import List

from transformers import AutoProcessor, CLIPSegForImageSegmentation, is_vision_available

from .base import PipelineTool


if is_vision_available():
    from PIL import Image


class ImageSegmentationTool(PipelineTool):
    pre_processor_class = AutoProcessor
    model_class = CLIPSegForImageSegmentation

    description = (
        "This is a tool that segments an image. It takes an input named `texts`, which should be a list of texts "
        "describing the objects to identify, and an input named `image`, which should be the image to segment, and "
        "returns one binary segmentation map per text."
    )
    default_checkpoint = "CIDAS/clipseg-rd64-refined"

    def __init__(self, *args, **kwargs):
        if not is_vision_available():
            raise ImportError("Pillow should be installed in order to use the ImageSegmentationTool.")

        super().__init__(*args, **kwargs)

    def encode(self, texts: List[str], image: "Image"):
        return self.pre_processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt")

    def decode(self, outputs):
        logits_array = outputs.logits
        segmentation_maps = []

        for logits in logits_array:
            # Threshold the logits at 0 to produce a binary segmentation map.
            array = logits.cpu().detach().numpy()
            array[array < 0] = 0
            array[array >= 0] = 1
            segmentation_maps.append(array)

        return segmentation_maps
src/transformers/tools/language_identifier.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from .text_classification import TextClassificationTool


class LanguageIdentificationTool(TextClassificationTool):
    """
    Example:

    ```py
    from transformers.tools import LanguageIdentificationTool

    classifier = LanguageIdentificationTool()
    classifier("This is a super nice API!")
    ```
    """

    default_checkpoint = "papluca/xlm-roberta-base-language-detection"
    description = (
        "This is a tool that identifies the language of the text passed as input. It takes one input named `text` and "
        "returns the two-letter label of the identified language."
    )

    def decode(self, outputs):
        return super().decode(outputs)["label"]
src/transformers/tools/python_interpreter.py (new file, 201 lines)
@@ -0,0 +1,201 @@
import ast
import difflib
from collections.abc import Mapping
from typing import Any, Callable, Dict


class InterpretorError(ValueError):
    """
    An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
    operations.
    """

    pass


def evaluate(code: str, tools: Dict[str, Callable], variables=None):
    """
    Evaluate a Python expression using the content of the variables stored in a state and only evaluating a given set
    of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        code (`str`):
            The code to evaluate.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
        variables (`Dict[str, Any]`, *optional*):
            A dictionary mapping variable names to values, used as the initial state. The state is updated if need be
            when the evaluation encounters assignments.
    """
    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        print("The code generated by the agent is not valid.\n", e)
        return
    state = {} if variables is None else variables.copy()
    result = None
    for idx, node in enumerate(expression.body):
        try:
            result = evaluate_ast(node, state, tools)
        except InterpretorError as e:
            print(
                f"Evaluation of the code stopped at line {idx} before the end because of the following error:\n{e}\n"
            )
            break

    if result is not None:
        return result
    if "result" in state:
        return state["result"]
    for key in state:
        if "result" in key:
            return state[key]

    print("No result found, returning the current state.")
    return state


def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
    """
    Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
    set of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
    """
    if isinstance(expression, ast.Assign):
        # Assignment -> we evaluate the assignment which should update the state
        evaluate_assign(expression, state, tools)
    elif isinstance(expression, ast.Call):
        # Function call -> we return the value of the function call
        return evaluate_call(expression, state, tools)
    elif isinstance(expression, ast.Constant):
        # Constant -> just return the value
        return expression.value
    elif isinstance(expression, ast.Expr):
        # Expression -> evaluate the content
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.FormattedValue):
        # Formatted value (part of f-string) -> evaluate the content and return
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.If):
        # If -> execute the right branch
        evaluate_if(expression, state, tools)
    elif isinstance(expression, ast.JoinedStr):
        # f-string -> evaluate and join all of its parts
        return "".join([evaluate_ast(v, state, tools) for v in expression.values])
    elif isinstance(expression, ast.Name):
        # Name -> pick up the value in the state
        return evaluate_name(expression, state, tools)
    elif isinstance(expression, ast.Subscript):
        # Subscript -> return the value of the indexing
        return evaluate_subscript(expression, state, tools)
    else:
        # For now we refuse anything else. Let's add things as we need them.
        raise InterpretorError(f"{expression.__class__.__name__} is not supported.")


def evaluate_assign(assign, state, tools):
    var_names = assign.targets
    result = evaluate_ast(assign.value, state, tools)

    if len(var_names) == 1:
        state[var_names[0].id] = result
    else:
        if len(result) != len(var_names):
            raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
        for var_name, r in zip(var_names, result):
            state[var_name.id] = r


def evaluate_call(call, state, tools):
    if not isinstance(call.func, ast.Name):
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
            f"type {type(call.func)})."
        )
    func_name = call.func.id
    if func_name not in tools:
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
        )

    func = tools[func_name]
    # Todo deal with args
    args = [evaluate_ast(arg, state, tools) for arg in call.args]
    kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
    return func(*args, **kwargs)


def evaluate_subscript(subscript, state, tools):
    index = evaluate_ast(subscript.slice, state, tools)
    value = evaluate_ast(subscript.value, state, tools)
    if index in value:
        return value[index]
    if isinstance(index, str) and isinstance(value, Mapping):
        # Fix close-but-misnamed dictionary keys generated by the LLM.
        close_matches = difflib.get_close_matches(index, list(value.keys()))
        if len(close_matches) > 0:
            return value[close_matches[0]]

    raise InterpretorError(f"Could not index {value} with {index}.")


def evaluate_name(name, state, tools):
    if name.id in state:
        return state[name.id]
    # Fix close-but-misnamed variable names generated by the LLM.
    close_matches = difflib.get_close_matches(name.id, list(state.keys()))
    if len(close_matches) > 0:
        return state[close_matches[0]]
    raise InterpretorError(f"Could not find a variable named {name.id} in the state.")


def evaluate_condition(condition, state, tools):
    if len(condition.ops) > 1:
        raise InterpretorError("Cannot evaluate conditions with multiple operators")

    left = evaluate_ast(condition.left, state, tools)
    comparator = condition.ops[0]
    right = evaluate_ast(condition.comparators[0], state, tools)

    if isinstance(comparator, ast.Eq):
        return left == right
    elif isinstance(comparator, ast.NotEq):
        return left != right
    elif isinstance(comparator, ast.Lt):
        return left < right
    elif isinstance(comparator, ast.LtE):
        return left <= right
    elif isinstance(comparator, ast.Gt):
        return left > right
    elif isinstance(comparator, ast.GtE):
        return left >= right
    elif isinstance(comparator, ast.Is):
        return left is right
    elif isinstance(comparator, ast.IsNot):
        return left is not right
    elif isinstance(comparator, ast.In):
        return left in right
    elif isinstance(comparator, ast.NotIn):
        return left not in right
    else:
        raise InterpretorError(f"Operator not supported: {comparator}")


def evaluate_if(if_statement, state, tools):
    if evaluate_condition(if_statement.test, state, tools):
        for line in if_statement.body:
            evaluate_ast(line, state, tools)
    else:
        for line in if_statement.orelse:
            evaluate_ast(line, state, tools)
src/transformers/tools/speech_to_text.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
from .base import PipelineTool, RemoteTool


class SpeechToTextTool(PipelineTool):
    default_checkpoint = "openai/whisper-base"
    pre_processor_class = WhisperProcessor
    model_class = WhisperForConditionalGeneration

    description = (
        "This is a tool that transcribes an audio into text. It takes an input named `audio` and returns the "
        "transcribed text."
    )

    def encode(self, audio):
        return self.pre_processor(audio, return_tensors="pt").input_features

    def forward(self, inputs):
        return self.model.generate(inputs=inputs)

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]


class RemoteSpeechToTextTool(RemoteTool):
    default_checkpoint = "openai/whisper-base"
    description = (
        "This is a tool that transcribes an audio into text. It takes an input named `audio` and returns the "
        "transcribed text."
    )

    def extract_outputs(self, outputs):
        return outputs["text"]

    def prepare_inputs(self, audio):
        return {"data": audio}
src/transformers/tools/stable_diffusion.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from transformers.tools.base import Tool
from transformers.utils import is_accelerate_available, is_diffusers_available


if is_accelerate_available():
    from accelerate.state import PartialState

if is_diffusers_available():
    from diffusers import DiffusionPipeline


class StableDiffusionTool(Tool):
    default_checkpoint = "runwayml/stable-diffusion-v1-5"
    description = (
        "This is a tool that creates an image according to a text description. It takes an input named `text` which "
        "contains the image description and outputs an image."
    )

    def __init__(self, device=None) -> None:
        if not is_accelerate_available():
            raise ImportError("Accelerate should be installed in order to use tools.")
        if not is_diffusers_available():
            raise ImportError("Diffusers should be installed in order to use the StableDiffusionTool.")

        super().__init__()

        self.device = device
        self.pipeline = None

    def setup(self):
        if self.device is None:
            self.device = PartialState().default_device

        self.pipeline = DiffusionPipeline.from_pretrained(self.default_checkpoint)
        self.pipeline.to(self.device)

        self.is_initialized = True

    def __call__(self, text):
        if not self.is_initialized:
            self.setup()

        return self.pipeline(text).images[0]
src/transformers/tools/text_classification.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import torch

from ..models.auto import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from .base import PipelineTool, RemoteTool


class TextClassificationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TextClassificationTool

    classifier = TextClassificationTool("distilbert-base-uncased-finetuned-sst-2-english")
    classifier("This is a super nice API!")
    ```
    """

    default_checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"  # Needs to be updated
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    description = (
        "This is a tool that classifies an English text using the following {n_labels} labels: {labels}. It takes an "
        "input named `text` which should be in English and returns a dictionary with two keys named 'label' (the "
        "predicted label) and 'score' (the probability associated with it)."
    )

    def post_init(self):
        if isinstance(self.model, str):
            config = AutoConfig.from_pretrained(self.model)
        else:
            config = self.model.config

        num_labels = config.num_labels
        labels = list(config.label2id.keys())

        # Turn the label list into a natural-language enumeration for the description.
        if len(labels) > 1:
            labels = [f"'{label}'" for label in labels]
            labels_string = ", ".join(labels[:-1])
            labels_string += f", and {labels[-1]}"
        else:
            raise ValueError("Not enough labels.")

        self.description = self.description.replace("{n_labels}", str(num_labels)).replace("{labels}", labels_string)

    def encode(self, text):
        return self.pre_processor(text, return_tensors="pt")

    def decode(self, outputs):
        logits = outputs.logits
        scores = torch.nn.functional.softmax(logits, dim=-1)
        label_id = torch.argmax(logits[0]).item()
        label = self.model.config.id2label[label_id]
        return {"label": label, "score": scores[0][label_id].item()}


class RemoteTextClassificationTool(RemoteTool):
    """
    Example:

    ```py
    from transformers.tools import RemoteTextClassificationTool

    classifier = RemoteTextClassificationTool("distilbert-base-uncased-finetuned-sst-2-english")
    classifier("This is a super nice API!")
    ```
    """

    default_checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"  # Needs to be updated
    description = (
        "This is a tool that classifies an English text using the following {n_labels} labels: {labels}. It takes an "
        "input named `text` which should be in English and returns a dictionary with two keys named 'label' (the "
        "predicted label) and 'score' (the probability associated with it)."
    )

    def post_init(self):
        config = AutoConfig.from_pretrained(self.repo_id)
        num_labels = config.num_labels
        labels = list(config.label2id.keys())

        if len(labels) > 1:
            labels = [f"'{label}'" for label in labels]
            labels_string = ", ".join(labels[:-1])
            labels_string += f", and {labels[-1]}"
        else:
            raise ValueError("Not enough labels.")

        self.description = self.description.replace("{n_labels}", str(num_labels)).replace("{labels}", labels_string)

    def extract_outputs(self, outputs):
        # Keep the label with the highest score returned by the Inference API.
        label = None
        max_score = 0
        for result in outputs:
            lbl = result["label"]
            score = float(result["score"])
            if score > max_score:
                label = lbl
                max_score = score

        return {"label": label, "score": max_score}
src/transformers/tools/text_to_speech.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import torch

from transformers.utils import is_datasets_available

from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from .base import PipelineTool


if is_datasets_available():
    from datasets import load_dataset


class TextToSpeechTool(PipelineTool):
    default_checkpoint = "microsoft/speecht5_tts"
    pre_processor_class = SpeechT5Processor
    model_class = SpeechT5ForTextToSpeech
    post_processor_class = SpeechT5HifiGan

    description = (
        "This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
        "text to read (in English) and returns a waveform object containing the sound."
    )

    def post_init(self):
        if self.post_processor is None:
            self.post_processor = "microsoft/speecht5_hifigan"

    def encode(self, text, speaker_embeddings=None):
        inputs = self.pre_processor(text=text, return_tensors="pt")

        if speaker_embeddings is None:
            if not is_datasets_available():
                raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")

            # Default to an x-vector from the CMU Arctic dataset as the speaker embedding.
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)

        return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}

    def forward(self, inputs):
        with torch.no_grad():
            return self.model.generate_speech(**inputs)

    def decode(self, outputs):
        with torch.no_grad():
            return self.post_processor(outputs).cpu().detach()
src/transformers/tools/translation.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool


class TranslationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TranslationTool

    translator = TranslationTool(src_lang="fra_Latn", tgt_lang="eng_Latn")
    translator("J'aime beaucoup Hugging Face!")
    ```
    """

    default_checkpoint = "facebook/nllb-200-distilled-600M"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    description = (
        "This is a tool that translates text from {src_lang} to {tgt_lang}. It takes an input named `text` which "
        "should be the text in {src_lang} and returns a dictionary with a single key `'translated_text'` that "
        "contains the translation in {tgt_lang}."
    )

    def __init__(
        self,
        model=None,
        pre_processor=None,
        post_processor=None,
        src_lang=None,
        tgt_lang=None,
        device=None,
        device_map=None,
        model_kwargs=None,
        **hub_kwargs,
    ):
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        super().__init__(
            model=model,
            pre_processor=pre_processor,
            post_processor=post_processor,
            device=device,
            device_map=device_map,
            model_kwargs=model_kwargs,
            **hub_kwargs,
        )

    def post_init(self):
        # Only French and English are supported for now.
        codes_to_lang = {"fra_Latn": "French", "eng_Latn": "English"}
        src_lang = codes_to_lang[self.src_lang]
        tgt_lang = codes_to_lang[self.tgt_lang]
        self.description = self.description.replace("{src_lang}", src_lang).replace("{tgt_lang}", tgt_lang)

    def encode(self, text):
        return self.pre_processor._build_translation_inputs(
            text, return_tensors="pt", src_lang=self.src_lang, tgt_lang=self.tgt_lang
        )

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return {"translated_text": self.post_processor.decode(outputs[0].tolist())}
src/transformers/utils/__init__.py
@@ -107,6 +107,7 @@ from .import_utils import (
     is_datasets_available,
     is_decord_available,
     is_detectron2_available,
+    is_diffusers_available,
     is_faiss_available,
     is_flax_available,
     is_ftfy_available,
@@ -120,6 +121,7 @@ from .import_utils import (
     is_natten_available,
     is_ninja_available,
     is_onnx_available,
+    is_opencv_available,
     is_pandas_available,
     is_phonemizer_available,
     is_protobuf_available,
src/transformers/utils/import_utils.py
@@ -125,6 +125,14 @@ except importlib_metadata.PackageNotFoundError:
     _datasets_available = False


+_diffusers_available = importlib.util.find_spec("diffusers") is not None
+try:
+    _diffusers_version = importlib_metadata.version("diffusers")
+    logger.debug(f"Successfully imported diffusers version {_diffusers_version}")
+except importlib_metadata.PackageNotFoundError:
+    _diffusers_available = False
+
+
 _detectron2_available = importlib.util.find_spec("detectron2") is not None
 try:
     _detectron2_version = importlib_metadata.version("detectron2")
@@ -185,6 +193,9 @@ except importlib_metadata.PackageNotFoundError:
     _onnx_available = False


+_opencv_available = importlib.util.find_spec("cv2") is not None
+
+
 _pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None
 try:
     _pytorch_quantization_version = importlib_metadata.version("pytorch_quantization")
@@ -426,6 +437,9 @@ def is_tf2onnx_available():
 def is_onnx_available():
     return _onnx_available

+def is_opencv_available():
+    return _opencv_available
+

 def is_flax_available():
     return _flax_available
@@ -493,6 +507,10 @@ def is_datasets_available():
     return _datasets_available


+def is_diffusers_available():
+    return _diffusers_available
+
+
 def is_detectron2_available():
     return _detectron2_available