Compare commits: fix_genera...v4.29.1 (20 commits)

Commit SHA1s: 118e981068, 37a508cd86, d2e5aedfb6, df98769c7a, 3652e1665b, b1189abc99, 10eeb2adf6, 6376aa2f2d, 48722778ea, 15f260a82f, 60fc8f8dcf, bcf9100975, 9dd6209c9f, 9d5b0e50f9, 7d415ba37c, bb57271ed6, 2a2be57697, d5e1c98120, fee5b5efbe, d30849f732
@@ -166,7 +166,6 @@ jobs:
            - v0.6-repository_consistency
      - run: pip install --upgrade pip
      - run: pip install .[all,quality]
      - run: pip install pytest
      - save_cache:
          key: v0.5-repository_consistency-{{ checksum "setup.py" }}
          paths:
@@ -51,8 +51,6 @@ class CircleCIJob:
    resource_class: Optional[str] = "xlarge"
    tests_to_run: Optional[List[str]] = None
    working_directory: str = "~/transformers"
    # This should be only used for doctest job!
    command_timeout: Optional[int] = None

    def __post_init__(self):
        # Deal with defaults for mutable attributes.
@@ -109,15 +107,11 @@ class CircleCIJob:
        steps.append({"store_artifacts": {"path": "~/transformers/installed.txt"}})

        all_options = {**COMMON_PYTEST_OPTIONS, **self.pytest_options}
        pytest_flags = [f"--{key}={value}" if (value is not None or key in ["doctest-modules"]) else f"-{key}" for key, value in all_options.items()]
        pytest_flags = [f"--{key}={value}" if value is not None else f"-{key}" for key, value in all_options.items()]
        pytest_flags.append(
            f"--make-reports={self.name}" if "examples" in self.name else f"--make-reports=tests_{self.name}"
        )
        test_command = ""
        if self.command_timeout:
            test_command = f"timeout {self.command_timeout} "
        test_command += f"python -m pytest -n {self.pytest_num_workers} " + " ".join(pytest_flags)

        test_command = f"python -m pytest -n {self.pytest_num_workers} " + " ".join(pytest_flags)
        if self.parallelism == 1:
            if self.tests_to_run is None:
                test_command += " << pipeline.parameters.tests_to_run >>"
@@ -167,37 +161,12 @@ class CircleCIJob:
            steps.append({"store_artifacts": {"path": "~/transformers/tests.txt"}})
            steps.append({"store_artifacts": {"path": "~/transformers/splitted_tests.txt"}})

            test_command = ""
            if self.timeout:
                test_command = f"timeout {self.timeout} "
            test_command += f"python -m pytest -n {self.pytest_num_workers} " + " ".join(pytest_flags)
            test_command = f"python -m pytest -n {self.pytest_num_workers} " + " ".join(pytest_flags)
            test_command += " $(cat splitted_tests.txt)"
        if self.marker is not None:
            test_command += f" -m {self.marker}"

        if self.name == "pr_documentation_tests":
            # can't use ` | tee tests_output.txt` as usual
            test_command += " > tests_output.txt"
            # Save the return code, so we can check if it is timeout in the next step.
            test_command += '; touch "$?".txt'
            # Never fail the test step for the doctest job. We will check the results in the next step, and fail that
            # step instead if actual test failures are found. This is to avoid the timeout being reported as a test
            # failure.
            test_command = f"({test_command}) || true"
        else:
            test_command += " | tee tests_output.txt"
        test_command += " | tee tests_output.txt"
        steps.append({"run": {"name": "Run tests", "command": test_command}})

        # return code `124` means the previous (pytest run) step timed out
        if self.name == "pr_documentation_tests":
            checkout_doctest_command = 'if [ -s reports/tests_pr_documentation_tests/failures_short.txt ]; '
            checkout_doctest_command += 'then echo "some test failed"; '
            checkout_doctest_command += 'cat reports/tests_pr_documentation_tests/failures_short.txt; '
            checkout_doctest_command += 'cat reports/tests_pr_documentation_tests/summary_short.txt; exit -1; '
            checkout_doctest_command += 'elif [ -s reports/tests_pr_documentation_tests/stats.txt ]; then echo "All tests pass!"; '
            checkout_doctest_command += 'elif [ -f 124.txt ]; then echo "doctest timeout!"; else echo "other fatal error"; exit -1; fi;'
            steps.append({"run": {"name": "Check doctest results", "command": checkout_doctest_command}})

        steps.append({"store_artifacts": {"path": "~/transformers/tests_output.txt"}})
        steps.append({"store_artifacts": {"path": "~/transformers/reports"}})
        job["steps"] = steps
@@ -432,51 +401,6 @@ repo_utils_job = CircleCIJob(
    tests_to_run="tests/repo_utils",
)

# At this moment, only the files that are in `utils/documentation_tests.txt` will be kept (together with a dummy file).
py_command = 'import os; import json; fp = open("pr_documentation_tests.txt"); data_1 = fp.read().strip().split("\\n"); fp = open("utils/documentation_tests.txt"); data_2 = fp.read().strip().split("\\n"); to_test = [x for x in data_1 if x in set(data_2)] + ["dummy.py"]; to_test = " ".join(to_test); print(to_test)'
py_command = f"$(python3 -c '{py_command}')"
command = f'echo "{py_command}" > pr_documentation_tests_filtered.txt'
doc_test_job = CircleCIJob(
    "pr_documentation_tests",
    additional_env={"TRANSFORMERS_VERBOSITY": "error", "DATASETS_VERBOSITY": "error", "SKIP_CUDA_DOCTEST": "1"},
    install_steps=[
        "sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng time",
        "pip install --upgrade pip",
        "pip install -e .[dev]",
        "pip install git+https://github.com/huggingface/accelerate",
        "pip install --upgrade pytest pytest-sugar",
        "find -name __pycache__ -delete",
        "find . -name \*.pyc -delete",
        # Add an empty file to keep the test step running correctly even if no file is selected to be tested.
        "touch dummy.py",
        {
            "name": "Get files to test",
            "command":
                "git remote add upstream https://github.com/huggingface/transformers.git && git fetch upstream \n"
                "git diff --name-only --relative --diff-filter=AMR refs/remotes/upstream/main...HEAD | grep -E '\.(py|mdx)$' | grep -Ev '^\..*|/\.' | grep -Ev '__' > pr_documentation_tests.txt"
        },
        {
            "name": "List files being changed: pr_documentation_tests.txt",
            "command":
                "cat pr_documentation_tests.txt"
        },
        {
            "name": "Filter pr_documentation_tests.txt",
            "command":
                command
        },
        {
            "name": "List files being tested: pr_documentation_tests_filtered.txt",
            "command":
                "cat pr_documentation_tests_filtered.txt"
        },
    ],
    tests_to_run="$(cat pr_documentation_tests_filtered.txt)",  # noqa
    pytest_options={"-doctest-modules": None, "doctest-glob": "*.mdx", "dist": "loadfile", "rvsA": None},
    command_timeout=1200,  # a test cannot run longer than 1200 seconds
    pytest_num_workers=1,
)

REGULAR_TESTS = [
    torch_and_tf_job,
    torch_and_flax_job,
@@ -487,7 +411,6 @@ REGULAR_TESTS = [
    hub_job,
    onnx_job,
    exotic_models_job,
    doc_test_job
]
EXAMPLES_TESTS = [
    examples_torch_job,

.github/workflows/doctests.yml (8 lines changed)
@@ -37,10 +37,18 @@ jobs:
      - name: Show installed libraries and their versions
        run: pip freeze

      - name: Prepare files for doctests
        run: |
          python3 utils/prepare_for_doc_test.py src docs

      - name: Run doctests
        run: |
          python3 -m pytest -v --make-reports doc_tests_gpu --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"

      - name: Clean files after doctests
        run: |
          python3 utils/prepare_for_doc_test.py src docs --remove_new_line

      - name: Failure short reports
        if: ${{ failure() }}
        continue-on-error: true

Makefile (8 lines changed)
@@ -47,10 +47,10 @@ repo-consistency:
# this target runs checks on all files

quality:
	black --check $(check_dirs) setup.py conftest.py
	black --check $(check_dirs) setup.py
	python utils/custom_init_isort.py --check_only
	python utils/sort_auto_mappings.py --check_only
	ruff $(check_dirs) setup.py conftest.py
	ruff $(check_dirs) setup.py
	doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
	python utils/check_doc_toc.py

@@ -65,8 +65,8 @@ extra_style_checks:
# this target runs checks on all files and potentially modifies some of them

style:
	black $(check_dirs) setup.py conftest.py
	ruff $(check_dirs) setup.py conftest.py --fix
	black $(check_dirs) setup.py
	ruff $(check_dirs) setup.py --fix
	${MAKE} autogenerate_code
	${MAKE} extra_style_checks
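In day-to-day use, the two targets shown above split check-only from auto-fix (standard `make` invocation of the targets in this diff):

```bash
make quality  # check formatting and lint without modifying files
make style    # reformat files in place, then run the extra style checks
```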

@@ -341,7 +341,7 @@ Current number of checkpoints:
1. **[FlauBERT](https://huggingface.co/docs/transformers/model_doc/flaubert)** (from CNRS) released with the paper [FlauBERT: Unsupervised Language Model Pre-training for French](https://arxiv.org/abs/1912.05372) by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab.
1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela.
1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.
1. **[FocalNet](https://huggingface.co/docs/transformers/main/model_doc/focalnet)** (from Microsoft Research) released with the paper [Focal Modulation Networks](https://arxiv.org/abs/2203.11926) by Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao.
1. **[FocalNet](https://huggingface.co/docs/transformers/model_doc/focalnet)** (from Microsoft Research) released with the paper [Focal Modulation Networks](https://arxiv.org/abs/2203.11926) by Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao.
1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang.
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
@@ -400,7 +400,7 @@ Current number of checkpoints:
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh.
1. **[OneFormer](https://huggingface.co/docs/transformers/model_doc/oneformer)** (from SHI Labs) released with the paper [OneFormer: One Transformer to Rule Universal Image Segmentation](https://arxiv.org/abs/2211.06220) by Jitesh Jain, Jiachen Li, MangTik Chiu, Ali Hassani, Nikita Orlov, Humphrey Shi.
1. **[OpenLlama](https://huggingface.co/docs/transformers/main/model_doc/open-llama)** (from [s-JoL](https://huggingface.co/s-JoL)) released in [Open-Llama](https://github.com/s-JoL/Open-Llama).
1. **[OpenLlama](https://huggingface.co/docs/transformers/model_doc/open-llama)** (from [s-JoL](https://huggingface.co/s-JoL)) released in [Open-Llama](https://github.com/s-JoL/Open-Llama).
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
@@ -422,9 +422,9 @@ Current number of checkpoints:
1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli.
1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng.
1. **[RWKV](https://huggingface.co/docs/transformers/model_doc/rwkv)** (from Bo Peng), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
1. **[SEW-D](https://huggingface.co/docs/transformers/model_doc/sew_d)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.

conftest.py (13 lines changed)
@@ -20,10 +20,6 @@ import sys
import warnings
from os.path import abspath, dirname, join

import _pytest

from transformers.utils.doctest_utils import HfDoctestModule, HfDocTestParser


# allow having multiple repository checkouts and not needing to remember to rerun
# 'pip install -e .[dev]' when switching between checkouts and running tests.
@@ -42,9 +38,12 @@ def pytest_configure(config):
    config.addinivalue_line(
        "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
    )
    config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
    config.addinivalue_line(
        "markers", "is_pipeline_test: mark test to run only when pipelines are tested"
    )
    config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
    config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
    config.addinivalue_line("markers", "tool_tests: mark the tool tests that are run on their specific schedule")


def pytest_addoption(parser):
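For orientation, markers registered via `pytest_configure` like the ones above are applied in test files with the standard `pytest.mark` decorator (the test name below is illustrative):

```py
import pytest

@pytest.mark.is_staging_test
def test_push_to_hub_on_staging():
    # Runs only when the staging-environment suite is selected,
    # e.g. via `pytest -m is_staging_test`.
    ...
```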
@@ -68,7 +67,7 @@ def pytest_sessionfinish(session, exitstatus):


# Doctest custom flag to ignore output.
IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")
IGNORE_RESULT = doctest.register_optionflag('IGNORE_RESULT')

OutputChecker = doctest.OutputChecker

@@ -81,5 +80,3 @@ class CustomOutputChecker(OutputChecker):


doctest.OutputChecker = CustomOutputChecker
_pytest.doctest.DoctestModule = HfDoctestModule
doctest.DocTestParser = HfDocTestParser
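For context, a docstring example opts into this custom flag with the usual doctest directive syntax; the `CustomOutputChecker` registered above is what makes the flag skip output comparison. An illustrative snippet (not part of the diff):

```py
def current_timestamp():
    """Return the current UNIX timestamp.

    >>> current_timestamp()  # doctest: +IGNORE_RESULT
    1683000000.0
    """
    import time

    return time.time()
```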

@@ -21,6 +21,8 @@
    title: Set up distributed training with 🤗 Accelerate
  - local: model_sharing
    title: Share your model
  - local: transformers_agents
    title: Agents
  title: Tutorials
- sections:
  - sections:
@@ -99,6 +101,8 @@
    title: Notebooks with examples
  - local: community
    title: Community resources
  - local: custom_tools
    title: Custom Tools and Prompts
  - local: troubleshooting
    title: Troubleshoot
  title: Developer guides
@@ -179,6 +183,8 @@
    title: Conceptual guides
- sections:
  - sections:
    - local: main_classes/agent
      title: Agents and Tools
    - local: model_doc/auto
      title: Auto Classes
    - local: main_classes/callback

docs/source/en/custom_tools.mdx (new file, 778 lines)

@@ -0,0 +1,778 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Custom Tools and Prompts

<Tip>

If you are not aware of what tools and agents are in the context of transformers, we recommend you read the
[Transformers Agents](transformers_agents) page first.

</Tip>

<Tip warning={true}>

Transformers Agent is an experimental API that is subject to change at any time. Results returned by the agents
can vary as the APIs or underlying models are prone to change.

</Tip>

Creating and using custom tools and prompts is paramount to empowering the agent and having it perform new tasks.
In this guide we'll take a look at:

- How to customize the prompt
- How to use custom tools
- How to create custom tools

## Customizing the prompt

As explained in [Transformers Agents](transformers_agents), agents can run in [`~Agent.run`] and [`~Agent.chat`] mode.
Both the `run` and `chat` modes rely on the same logic. The language model powering the agent is conditioned on a long
prompt and completes the prompt by generating the next tokens until the stop token is reached.
The only difference between the two modes is that during the `chat` mode the prompt is extended with
previous user inputs and model generations. This allows the agent to have access to past interactions,
seemingly giving the agent some kind of memory.

### Structure of the prompt

Let's take a closer look at how the prompt is structured to understand how it can be best customized.
The prompt is structured broadly into four parts:

1. Introduction: how the agent should behave and an explanation of the concept of tools.
2. A description of all the tools. This is defined by a `<<all_tools>>` token that is dynamically replaced at runtime with the tools defined/chosen by the user.
3. A set of examples of tasks and their solutions.
4. The current example, and a request for a solution.

To better understand each part, let's look at a shortened version of what the `run` prompt can look like:

````text
I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
[...]
You can print intermediate results if it makes sense to do so.

Tools:
- document_qa: This is a tool that answers a question about a document (pdf). It takes an input named `document` which should be the document containing the information, as well as a `question` that is the question about the document. It returns a text that contains the answer to the question.
- image_captioner: This is a tool that generates a description of an image. It takes an input named `image` which should be the image to caption and returns a text that contains the description in English.
[...]

Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."

I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.

Answer:
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(image=image, question=translated_question)
print(f"The answer is {answer}")
```

Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."

I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.

Answer:
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
image = image_generator("A banner showing " + answer)
```

[...]

Task: "Draw me a picture of rivers and lakes"

I will use the following
````

The introduction (the text before *"Tools:"*) explains precisely how the model shall behave and what it should do.
This part most likely does not need to be customized, as the agent shall always behave the same way.

The second part (the bullet points below *"Tools"*) is dynamically added upon calling `run` or `chat`. There are
exactly as many bullet points as there are tools in `agent.toolbox`, and each bullet point consists of the name
and description of the tool:

```text
- <tool.name>: <tool.description>
```

Let's verify this quickly by loading the `document_qa` tool and printing out its name and description.

```py
from transformers import load_tool

document_qa = load_tool("document-question-answering")
print(f"- {document_qa.name}: {document_qa.description}")
```

which gives:
```text
- document_qa: This is a tool that answers a question about a document (pdf). It takes an input named `document` which should be the document containing the information, as well as a `question` that is the question about the document. It returns a text that contains the answer to the question.
```

We can see that the tool name is short and precise. The description includes two parts: the first explains
what the tool does, and the second states what input arguments and return values are expected.

A good tool name and tool description are very important for the agent to use it correctly. Note that the only
information the agent has about the tool is its name and description, so one should make sure that both
are precisely written and match the style of the existing tools in the toolbox. In particular, make sure the description
mentions all the arguments expected by name in code-style, along with the expected type and a description of what they
are.

<Tip>

Check the naming and description of the curated Transformers tools to better understand what name and
description a tool is expected to have. You can see all tools with the [`Agent.toolbox`] property.

</Tip>

The third part includes a set of curated examples that show the agent exactly what code it should produce
for what kind of user request. The large language models empowering the agent are extremely good at
recognizing patterns in a prompt and repeating the pattern with new data. Therefore, it is very important
that the examples are written in a way that maximizes the likelihood of the agent generating correct,
executable code in practice.

Let's have a look at one example:

````text
Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."

I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.

Answer:
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
image = image_generator("A banner showing " + answer)
```

````

The pattern the model is prompted to repeat has three parts: the task statement, the agent's explanation of
what it intends to do, and finally the generated code. Every example that is part of the prompt has this exact
pattern, thus making sure that the agent will reproduce exactly the same pattern when generating new tokens.

The prompt examples are curated by the Transformers team and rigorously evaluated on a set of
[problem statements](https://github.com/huggingface/transformers/blob/main/src/transformers/tools/evaluate_agent.py)
to ensure that the agent's prompt is as good as possible to solve real use cases of the agent.

The final part of the prompt corresponds to:
```text
Task: "Draw me a picture of rivers and lakes"

I will use the following
```

which is a final, unfinished example that the agent is tasked to complete. The unfinished example
is dynamically created based on the actual user input. For the above example, the user ran:

```py
agent.run("Draw me a picture of rivers and lakes")
```

The user input - *a.k.a.* the task - *"Draw me a picture of rivers and lakes"* is cast into the
prompt template: "Task: <task> \n\n I will use the following". This sentence makes up the final lines of the
prompt the agent is conditioned on, therefore strongly influencing the agent to finish the example
exactly in the same way it was previously done in the examples.
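To make the templating concrete, a minimal sketch of this casting step might look like the following (the variable names are illustrative, not Transformers internals):

```py
# Illustrative only: cast the user's task into the unfinished example ending the prompt.
task = "Draw me a picture of rivers and lakes"
unfinished_example = f'Task: "{task}"\n\nI will use the following'
print(unfinished_example)
```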

Without going into too much detail, the chat template has the same prompt structure, with the
examples having a slightly different style, *e.g.*:

````text
[...]

=====

Human: Answer the question in the variable `question` about the image stored in the variable `image`.

Assistant: I will use the tool `image_qa` to answer the question on the input image.

```py
answer = image_qa(text=question, image=image)
print(f"The answer is {answer}")
```

Human: I tried this code, it worked but didn't give me a good result. The question is in French

Assistant: In this case, the question needs to be translated first. I will use the tool `translator` to do this.

```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(text=translated_question, image=image)
print(f"The answer is {answer}")
```

=====

[...]
````

Contrary to the examples of the `run` prompt, each `chat` prompt example has one or more exchanges between the
*Human* and the *Assistant*. Every exchange is structured similarly to the example of the `run` prompt.
The user's input is appended behind *Human:* and the agent is prompted to first generate what needs to be done
before generating code. An exchange can build on previous exchanges, therefore allowing the user to refer
to past exchanges, as is done *e.g.* above, where the user's input of "I tried **this** code" refers to the
previously generated code of the agent.

Upon running `.chat`, the user's input or *task* is cast into an unfinished example of the form:
```text
Human: <user-input>\n\nAssistant:
```
which the agent completes. Contrary to the `run` command, the `chat` command then appends the completed example
to the prompt, thus giving the agent more context for the next `chat` turn.
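A minimal sketch of that bookkeeping could look as follows (illustrative pseudocode for the behavior described above, not the library's actual implementation):

```py
# Illustrative only: an accumulating chat prompt.
chat_history = ""

def build_chat_prompt(task: str) -> str:
    # Each turn ends with the unfinished-example format shown above.
    return chat_history + f"Human: {task}\n\nAssistant:"

def record_turn(task: str, completion: str) -> None:
    # After the agent completes the example, the whole exchange is kept,
    # giving the next turn access to past interactions.
    global chat_history
    chat_history += f"Human: {task}\n\nAssistant:{completion}\n\n"
```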

Great! Now that we know how the prompt is structured, let's see how we can customize it!

### Writing good user inputs

While large language models are getting better and better at understanding users' intentions, it helps
enormously to be as precise as possible to help the agent pick the correct task. What does it mean to be
as precise as possible?

The agent sees a list of tool names and their descriptions in its prompt. The more tools are added, the
more difficult it becomes for the agent to choose the correct tool, and it's even more difficult to choose
the correct sequence of tools to run. Let's look at a common failure case; here we will only return
the code to analyze it.

```py
from transformers import HfAgent

agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")

agent.run("Show me a tree", return_code=True)
```

gives:

```text
==Explanation from the agent==
I will use the following tool: `image_segmenter` to create a segmentation mask for the image.


==Code generated by the agent==
mask = image_segmenter(image, prompt="tree")
```

which is probably not what we wanted. Instead, it is more likely that we want an image of a tree to be generated.
To steer the agent more towards using a specific tool, it can therefore be very helpful to use important keywords that
are present in the tool's name and description. Let's have a look.
```py
agent.toolbox["image_generator"].description
```

```text
'This is a tool that creates an image according to a prompt, which is a text description. It takes an input named `prompt` which contains the image description and outputs an image.'
```

The name and description make use of the keywords "image", "prompt", "create" and "generate". Using these words will most likely work better here. Let's refine our prompt a bit.

```py
agent.run("Create an image of a tree", return_code=True)
```

gives:
```text
==Explanation from the agent==
I will use the following tool `image_generator` to generate an image of a tree.


==Code generated by the agent==
image = image_generator(prompt="tree")
```

Much better! That looks more like what we want. In short, when you notice that the agent struggles to
map your task to the correct tools, try looking up the most pertinent keywords of the tool's name
and description and refine your task request with them.
### Customizing the tool descriptions

As we've seen before, the agent has access to each of the tools' names and descriptions. The base tools
should have very precise names and descriptions; however, you might find that it could help to change
the description or name of a tool for your specific use case. This might become especially important
when you've added multiple tools that are very similar or if you want to use your agent only for a certain
domain, *e.g.* image generation and transformations.

A common problem is that the agent confuses image generation with image transformation/modification when
used a lot for image generation tasks, *e.g.*
```py
agent.run("Make an image of a house and a car", return_code=True)
```
returns
```text
==Explanation from the agent==
I will use the following tools `image_generator` to generate an image of a house and `image_transformer` to transform the image of a car into the image of a house.

==Code generated by the agent==
house_image = image_generator(prompt="A house")
car_image = image_generator(prompt="A car")
house_car_image = image_transformer(image=car_image, prompt="A house")
```

which is probably not exactly what we want here. It seems like the agent has a difficult time
understanding the difference between `image_generator` and `image_transformer` and often uses the two together.

We can help the agent here by changing the tool name and description of `image_transformer`. Let's instead call it `modifier`
to disassociate it a bit from "image" and "prompt":
```py
agent.toolbox["modifier"] = agent.toolbox.pop("image_transformer")
agent.toolbox["modifier"].description = agent.toolbox["modifier"].description.replace(
    "transforms an image according to a prompt", "modifies an image"
)
```

Now "modify" is a strong cue to use the new image processor, which should help with the above prompt. Let's run it again.

```py
agent.run("Make an image of a house and a car", return_code=True)
```

Now we're getting:
```text
==Explanation from the agent==
I will use the following tools: `image_generator` to generate an image of a house, then `image_generator` to generate an image of a car.


==Code generated by the agent==
house_image = image_generator(prompt="A house")
car_image = image_generator(prompt="A car")
```

which is definitely closer to what we had in mind! However, we want to have both the house and car in the same image. Steering the task more toward single image generation should help:

```py
agent.run("Create image: 'A house and car'", return_code=True)
```

```text
==Explanation from the agent==
I will use the following tool: `image_generator` to generate an image.


==Code generated by the agent==
image = image_generator(prompt="A house and car")
```

<Tip warning={true}>

Agents are still brittle for many use cases, especially when it comes to
slightly more complex use cases like generating an image of multiple objects.
Both the agent itself and the underlying prompt will be further improved in the coming
months, making sure that agents become more robust to a variety of user inputs.

</Tip>

### Customizing the whole prompt

To give the user maximum flexibility, the whole prompt template as explained [above](#structure-of-the-prompt)
can be overwritten by the user. In this case, make sure that your custom prompt includes an introduction section,
a tool section, an example section, and an unfinished example section. If you want to overwrite the `run` prompt template,
you can do as follows:

```py
template = """ [...] """

agent = HfAgent(your_endpoint, run_prompt_template=template)
```

<Tip warning={true}>

Please make sure to have the `<<all_tools>>` string and the `<<prompt>>` string defined somewhere in the `template` so that the agent can be aware
of the tools it has available and correctly insert the user's prompt.

</Tip>

Similarly, one can overwrite the `chat` prompt template. Note that the `chat` mode always uses the following format for the exchanges:
```text
Human: <<task>>

Assistant:
```

Therefore it is important that the examples of the custom `chat` prompt template also make use of this format.
You can overwrite the `chat` template at instantiation as follows.

```py
template = """ [...] """

agent = HfAgent(url_endpoint=your_endpoint, chat_prompt_template=template)
```

<Tip warning={true}>

Please make sure to have the `<<all_tools>>` string defined somewhere in the `template` so that the agent can be aware
of the tools it has available.

</Tip>

## Using custom tools

In this section, we'll be leveraging two existing custom tools that are specific to image generation:

- We replace [huggingface-tools/image-transformation](https://huggingface.co/spaces/huggingface-tools/image-transformation)
  with [diffusers/controlnet-canny-tool](https://huggingface.co/spaces/diffusers/controlnet-canny-tool)
  to allow for more image modifications.
- We add a new tool for image upscaling to the default toolbox:
  [diffusers/latent-upscaler-tool](https://huggingface.co/spaces/diffusers/latent-upscaler-tool).

We'll start by loading the custom tools with the convenient [`load_tool`] function:

```py
from transformers import load_tool

controlnet_transformer = load_tool("diffusers/controlnet-canny-tool")
upscaler = load_tool("diffusers/latent-upscaler-tool")
```

Upon adding custom tools to an agent, the tools' descriptions and names are automatically
included in the agent's prompt. Thus, it is imperative that custom tools have
a well-written description and name in order for the agent to understand how to use them.
Let's take a look at the description and name of `controlnet_transformer`:

```py
print(f"Description: '{controlnet_transformer.description}'")
print(f"Name: '{controlnet_transformer.name}'")
```

gives
```text
Description: 'This is a tool that transforms an image with ControlNet according to a prompt.
It takes two inputs: `image`, which should be the image to transform, and `prompt`, which should be the prompt to use to change it. It returns the modified image.'
Name: 'image_transformer'
```

The name and description are accurate and fit the style of the [curated set of tools](./transformers_agents#a-curated-set-of-tools).
Next, let's instantiate an agent with `controlnet_transformer` and `upscaler`:

```py
tools = [controlnet_transformer, upscaler]
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=tools)
```

This command should give you the following info:

```text
image_transformer has been replaced by <transformers_modules.diffusers.controlnet-canny-tool.bd76182c7777eba9612fc03c08718a60c0aa6312.image_transformation.ControlNetTransformationTool object at 0x7f1d3bfa3a00> as provided in `additional_tools`
```

The set of curated tools already has an `image_transformer` tool, which is hereby replaced with our custom tool.

<Tip>

Overwriting existing tools can be beneficial if we want to use a custom tool for exactly the same task as an existing tool,
because the agent is already well-versed in that specific task. Beware that the custom tool should follow the exact same API
as the overwritten tool in this case, or you should adapt the prompt template to make sure all examples using that
tool are updated.

</Tip>

The upscaler tool was given the name `image_upscaler`, which is not yet present in the default toolbox and is therefore simply added to the list of tools.
You can always have a look at the toolbox that is currently available to the agent via the `agent.toolbox` attribute:

```py
print("\n".join([f"- {a}" for a in agent.toolbox.keys()]))
```

```text
- document_qa
- image_captioner
- image_qa
- image_segmenter
- transcriber
- summarizer
- text_classifier
- text_qa
- text_reader
- translator
- image_transformer
- text_downloader
- image_generator
- video_generator
- image_upscaler
```

Note how `image_upscaler` is now part of the agent's toolbox.

Let's now try out the new tools! We will re-use the image we generated in [Transformers Agents Quickstart](./transformers_agents#single-execution-run).

```py
from diffusers.utils import load_image

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png"
)
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>

Let's transform the image into a beautiful winter landscape:

```py
image = agent.run("Transform the image: 'A frozen lake and snowy forest'", image=image)
```

```text
==Explanation from the agent==
I will use the following tool: `image_transformer` to transform the image.


==Code generated by the agent==
image = image_transformer(image, prompt="A frozen lake and snowy forest")
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter.png" width=200>

The new image processing tool is based on ControlNet, which can make very strong modifications to the image.
By default the image processing tool returns an image of size 512x512 pixels. Let's see if we can upscale it.

```py
image = agent.run("Upscale the image", image)
```

```text
==Explanation from the agent==
I will use the following tool: `image_upscaler` to upscale the image.


==Code generated by the agent==
upscaled_image = image_upscaler(image)
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter_upscale.png" width=400>

The agent automatically mapped our prompt "Upscale the image" to the newly added upscaler tool, purely based on the name and description of that tool,
and was able to run it correctly.

Next, let's have a look at how you can create a new custom tool.

### Adding new tools

In this section, we show how to create a new tool that can be added to the agent.

#### Creating a new tool

We'll first start by creating a tool. We'll add the not-so-useful yet fun task of fetching the model on the Hugging Face
Hub with the most downloads for a given task.

We can do that with the following code:

```python
from huggingface_hub import list_models

task = "text-classification"

model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
print(model.id)
```

For the task `text-classification`, this returns `'facebook/bart-large-mnli'`; for `translation`, it returns `'t5-base'`.

How do we convert this to a tool that the agent can leverage? All tools depend on the superclass `Tool` that holds the
main attributes necessary. We'll create a class that inherits from it:

```python
from transformers import Tool


class HFModelDownloadsTool(Tool):
    pass
```

This class has a few needs:
- An attribute `name`, which corresponds to the name of the tool itself. To be in tune with other tools which have a
  performative name, we'll name it `model_download_counter`.
- An attribute `description`, which will be used to populate the prompt of the agent.
- `inputs` and `outputs` attributes. Defining these will help the Python interpreter make educated choices about types,
  and will allow for a gradio-demo to be spawned when we push our tool to the Hub. They're both a list of expected
  values, which can be `text`, `image`, or `audio`.
- A `__call__` method which contains the inference code. This is the code we've played with above!

Here's what our class looks like now:

```python
from transformers import Tool
from huggingface_hub import list_models


class HFModelDownloadsTool(Tool):
    name = "model_download_counter"
    description = (
        "This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
        "It takes the name of the category (such as text-classification, depth-estimation, etc), and "
        "returns the name of the checkpoint."
    )

    inputs = ["text"]
    outputs = ["text"]

    def __call__(self, task: str):
        model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
        return model.id
```

We now have our tool handy. Save it in a file and import it from your main script. Let's name this file
`model_downloads.py`, so the resulting import code looks like this:

```python
from model_downloads import HFModelDownloadsTool

tool = HFModelDownloadsTool()
```

In order to let others benefit from it and for simpler initialization, we recommend pushing it to the Hub under your
namespace. To do so, just call `push_to_hub` on the `tool` variable:

```python
tool.push_to_hub("hf-model-downloads")
```

You now have your code on the Hub! Let's take a look at the final step, which is to have the agent use it.

#### Having the agent use the tool

We now have our tool that lives on the Hub, which can be instantiated as such (change the user name for your tool):

```python
from transformers import load_tool

tool = load_tool("lysandre/hf-model-downloads")
```

In order to use it in the agent, simply pass it in the `additional_tools` parameter of the agent initialization method:

```python
from transformers import HfAgent

agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])

agent.run(
    "Can you read out loud the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
)
```
which outputs the following:
```text
==Code generated by the agent==
model = model_download_counter(task="text-to-video")
print(f"The model with the most downloads is {model}.")
audio_model = text_reader(model)


==Result==
The model with the most downloads is damo-vilab/text-to-video-ms-1.7b.
```

and generates the following audio.

| **Audio**                                                                                                                                                     |
|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/></audio> |

<Tip>

Depending on the LLM, some are quite brittle and require very exact prompts in order to work well. Having a well-defined
name and description of the tool is paramount to having it be leveraged by the agent.

</Tip>
### Replacing existing tools

Replacing existing tools can be done simply by assigning a new item to the agent's toolbox. Here's how one would do so:

```python
from transformers import HfAgent, load_tool

agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.toolbox["image-transformation"] = load_tool("diffusers/controlnet-canny-tool")
```

<Tip>

Beware when replacing tools with others! This will also adjust the agent's prompt. This can be good if you have a better
prompt suited for the task, but it can also result in your tool being selected way more often than others, or in other
tools being selected instead of the one you have defined.

</Tip>

## Leveraging gradio-tools

[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
Face Spaces as tools. It supports many existing Spaces as well as custom Spaces designed with it.

We offer support for `gradio_tools` by using the `Tool.from_gradio` method. For example, we want to take
advantage of the `StableDiffusionPromptGeneratorTool` tool offered in the `gradio-tools` toolkit so as to
improve our prompts and generate better images.

We first import the tool from `gradio_tools` and instantiate it:

```python
from gradio_tools import StableDiffusionPromptGeneratorTool

gradio_tool = StableDiffusionPromptGeneratorTool()
```

We pass that instance to the `Tool.from_gradio` method:

```python
from transformers import Tool

tool = Tool.from_gradio(gradio_tool)
```

Now we can manage it exactly as we would a usual custom tool. We leverage it to improve our prompt
`a rabbit wearing a space suit`:

```python
from transformers import HfAgent

agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])

agent.run("Generate an image of the `prompt` after improving it.", prompt="A rabbit wearing a space suit")
```

The model adequately leverages the tool:
```text
==Explanation from the agent==
I will use the following tools: `StableDiffusionPromptGenerator` to improve the prompt, then `image_generator` to generate an image according to the improved prompt.


==Code generated by the agent==
improved_prompt = StableDiffusionPromptGenerator(prompt)
print(f"The improved prompt is {improved_prompt}.")
image = image_generator(improved_prompt)
```

Before finally generating the image:

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">

<Tip warning={true}>

gradio-tools requires *textual* inputs and outputs, even when working with different modalities, whereas this implementation
works with image and audio objects. The two are currently incompatible, but will rapidly become compatible as we
work to improve the support.

</Tip>

## Future compatibility with Langchain

We love Langchain and think it has a very compelling suite of tools. In order to handle these tools,
Langchain requires *textual* inputs and outputs, even when working with different modalities.
This is often the serialized version (i.e., saved to disk) of the objects.

This difference means that multi-modality isn't handled between transformers-agents and langchain.
We aim for this limitation to be resolved in future versions, and welcome any help from avid langchain
users to help us achieve this compatibility.

We would love to have better support. If you would like to help, please
[open an issue](https://github.com/huggingface/transformers/issues/new) and share what you have in mind.

docs/source/en/main_classes/agent.mdx (new file, 64 lines)

@@ -0,0 +1,64 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Agents & Tools

<Tip warning={true}>

Transformers Agent is an experimental API which is subject to change at any time. Results returned by the agents
can vary as the APIs or underlying models are prone to change.

</Tip>

To learn more about agents and tools, make sure to read the [introductory guide](../transformers_agents). This page
contains the API docs for the underlying classes.

## Agents

We provide two types of agents: [`HfAgent`] uses inference endpoints for open-source models, and [`OpenAiAgent`] uses OpenAI's closed models.

### HfAgent

[[autodoc]] HfAgent

### OpenAiAgent

[[autodoc]] OpenAiAgent

### Agent

[[autodoc]] Agent
    - chat
    - run
    - prepare_for_new_chat

## Tools

### load_tool

[[autodoc]] load_tool

### Tool

[[autodoc]] Tool

### PipelineTool

[[autodoc]] PipelineTool

### RemoteTool

[[autodoc]] RemoteTool

### launch_gradio_demo

[[autodoc]] launch_gradio_demo
@ -212,12 +212,20 @@ Example:
```"""

```
3 steps are required to debug the docstring examples:
1. In order to properly run the test, **an extra line has to be added** at the end of the docstring. This can be automatically done on any file using:
```bash
python utils/prepare_for_doc_test.py <path_to_file_or_dir>
```

Just run the following line to automatically test every docstring example in the desired file:
2. Then, you can use the following line to automatically test every docstring example in the desired file:
```bash
pytest --doctest-modules <path_to_file_or_dir>
```
If the file has a markdown extension, you should add the `--doctest-glob="*.mdx"` argument (a combined example follows these steps).
3. Once you are done debugging, you need to remove the extra line added in step **1.** by running the following:
```bash
python utils/prepare_for_doc_test.py <path_to_file_or_dir> --remove_new_line
```
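
Putting the three steps together, a debugging session on a single documentation file might look like the following; the file path below is only an illustrative example:

```bash
# 1. add the extra line the doctest runner needs
python utils/prepare_for_doc_test.py docs/source/en/quicktour.mdx
# 2. run every docstring example in the file (--doctest-glob is needed for markdown files)
pytest --doctest-modules docs/source/en/quicktour.mdx --doctest-glob="*.mdx"
# 3. remove the extra line once you are done
python utils/prepare_for_doc_test.py docs/source/en/quicktour.mdx --remove_new_line
```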

### Run only modified tests

331
docs/source/en/transformers_agents.mdx
Normal file
@ -0,0 +1,331 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Transformers Agent

<Tip warning={true}>

Transformers Agent is an experimental API which is subject to change at any time. Results returned by the agents
can vary as the APIs or underlying models are prone to change.

</Tip>

Transformers version v4.29.0 introduces a new API building on the concept of *tools* and *agents*. You can play with it in
[this colab](https://colab.research.google.com/drive/1c7MHD-T1forUPGcC_jlwsIptOzpG3hSj).

In short, it provides a natural language API on top of transformers: we define a set of curated tools and design an
agent to interpret natural language and to use these tools. It is extensible by design; we curated some relevant tools,
but we'll show you how the system can be extended easily to use any tool developed by the community.

Let's start with a few examples of what can be achieved with this new API. It is particularly powerful when it comes
to multimodal tasks, so let's take it for a spin to generate images and read text out loud.

```py
agent.run("Caption the following image", image=image)
```

| **Input** | **Output** |
|-----------|------------|
| <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/beaver.png" width=200> | A beaver is swimming in the water |

---

```py
agent.run("Read the following text out loud", text=text)
```

| **Input** | **Output** |
|-----------|------------|
| A beaver is swimming in the water | <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tts_example.wav" type="audio/wav"> your browser does not support the audio element. </audio> |

---

```py
agent.run(
    "In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
    document=document,
)
```

| **Input** | **Output** |
|-----------|------------|
| <img src="https://datasets-server.huggingface.co/assets/hf-internal-testing/example-documents/--/hf-internal-testing--example-documents/test/0/image/image.jpg" width=200> | ballroom foyer |

## Quickstart

Before being able to use `agent.run`, you will need to instantiate an agent, which is a large language model (LLM).
We provide support for OpenAI models as well as open-source alternatives from BigCode and OpenAssistant. The OpenAI
models perform better (but require you to have an OpenAI API key, so cannot be used for free); Hugging Face is
providing free access to endpoints for BigCode and OpenAssistant models.

To start with, please install the `agents` extras in order to install all default dependencies.

```bash
pip install transformers[agents]
```

To use OpenAI models, you instantiate an [`OpenAiAgent`] after installing the `openai` dependency:

```bash
pip install openai
```

```py
from transformers import OpenAiAgent

agent = OpenAiAgent(model="text-davinci-003", api_key="<your_api_key>")
```

To use BigCode or OpenAssistant, start by logging in to have access to the Inference API:

```py
from huggingface_hub import login

login("<YOUR_TOKEN>")
```

Then, instantiate the agent:

```py
from transformers import HfAgent

# Starcoder
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
# StarcoderBase
# agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoderbase")
# OpenAssistant
# agent = HfAgent(url_endpoint="https://api-inference.huggingface.co/models/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5")
```

This is using the inference API that Hugging Face provides for free at the moment. If you have your own inference
endpoint for this model (or another one), you can replace the URL above with your URL endpoint.
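
For example, a minimal sketch of pointing the agent at a custom endpoint (the URL below is a placeholder, not a real deployment):

```py
agent = HfAgent(url_endpoint="https://your-custom-inference-endpoint.example.com")
```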

<Tip>

StarCoder and OpenAssistant are free to use and perform admirably well on simple tasks. However, the checkpoints
don't hold up when handling more complex prompts. If you're facing such an issue, we recommend trying out the OpenAI
model which, while sadly not open-source, performs better at this given time.

</Tip>

You're now good to go! Let's dive into the two APIs that you now have at your disposal.

### Single execution (run)

The single execution method is when using the [`~Agent.run`] method of the agent:

```py
agent.run("Draw me a picture of rivers and lakes.")
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>

It automatically selects the tool (or tools) appropriate for the task you want to perform and runs them appropriately. It
can perform one or several tasks in the same instruction (though the more complex your instruction, the more likely
the agent is to fail).

```py
agent.run("Draw me a picture of the sea then transform the picture to add an island")
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/sea_and_island.png" width=200>

<br/>

Every [`~Agent.run`] operation is independent, so you can run it several times in a row with different tasks.

Note that your `agent` is just a large language model, so small variations in your prompt might yield completely
different results. It's important to explain as clearly as possible the task you want to perform. We go more in-depth
on how to write good prompts [here](custom_tools#writing-good-user-inputs).

If you'd like to keep a state across executions or to pass non-text objects to the agent, you can do so by specifying
variables that you would like the agent to use. For example, you could generate the first image of rivers and lakes,
and ask the model to update that picture to add an island by doing the following:

```python
picture = agent.run("Generate a picture of rivers and lakes.")
updated_picture = agent.run("Transform the image in `picture` to add an island to it.", picture=picture)
```

<Tip>

This can be helpful when the model is unable to understand your request and mixes tools. An example would be:

```py
agent.run("Draw me the picture of a capybara swimming in the sea")
```

Here, the model could interpret the request in two ways:
- Have the `text-to-image` tool generate a capybara swimming in the sea
- Or, have the `text-to-image` tool generate a capybara, then use the `image-transformation` tool to have it swim in the sea

In case you would like to force the first scenario, you could do so by passing it the prompt as an argument:

```py
agent.run("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
```

</Tip>

### Chat-based execution (chat)

The agent also has a chat-based approach, using the [`~Agent.chat`] method:

```py
agent.chat("Generate a picture of rivers and lakes")
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>

```py
agent.chat("Transform the picture so that there is a rock in there")
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_and_beaver.png" width=200>

<br/>

This is an interesting approach when you want to keep the state across instructions. It's better for experimentation,
but tends to work much better on single instructions than on complex instructions (which the [`~Agent.run`]
method handles better).

This method can also take arguments if you would like to pass non-text types or specific prompts.
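
For instance, a minimal sketch of passing a non-text object to `chat` (assuming `picture` holds an image generated in an earlier step of the conversation):

```py
agent.chat("Transform the image in `picture` so that there is snow on the mountains", picture=picture)
```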

### ⚠️ Remote execution

For demonstration purposes and so that this can be used with all setups, we have created remote executors for several
of the default tools the agent has access to. These are created using
[inference endpoints](https://huggingface.co/inference-endpoints). To see how to set up remote executor tools yourself,
we recommend reading the [custom tool guide](./custom_tools).

In order to run with remote tools, specifying `remote=True` to either [`~Agent.run`] or [`~Agent.chat`] is sufficient.

For example, the following command could be run on any device efficiently, without needing significant RAM or GPU:

```py
agent.run("Draw me a picture of rivers and lakes", remote=True)
```

The same can be said for [`~Agent.chat`]:

```py
agent.chat("Draw me a picture of rivers and lakes", remote=True)
```

### What's happening here? What are tools, and what are agents?

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/diagram.png">

#### Agents

The "agent" here is a large language model, and we're prompting it so that it has access to a specific set of tools.

LLMs are pretty good at generating small samples of code, so this API takes advantage of that by prompting the
LLM to give a small sample of code performing a task with a set of tools. This prompt is then completed by the
task you give your agent and the description of the tools you give it. This way it gets access to the doc of the
tools you are using, especially their expected inputs and outputs, and can generate the relevant code.

#### Tools

Tools are very simple: they're a single function, with a name, and a description. We then use these tools' descriptions
to prompt the agent. Through the prompt, we show the agent how it would leverage tools to perform what was
requested in the query.

This is using brand-new tools and not pipelines, because the agent writes better code with very atomic tools.
Pipelines are more refactored and often combine several tasks in one. Tools are meant to be focused on
one very simple task only.
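
As a rough illustration, a tool can be as small as a class with a `name`, a `description`, and a `__call__` method. The sketch below is a hypothetical example, not one of the tools shipped with the library:

```py
from transformers import Tool


class TextReverserTool(Tool):
    name = "text_reverser"
    description = "This is a tool that reverses a piece of text. It takes the text as input and returns the reversed text."

    def __call__(self, text: str):
        # The agent only ever sees the name and description; the body is plain Python.
        return text[::-1]
```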

#### Code-execution?!

This code is then executed with our small Python interpreter on the set of inputs passed along with your tools.
We hear you screaming "Arbitrary code execution!" in the back, but let us explain why that is not the case.

The only functions that can be called are the tools you provided and the print function, so you're already
limited in what can be executed. You should be safe if it's limited to Hugging Face tools.

Then, we don't allow any attribute lookup or imports (which shouldn't be needed anyway for passing along
inputs/outputs to a small set of functions), so all the most obvious attacks (and you'd need to prompt the LLM
to output them anyway) shouldn't be an issue. If you want to be on the super safe side, you can execute the
`run()` method with the additional argument `return_code=True`, in which case the agent will just return the code
to execute and you can decide whether to do it or not.

The execution will stop at any line trying to perform an illegal operation or if there is a regular Python error
with the code generated by the agent.

### A curated set of tools

We identify a set of tools that can empower such agents. Here is an updated list of the tools we have integrated
in `transformers`:

- **Document question answering**: given a document (such as a PDF) in image format, answer a question on this document ([Donut](./model_doc/donut))
- **Text question answering**: given a long text and a question, answer the question in the text ([Flan-T5](./model_doc/flan-t5))
- **Unconditional image captioning**: caption the image! ([BLIP](./model_doc/blip))
- **Image question answering**: given an image, answer a question on this image ([VILT](./model_doc/vilt))
- **Image segmentation**: given an image and a prompt, output the segmentation mask of that prompt ([CLIPSeg](./model_doc/clipseg))
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
- **Zero-shot text classification**: given a text and a list of labels, identify to which label the text corresponds the most ([BART](./model_doc/bart))
- **Text summarization**: summarize a long text in one or a few sentences ([BART](./model_doc/bart))
- **Translation**: translate the text into a given language ([NLLB](./model_doc/nllb))

These tools have an integration in transformers, and can be used manually as well; for example:

```py
from transformers import load_tool

tool = load_tool("text-to-speech")
audio = tool("This is a text to speech tool")
```

### Custom tools

While we identify a curated set of tools, we strongly believe that the main value provided by this implementation is
the ability to quickly create and share custom tools.

By pushing the code of a tool to a Hugging Face Space or a model repository, you're then able to leverage the tool
directly with the agent. We've added a few
**transformers-agnostic** tools to the [`huggingface-tools` organization](https://huggingface.co/huggingface-tools):

- **Text downloader**: to download a text from a web URL
- **Text to image**: generate an image according to a prompt, leveraging stable diffusion
- **Image transformation**: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
- **Text to video**: generate a small video according to a prompt, leveraging damo-vilab

The text-to-image tool we have been using since the beginning is a remote tool that lives in
[*huggingface-tools/text-to-image*](https://huggingface.co/spaces/huggingface-tools/text-to-image)! We will
continue releasing such tools on this and other organizations, to further supercharge this implementation.

The agents have by default access to tools that reside on [`huggingface-tools`](https://huggingface.co/huggingface-tools).
We explain how you can write and share your tools, as well as leverage any custom tool that resides on the Hub, in the [following guide](custom_tools).

### Code generation

So far we have shown how to use the agents to perform actions for you. However, the agent is only generating code
that we then execute using a very restricted Python interpreter. In case you would like to use the code generated in
a different setting, the agent can be prompted to return the code, along with the tool definitions and accurate imports.

For example, the following instruction

```python
agent.run("Draw me a picture of rivers and lakes", return_code=True)
```

returns the following code

```python
from transformers import load_tool

image_generator = load_tool("huggingface-tools/text-to-image")

image = image_generator(prompt="rivers and lakes")
```

that you can then modify and execute yourself.
@ -61,7 +61,7 @@ from transformers.utils import check_min_version, get_full_repo_name, send_examp
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

Array = Any
Dataset = datasets.arrow_dataset.Dataset

@ -54,7 +54,7 @@ from transformers.utils import check_min_version, get_full_repo_name, send_examp

logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

Array = Any
Dataset = datasets.arrow_dataset.Dataset

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version

logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = get_logger(__name__)

@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.25.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = get_logger(__name__)

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = logging.getLogger(__name__)

@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, get_full_repo

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = get_logger(__name__)

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = get_logger(__name__)

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version(
    "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = logging.getLogger(__name__)

@ -48,7 +48,7 @@ from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME, check_min_version,

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

logger = logging.getLogger(__name__)

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version

# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

task_to_keys = {
    "cola": ("sentence", None),

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.29.0.dev0")
check_min_version("4.29.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@ -1,3 +1,2 @@
[tool:pytest]
doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
doctest_glob=**/*.mdx
10
setup.py
@ -112,6 +112,7 @@ _deps = [
    "datasets!=2.5.0",
    "decord==0.6.0",
    "deepspeed>=0.8.3",
    "diffusers",
    "dill<0.3.5",
    "evaluate>=0.2.0",
    "fairscale>0.3",

@ -123,7 +124,7 @@ _deps = [
    "fugashi>=1.0",
    "GitPython<3.1.19",
    "hf-doc-builder>=0.3.0",
    "huggingface-hub>=0.11.0,<1.0",
    "huggingface-hub>=0.14.1,<1.0",
    "importlib_metadata",
    "ipadic>=1.0.0,<2.0",
    "isort>=5.5.4",

@ -140,6 +141,7 @@ _deps = [
    "onnxconverter-common",
    "onnxruntime-tools>=1.4.2",
    "onnxruntime>=1.4.0",
    "opencv-python",
    "optuna",
    "optax>=0.0.8,<=0.1.4",
    "packaging>=20.0",

@ -412,6 +414,10 @@ extras["torchhub"] = deps_list(
    "tqdm",
)

extras["agents"] = deps_list(
    "diffusers", "accelerate", "datasets", "torch", "sentencepiece", "opencv-python", "Pillow"
)

# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
install_requires = [
    deps["importlib_metadata"] + ";python_version<'3.8'",  # importlib_metadata for Python versions that don't have it

@ -428,7 +434,7 @@ install_requires = [

setup(
    name="transformers",
    version="4.29.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    version="4.29.1",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
    author_email="transformers@huggingface.co",
    description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).

__version__ = "4.29.0.dev0"
__version__ = "4.29.1"

from typing import TYPE_CHECKING

@ -610,6 +610,16 @@ _import_structure = {
        "SpecialTokensMixin",
        "TokenSpan",
    ],
    "tools": [
        "Agent",
        "HfAgent",
        "OpenAiAgent",
        "PipelineTool",
        "RemoteTool",
        "Tool",
        "launch_gradio_demo",
        "load_tool",
    ],
    "trainer_callback": [
        "DefaultFlowCallback",
        "EarlyStoppingCallback",

@ -4340,6 +4350,9 @@ if TYPE_CHECKING:
        TokenSpan,
    )

    # Tools
    from .tools import Agent, HfAgent, OpenAiAgent, PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool

    # Trainer
    from .trainer_callback import (
        DefaultFlowCallback,

@ -13,6 +13,7 @@ deps = {
    "datasets": "datasets!=2.5.0",
    "decord": "decord==0.6.0",
    "deepspeed": "deepspeed>=0.8.3",
    "diffusers": "diffusers",
    "dill": "dill<0.3.5",
    "evaluate": "evaluate>=0.2.0",
    "fairscale": "fairscale>0.3",

@ -24,7 +25,7 @@ deps = {
    "fugashi": "fugashi>=1.0",
    "GitPython": "GitPython<3.1.19",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.11.0,<1.0",
    "huggingface-hub": "huggingface-hub>=0.14.1,<1.0",
    "importlib_metadata": "importlib_metadata",
    "ipadic": "ipadic>=1.0.0,<2.0",
    "isort": "isort>=5.5.4",

@ -41,6 +42,7 @@ deps = {
    "onnxconverter-common": "onnxconverter-common",
    "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
    "onnxruntime": "onnxruntime>=1.4.0",
    "opencv-python": "opencv-python",
    "optuna": "optuna",
    "optax": "optax>=0.0.8,<=0.1.4",
    "packaging": "packaging>=20.0",
@ -115,9 +115,9 @@ def get_relative_import_files(module_file):
    return all_relative_imports


def check_imports(filename):
def get_imports(filename):
    """
    Check if the current Python environment contains all the libraries that are imported in a file.
    Extracts all the libraries that are imported in a file.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

@ -131,9 +131,14 @@ def check_imports(filename):
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
    # Only keep the top-level module
    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
    return list(set(imports))

    # Unique-ify and test we got them all
    imports = list(set(imports))

def check_imports(filename):
    """
    Check if the current Python environment contains all the libraries that are imported in a file.
    """
    imports = get_imports(filename)
    missing_packages = []
    for imp in imports:
        try:

@ -169,6 +174,7 @@ def get_cached_module_file(
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    _commit_hash: Optional[str] = None,
):
    """

@ -207,6 +213,8 @@ def get_cached_module_file(
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).

    <Tip>

@ -229,7 +237,7 @@ def get_cached_module_file(
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        cached_module = try_to_load_from_cache(
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )

    new_files = []

@ -245,6 +253,7 @@ def get_cached_module_file(
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            repo_type=repo_type,
            _commit_hash=_commit_hash,
        )
        if not is_local and cached_module != resolved_module_file:

@ -309,8 +318,10 @@ def get_cached_module_file(

    if len(new_files) > 0:
        new_files = "\n".join([f"- {f}" for f in new_files])
        repo_type_str = "" if repo_type is None else f"{repo_type}s/"
        url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
        logger.warning(
            f"A new version of the following files was downloaded from {pretrained_model_name_or_path}:\n{new_files}"
            f"A new version of the following files was downloaded from {url}:\n{new_files}"
            "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
            "versions of the code file, you can pin a revision."
        )

@ -328,6 +339,7 @@ def get_class_from_dynamic_module(
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    **kwargs,
):
    """

@ -377,6 +389,8 @@ def get_class_from_dynamic_module(
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).

    <Tip>

@ -418,6 +432,7 @@ def get_class_from_dynamic_module(
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
        repo_type=repo_type,
    )
    return get_class_in_module(class_name, final_module.replace(".py", ""))

@ -439,6 +454,7 @@ def custom_object_save(obj, folder, config=None):
            "this code in a separate module so we can include it in the saved folder and make it easier to share via "
            "the Hub."
        )
        return

    def _set_auto_map_in_config(_config):
        module_name = obj.__class__.__module__

@ -478,12 +494,17 @@ def custom_object_save(obj, folder, config=None):
    elif config is not None:
        _set_auto_map_in_config(config)

    result = []
    # Copy module file to the output folder.
    object_file = sys.modules[obj.__module__].__file__
    dest_file = Path(folder) / (Path(object_file).name)
    shutil.copy(object_file, dest_file)
    result.append(dest_file)

    # Gather all relative imports recursively and make sure they are copied as well.
    for needed_file in get_relative_import_files(object_file):
        dest_file = Path(folder) / (Path(needed_file).name)
        shutil.copy(needed_file, dest_file)
        result.append(dest_file)

    return result
@ -64,6 +64,10 @@ class ChannelDimension(ExplicitEnum):
    LAST = "channels_last"


def is_pil_image(img):
    return is_vision_available() and isinstance(img, PIL.Image.Image)


def is_valid_image(img):
    return (
        (is_vision_available() and isinstance(img, PIL.Image.Image))
@ -207,29 +207,21 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
        # if no floating dtype was found return whatever the first dtype is
        return last_dtype

    for t in parameter.buffers():
        last_dtype = t.dtype
        if t.is_floating_point():
            return t.dtype
    else:
        # For nn.DataParallel compatibility in PyTorch > 1.5
        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

    if last_dtype is not None:
        # if no floating dtype was found return whatever the first dtype is
        return last_dtype
        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        last_tuple = None
        for tuple in gen:
            last_tuple = tuple
            if tuple[1].is_floating_point():
                return tuple[1].dtype

    # For nn.DataParallel compatibility in PyTorch > 1.5
    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
        return tuples

    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
    last_tuple = None
    for tuple in gen:
        last_tuple = tuple
        if tuple[1].is_floating_point():
            return tuple[1].dtype

        # fallback to the last dtype
        return last_tuple[1].dtype
    # fallback to the last dtype
    return last_tuple[1].dtype


def get_state_dict_float_dtype(state_dict):
@ -407,8 +407,7 @@ class _BaseAutoModelClass:
            repo_id, class_ref = class_ref.split("--")
        else:
            repo_id = config.name_or_path
        module_file, class_name = class_ref.split(".")
        model_class = get_class_from_dynamic_module(repo_id, module_file + ".py", class_name, **kwargs)
        model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
        return model_class._from_config(config, **kwargs)
    elif type(config) in cls._model_mapping.keys():
        model_class = _get_model_class(config, cls._model_mapping)
@ -148,6 +148,7 @@ _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=Fa
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
_run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False)


def is_pt_tf_cross_test(test_case):

@ -221,6 +222,21 @@ def is_pipeline_test(test_case):
        return pytest.mark.is_pipeline_test()(test_case)


def is_tool_test(test_case):
    """
    Decorator marking a test as a tool test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
    """
    if not _run_tool_tests:
        return unittest.skip("test is a tool test")(test_case)
    else:
        try:
            import pytest  # We don't need a hard dependency on pytest in the main library
        except ImportError:
            return test_case
        else:
            return pytest.mark.is_tool_test()(test_case)
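
# A rough usage sketch (hypothetical test case, not part of this diff): the decorator
# is applied to a test case and gated through the RUN_TOOL_TESTS environment variable:
#
#     @is_tool_test
#     class TextToSpeechToolTest(unittest.TestCase):
#         def test_call(self):
#             ...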

def slow(test_case):
    """
    Decorator marking a test as slow.
73
src/transformers/tools/__init__.py
Normal file
@ -0,0 +1,73 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ..utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "agents": ["Agent", "HfAgent", "OpenAiAgent"],
    "base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
    _import_structure["image_captioning"] = ["ImageCaptioningTool"]
    _import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
    _import_structure["image_segmentation"] = ["ImageSegmentationTool"]
    _import_structure["language_identifier"] = ["LanguageIdentificationTool"]
    _import_structure["speech_to_text"] = ["SpeechToTextTool"]
    _import_structure["text_classification"] = ["TextClassificationTool"]
    _import_structure["text_question_answering"] = ["TextQuestionAnsweringTool"]
    _import_structure["text_summarization"] = ["TextSummarizationTool"]
    _import_structure["text_to_speech"] = ["TextToSpeechTool"]
    _import_structure["translation"] = ["TranslationTool"]

if TYPE_CHECKING:
    from .agents import Agent, HfAgent, OpenAiAgent
    from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .document_question_answering import DocumentQuestionAnsweringTool
        from .image_captioning import ImageCaptioningTool
        from .image_question_answering import ImageQuestionAnsweringTool
        from .image_segmentation import ImageSegmentationTool
        from .language_identifier import LanguageIdentificationTool
        from .speech_to_text import SpeechToTextTool
        from .text_classification import TextClassificationTool
        from .text_question_answering import TextQuestionAnsweringTool
        from .text_summarization import TextSummarizationTool
        from .text_to_speech import TextToSpeechTool
        from .translation import TranslationTool
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
495
src/transformers/tools/agents.py
Normal file
@ -0,0 +1,495 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import json
import os
import time
from dataclasses import dataclass
from typing import Dict

import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces

from ..utils import is_openai_available, logging
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
from .python_interpreter import evaluate


logger = logging.get_logger(__name__)


if is_openai_available():
    import openai

_tools_are_initialized = False


BASE_PYTHON_TOOLS = {
    "print": print,
    "float": float,
    "int": int,
    "bool": bool,
    "str": str,
}


@dataclass
class PreTool:
    task: str
    description: str
    repo_id: str


HUGGINGFACE_DEFAULT_TOOLS = {}


HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
    "image-transformation",
    "text-download",
    "text-to-image",
    "text-to-video",
]


def get_remote_tools(organization="huggingface-tools"):
    spaces = list_spaces(author=organization)
    tools = {}
    for space_info in spaces:
        repo_id = space_info.id
        resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
        with open(resolved_config_file, encoding="utf-8") as reader:
            config = json.load(reader)

        task = repo_id.split("/")[-1]
        tools[config["name"]] = PreTool(task=task, description=config["description"], repo_id=repo_id)

    return tools


def _setup_default_tools():
    global HUGGINGFACE_DEFAULT_TOOLS
    global _tools_are_initialized

    if _tools_are_initialized:
        return

    main_module = importlib.import_module("transformers")
    tools_module = main_module.tools

    remote_tools = get_remote_tools()
    for task_name in TASK_MAPPING:
        tool_class_name = TASK_MAPPING.get(task_name)
        tool_class = getattr(tools_module, tool_class_name)
        description = tool_class.description
        HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(task=task_name, description=description, repo_id=None)

    for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB:
        found = False
        for tool_name, tool in remote_tools.items():
            if tool.task == task_name:
                HUGGINGFACE_DEFAULT_TOOLS[tool_name] = tool
                found = True
                break

        if not found:
            raise ValueError(f"{task_name} is not implemented on the Hub.")

    _tools_are_initialized = True


def resolve_tools(code, toolbox, remote=False, cached_tools=None):
    if cached_tools is None:
        resolved_tools = BASE_PYTHON_TOOLS.copy()
    else:
        resolved_tools = cached_tools
    for name, tool in toolbox.items():
        if name not in code or name in resolved_tools:
            continue

        if isinstance(tool, Tool):
            resolved_tools[name] = tool
        else:
            task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
            _remote = remote and supports_remote(task_or_repo_id)
            resolved_tools[name] = load_tool(task_or_repo_id, remote=_remote)

    return resolved_tools


def get_tool_creation_code(code, toolbox, remote=False):
    code_lines = ["from transformers import load_tool", ""]
    for name, tool in toolbox.items():
        if name not in code or isinstance(tool, Tool):
            continue

        task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
        line = f'{name} = load_tool("{task_or_repo_id}"'
        if remote:
            line += ", remote=True"
        line += ")"
        code_lines.append(line)

    return "\n".join(code_lines) + "\n"
def clean_code_for_chat(result):
    lines = result.split("\n")
    idx = 0
    while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
        idx += 1
    explanation = "\n".join(lines[:idx]).strip()
    if idx == len(lines):
        return explanation, None

    idx += 1
    start_idx = idx
    while not lines[idx].lstrip().startswith("```"):
        idx += 1
    code = "\n".join(lines[start_idx:idx]).strip()

    return explanation, code


def clean_code_for_run(result):
    result = f"I will use the following {result}"
    explanation, code = result.split("Answer:")
    explanation = explanation.strip()
    code = code.strip()

    code_lines = code.split("\n")
    if code_lines[0] in ["```", "```py", "```python"]:
        code_lines = code_lines[1:]
    if code_lines[-1] == "```":
        code_lines = code_lines[:-1]
    code = "\n".join(code_lines)

    return explanation, code
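As an illustration, here is how `clean_code_for_run` splits a model completion that follows the "Answer:" convention used by the run prompt (the completion string is hypothetical):

# Hypothetical model completion for a `run` prompt.
completion = 'tool to draw.\n\nAnswer:\n```py\nimage = image_generator("a lobster")\n```'
explanation, code = clean_code_for_run(completion)
# explanation == "I will use the following tool to draw."
# code == 'image = image_generator("a lobster")'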
class Agent:
    """
    Base class for all agents which contains the main API methods.

    Args:
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.
    """

    def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
        _setup_default_tools()

        self.chat_prompt_template = CHAT_MESSAGE_PROMPT if chat_prompt_template is None else chat_prompt_template
        self.run_prompt_template = RUN_PROMPT_TEMPLATE if run_prompt_template is None else run_prompt_template
        self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
        if additional_tools is not None:
            if isinstance(additional_tools, (list, tuple)):
                additional_tools = {t.name: t for t in additional_tools}
            elif not isinstance(additional_tools, dict):
                additional_tools = {additional_tools.name: additional_tools}

            replacements = {name: tool for name, tool in additional_tools.items() if name in HUGGINGFACE_DEFAULT_TOOLS}
            self._toolbox.update(additional_tools)
            if len(replacements) > 1:
                names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()])
                logger.warning(
                    f"The following tools have been replaced by the ones provided in `additional_tools`:\n{names}."
                )
            elif len(replacements) == 1:
                name = list(replacements.keys())[0]
                logger.warning(f"{name} has been replaced by {replacements[name]} as provided in `additional_tools`.")

        self.prepare_for_new_chat()

    @property
    def toolbox(self) -> Dict[str, Tool]:
        """Get all tools currently available to the agent."""
        return self._toolbox

    def format_prompt(self, task, chat_mode=False):
        description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()])
        if chat_mode:
            if self.chat_history is None:
                prompt = CHAT_PROMPT_TEMPLATE.replace("<<all_tools>>", description)
            else:
                prompt = self.chat_history
            prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
        else:
            prompt = self.run_prompt_template.replace("<<all_tools>>", description)
            prompt = prompt.replace("<<prompt>>", task)
        return prompt

    def chat(self, task, *, return_code=False, remote=False, **kwargs):
        """
        Sends a new request to the agent in a chat. Will use the previous ones in its history.

        Args:
            task (`str`): The task to perform.
            return_code (`bool`, *optional*, defaults to `False`):
                Whether to just return code and not evaluate it.
            remote (`bool`, *optional*, defaults to `False`):
                Whether or not to use remote tools (inference endpoints) instead of local ones.
            kwargs:
                Any keyword argument to send to the agent when evaluating the code.

        Example:

        ```py
        from transformers import HfAgent

        agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
        agent.chat("Draw me a picture of rivers and lakes")

        agent.chat("Transform the picture so that there is a rock in there")
        ```
        """
        prompt = self.format_prompt(task, chat_mode=True)
        result = self.generate_one(prompt, stop=["Human:", "====="])
        self.chat_history = prompt + result.strip() + "\n"
        explanation, code = clean_code_for_chat(result)

        print(f"==Explanation from the agent==\n{explanation}")

        if code is not None:
            print(f"\n\n==Code generated by the agent==\n{code}")
            if not return_code:
                print("\n\n==Result==")
                self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
                self.chat_state.update(kwargs)
                return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
            else:
                tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
                return f"{tool_code}\n{code}"

    def prepare_for_new_chat(self):
        """
        Clears the history of prior calls to [`~Agent.chat`].
        """
        self.chat_history = None
        self.chat_state = {}
        self.cached_tools = None

    def run(self, task, *, return_code=False, remote=False, **kwargs):
        """
        Sends a request to the agent.

        Args:
            task (`str`): The task to perform.
            return_code (`bool`, *optional*, defaults to `False`):
                Whether to just return code and not evaluate it.
            remote (`bool`, *optional*, defaults to `False`):
                Whether or not to use remote tools (inference endpoints) instead of local ones.
            kwargs:
                Any keyword argument to send to the agent when evaluating the code.

        Example:

        ```py
        from transformers import HfAgent

        agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
        agent.run("Draw me a picture of rivers and lakes")
        ```
        """
        prompt = self.format_prompt(task)
        result = self.generate_one(prompt, stop=["Task:"])
        explanation, code = clean_code_for_run(result)

        print(f"==Explanation from the agent==\n{explanation}")

        print(f"\n\n==Code generated by the agent==\n{code}")
        if not return_code:
            print("\n\n==Result==")
            self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
            return evaluate(code, self.cached_tools, state=kwargs.copy())
        else:
            tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
            return f"{tool_code}\n{code}"

    def generate_one(self, prompt, stop):
        # This is the method to implement in your custom agent.
        raise NotImplementedError

    def generate_many(self, prompts, stop):
        # Override if you have a way to do batch generation faster than one by one
        return [self.generate_one(prompt, stop) for prompt in prompts]
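Only `generate_one` is abstract, so a custom agent can be a very small class. A minimal sketch follows; the stub returns a canned completion instead of querying an LLM, and the class name and completion text are hypothetical:

class CannedAgent(Agent):
    def generate_one(self, prompt, stop):
        # A real implementation would send `prompt` to an LLM and truncate the
        # generation at the first of the `stop` sequences.
        return 'tool to greet.\n\nAnswer:\n```py\nprint("Hello!")\n```'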
class OpenAiAgent(Agent):
    """
    Agent that uses the openai API to generate code.

    <Tip warning={true}>

    The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
    `"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a future version.

    </Tip>

    Args:
        model (`str`, *optional*, defaults to `"text-davinci-003"`):
            The name of the OpenAI model to use.
        api_key (`str`, *optional*):
            The API key to use. If unset, will look for the environment variable `"OPENAI_API_KEY"`.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import OpenAiAgent

    agent = OpenAiAgent(model="text-davinci-003", api_key=xxx)
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    """

    def __init__(
        self,
        model="text-davinci-003",
        api_key=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
    ):
        if not is_openai_available():
            raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")

        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY", None)
        if api_key is None:
            raise ValueError(
                "You need an openai key to use `OpenAiAgent`. You can get one here: https://openai.com/api/. "
                "If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = xxx`."
            )
        else:
            openai.api_key = api_key
        self.model = model
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_many(self, prompts, stop):
        if "gpt" in self.model:
            return [self._chat_generate(prompt, stop) for prompt in prompts]
        else:
            return self._completion_generate(prompts, stop)

    def generate_one(self, prompt, stop):
        if "gpt" in self.model:
            return self._chat_generate(prompt, stop)
        else:
            return self._completion_generate([prompt], stop)[0]

    def _chat_generate(self, prompt, stop):
        result = openai.ChatCompletion.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            stop=stop,
        )
        return result["choices"][0]["message"]["content"]

    def _completion_generate(self, prompts, stop):
        result = openai.Completion.create(
            model=self.model,
            prompt=prompts,
            temperature=0,
            stop=stop,
            max_tokens=200,
        )
        return [answer["text"] for answer in result["choices"]]


class HfAgent(Agent):
    """
    Agent that uses an inference endpoint to generate code.

    Args:
        url_endpoint (`str`):
            The URL of the endpoint to use.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import HfAgent

    agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    """

    def __init__(
        self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
    ):
        self.url_endpoint = url_endpoint
        if token is None:
            self.token = f"Bearer {HfFolder().get_token()}"
        elif token.startswith("Bearer") or token.startswith("Basic"):
            self.token = token
        else:
            self.token = f"Bearer {token}"
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_one(self, prompt, stop):
        headers = {"Authorization": self.token}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
        }

        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Error {response.status_code}: {response.json()}")

        result = response.json()[0]["generated_text"]
        # Inference API returns the stop sequence
        for stop_seq in stop:
            if result.endswith(stop_seq):
                result = result[: -len(stop_seq)]
        return result
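Because `run` evaluates the generated program, a cautious pattern is to fetch the code first and review it before executing anything yourself. A usage sketch, reusing the endpoint from the docstring above:

agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
code = agent.run("Draw me a picture of rivers and lakes", return_code=True)
print(code)  # inspect the `load_tool` calls and generated program before running it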
src/transformers/tools/base.py (new file, 723 lines)
@@ -0,0 +1,723 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import importlib
import inspect
import io
import json
import os
import tempfile
from typing import Any, Dict, List, Optional, Union

from huggingface_hub import CommitOperationAdd, HfFolder, create_commit, create_repo, hf_hub_download, metadata_update
from huggingface_hub.utils import RepositoryNotFoundError, get_session

from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
from ..image_utils import is_pil_image
from ..models.auto import AutoProcessor
from ..utils import (
    CONFIG_NAME,
    cached_file,
    is_accelerate_available,
    is_torch_available,
    is_vision_available,
    logging,
)


logger = logging.get_logger(__name__)

if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate.utils import send_to_device


TOOL_CONFIG_FILE = "tool_config.json"


def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
    if repo_type is not None:
        return repo_type
    try:
        hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
        return "space"
    except RepositoryNotFoundError:
        try:
            hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
            return "model"
        except RepositoryNotFoundError:
            raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
        except Exception:
            return "model"
    except Exception:
        return "space"


# docstyle-ignore
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
from {module_name} import {class_name}

launch_gradio_demo({class_name})
"""


class Tool:
    """
    A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
    following class attributes:

    - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
      will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
      returns the text contained in the file'.
    - **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
      `"text-classifier"` or `"image_generator"`.
    - **inputs** (`List[str]`) -- The list of modalities expected for the inputs (in the same order as in the call).
      Modalities should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo` or to make a
      nice space from your tool.
    - **outputs** (`List[str]`) -- The list of modalities returned by the tool (in the same order as the return of the
      call method). Modalities should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo`
      or to make a nice space from your tool.

    You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
    usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
    instantiation.
    """

    description: str = "This is a tool that ..."
    name: str = ""

    inputs: List[str]
    outputs: List[str]

    def __init__(self, *args, **kwargs):
        self.is_initialized = False

    def __call__(self, *args, **kwargs):
        raise NotImplementedError("Write this method in your subclass of `Tool`.")

    def setup(self):
        """
        Overwrite this method for any operation that is expensive and needs to be executed before you start using
        your tool, such as loading a big model.
        """
        self.is_initialized = True

    def save(self, output_dir):
        """
        Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
        tool in `output_dir` as well as autogenerate:

        - a config file named `tool_config.json`
        - an `app.py` file so that your tool can be converted to a space
        - a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
          code)

        You should only use this method to save tools that are defined in a separate module (not `__main__`).

        Args:
            output_dir (`str`): The folder in which you want to save your tool.
        """
        os.makedirs(output_dir, exist_ok=True)
        # Save module file
        if self.__module__ == "__main__":
            raise ValueError(
                f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
                "have to put this code in a separate module so we can include it in the saved folder."
            )
        module_files = custom_object_save(self, output_dir)

        module_name = self.__class__.__module__
        last_module = module_name.split(".")[-1]
        full_name = f"{last_module}.{self.__class__.__name__}"

        # Save config file
        config_file = os.path.join(output_dir, "tool_config.json")
        if os.path.isfile(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                tool_config = json.load(f)
        else:
            tool_config = {}

        tool_config.update({"tool_class": full_name, "description": self.description, "name": self.name})
        with open(config_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")

        # Save app file
        app_file = os.path.join(output_dir, "app.py")
        with open(app_file, "w", encoding="utf-8") as f:
            f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))

        # Save requirements file
        requirements_file = os.path.join(output_dir, "requirements.txt")
        imports = []
        for module in module_files:
            imports.extend(get_imports(module))
        imports = list(set(imports))
        with open(requirements_file, "w", encoding="utf-8") as f:
            f.write("\n".join(imports) + "\n")

    @classmethod
    def from_hub(cls, repo_id, model_repo_id=None, token=None, remote=False, **kwargs):
        """
        Loads a tool defined on the Hub.

        Args:
            repo_id (`str`):
                The name of the repo on the Hub where your tool is defined.
            model_repo_id (`str`, *optional*):
                If your tool uses a model and you want to use a different model than the default, you can pass a second
                repo ID or an endpoint url to this argument.
            token (`str`, *optional*):
                The token to identify you on hf.co. If unset, will use the token generated when running
                `huggingface-cli login` (stored in `~/.huggingface`).
            remote (`bool`, *optional*, defaults to `False`):
                Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
            kwargs:
                Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
                `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
                others will be passed along to its init.
        """
        if remote and model_repo_id is None:
            endpoints = get_default_endpoints()
            if repo_id not in endpoints:
                raise ValueError(
                    f"Could not infer a default endpoint for {repo_id}, you need to pass one using the "
                    "`model_repo_id` argument."
                )
            model_repo_id = endpoints[repo_id]
        hub_kwargs_names = [
            "cache_dir",
            "force_download",
            "resume_download",
            "proxies",
            "revision",
            "repo_type",
            "subfolder",
            "local_files_only",
        ]
        hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}

        # Try to get the tool config first.
        hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
        resolved_config_file = cached_file(
            repo_id,
            TOOL_CONFIG_FILE,
            use_auth_token=token,
            **hub_kwargs,
            _raise_exceptions_for_missing_entries=False,
            _raise_exceptions_for_connection_errors=False,
        )
        is_tool_config = resolved_config_file is not None
        if resolved_config_file is None:
            resolved_config_file = cached_file(
                repo_id,
                CONFIG_NAME,
                use_auth_token=token,
                **hub_kwargs,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
            )
        if resolved_config_file is None:
            raise EnvironmentError(
                f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
            )

        with open(resolved_config_file, encoding="utf-8") as reader:
            config = json.load(reader)

        if not is_tool_config:
            if "custom_tool" not in config:
                raise EnvironmentError(
                    f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
                )
            custom_tool = config["custom_tool"]
        else:
            custom_tool = config

        tool_class = custom_tool["tool_class"]
        tool_class = get_class_from_dynamic_module(tool_class, repo_id, use_auth_token=token, **hub_kwargs)

        if remote:
            return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
        return tool_class(model_repo_id, token=token, **kwargs)

    def push_to_hub(
        self,
        repo_id: str,
        commit_message: str = "Upload tool",
        private: Optional[bool] = None,
        token: Optional[Union[bool, str]] = None,
        create_pr: bool = False,
    ) -> str:
        """
        Upload the tool to the Hub.

        Parameters:
            repo_id (`str`):
                The name of the repository you want to push your tool to. It should contain your organization name when
                pushing to a given organization.
            commit_message (`str`, *optional*, defaults to `"Upload tool"`):
                Message to commit while pushing.
            private (`bool`, *optional*):
                Whether or not the repository created should be private.
            token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether or not to create a PR with the uploaded files or directly commit.
        """
        repo_url = create_repo(
            repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
        )
        metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
        repo_id = repo_url.repo_id

        with tempfile.TemporaryDirectory() as work_dir:
            # Save all files.
            self.save(work_dir)
            operations = [
                CommitOperationAdd(path_or_fileobj=os.path.join(work_dir, f), path_in_repo=f)
                for f in os.listdir(work_dir)
            ]
            logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
            return create_commit(
                repo_id=repo_id,
                operations=operations,
                commit_message=commit_message,
                token=token,
                create_pr=create_pr,
                repo_type="space",
            )

    @staticmethod
    def from_gradio(gradio_tool):
        """
        Creates a [`Tool`] from a gradio tool.
        """

        class GradioToolWrapper(Tool):
            def __init__(self, _gradio_tool):
                super().__init__()
                self.name = _gradio_tool.name
                self.description = _gradio_tool.description

        GradioToolWrapper.__call__ = gradio_tool.run
        return GradioToolWrapper(gradio_tool)
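A minimal custom tool only needs the class attributes described above plus `__call__`. A hypothetical sketch (the name, description and behavior are illustrative):

class TextReverserTool(Tool):
    name = "text_reverser"
    description = (
        "This is a tool that reverses a `text` given as input. It takes the `text` as input and returns the "
        "reversed text."
    )
    inputs = ["text"]
    outputs = ["text"]

    def __call__(self, text: str) -> str:
        return text[::-1]


reverser = TextReverserTool()
assert reverser("abc") == "cba"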
class RemoteTool(Tool):
    """
    A [`Tool`] that will make requests to an inference endpoint.

    Args:
        endpoint_url (`str`):
            The url of the endpoint to use.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        tool_class (`type`, *optional*):
            The corresponding `tool_class` if this is a remote version of an existing tool. Will help determine when
            the output should be converted to another type (like images).
    """

    def __init__(self, endpoint_url=None, token=None, tool_class=None):
        self.endpoint_url = endpoint_url
        self.client = EndpointClient(endpoint_url, token=token)
        self.tool_class = tool_class

    def prepare_inputs(self, *args, **kwargs):
        """
        Prepare the inputs received for the HTTP client sending data to the endpoint. Positional arguments will be
        matched with the signature of the `tool_class` if it was provided at instantiation. Images will be encoded into
        bytes.

        You can override this method in your custom class of [`RemoteTool`].
        """
        inputs = kwargs.copy()
        if len(args) > 0:
            if self.tool_class is not None:
                # Match args with the signature
                if issubclass(self.tool_class, PipelineTool):
                    call_method = self.tool_class.encode
                else:
                    call_method = self.tool_class.__call__
                signature = inspect.signature(call_method).parameters
                parameters = [
                    k
                    for k, p in signature.items()
                    if p.kind not in [inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD]
                ]
                if parameters[0] == "self":
                    parameters = parameters[1:]
                if len(args) > len(parameters):
                    raise ValueError(
                        f"{self.tool_class} only accepts {len(parameters)} arguments but {len(args)} were given."
                    )
                for arg, name in zip(args, parameters):
                    inputs[name] = arg
            elif len(args) > 1:
                raise ValueError("A `RemoteTool` can only accept one positional input.")
            elif len(args) == 1:
                if is_pil_image(args[0]):
                    return {"inputs": self.client.encode_image(args[0])}
                return {"inputs": args[0]}

        for key, value in inputs.items():
            if is_pil_image(value):
                inputs[key] = self.client.encode_image(value)

        return {"inputs": inputs}

    def extract_outputs(self, outputs):
        """
        You can override this method in your custom class of [`RemoteTool`] to apply some custom post-processing of the
        outputs of the endpoint.
        """
        return outputs

    def __call__(self, *args, **kwargs):
        output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
        inputs = self.prepare_inputs(*args, **kwargs)
        if isinstance(inputs, dict):
            outputs = self.client(**inputs, output_image=output_image)
        else:
            outputs = self.client(inputs, output_image=output_image)
        if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
            outputs = outputs[0]
        return self.extract_outputs(outputs)
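`extract_outputs` is the intended customization point for endpoint-specific post-processing. A hypothetical sketch, assuming an endpoint whose payload is a list of `{"generated_text": ...}` dictionaries:

class GeneratedTextRemoteTool(RemoteTool):
    def extract_outputs(self, outputs):
        # Keep only the generated string from the (assumed) raw endpoint payload.
        return outputs[0]["generated_text"]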
class PipelineTool(Tool):
    """
    A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
    need to specify:

    - **model_class** (`type`) -- The class to use to load the model in this tool.
    - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
    - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      pre-processor
    - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      post-processor (when different from the pre-processor).

    Args:
        model (`str` or [`PreTrainedModel`], *optional*):
            The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
            value of the class attribute `default_checkpoint`.
        pre_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
            unset.
        post_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
            unset.
        device (`int`, `str` or `torch.device`, *optional*):
            The device on which to execute the model. Will default to any accelerator available (GPU, MPS, etc.), the
            CPU otherwise.
        device_map (`str` or `dict`, *optional*):
            If passed along, will be used to instantiate the model.
        model_kwargs (`dict`, *optional*):
            Any keyword argument to send to the model instantiation.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        hub_kwargs:
            Any additional keyword argument to send to the methods that will load the data from the Hub.
    """

    pre_processor_class = AutoProcessor
    model_class = None
    post_processor_class = AutoProcessor
    default_checkpoint = None

    def __init__(
        self,
        model=None,
        pre_processor=None,
        post_processor=None,
        device=None,
        device_map=None,
        model_kwargs=None,
        token=None,
        **hub_kwargs,
    ):
        if not is_torch_available():
            raise ImportError("Please install torch in order to use this tool.")

        if not is_accelerate_available():
            raise ImportError("Please install accelerate in order to use this tool.")

        if model is None:
            if self.default_checkpoint is None:
                raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
            model = self.default_checkpoint
        if pre_processor is None:
            pre_processor = model

        self.model = model
        self.pre_processor = pre_processor
        self.post_processor = post_processor
        self.device = device
        self.device_map = device_map
        self.model_kwargs = {} if model_kwargs is None else model_kwargs
        if device_map is not None:
            self.model_kwargs["device_map"] = device_map
        self.hub_kwargs = hub_kwargs
        self.hub_kwargs["use_auth_token"] = token

        self.is_initialized = False

    def setup(self):
        """
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
        """
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        if isinstance(self.model, str):
            self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        if self.device is None:
            if self.device_map is not None:
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = get_default_device()

        if self.device_map is None:
            self.model.to(self.device)

    def encode(self, raw_inputs):
        """
        Uses the `pre_processor` to prepare the inputs for the `model`.
        """
        return self.pre_processor(raw_inputs)

    def forward(self, inputs):
        """
        Sends the inputs through the `model`.
        """
        with torch.no_grad():
            return self.model(**inputs)

    def decode(self, outputs):
        """
        Uses the `post_processor` to decode the model output.
        """
        return self.post_processor(outputs)

    def __call__(self, *args, **kwargs):
        if not self.is_initialized:
            self.setup()

        encoded_inputs = self.encode(*args, **kwargs)
        encoded_inputs = send_to_device(encoded_inputs, self.device)
        outputs = self.forward(encoded_inputs)
        outputs = send_to_device(outputs, "cpu")
        return self.decode(outputs)
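Concretely, a `PipelineTool` subclass usually pins one checkpoint and overrides `encode`/`forward`/`decode` around it. A hypothetical summarization sketch (the checkpoint name is an assumption; any seq2seq summarization checkpoint would do):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class SimpleSummarizerTool(PipelineTool):
    default_checkpoint = "sshleifer/distilbart-cnn-12-6"  # assumption: illustrative checkpoint
    model_class = AutoModelForSeq2SeqLM
    pre_processor_class = AutoTokenizer
    name = "simple_summarizer"
    description = "This is a tool that summarizes a `text` and returns the summary."
    inputs = ["text"]
    outputs = ["text"]

    def encode(self, text):
        # Tokenize to PyTorch tensors so the base class can move them to the device.
        return self.pre_processor(text, return_tensors="pt", truncation=True)

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]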
def launch_gradio_demo(tool_class: Tool):
    """
    Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
    `inputs` and `outputs`.

    Args:
        tool_class (`type`): The class of the tool for which to launch the demo.
    """
    try:
        import gradio as gr
    except ImportError:
        raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    tool = tool_class()

    def fn(*args, **kwargs):
        return tool(*args, **kwargs)

    gr.Interface(
        fn=fn,
        inputs=tool_class.inputs,
        outputs=tool_class.outputs,
        title=tool_class.__name__,
        article=tool.description,
    ).launch()


# TODO: Migrate to Accelerate for this once `PartialState.default_device` makes its way into a release.
def get_default_device():
    if not is_torch_available():
        raise ImportError("Please install torch in order to use this tool.")

    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


TASK_MAPPING = {
    "document-question-answering": "DocumentQuestionAnsweringTool",
    "image-captioning": "ImageCaptioningTool",
    "image-question-answering": "ImageQuestionAnsweringTool",
    "image-segmentation": "ImageSegmentationTool",
    "speech-to-text": "SpeechToTextTool",
    "summarization": "TextSummarizationTool",
    "text-classification": "TextClassificationTool",
    "text-question-answering": "TextQuestionAnsweringTool",
    "text-to-speech": "TextToSpeechTool",
    "translation": "TranslationTool",
}


def get_default_endpoints():
    endpoints_file = cached_file("huggingface-tools/default-endpoints", "default_endpoints.json", repo_type="dataset")
    with open(endpoints_file, "r", encoding="utf-8") as f:
        endpoints = json.load(f)
    return endpoints


def supports_remote(task_or_repo_id):
    endpoints = get_default_endpoints()
    return task_or_repo_id in endpoints


def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **kwargs):
    """
    Main function to quickly load a tool, be it on the Hub or in the Transformers library.

    Args:
        task_or_repo_id (`str`):
            The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
            are:

            - `"document-question-answering"`
            - `"image-captioning"`
            - `"image-question-answering"`
            - `"image-segmentation"`
            - `"speech-to-text"`
            - `"summarization"`
            - `"text-classification"`
            - `"text-question-answering"`
            - `"text-to-speech"`
            - `"translation"`

        model_repo_id (`str`, *optional*):
            Use this argument to use a different model than the default one for the tool you selected.
        remote (`bool`, *optional*, defaults to `False`):
            Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
        token (`str`, *optional*):
            The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
            login` (stored in `~/.huggingface`).
        kwargs:
            Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
            `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
            will be passed along to its init.
    """
    if task_or_repo_id in TASK_MAPPING:
        tool_class_name = TASK_MAPPING[task_or_repo_id]
        main_module = importlib.import_module("transformers")
        tools_module = main_module.tools
        tool_class = getattr(tools_module, tool_class_name)

        if remote:
            if model_repo_id is None:
                endpoints = get_default_endpoints()
                if task_or_repo_id not in endpoints:
                    raise ValueError(
                        f"Could not infer a default endpoint for {task_or_repo_id}, you need to pass one using the "
                        "`model_repo_id` argument."
                    )
                model_repo_id = endpoints[task_or_repo_id]
            return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
        else:
            return tool_class(model_repo_id, token=token, **kwargs)
    else:
        return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, remote=remote, **kwargs)
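A usage sketch of `load_tool` with one of the built-in tasks, assuming the text-classification tool's `(text, labels)` call signature (the input string is illustrative; the first call downloads the default checkpoint):

classifier = load_tool("text-classification")
print(classifier("This new API is lovely!", labels=["positive", "negative"]))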
def add_description(description):
    """
    A decorator that adds a description to a function.
    """

    def inner(func):
        func.description = description
        func.name = func.__name__
        return func

    return inner
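For illustration, the decorator can turn a plain function into something an agent can list in its prompt (the function below is hypothetical):

@add_description("This is a tool that computes the length of a `text` and returns it as text.")
def text_length(text):
    return str(len(text))


# `text_length.name` and `text_length.description` are now set, so the function
# can be passed to an agent through `additional_tools`.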
## Will move to the Hub
class EndpointClient:
    def __init__(self, endpoint_url: str, token: Optional[str] = None):
        if token is None:
            token = HfFolder().get_token()
        self.headers = {"authorization": f"Bearer {token}", "Content-Type": "application/json"}
        self.endpoint_url = endpoint_url

    @staticmethod
    def encode_image(image):
        _bytes = io.BytesIO()
        image.save(_bytes, format="PNG")
        b64 = base64.b64encode(_bytes.getvalue())
        return b64.decode("utf-8")

    @staticmethod
    def decode_image(raw_image):
        if not is_vision_available():
            raise ImportError(
                "This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
            )

        from PIL import Image

        b64 = base64.b64decode(raw_image)
        _bytes = io.BytesIO(b64)
        return Image.open(_bytes)

    def __call__(
        self,
        inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
        params: Optional[Dict] = None,
        data: Optional[bytes] = None,
        output_image: bool = False,
    ) -> Any:
        # Build payload
        payload = {}
        if inputs:
            payload["inputs"] = inputs
        if params:
            payload["parameters"] = params

        # Make API call
        response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)

        # By default, parse the response for the user.
        if output_image:
            return self.decode_image(response.content)
        else:
            return response.json()
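A minimal usage sketch of `EndpointClient` (the endpoint URL and parameters are hypothetical):

client = EndpointClient("https://my-tool-endpoint.example.com")  # hypothetical endpoint
result = client(inputs="Hello world!", params={"max_new_tokens": 20})
print(result)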
src/transformers/tools/document_question_answering.py (new file, 80 lines)
@@ -0,0 +1,80 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from ..models.auto import AutoProcessor
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
from ..utils import is_vision_available
from .base import PipelineTool


if is_vision_available():
    from PIL import Image


class DocumentQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
    description = (
        "This is a tool that answers a question about a document (pdf). It takes an input named `document` which "
        "should be the document containing the information, as well as a `question` that is the question about the "
        "document. It returns a text that contains the answer to the question."
    )
    name = "document_qa"
    pre_processor_class = AutoProcessor
    model_class = VisionEncoderDecoderModel

    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        if not is_vision_available():
            raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")

        super().__init__(*args, **kwargs)

    def encode(self, image: "Image", question: str):
        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
        prompt = task_prompt.replace("{user_input}", question)
        decoder_input_ids = self.pre_processor.tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids
        pixel_values = self.pre_processor(image, return_tensors="pt").pixel_values

        return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}

    def forward(self, inputs):
        return self.model.generate(
            inputs["pixel_values"].to(self.device),
            decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
            max_length=self.model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=self.pre_processor.tokenizer.pad_token_id,
            eos_token_id=self.pre_processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        ).sequences

    def decode(self, outputs):
        sequence = self.pre_processor.batch_decode(outputs)[0]
        sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
        sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
        sequence = self.pre_processor.token2json(sequence)

        return sequence["answer"]
src/transformers/tools/evaluate_agent.py (new file, 692 lines)
@@ -0,0 +1,692 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .agents import BASE_PYTHON_TOOLS, clean_code_for_chat, clean_code_for_run
from .python_interpreter import InterpretorError, evaluate


### Fake tools for test
def classifier(text, labels):
    return f"This is the classification of {text} along {labels}."


def translator(text, src_lang, tgt_lang):
    return f"This is the translation of {text} from {src_lang} to {tgt_lang}."


def speaker(text):
    return f"This is actually a sound reading {text}."


def transcriber(audio):
    if "sound" not in audio:
        raise ValueError(f"`audio` ({audio}) is not a sound.")
    return f"This is the transcribed text from {audio}."


def image_generator(prompt):
    return f"This is actually an image representing {prompt}."


def image_captioner(image):
    if "image" not in image:
        raise ValueError(f"`image` ({image}) is not an image.")
    return f"This is a description of {image}."


def image_transformer(image, prompt):
    if "image" not in image:
        raise ValueError(f"`image` ({image}) is not an image.")
    return f"This is a transformation of {image} according to {prompt}."


def question_answerer(text, question):
    return f"This is the answer to {question} from {text}."


def image_qa(image, question):
    if "image" not in image:
        raise ValueError(f"`image` ({image}) is not an image.")
    return f"This is the answer to {question} from {image}."


def text_downloader(url):
    return f"This is the content of {url}."


def summarizer(text):
    return f"This is a summary of {text}."


def video_generator(prompt, seconds=2):
    return f"A video of {prompt}"


def document_qa(image, question):
    return f"This is the answer to {question} from the document {image}."


def image_segmenter(image, prompt):
    return f"This is the mask of {prompt} in {image}"


TEST_TOOLS = {
    "text_classifier": classifier,
    "translator": translator,
    "text_reader": speaker,
    "summarizer": summarizer,
    "transcriber": transcriber,
    "image_generator": image_generator,
    "image_captioner": image_captioner,
    "image_transformer": image_transformer,
    "text_qa": question_answerer,
    "text_downloader": text_downloader,
    "image_qa": image_qa,
    "video_generator": video_generator,
    "document_qa": document_qa,
    "image_segmenter": image_segmenter,
}


class Problem:
    """
    A class regrouping all the information to solve a problem on which we will evaluate agents.

    Args:
        task (`str` or `list[str]`):
            One or several descriptions of the task to perform. If a list, it should contain variations on the
            phrasing, but for the same task.
        inputs (`list[str]` or `dict[str, str]`):
            The inputs that will be fed to the tools. For this testing environment, only strings are accepted as
            values. Pass along a dictionary when you want to specify the values of each input, or just the list of
            inputs expected (the value used will be `<<input_name>>` in this case).
        answer (`str` or `list[str]`):
            The theoretical answer (or list of possible valid answers) to the problem, as code.
    """

    def __init__(self, task, inputs, answer):
        self.task = task
        self.inputs = inputs
        self.answer = answer
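For illustration, a `Problem` pairs task phrasings with a reference program over the fake tools above (the instance below is hypothetical):

problem = Problem(
    task=["Read the `text` out loud.", "Read me the `text`."],
    inputs=["text"],
    answer="text_reader(text)",
)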
|
||||
|
||||
|
||||
### The list of problems the agent will be evaluated on.
|
||||
EVALUATION_TASKS = [
|
||||
Problem(
|
||||
task=[
|
||||
"Is the following `text` (in Spanish) positive or negative?",
|
||||
"Is the text in the variable `text` (in Spanish) positive or negative?",
|
||||
"Translate the following `text` from Spanish to English then tell me if its positive or negative.",
|
||||
],
|
||||
inputs=["text"],
|
||||
answer="""text_classifier(translator(text, src_lang="Spanish", tgt_lang="English"), labels=["positive", "negative"])""",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Tell me out loud what the `image` contains.",
|
||||
"Describe the following `image` out loud.",
|
||||
"Find what is in the picture stored in `image` then read it out loud.",
|
||||
],
|
||||
inputs=["image"],
|
||||
answer=[
|
||||
"text_reader(image_captioner(image))",
|
||||
"text_reader(image_qa(image, question='What is in the image?'))",
|
||||
],
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Generate an image from the text given in `text_input`. Then transform it according to the text in `prompt`.",
|
||||
"Use the following `text_input` to generate an image, then transform it by using the text in `prompt`.",
|
||||
],
|
||||
inputs=["text_input", "prompt"],
|
||||
answer="image_transformer(image_generator(text_input), prompt)",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Download the content of `url`, summarize it then generate an image from its content.",
|
||||
"Use a summary of the web page at `url` to generate an image.",
|
||||
"Summarize the content of the web page at `url`, and use the result to generate an image.",
|
||||
],
|
||||
inputs=["url"],
|
||||
answer="image_generator(summarizer(text_downloader(url)))",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Transform the following `image` using the prompt in `text`. The prompt is in Spanish.",
|
||||
"Use the text prompt in `text` (in Spanish) to transform the following `image`.",
|
||||
"Translate the `text` from Spanish to English then use it to transform the picture in `image`.",
|
||||
],
|
||||
inputs=["text", "image"],
|
||||
answer="image_transformer(image, translator(text, src_lang='Spanish', tgt_lang='English'))",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Download the content of `url`, summarize it then read it out loud to me.",
|
||||
"Read me a summary of the web page at `url`.",
|
||||
],
|
||||
inputs=["url"],
|
||||
answer="text_reader(summarizer(text_downloader(url)))",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Generate an image from the text given in `text_input`.",
|
||||
],
|
||||
inputs=["text_input"],
|
||||
answer="image_generator(text_input)",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Replace the beaver in the `image` by the `prompt`.",
|
||||
"Transform the `image` so that it contains the `prompt`.",
|
||||
"Use `prompt` to transform this `image`.",
|
||||
],
|
||||
inputs=["image", "prompt"],
|
||||
answer="image_transformer(image, prompt)",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Provide me the summary of the `text`, then read it to me before transcribing it and translating it in French.",
|
||||
"Summarize `text`, read it out loud then transcribe the audio and translate it in French.",
|
||||
"Read me a summary of the the `text` out loud. Transcribe this and translate it in French.",
|
||||
],
|
||||
inputs=["text"],
|
||||
answer="translator(transcriber(text_reader(summarizer(text))), src_lang='English', tgt_lang='French')",
|
||||
),
|
||||
Problem(
|
||||
task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
|
||||
inputs={"prompt": "A lobster swimming"},
|
||||
answer="video_generator('A lobster swimming')",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Download the following file `url`, summarize it in a few words and generate a video from it."
|
||||
"Fetch the file at this `url`, summarize it, and create an animation out of it."
|
||||
],
|
||||
inputs=["url"],
|
||||
answer="video_generator(summarizer(text_downloader(url)))",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
EVALUATION_CHATS = [
    [
        Problem(
            task=[
                "Translate the following `text` from Spanish to English.",
                "Translate the following `text` from Spanish to English.",
            ],
            inputs=["text"],
            answer="translated_text=translator(text, src_lang='Spanish', tgt_lang='English')",
        ),
        Problem(
            task=[
                "Is it positive or negative?",
                "Tell me if it's positive or negative.",
            ],
            inputs=[],
            answer="text_classifier(translated_text, labels=['positive', 'negative'])",
        ),
    ],
    [
        Problem(
            task=[
                "What does this `image` contain?",
                "Describe the following `image`.",
                "Find what is in the picture stored in `image`",
            ],
            inputs=["image"],
            answer=[
                "description=image_captioner(image)",
                "description=image_qa(image, question='What is in the image?')",
            ],
        ),
        Problem(
            task=["Now, read the description out loud.", "Great! Can you read it out loud?", "Read it out loud."],
            inputs=[],
            answer=["audio=text_reader(description)", "audio=text_reader(description)"],
        ),
    ],
    [
        Problem(
            task=[
                "Generate an image from the text given in `text_input`.",
                "Use the following `text_input` to generate an image",
            ],
            inputs=["text_input"],
            answer="image = image_generator(text_input)",
        ),
        Problem(
            task=[
                "Transform it according to the text in `prompt`.",
                "Transform it by using the text in `prompt`.",
            ],
            inputs=["prompt"],
            answer="image_transformer(image, prompt)",
        ),
    ],
    [
        Problem(
            task=[
                "Download the content of `url` and summarize it.",
                "Summarize the content of the web page at `url`.",
            ],
            inputs=["url"],
            answer="summary = summarizer(text_downloader(url))",
        ),
        Problem(
            task=[
                "Generate an image from its content.",
                "Use the previous result to generate an image.",
            ],
            inputs=[],
            answer="image_generator(summary)",
        ),
    ],
    [
        Problem(
            task=[
                "Translate this Spanish `text` in English.",
                "Translate the `text` from Spanish to English.",
            ],
            inputs=["text"],
            answer="translated_text = translator(text, src_lang='Spanish', tgt_lang='English')",
        ),
        Problem(
            task=[
                "Transform the following `image` using the translated `text`.",
                "Use the previous result to transform the following `image`.",
            ],
            inputs=["image"],
            answer="image_transformer(image, translated_text)",
        ),
    ],
    [
        Problem(
            task=["Download the content of `url`.", "Get me the text on the web page `url`."],
            inputs=["url"],
            answer="text = text_downloader(url)",
        ),
        Problem(
            task=["Summarize this text.", "Summarize this text."],
            inputs=[],
            answer="summary = summarizer(text)",
        ),
        Problem(
            task=["Read it out loud to me.", "Read me the previous result."],
            inputs=[],
            answer="text_reader(summary)",
        ),
    ],
    [
        Problem(
            task=[
                "Generate an image from the text given in `text_input`.",
            ],
            inputs=["text_input"],
            answer="image_generator(text_input)",
        ),
    ],
    [
        Problem(
            task=[
                "Replace the beaver in the `image` by the `prompt`.",
                "Transform the `image` so that it contains the `prompt`.",
                "Use `prompt` to transform this `image`.",
            ],
            inputs=["image", "prompt"],
            answer="image_transformer(image, prompt)",
        ),
    ],
    [
        Problem(
            task=["Provide me the summary of the `text`.", "Summarize `text`."],
            inputs=["text"],
            answer="summary = summarizer(text)",
        ),
        Problem(
            task=["Read this summary to me.", "Read it out loud."],
            inputs=[],
            answer="audio = text_reader(summarizer(text))",
        ),
        Problem(
            task=["Transcribe the previous result back into text.", "Transcribe the audio."],
            inputs=[],
            answer="text = transcriber(audio)",
        ),
        Problem(
            task=["Translate the last result into French.", "Translate this in French."],
            inputs=[],
            answer="translator(text, src_lang='English', tgt_lang='French')",
        ),
    ],
    [
        Problem(
            task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
            inputs={"prompt": "A lobster swimming"},
            answer="video_generator('A lobster swimming')",
        ),
    ],
    [
        Problem(
            task=[
                "Download the content of `url` and summarize it.",
                "Summarize the content of the web page at `url`.",
            ],
            inputs=["url"],
            answer="summary = summarizer(text_downloader(url))",
        ),
        Problem(
            task=["Generate a video from it.", "Create an animation from the last result."],
            inputs=[],
            answer="video_generator(summary)",
        ),
    ],
]


def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
    if not isinstance(theoretical_answer, list):
        return {name for name in TEST_TOOLS if name in code_answer}

    if isinstance(agent_answer, dict):
        for one_answer, one_code in zip(theoretical_answer, code_answer):
            if one_answer in agent_answer.values():
                return {name for name in TEST_TOOLS if name in one_code}

    for one_answer, one_code in zip(theoretical_answer, code_answer):
        if agent_answer == one_answer:
            return {name for name in TEST_TOOLS if name in one_code}

    return {name for name in TEST_TOOLS if name in code_answer[0]}


def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpretor_error=False):
    tools = BASE_PYTHON_TOOLS.copy()
    for name, tool in TEST_TOOLS.items():
        if name not in code:
            continue
        tools[name] = tool

    if isinstance(inputs, dict):
        inputs = inputs.copy()
    elif inputs is not None:
        inputs = {inp: f"<<{inp}>>" for inp in inputs}

    if state is not None:
        state.update(inputs)
    else:
        state = inputs

    try:
        return evaluate(code, tools, state)
    except InterpretorError as e:
        return str(e)
    except Exception as e:
        if verbose:
            print(e)
        return None

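# A minimal usage sketch of `evaluate_code` (illustrative only: the fake tools in
# TEST_TOOLS return placeholder strings, so the value below is such a placeholder):
#
#     agent_answer = evaluate_code("summarizer(text)", inputs=["text"])
#     # `text` was bound to the placeholder "<<text>>" before evaluation started.
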
def score_code(agent_answer, theoretical_answer, verbose: bool = False):
    if verbose:
        print(agent_answer, theoretical_answer)
    theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]

    if agent_answer in theoretical_answer:
        if verbose:
            print("Perfect!")
        return 1
    elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
        if verbose:
            print("Almost perfect, result in state!")
        return 0.75
    else:
        if verbose:
            print("Result is not the right one but code executed.")
        return 0.3

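# The scoring scale above in a nutshell (a sketch with made-up placeholder values):
#
#     score_code("<<summary>>", "<<summary>>")                 # 1: exact match
#     score_code({"summary": "<<summary>>"}, "<<summary>>")    # 0.75: right value, left in state
#     score_code("something else", "<<summary>>")              # 0.3: code ran, wrong result
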
def evaluate_one_result(explanation, code, agent_answer, theoretical_answer, answer, verbose=False):
    tools_in_explanation = {name for name in TEST_TOOLS if f"`{name}`" in explanation}
    theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)
    if tools_in_explanation == theoretical_tools:
        tool_selection_score = 1.0
        tool_selection_errors = None
    else:
        missing_tools = len(theoretical_tools - tools_in_explanation)
        unexpected_tools = len(tools_in_explanation - theoretical_tools)
        tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)

        tool_selection_errors = {
            "selected_tools": tools_in_explanation,
            "theoretical_tools": theoretical_tools,
        }

    tools_in_code = {name for name in TEST_TOOLS if name in code}
    if tools_in_code == theoretical_tools:
        tool_used_score = 1.0
        tool_used_errors = None
    else:
        missing_tools = len(theoretical_tools - tools_in_code)
        unexpected_tools = len(tools_in_code - theoretical_tools)
        tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)

        tool_used_errors = {
            "selected_tools": tools_in_code,
            "theoretical_tools": theoretical_tools,
        }

    score = score_code(agent_answer, theoretical_answer, verbose=verbose)
    if score < 1.0:
        code_errors = {
            "code_produced": code,
            "evaluation": agent_answer,
            "theoretical_answer": theoretical_answer,
        }
    else:
        code_errors = None

    return (tool_selection_score, tool_used_score, score), (tool_selection_errors, tool_used_errors, code_errors)


def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
    """
    Evaluates a new agent on all `EVALUATION_TASKS`.

    Example:

    ```py
    agent = OpenAiAgent(model="text-davinci-003", api_key=your_api_key)
    scores = evaluate_agent(agent)
    print(scores)
    ```
    """
    # Sanity check
    agent_tools = set(agent.toolbox.keys())
    if agent_tools != set(TEST_TOOLS):
        missing_tools = set(TEST_TOOLS) - agent_tools
        unexpected_tools = agent_tools - set(TEST_TOOLS)
        raise ValueError(
            f"Fix the test tools in the evaluate_agent module. Tools missing: {missing_tools}. Extra tools: {unexpected_tools}."
        )

    eval_tasks = []
    eval_idx = []
    for idx, pb in enumerate(EVALUATION_TASKS):
        if isinstance(pb.task, list):
            eval_tasks.extend(pb.task)
            eval_idx.extend([idx] * len(pb.task))
        else:
            eval_tasks.append(pb.task)
            eval_idx.append(idx)

    tool_selection_score = 0
    tool_used_score = 0
    code_score = 0

    if return_errors:
        tool_selection_errors = {}
        tool_used_errors = {}
        code_errors = {}

    for start_idx in range(0, len(eval_tasks), batch_size):
        end_idx = min(start_idx + batch_size, len(eval_tasks))
        batch_tasks = eval_tasks[start_idx:end_idx]

        prompts = [agent.format_prompt(task) for task in batch_tasks]
        results = agent.generate_many(prompts, stop=["Task:"])

        for idx, result in enumerate(results):
            problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
            if verbose:
                print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")
            explanation, code = clean_code_for_run(result)

            # Evaluate agent answer and code answer
            agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
            if isinstance(problem.answer, list):
                theoretical_answer = [evaluate_code(answer, problem.inputs) for answer in problem.answer]
            else:
                theoretical_answer = evaluate_code(problem.answer, problem.inputs)

            scores, errors = evaluate_one_result(
                explanation, code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
            )

            tool_selection_score += scores[0]
            tool_used_score += scores[1]
            code_score += scores[2]

            if return_errors:
                if errors[0] is not None:
                    tool_selection_errors[batch_tasks[idx]] = errors[0]
                if errors[1] is not None:
                    tool_used_errors[batch_tasks[idx]] = errors[1]
                if errors[2] is not None:
                    code_errors[batch_tasks[idx]] = errors[2]

    scores = {
        "tool selection score": 100 * (tool_selection_score / len(eval_tasks)),
        "tool used score": 100 * (tool_used_score / len(eval_tasks)),
        "code score": 100 * (code_score / len(eval_tasks)),
    }

    if return_errors:
        return scores, tool_selection_errors, tool_used_errors, code_errors
    else:
        return scores


def evaluate_chat_agent(agent, verbose=False, return_errors=False):
    """
    Evaluates a new agent on all `EVALUATION_CHATS`.

    Example:

    ```py
    agent = OpenAiAgent(model="text-davinci-003", api_key=your_api_key)
    scores = evaluate_chat_agent(agent)
    print(scores)
    ```
    """
    # Sanity check
    agent_tools = set(agent.toolbox.keys())
    if agent_tools != set(TEST_TOOLS):
        missing_tools = set(TEST_TOOLS) - agent_tools
        unexpected_tools = agent_tools - set(TEST_TOOLS)
        raise ValueError(
            f"Fix the test tools in the evaluate_agent module. Tools missing: {missing_tools}. Extra tools: {unexpected_tools}."
        )

    tool_selection_score = 0
    tool_used_score = 0
    code_score = 0
    total_steps = 0

    if return_errors:
        tool_selection_errors = {}
        tool_used_errors = {}
        code_errors = {}

    for chat_problem in EVALUATION_CHATS:
        if isinstance(chat_problem[0].task, str):
            resolved_problems = [chat_problem]
        else:
            resolved_problems = [
                [Problem(task=pb.task[i], inputs=pb.inputs, answer=pb.answer) for pb in chat_problem]
                for i in range(len(chat_problem[0].task))
            ]
        for problem in resolved_problems:
            agent.prepare_for_new_chat()
            agent_state = {}
            theoretical_state = (
                [{} for _ in range(len(problem[0].answer))] if isinstance(problem[0].answer, list) else {}
            )

            for step, step_problem in enumerate(problem):
                if verbose:
                    print(step_problem.task)
                total_steps += 1
                prompt = agent.format_prompt(step_problem.task, chat_mode=True)
                result = agent.generate_one(prompt, stop=["Human:", "====="])
                agent.chat_history = prompt + result + "\n"

                explanation, code = clean_code_for_chat(result)

                if verbose:
                    print(f"==Explanation from the agent==\n{explanation}")
                    print(f"\n==Code generated by the agent==\n{code}")

                # Evaluate agent answer and code answer
                agent_answer = evaluate_code(code, step_problem.inputs, state=agent_state, verbose=verbose)

                answer = step_problem.answer
                if isinstance(answer, list):
                    theoretical_answer = [
                        evaluate_code(a, step_problem.inputs, state=state)
                        for a, state in zip(answer, theoretical_state)
                    ]
                else:
                    theoretical_answer = evaluate_code(answer, step_problem.inputs, state=theoretical_state)

                scores, errors = evaluate_one_result(
                    explanation, code, agent_answer, theoretical_answer, answer, verbose=verbose
                )

                tool_selection_score += scores[0]
                tool_used_score += scores[1]
                code_score += scores[2]

                if return_errors:
                    if errors[0] is not None:
                        tool_selection_errors[step_problem.task] = errors[0]
                    if errors[1] is not None:
                        tool_used_errors[step_problem.task] = errors[1]
                    if errors[2] is not None:
                        code_errors[step_problem.task] = errors[2]

    scores = {
        "tool selection score": 100 * (tool_selection_score / total_steps),
        "tool used score": 100 * (tool_used_score / total_steps),
        "code score": 100 * (code_score / total_steps),
    }

    if return_errors:
        return scores, tool_selection_errors, tool_used_errors, code_errors
    else:
        return scores
src/transformers/tools/image_captioning.py (new file, 51 lines)
@ -0,0 +1,51 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ..models.auto import AutoModelForVision2Seq
from ..utils import requires_backends
from .base import PipelineTool


if TYPE_CHECKING:
    from PIL import Image


class ImageCaptioningTool(PipelineTool):
    default_checkpoint = "Salesforce/blip-image-captioning-base"
    description = (
        "This is a tool that generates a description of an image. It takes an input named `image` which should be "
        "the image to caption, and returns a text that contains the description in English."
    )
    name = "image_captioner"
    model_class = AutoModelForVision2Seq

    inputs = ["image"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def encode(self, image: "Image"):
        return self.pre_processor(images=image, return_tensors="pt")

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
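
# Example usage (a sketch: assumes a local "photo.png" and downloads the default
# BLIP checkpoint on first use; the import path mirrors the other tools' docstrings):
#
#     from PIL import Image
#     from transformers.tools import ImageCaptioningTool
#
#     captioner = ImageCaptioningTool()
#     caption = captioner(Image.open("photo.png"))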
src/transformers/tools/image_question_answering.py (new file, 57 lines)
@ -0,0 +1,57 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

import torch

from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
from ..utils import requires_backends
from .base import PipelineTool


if TYPE_CHECKING:
    from PIL import Image


class ImageQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
    description = (
        "This is a tool that answers a question about an image. It takes an input named `image` which should be the "
        "image containing the information, as well as a `question` which should be the question in English. It "
        "returns a text that is the answer to the question."
    )
    name = "image_qa"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVisualQuestionAnswering

    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def encode(self, image: "Image", question: str):
        return self.pre_processor(image, question, return_tensors="pt")

    def forward(self, inputs):
        with torch.no_grad():
            return self.model(**inputs).logits

    def decode(self, outputs):
        idx = outputs.argmax(-1).item()
        return self.model.config.id2label[idx]
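
# Example usage (a sketch: assumes a local "photo.png"; downloads the default
# ViLT VQA checkpoint on first use):
#
#     from PIL import Image
#     from transformers.tools import ImageQuestionAnsweringTool
#
#     image_qa = ImageQuestionAnsweringTool()
#     answer = image_qa(image=Image.open("photo.png"), question="How many dogs are in the picture?")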
src/transformers/tools/image_segmentation.py (new file, 59 lines)
@ -0,0 +1,59 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch

from ..models.clipseg import CLIPSegForImageSegmentation
from ..utils import is_vision_available, requires_backends
from .base import PipelineTool


if is_vision_available():
    from PIL import Image


class ImageSegmentationTool(PipelineTool):
    description = (
        "This is a tool that creates a segmentation mask of an image according to a label. It cannot create an "
        "image. It takes two arguments named `image` which should be the original image, and `label` which should "
        "be a text describing the elements that should be identified in the segmentation mask. The tool returns "
        "the mask."
    )
    default_checkpoint = "CIDAS/clipseg-rd64-refined"
    name = "image_segmenter"
    model_class = CLIPSegForImageSegmentation

    inputs = ["image", "text"]
    outputs = ["image"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def encode(self, image: "Image", label: str):
        self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]}
        return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt")

    def forward(self, inputs):
        with torch.no_grad():
            logits = self.model(**inputs).logits
        return logits

    def decode(self, outputs):
        array = outputs.cpu().detach().numpy()
        # Binarize the logits, then turn the 0/1 mask into a black-and-white image.
        array[array <= 0] = 0
        array[array > 0] = 1
        return Image.fromarray((array * 255).astype(np.uint8))
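
# Example usage (a sketch: assumes a local "photo.png"; the returned mask is a
# black-and-white PIL image of the same size as the input):
#
#     from PIL import Image
#     from transformers.tools import ImageSegmentationTool
#
#     segmenter = ImageSegmentationTool()
#     mask = segmenter(image=Image.open("photo.png"), label="dog")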
src/transformers/tools/prompts.py (new file, 186 lines)
@ -0,0 +1,186 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# docstyle-ignore
RUN_PROMPT_TEMPLATE = """I will ask you to perform a task; your job is to come up with a series of simple commands in Python that will perform the task.
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.

Tools:
<<all_tools>>


Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."

I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.

Answer:
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(image=image, question=translated_question)
print(f"The answer is {answer}")
```

Task: "Identify the oldest person in the `document` and create an image showcasing the result."

I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.

Answer:
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
image = image_generator(answer)
```

Task: "Generate an image using the text given in the variable `caption`."

I will use the following tool: `image_generator` to generate an image.

Answer:
```py
image = image_generator(prompt=caption)
```

Task: "Summarize the text given in the variable `text` and read it out loud."

I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.

Answer:
```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(summarized_text)
```

Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."

I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.

Answer:
```py
answer = text_qa(text=text, question=question)
print(f"The answer is {answer}.")
image = image_generator(answer)
```

Task: "Caption the following `image`."

I will use the following tool: `image_captioner` to generate a caption for the image.

Answer:
```py
caption = image_captioner(image)
```

Task: "<<prompt>>"

I will use the following"""

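# A minimal sketch of how the placeholders above get filled in (the actual
# substitution lives in the agent's prompt formatting; `toolbox` is a hypothetical
# name for a dict mapping tool names to tool instances):
#
#     tool_descriptions = "\n".join(f"- {name}: {tool.description}" for name, tool in toolbox.items())
#     prompt = RUN_PROMPT_TEMPLATE.replace("<<all_tools>>", tool_descriptions)
#     prompt = prompt.replace("<<prompt>>", "Caption the `image` and read it out loud.")
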
# docstyle-ignore
CHAT_PROMPT_TEMPLATE = """Below are a series of dialogues between various people and an AI assistant specialized in coding. The AI assistant tries to be helpful, polite, honest, and humble-but-knowledgeable.

The job of the AI assistant is to come up with a series of simple commands in Python that will perform the task the human wants to perform.
To help with that, the AI assistant has access to a set of tools. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
The AI assistant should first explain the tools it will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. The AI assistant can print intermediate results if it makes sense to do so.

Tools:
<<all_tools>>

=====

Human: Answer the question in the variable `question` about the image stored in the variable `image`.

Assistant: I will use the tool `image_qa` to answer the question on the input image.

```py
answer = image_qa(text=question, image=image)
print(f"The answer is {answer}")
```

Human: I tried this code, it worked but didn't give me a good result. The question is in French.

Assistant: In this case, the question needs to be translated first. I will use the tool `translator` to do this.

```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(text=translated_question, image=image)
print(f"The answer is {answer}")
```

=====

Human: Identify the oldest person in the `document`.

Assistant: I will use the tool `document_qa` to find the oldest person in the document.

```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
```

Human: Can you generate an image with the result?

Assistant: I will use the tool `image_generator` to do that.

```py
image = image_generator(answer)
```

=====

Human: Summarize the text given in the variable `text` and read it out loud.

Assistant: I will use the tool `summarizer` to create a summary of the input text, then the tool `text_reader` to read it out loud.

```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(text=summary)
```

Human: I got the following error: "The variable `summary` is not defined."

Assistant: My bad! Let's try this code instead.

```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(text=summarized_text)
```

Human: It worked! Can you translate the summary into German?

Assistant: I will use the tool `translator` to translate the text into German.

```py
translated_summary = translator(summarized_text, src_lang="English", tgt_lang="German")
```

=====
"""


# docstyle-ignore
CHAT_MESSAGE_PROMPT = """
Human: <<task>>

Assistant: """
src/transformers/tools/python_interpreter.py (new file, 238 lines)
@ -0,0 +1,238 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import difflib
from collections.abc import Mapping
from typing import Any, Callable, Dict


class InterpretorError(ValueError):
    """
    An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
    operations.
    """

    pass


def evaluate(code: str, tools: Dict[str, Callable], state=None, chat_mode=False):
    """
    Evaluate a Python expression using the content of the variables stored in a state and only evaluating a given set
    of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        code (`str`):
            The code to evaluate.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
            updated by this function to contain all variables as they are evaluated.
        chat_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the function is called from `Agent.chat`.
    """
    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        print("The code generated by the agent is not valid.\n", e)
        return
    if state is None:
        state = {}
    result = None
    for idx, node in enumerate(expression.body):
        try:
            line_result = evaluate_ast(node, state, tools)
        except InterpretorError as e:
            msg = f"Evaluation of the code stopped at line {idx} before the end because of the following error"
            if chat_mode:
                msg += (
                    f". Copy paste the following error message and send it back to the agent:\nI get an error: '{e}'"
                )
            else:
                msg += f":\n{e}"
            print(msg)
            break
        if line_result is not None:
            result = line_result

    return result

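# A short sketch of `evaluate` on a toy program (the tool `add` is made up for
# the example; any plain callable works):
#
#     def add(a, b):
#         return a + b
#
#     state = {}
#     result = evaluate("x = add(1, 2)\nx", {"add": add}, state=state)
#     # result == 3, and state == {"x": 3} afterwards
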
def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
    """
    Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
    set of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
    """
    if isinstance(expression, ast.Assign):
        # Assignment -> we evaluate the assignment which should update the state
        # We return the variable assigned as it may be used to determine the final result.
        return evaluate_assign(expression, state, tools)
    elif isinstance(expression, ast.Call):
        # Function call -> we return the value of the function call
        return evaluate_call(expression, state, tools)
    elif isinstance(expression, ast.Constant):
        # Constant -> just return the value
        return expression.value
    elif isinstance(expression, ast.Dict):
        # Dict -> evaluate all keys and values
        keys = [evaluate_ast(k, state, tools) for k in expression.keys]
        values = [evaluate_ast(v, state, tools) for v in expression.values]
        return dict(zip(keys, values))
    elif isinstance(expression, ast.Expr):
        # Expression -> evaluate the content
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.FormattedValue):
        # Formatted value (part of f-string) -> evaluate the content and return
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.If):
        # If -> execute the right branch
        return evaluate_if(expression, state, tools)
    elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
        # Index node (only exists on older Python versions) -> evaluate the wrapped value
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.JoinedStr):
        # f-string -> evaluate all parts and join them
        return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
    elif isinstance(expression, ast.List):
        # List -> evaluate all elements
        return [evaluate_ast(elt, state, tools) for elt in expression.elts]
    elif isinstance(expression, ast.Name):
        # Name -> pick up the value in the state
        return evaluate_name(expression, state, tools)
    elif isinstance(expression, ast.Subscript):
        # Subscript -> return the value of the indexing
        return evaluate_subscript(expression, state, tools)
    else:
        # For now we refuse anything else. Let's add things as we need them.
        raise InterpretorError(f"{expression.__class__.__name__} is not supported.")


def evaluate_assign(assign, state, tools):
    var_names = assign.targets
    result = evaluate_ast(assign.value, state, tools)

    if len(var_names) == 1:
        state[var_names[0].id] = result
    else:
        if len(result) != len(var_names):
            raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
        for var_name, r in zip(var_names, result):
            state[var_name.id] = r
    return result


def evaluate_call(call, state, tools):
    if not isinstance(call.func, ast.Name):
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
            f"type {type(call.func)})."
        )
    func_name = call.func.id
    if func_name not in tools:
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
        )

    func = tools[func_name]
    # TODO: deal with args
    args = [evaluate_ast(arg, state, tools) for arg in call.args]
    kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
    return func(*args, **kwargs)


def evaluate_subscript(subscript, state, tools):
    index = evaluate_ast(subscript.slice, state, tools)
    value = evaluate_ast(subscript.value, state, tools)
    if isinstance(value, (list, tuple)):
        return value[int(index)]
    if index in value:
        return value[index]
    if isinstance(index, str) and isinstance(value, Mapping):
        close_matches = difflib.get_close_matches(index, list(value.keys()))
        if len(close_matches) > 0:
            return value[close_matches[0]]

    raise InterpretorError(f"Could not index {value} with '{index}'.")


def evaluate_name(name, state, tools):
    if name.id in state:
        return state[name.id]
    close_matches = difflib.get_close_matches(name.id, list(state.keys()))
    if len(close_matches) > 0:
        return state[close_matches[0]]
    raise InterpretorError(f"The variable `{name.id}` is not defined.")

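# Note on the fuzzy fallbacks above: both `evaluate_subscript` and `evaluate_name`
# resort to `difflib.get_close_matches`, so a slightly misspelled variable or key
# produced by the agent still resolves. For instance (sketch):
#
#     state = {"translated_text": "hello"}
#     evaluate("translated_txt", {}, state=state)  # returns "hello" via the close match
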
def evaluate_condition(condition, state, tools):
    if len(condition.ops) > 1:
        raise InterpretorError("Cannot evaluate conditions with multiple operators")

    left = evaluate_ast(condition.left, state, tools)
    comparator = condition.ops[0]
    right = evaluate_ast(condition.comparators[0], state, tools)

    if isinstance(comparator, ast.Eq):
        return left == right
    elif isinstance(comparator, ast.NotEq):
        return left != right
    elif isinstance(comparator, ast.Lt):
        return left < right
    elif isinstance(comparator, ast.LtE):
        return left <= right
    elif isinstance(comparator, ast.Gt):
        return left > right
    elif isinstance(comparator, ast.GtE):
        return left >= right
    elif isinstance(comparator, ast.Is):
        return left is right
    elif isinstance(comparator, ast.IsNot):
        return left is not right
    elif isinstance(comparator, ast.In):
        return left in right
    elif isinstance(comparator, ast.NotIn):
        return left not in right
    else:
        raise InterpretorError(f"Operator not supported: {comparator}")


def evaluate_if(if_statement, state, tools):
    result = None
    if evaluate_condition(if_statement.test, state, tools):
        for line in if_statement.body:
            line_result = evaluate_ast(line, state, tools)
            if line_result is not None:
                result = line_result
    else:
        for line in if_statement.orelse:
            line_result = evaluate_ast(line, state, tools)
            if line_result is not None:
                result = line_result
    return result
src/transformers/tools/speech_to_text.py (new file, 41 lines)
@ -0,0 +1,41 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
from .base import PipelineTool


class SpeechToTextTool(PipelineTool):
    default_checkpoint = "openai/whisper-base"
    description = (
        "This is a tool that transcribes audio into text. It takes an input named `audio` and returns the "
        "transcribed text."
    )
    name = "transcriber"
    pre_processor_class = WhisperProcessor
    model_class = WhisperForConditionalGeneration

    inputs = ["audio"]
    outputs = ["text"]

    def encode(self, audio):
        return self.pre_processor(audio, return_tensors="pt").input_features

    def forward(self, inputs):
        return self.model.generate(inputs=inputs)

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
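
# Example usage (a sketch: `audio` is assumed to be a raw waveform, e.g. a 16 kHz
# numpy array; the default Whisper checkpoint is downloaded on first use):
#
#     from transformers.tools import SpeechToTextTool
#
#     transcriber = SpeechToTextTool()
#     text = transcriber(audio)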
src/transformers/tools/text_classification.py (new file, 70 lines)
@ -0,0 +1,70 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from ..models.auto import AutoModelForSequenceClassification, AutoTokenizer
from .base import PipelineTool


class TextClassificationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TextClassificationTool

    classifier = TextClassificationTool()
    classifier("This is a super nice API!", labels=["positive", "negative"])
    ```
    """

    default_checkpoint = "facebook/bart-large-mnli"
    description = (
        "This is a tool that classifies an English text using provided labels. It takes two inputs: `text`, which "
        "should be the text to classify, and `labels`, which should be the list of labels to use for classification. "
        "It returns the most likely label in the list of provided `labels` for the input text."
    )
    name = "text_classifier"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    inputs = ["text", ["text"]]
    outputs = ["text"]

    def setup(self):
        super().setup()
        config = self.model.config
        self.entailment_id = -1
        for idx, label in config.id2label.items():
            if label.lower().startswith("entail"):
                self.entailment_id = int(idx)
        if self.entailment_id == -1:
            raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")

    def encode(self, text, labels):
        self._labels = labels
        return self.pre_processor(
            [text] * len(labels),
            [f"This example is {label}" for label in labels],
            return_tensors="pt",
            padding="max_length",
        )

    def decode(self, outputs):
        logits = outputs.logits
        # Use the entailment column determined in `setup` (it is not always index 2).
        label_id = torch.argmax(logits[:, self.entailment_id]).item()
        return self._labels[label_id]
src/transformers/tools/text_question_answering.py (new file, 52 lines)
@ -0,0 +1,52 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool


QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.

Can you answer this question about the text: '{question}'"""


class TextQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "google/flan-t5-base"
    description = (
        "This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
        "text where to find the answer, and `question`, which is the question, and returns the answer to the question."
    )
    name = "text_qa"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    inputs = ["text", "text"]
    outputs = ["text"]

    def encode(self, text: str, question: str):
        prompt = QA_PROMPT.format(text=text, question=question)
        return self.pre_processor(prompt, return_tensors="pt")

    def forward(self, inputs):
        output_ids = self.model.generate(**inputs)

        in_b, _ = inputs["input_ids"].shape
        out_b = output_ids.shape[0]

        return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]

    def decode(self, outputs):
        return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
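
# Example usage (a sketch: `long_text` is any string; downloads the default
# flan-t5-base checkpoint on first use):
#
#     from transformers.tools import TextQuestionAnsweringTool
#
#     text_qa = TextQuestionAnsweringTool()
#     answer = text_qa(text=long_text, question="Who wrote this report?")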
src/transformers/tools/text_summarization.py (new file, 52 lines)
@ -0,0 +1,52 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool


class TextSummarizationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TextSummarizationTool

    summarizer = TextSummarizationTool()
    summarizer(long_text)
    ```
    """

    default_checkpoint = "philschmid/bart-large-cnn-samsum"
    description = (
        "This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, "
        "and returns a summary of the text."
    )
    name = "summarizer"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    inputs = ["text"]
    outputs = ["text"]

    def encode(self, text):
        return self.pre_processor(text, return_tensors="pt", truncation=True)

    def forward(self, inputs):
        return self.model.generate(**inputs)[0]

    def decode(self, outputs):
        return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
src/transformers/tools/text_to_speech.py (new file, 65 lines)
@ -0,0 +1,65 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from ..utils import is_datasets_available
from .base import PipelineTool


if is_datasets_available():
    from datasets import load_dataset


class TextToSpeechTool(PipelineTool):
    default_checkpoint = "microsoft/speecht5_tts"
    description = (
        "This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
        "text to read (in English) and returns a waveform object containing the sound."
    )
    name = "text_reader"
    pre_processor_class = SpeechT5Processor
    model_class = SpeechT5ForTextToSpeech
    post_processor_class = SpeechT5HifiGan

    inputs = ["text"]
    outputs = ["audio"]

    def setup(self):
        if self.post_processor is None:
            self.post_processor = "microsoft/speecht5_hifigan"
        super().setup()

    def encode(self, text, speaker_embeddings=None):
        inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)

        if speaker_embeddings is None:
            if not is_datasets_available():
                raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")

            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)

        return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}

    def forward(self, inputs):
        with torch.no_grad():
            return self.model.generate_speech(**inputs)

    def decode(self, outputs):
        with torch.no_grad():
            return self.post_processor(outputs).cpu().detach()
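
# Example usage (a sketch: requires the `datasets` package for the default speaker
# embeddings; returns a waveform tensor that can be written out or played back):
#
#     from transformers.tools import TextToSpeechTool
#
#     reader = TextToSpeechTool()
#     waveform = reader("This is a super nice API!")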
src/transformers/tools/translation.py (new file, 271 lines)
@ -0,0 +1,271 @@
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
from .base import PipelineTool
|
||||
|
||||
|
||||
LANGUAGE_CODES = {
|
||||
"Acehnese Arabic": "ace_Arab",
|
||||
"Acehnese Latin": "ace_Latn",
|
||||
"Mesopotamian Arabic": "acm_Arab",
|
||||
"Ta'izzi-Adeni Arabic": "acq_Arab",
|
||||
"Tunisian Arabic": "aeb_Arab",
|
||||
"Afrikaans": "afr_Latn",
|
||||
"South Levantine Arabic": "ajp_Arab",
|
||||
"Akan": "aka_Latn",
|
||||
"Amharic": "amh_Ethi",
|
||||
"North Levantine Arabic": "apc_Arab",
|
||||
"Modern Standard Arabic": "arb_Arab",
|
||||
"Modern Standard Arabic Romanized": "arb_Latn",
|
||||
"Najdi Arabic": "ars_Arab",
|
||||
"Moroccan Arabic": "ary_Arab",
|
||||
"Egyptian Arabic": "arz_Arab",
|
||||
"Assamese": "asm_Beng",
|
||||
"Asturian": "ast_Latn",
|
||||
"Awadhi": "awa_Deva",
|
||||
"Central Aymara": "ayr_Latn",
|
||||
"South Azerbaijani": "azb_Arab",
|
||||
"North Azerbaijani": "azj_Latn",
|
||||
"Bashkir": "bak_Cyrl",
|
||||
"Bambara": "bam_Latn",
|
||||
"Balinese": "ban_Latn",
|
||||
"Belarusian": "bel_Cyrl",
|
||||
"Bemba": "bem_Latn",
|
||||
"Bengali": "ben_Beng",
|
||||
"Bhojpuri": "bho_Deva",
|
||||
"Banjar Arabic": "bjn_Arab",
|
||||
"Banjar Latin": "bjn_Latn",
|
||||
"Standard Tibetan": "bod_Tibt",
|
||||
"Bosnian": "bos_Latn",
|
||||
"Buginese": "bug_Latn",
|
||||
"Bulgarian": "bul_Cyrl",
|
||||
"Catalan": "cat_Latn",
|
||||
"Cebuano": "ceb_Latn",
|
||||
"Czech": "ces_Latn",
|
||||
"Chokwe": "cjk_Latn",
|
||||
"Central Kurdish": "ckb_Arab",
|
||||
"Crimean Tatar": "crh_Latn",
|
||||
"Welsh": "cym_Latn",
|
||||
"Danish": "dan_Latn",
|
||||
"German": "deu_Latn",
|
||||
"Southwestern Dinka": "dik_Latn",
|
||||
"Dyula": "dyu_Latn",
|
||||
"Dzongkha": "dzo_Tibt",
|
||||
"Greek": "ell_Grek",
|
||||
"English": "eng_Latn",
|
||||
"Esperanto": "epo_Latn",
|
||||
"Estonian": "est_Latn",
|
||||
"Basque": "eus_Latn",
|
||||
"Ewe": "ewe_Latn",
|
||||
"Faroese": "fao_Latn",
|
||||
"Fijian": "fij_Latn",
|
||||
"Finnish": "fin_Latn",
|
||||
"Fon": "fon_Latn",
|
||||
"French": "fra_Latn",
|
||||
"Friulian": "fur_Latn",
|
||||
"Nigerian Fulfulde": "fuv_Latn",
|
||||
"Scottish Gaelic": "gla_Latn",
|
||||
"Irish": "gle_Latn",
|
||||
"Galician": "glg_Latn",
|
||||
"Guarani": "grn_Latn",
|
||||
"Gujarati": "guj_Gujr",
|
||||
"Haitian Creole": "hat_Latn",
|
||||
"Hausa": "hau_Latn",
|
||||
"Hebrew": "heb_Hebr",
|
||||
"Hindi": "hin_Deva",
|
||||
"Chhattisgarhi": "hne_Deva",
|
||||
"Croatian": "hrv_Latn",
|
||||
"Hungarian": "hun_Latn",
|
||||
"Armenian": "hye_Armn",
|
||||
"Igbo": "ibo_Latn",
|
||||
"Ilocano": "ilo_Latn",
|
||||
"Indonesian": "ind_Latn",
|
||||
"Icelandic": "isl_Latn",
|
||||
"Italian": "ita_Latn",
|
||||
"Javanese": "jav_Latn",
|
||||
"Japanese": "jpn_Jpan",
|
||||
"Kabyle": "kab_Latn",
|
||||
"Jingpho": "kac_Latn",
|
||||
"Kamba": "kam_Latn",
|
||||
"Kannada": "kan_Knda",
|
||||
"Kashmiri Arabic": "kas_Arab",
|
||||
"Kashmiri Devanagari": "kas_Deva",
|
||||
"Georgian": "kat_Geor",
|
||||
"Central Kanuri Arabic": "knc_Arab",
|
||||
"Central Kanuri Latin": "knc_Latn",
|
||||
"Kazakh": "kaz_Cyrl",
|
||||
"Kabiyè": "kbp_Latn",
|
||||
"Kabuverdianu": "kea_Latn",
|
||||
"Khmer": "khm_Khmr",
|
||||
"Kikuyu": "kik_Latn",
|
||||
"Kinyarwanda": "kin_Latn",
|
||||
"Kyrgyz": "kir_Cyrl",
|
||||
"Kimbundu": "kmb_Latn",
|
||||
"Northern Kurdish": "kmr_Latn",
|
||||
"Kikongo": "kon_Latn",
|
||||
"Korean": "kor_Hang",
|
||||
"Lao": "lao_Laoo",
|
||||
"Ligurian": "lij_Latn",
|
||||
"Limburgish": "lim_Latn",
|
||||
"Lingala": "lin_Latn",
|
||||
"Lithuanian": "lit_Latn",
|
||||
"Lombard": "lmo_Latn",
|
||||
"Latgalian": "ltg_Latn",
|
||||
"Luxembourgish": "ltz_Latn",
|
||||
"Luba-Kasai": "lua_Latn",
|
||||
"Ganda": "lug_Latn",
|
||||
"Luo": "luo_Latn",
|
||||
"Mizo": "lus_Latn",
|
||||
"Standard Latvian": "lvs_Latn",
|
||||
"Magahi": "mag_Deva",
|
||||
"Maithili": "mai_Deva",
|
||||
"Malayalam": "mal_Mlym",
|
||||
"Marathi": "mar_Deva",
|
||||
"Minangkabau Arabic ": "min_Arab",
|
||||
"Minangkabau Latin": "min_Latn",
|
||||
"Macedonian": "mkd_Cyrl",
|
||||
"Plateau Malagasy": "plt_Latn",
|
||||
"Maltese": "mlt_Latn",
|
||||
"Meitei Bengali": "mni_Beng",
|
||||
"Halh Mongolian": "khk_Cyrl",
|
||||
"Mossi": "mos_Latn",
|
||||
"Maori": "mri_Latn",
|
||||
"Burmese": "mya_Mymr",
|
||||
"Dutch": "nld_Latn",
|
||||
"Norwegian Nynorsk": "nno_Latn",
|
||||
"Norwegian Bokmål": "nob_Latn",
|
||||
"Nepali": "npi_Deva",
|
||||
"Northern Sotho": "nso_Latn",
|
||||
"Nuer": "nus_Latn",
|
||||
"Nyanja": "nya_Latn",
|
||||
"Occitan": "oci_Latn",
|
||||
"West Central Oromo": "gaz_Latn",
|
||||
"Odia": "ory_Orya",
|
||||
"Pangasinan": "pag_Latn",
|
||||
"Eastern Panjabi": "pan_Guru",
|
||||
"Papiamento": "pap_Latn",
|
||||
"Western Persian": "pes_Arab",
|
||||
"Polish": "pol_Latn",
|
||||
"Portuguese": "por_Latn",
|
||||
"Dari": "prs_Arab",
|
||||
"Southern Pashto": "pbt_Arab",
|
||||
"Ayacucho Quechua": "quy_Latn",
|
||||
"Romanian": "ron_Latn",
|
||||
"Rundi": "run_Latn",
|
||||
"Russian": "rus_Cyrl",
|
||||
"Sango": "sag_Latn",
|
||||
"Sanskrit": "san_Deva",
|
||||
"Santali": "sat_Olck",
|
||||
"Sicilian": "scn_Latn",
|
||||
"Shan": "shn_Mymr",
|
||||
"Sinhala": "sin_Sinh",
|
||||
"Slovak": "slk_Latn",
|
||||
"Slovenian": "slv_Latn",
|
||||
"Samoan": "smo_Latn",
|
||||
"Shona": "sna_Latn",
|
||||
"Sindhi": "snd_Arab",
|
||||
"Somali": "som_Latn",
|
||||
"Southern Sotho": "sot_Latn",
|
||||
"Spanish": "spa_Latn",
|
||||
"Tosk Albanian": "als_Latn",
|
||||
"Sardinian": "srd_Latn",
|
||||
"Serbian": "srp_Cyrl",
|
||||
"Swati": "ssw_Latn",
|
||||
"Sundanese": "sun_Latn",
|
||||
"Swedish": "swe_Latn",
|
||||
"Swahili": "swh_Latn",
|
||||
"Silesian": "szl_Latn",
|
||||
"Tamil": "tam_Taml",
|
||||
"Tatar": "tat_Cyrl",
|
||||
"Telugu": "tel_Telu",
|
||||
"Tajik": "tgk_Cyrl",
|
||||
"Tagalog": "tgl_Latn",
|
||||
"Thai": "tha_Thai",
|
||||
"Tigrinya": "tir_Ethi",
|
||||
"Tamasheq Latin": "taq_Latn",
|
||||
"Tamasheq Tifinagh": "taq_Tfng",
|
||||
"Tok Pisin": "tpi_Latn",
|
||||
"Tswana": "tsn_Latn",
|
||||
"Tsonga": "tso_Latn",
|
||||
"Turkmen": "tuk_Latn",
|
||||
"Tumbuka": "tum_Latn",
|
||||
"Turkish": "tur_Latn",
|
||||
"Twi": "twi_Latn",
|
||||
"Central Atlas Tamazight": "tzm_Tfng",
|
||||
"Uyghur": "uig_Arab",
|
||||
"Ukrainian": "ukr_Cyrl",
|
||||
"Umbundu": "umb_Latn",
|
||||
"Urdu": "urd_Arab",
|
||||
"Northern Uzbek": "uzn_Latn",
|
||||
"Venetian": "vec_Latn",
|
||||
"Vietnamese": "vie_Latn",
|
||||
"Waray": "war_Latn",
|
||||
"Wolof": "wol_Latn",
|
||||
"Xhosa": "xho_Latn",
|
||||
"Eastern Yiddish": "ydd_Hebr",
|
||||
"Yoruba": "yor_Latn",
|
||||
"Yue Chinese": "yue_Hant",
|
||||
"Chinese Simplified": "zho_Hans",
|
||||
"Chinese Traditional": "zho_Hant",
|
||||
"Standard Malay": "zsm_Latn",
|
||||
"Zulu": "zul_Latn",
|
||||
}
|
||||
|
||||
|
||||
class TranslationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TranslationTool

    translator = TranslationTool()
    translator("This is a super nice API!", src_lang="English", tgt_lang="French")
    ```
    """

    default_checkpoint = "facebook/nllb-200-distilled-600M"
    description = (
        "This is a tool that translates text from one language to another. It takes three inputs: `text`, which "
        "should be the text to translate, `src_lang`, which should be the language of the text to translate, and "
        "`tgt_lang`, which should be the desired output language. Both `src_lang` and `tgt_lang` are written in "
        "plain English, such as 'Romanian', or 'Albanian'. It returns the text translated in `tgt_lang`."
    )
    name = "translator"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM
    lang_to_code = LANGUAGE_CODES

    inputs = ["text", "text", "text"]
    outputs = ["text"]

    def encode(self, text, src_lang, tgt_lang):
        if src_lang not in self.lang_to_code:
            raise ValueError(f"{src_lang} is not a supported language.")
        if tgt_lang not in self.lang_to_code:
            raise ValueError(f"{tgt_lang} is not a supported language.")
        src_lang = self.lang_to_code[src_lang]
        tgt_lang = self.lang_to_code[tgt_lang]
        return self.pre_processor._build_translation_inputs(
            text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
        )

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)
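As a quick sanity check on the class above, a minimal usage sketch (this assumes a working install where the NLLB checkpoint can be downloaded; the error behavior follows `encode` above):

```py
from transformers.tools import TranslationTool

translator = TranslationTool()  # lazily loads facebook/nllb-200-distilled-600M on first call
print(translator("Machine learning is fun.", src_lang="English", tgt_lang="German"))

# Language names must match LANGUAGE_CODES keys exactly, otherwise encode() raises:
# translator("hi", src_lang="Klingon", tgt_lang="French")  # ValueError: Klingon is not a supported language.
```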
@ -27,7 +27,6 @@ from .doc import (
    copy_func,
    replace_return_docstrings,
)
from .doctest_utils import HfDocTestParser
from .generic import (
    ContextManagers,
    ExplicitEnum,
@ -122,6 +121,7 @@ from .import_utils import (
    is_natten_available,
    is_ninja_available,
    is_onnx_available,
    is_openai_available,
    is_optimum_available,
    is_pandas_available,
    is_peft_available,
@ -1,189 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utils to run the documentation tests without having to overwrite any files.

The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is
made, as a print would otherwise fail the corresponding line.

To skip CUDA tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules <path_to_files_to_test>`
"""
import doctest
import inspect
import os
import re
from typing import Iterable

from _pytest.doctest import (
    Module,
    _get_checker,
    _get_continue_on_failure,
    _get_runner,
    _is_mocked,
    _patch_unwrap_mock_aware,
    get_optionflags,
    import_path,
)
from _pytest.outcomes import skip
from pytest import DoctestItem


def preprocess_string(string, skip_cuda_tests):
    """Prepare a docstring or a `.mdx` file to be run by doctest.

    The argument `string` would be the whole file content if it is a `.mdx` file. For a Python file, it would be one
    of its docstrings. In each case, it may contain multiple Python code examples. If `skip_cuda_tests` is `True` and
    CUDA-related code is detected (via a heuristic), this method will return an empty string so no doctest will be
    run for `string`.
    """
    codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )((?:.*?\n)*?.*?```)"
    codeblocks = re.split(re.compile(codeblock_pattern, flags=re.MULTILINE | re.DOTALL), string)
    is_cuda_found = False
    for i, codeblock in enumerate(codeblocks):
        if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock:
            codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock)
        if (
            (">>>" in codeblock or "..." in codeblock)
            and re.search(r"cuda|to\(0\)|device=0", codeblock)
            and skip_cuda_tests
        ):
            is_cuda_found = True
            break
    modified_string = ""
    if not is_cuda_found:
        modified_string = "".join(codeblocks)
    return modified_string


class HfDocTestParser(doctest.DocTestParser):
    """
    Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This
    means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also
    added anywhere a `load_dataset` call is made, as a print would otherwise fail the corresponding line.

    Tests involving cuda are skipped based on a naive pattern that should be updated if it is not enough.
    """

    # This regular expression is used to find doctest examples in a
    # string. It defines three groups: `source` is the source code
    # (including leading indentation and prompts); `indent` is the
    # indentation of the first (PS1) line of the source code; and
    # `want` is the expected output (including leading indentation).
    # fmt: off
    _EXAMPLE_RE = re.compile(r'''
        # Source consists of a PS1 line followed by zero or more PS2 lines.
        (?P<source>
            (?:^(?P<indent> [ ]*) >>>    .*)   # PS1 line
            (?:\n           [ ]*  \.\.\. .*)*) # PS2 lines
        \n?
        # Want consists of any non-blank lines that do not start with PS1.
        (?P<want> (?:(?![ ]*$)     # Not a blank line
                     (?![ ]*>>>)   # Not a line starting with PS1
                     # !!!!!!!!!!! HF Specific !!!!!!!!!!!
                     (?:(?!```).)* # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line)
                     # !!!!!!!!!!! HF Specific !!!!!!!!!!!
                     (?:\n|$)      # Match a new line or end of string
                  )*)
        ''', re.MULTILINE | re.VERBOSE
    )
    # fmt: on

    # !!!!!!!!!!! HF Specific !!!!!!!!!!!
    skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", False))
    # !!!!!!!!!!! HF Specific !!!!!!!!!!!

    def parse(self, string, name="<string>"):
        """
        Overwrites the `parse` method to incorporate a skip for CUDA tests, and remove logs and dataset prints before
        calling `super().parse`.
        """
        string = preprocess_string(string, self.skip_cuda_tests)
        return super().parse(string, name)


class HfDoctestModule(Module):
    """
    Overwrites the `DoctestModule` of the pytest package to make sure the HfDocTestParser is used when discovering
    tests.
    """

    def collect(self) -> Iterable[DoctestItem]:
        class MockAwareDocTestFinder(doctest.DocTestFinder):
            """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug.

            https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532
            """

            def _find_lineno(self, obj, source_lines):
                """Doctest code does not take into account `@property`; this
                is a hackish way to fix it. https://bugs.python.org/issue17446

                Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be
                reported upstream. #8796
                """
                if isinstance(obj, property):
                    obj = getattr(obj, "fget", obj)

                if hasattr(obj, "__wrapped__"):
                    # Get the main obj in case of it being wrapped
                    obj = inspect.unwrap(obj)

                # Type ignored because this is a private function.
                return super()._find_lineno(  # type:ignore[misc]
                    obj,
                    source_lines,
                )

            def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None:
                if _is_mocked(obj):
                    return
                with _patch_unwrap_mock_aware():
                    # Type ignored because this is a private function.
                    super()._find(  # type:ignore[misc]
                        tests, obj, name, module, source_lines, globs, seen
                    )

        if self.path.name == "conftest.py":
            module = self.config.pluginmanager._importconftest(
                self.path,
                self.config.getoption("importmode"),
                rootpath=self.config.rootpath,
            )
        else:
            try:
                module = import_path(
                    self.path,
                    root=self.config.rootpath,
                    mode=self.config.getoption("importmode"),
                )
            except ImportError:
                if self.config.getvalue("doctest_ignore_import_errors"):
                    skip("unable to import module %r" % self.path)
                else:
                    raise

        # !!!!!!!!!!! HF Specific !!!!!!!!!!!
        finder = MockAwareDocTestFinder(parser=HfDocTestParser())
        # !!!!!!!!!!! HF Specific !!!!!!!!!!!
        optionflags = get_optionflags(self)
        runner = _get_runner(
            verbose=False,
            optionflags=optionflags,
            checker=_get_checker(),
            continue_on_failure=_get_continue_on_failure(self.config),
        )
        for test in finder.find(module, module.__name__):
            if test.examples:  # skip empty doctests and cuda
                yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
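For reference, a small sketch of what `preprocess_string` above does (illustrative snippets; this assumes the function as defined in this file):

```py
fence = "`" * 3  # avoid writing a literal code fence inside this example

gpu_snippet = fence + 'python\n>>> import torch\n>>> model.to("cuda")\n' + fence
# With skip_cuda_tests=True the CUDA heuristic matches and the whole block is dropped:
print(preprocess_string(gpu_snippet, skip_cuda_tests=True))   # -> ""

ds_snippet = fence + 'python\n>>> import datasets\n>>> ds = load_dataset("squad", split="train")\n' + fence
print(preprocess_string(ds_snippet, skip_cuda_tests=False))
# The load_dataset line now ends with "# doctest: +IGNORE_RESULT", so its printed output is ignored.
```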
@ -235,6 +235,7 @@ def try_to_load_from_cache(
    filename: str,
    cache_dir: Union[str, Path, None] = None,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
) -> Optional[str]:
    """
    Explores the cache to return the latest cached file for a given revision if found.
@ -251,6 +252,8 @@ def try_to_load_from_cache(
        revision (`str`, *optional*):
            The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
            provided either.
        repo_type (`str`, *optional*):
            The type of the repo.

    Returns:
        `Optional[str]` or `_CACHED_NO_EXIST`:
@ -266,7 +269,9 @@ def try_to_load_from_cache(
        cache_dir = TRANSFORMERS_CACHE

    object_id = repo_id.replace("/", "--")
    repo_cache = os.path.join(cache_dir, f"models--{object_id}")
    if repo_type is None:
        repo_type = "model"
    repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
    if not os.path.isdir(repo_cache):
        # No cache for this model
        return None
@ -303,6 +308,7 @@ def cached_file(
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    repo_type: Optional[str] = None,
    user_agent: Optional[Union[str, Dict[str, str]]] = None,
    _raise_exceptions_for_missing_entries: bool = True,
    _raise_exceptions_for_connection_errors: bool = True,
@ -342,6 +348,8 @@ def cached_file(
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).

    <Tip>

@ -393,7 +401,7 @@ def cached_file(
    if _commit_hash is not None and not force_download:
        # If the file is cached under that commit hash, we return it directly.
        resolved_file = try_to_load_from_cache(
            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )
        if resolved_file is not None:
            if resolved_file is not _CACHED_NO_EXIST:
@ -410,6 +418,7 @@ def cached_file(
            path_or_repo_id,
            filename,
            subfolder=None if len(subfolder) == 0 else subfolder,
            repo_type=repo_type,
            revision=revision,
            cache_dir=cache_dir,
            user_agent=user_agent,
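The practical effect of the new `repo_type` argument, as a sketch (the space repo id below is hypothetical; cache folders follow the `{repo_type}s--{org}--{name}` layout shown above):

```py
from transformers.utils import try_to_load_from_cache

# Model repos keep the previous behavior: repo_type defaults to "model" -> "models--..." folders.
config_path = try_to_load_from_cache("bert-base-uncased", "config.json")

# Space (or dataset) repos can now be resolved from the local cache as well:
app_path = try_to_load_from_cache("my-org/my-space", "app.py", repo_type="space")  # hypothetical repo id

# Each call returns the cached file path, the _CACHED_NO_EXIST sentinel, or None when nothing is cached.
```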
@ -125,6 +125,14 @@ except importlib_metadata.PackageNotFoundError:
    _datasets_available = False


_diffusers_available = importlib.util.find_spec("diffusers") is not None
try:
    _diffusers_version = importlib_metadata.version("diffusers")
    logger.debug(f"Successfully imported diffusers version {_diffusers_version}")
except importlib_metadata.PackageNotFoundError:
    _diffusers_available = False


_detectron2_available = importlib.util.find_spec("detectron2") is not None
try:
    _detectron2_version = importlib_metadata.version("detectron2")
@ -185,6 +193,9 @@ except importlib_metadata.PackageNotFoundError:
    _onnx_available = False


_opencv_available = importlib.util.find_spec("cv2") is not None


_pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None
try:
    _pytorch_quantization_version = importlib_metadata.version("pytorch_quantization")
@ -431,6 +442,10 @@ def is_onnx_available():
    return _onnx_available


def is_openai_available():
    return importlib.util.find_spec("openai") is not None


def is_flax_available():
    return _flax_available
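A sketch of how the new `is_openai_available` check is meant to be consumed, following the repository's usual guarded-import pattern (the fallback message is illustrative):

```py
from transformers.utils import is_openai_available

if is_openai_available():
    import openai  # only imported when the package is actually installed
else:
    print("`openai` is not installed; OpenAI-backed features will be unavailable.")
```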
tests/tools/__init__.py (new, empty file)
tests/tools/test_document_question_answering.py (new file)
@ -0,0 +1,57 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from datasets import load_dataset

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("document-question-answering")
        self.tool.setup()
        self.remote_tool = load_tool("document-question-answering", remote=True)

    def test_exact_match_arg(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        image = dataset[0]["image"]

        result = self.tool(image, "When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")

    def test_exact_match_arg_remote(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        image = dataset[0]["image"]

        result = self.remote_tool(image, "When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")

    def test_exact_match_kwarg(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        image = dataset[0]["image"]

        result = self.tool(image=image, question="When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")

    def test_exact_match_kwarg_remote(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        image = dataset[0]["image"]

        result = self.remote_tool(image=image, question="When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")
tests/tools/test_image_captioning.py (new file)
@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path

from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir

from .test_tools_common import ToolTesterMixin


if is_vision_available():
    from PIL import Image


class ImageCaptioningToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("image-captioning")
        self.tool.setup()
        self.remote_tool = load_tool("image-captioning", remote=True)

    def test_exact_match_arg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image)
        self.assertEqual(result, "two cats sleeping on a couch")

    def test_exact_match_arg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.remote_tool(image)
        self.assertEqual(result, "two cats sleeping on a couch")

    def test_exact_match_kwarg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image=image)
        self.assertEqual(result, "two cats sleeping on a couch")

    def test_exact_match_kwarg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.remote_tool(image=image)
        self.assertEqual(result, "two cats sleeping on a couch")
tests/tools/test_image_question_answering.py (new file)
@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path

from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir

from .test_tools_common import ToolTesterMixin


if is_vision_available():
    from PIL import Image


class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("image-question-answering")
        self.tool.setup()
        self.remote_tool = load_tool("image-question-answering", remote=True)

    def test_exact_match_arg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image, "How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")

    def test_exact_match_arg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.remote_tool(image, "How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")

    def test_exact_match_kwarg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image=image, question="How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")

    def test_exact_match_kwarg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.remote_tool(image=image, question="How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")
tests/tools/test_image_segmentation.py (new file)
@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path

from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir

from .test_tools_common import ToolTesterMixin


if is_vision_available():
    from PIL import Image


class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("image-segmentation")
        self.tool.setup()
        self.remote_tool = load_tool("image-segmentation", remote=True)

    def test_exact_match_arg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
        result = self.tool(image, "cat")
        self.assertTrue(isinstance(result, Image.Image))

    def test_exact_match_arg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
        result = self.remote_tool(image, "cat")
        self.assertTrue(isinstance(result, Image.Image))

    def test_exact_match_kwarg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
        result = self.tool(image=image, label="cat")
        self.assertTrue(isinstance(result, Image.Image))

    def test_exact_match_kwarg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
        result = self.remote_tool(image=image, label="cat")
        self.assertTrue(isinstance(result, Image.Image))
tests/tools/test_python_interpreter.py (new file)
@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.testing_utils import CaptureStdout
from transformers.tools.python_interpreter import evaluate


# Fake function we will use as tool
def add_two(x):
    return x + 2


class PythonInterpreterTester(unittest.TestCase):
    def test_evaluate_assign(self):
        code = "x = 3"
        state = {}
        result = evaluate(code, {}, state=state)
        assert result == 3
        self.assertDictEqual(state, {"x": 3})

        code = "x = y"
        state = {"y": 5}
        result = evaluate(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == 5
        self.assertDictEqual(state, {"x": 5, "y": 5})

    def test_evaluate_call(self):
        code = "y = add_two(x)"
        state = {"x": 3}
        result = evaluate(code, {"add_two": add_two}, state=state)
        assert result == 5
        self.assertDictEqual(state, {"x": 3, "y": 5})

        # Won't work without the tool
        with CaptureStdout() as out:
            result = evaluate(code, {}, state=state)
        assert result is None
        assert "tried to execute add_two" in out.out

    def test_evaluate_constant(self):
        code = "x = 3"
        state = {}
        result = evaluate(code, {}, state=state)
        assert result == 3
        self.assertDictEqual(state, {"x": 3})

    def test_evaluate_dict(self):
        code = "test_dict = {'x': x, 'y': add_two(x)}"
        state = {"x": 3}
        result = evaluate(code, {"add_two": add_two}, state=state)
        self.assertDictEqual(result, {"x": 3, "y": 5})
        self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})

    def test_evaluate_expression(self):
        code = "x = 3\ny = 5"
        state = {}
        result = evaluate(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == 5
        self.assertDictEqual(state, {"x": 3, "y": 5})

    def test_evaluate_f_string(self):
        code = "text = f'This is x: {x}.'"
        state = {"x": 3}
        result = evaluate(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == "This is x: 3."
        self.assertDictEqual(state, {"x": 3, "text": "This is x: 3."})

    def test_evaluate_if(self):
        code = "if x <= 3:\n    y = 2\nelse:\n    y = 5"
        state = {"x": 3}
        result = evaluate(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == 2
        self.assertDictEqual(state, {"x": 3, "y": 2})

        state = {"x": 8}
        result = evaluate(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == 5
        self.assertDictEqual(state, {"x": 8, "y": 5})

    def test_evaluate_list(self):
        code = "test_list = [x, add_two(x)]"
        state = {"x": 3}
        result = evaluate(code, {"add_two": add_two}, state=state)
        self.assertListEqual(result, [3, 5])
        self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})

    def test_evaluate_name(self):
        code = "y = x"
        state = {"x": 3}
        result = evaluate(code, {}, state=state)
        assert result == 3
        self.assertDictEqual(state, {"x": 3, "y": 3})

    def test_evaluate_subscript(self):
        code = "test_list = [x, add_two(x)]\ntest_list[1]"
        state = {"x": 3}
        result = evaluate(code, {"add_two": add_two}, state=state)
        assert result == 5
        self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})

        code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
        state = {"x": 3}
        result = evaluate(code, {"add_two": add_two}, state=state)
        assert result == 5
        self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
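Mirroring the tests above, a sketch of calling `evaluate` directly (`add_two` is the same fake tool the tests use):

```py
from transformers.tools.python_interpreter import evaluate


def add_two(x):
    return x + 2


state = {"x": 3}
# evaluate returns the value of the last statement and mutates `state` in place.
result = evaluate("y = add_two(x)\nz = y + 1", {"add_two": add_two}, state=state)
print(result)  # 6
print(state)   # {'x': 3, 'y': 5, 'z': 6}
```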
tests/tools/test_speech_to_text.py (new file)
@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import is_torch_available, load_tool

from .test_tools_common import ToolTesterMixin


if is_torch_available():
    import torch


class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("speech-to-text")
        self.tool.setup()

    def test_exact_match_arg(self):
        result = self.tool(torch.ones(3000))
        self.assertEqual(result, " you")

    def test_exact_match_kwarg(self):
        result = self.tool(audio=torch.ones(3000))
        self.assertEqual(result, " you")
tests/tools/test_text_classification.py (new file)
@ -0,0 +1,43 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


class TextClassificationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("text-classification")
        self.tool.setup()
        self.remote_tool = load_tool("text-classification", remote=True)

    def test_exact_match_arg(self):
        result = self.tool("That's quite cool", ["positive", "negative"])
        self.assertEqual(result, "positive")

    def test_exact_match_arg_remote(self):
        result = self.remote_tool("That's quite cool", ["positive", "negative"])
        self.assertEqual(result, "positive")

    def test_exact_match_kwarg(self):
        result = self.tool(text="That's quite cool", labels=["positive", "negative"])
        self.assertEqual(result, "positive")

    def test_exact_match_kwarg_remote(self):
        result = self.remote_tool(text="That's quite cool", labels=["positive", "negative"])
        self.assertEqual(result, "positive")
tests/tools/test_text_question_answering.py (new file)
@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.

In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]

On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""


class TextQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("text-question-answering")
        self.tool.setup()
        self.remote_tool = load_tool("text-question-answering", remote=True)

    def test_exact_match_arg(self):
        result = self.tool(TEXT, "What did Hugging Face do in April 2021?")
        self.assertEqual(result, "launched the BigScience Research Workshop")

    def test_exact_match_arg_remote(self):
        result = self.remote_tool(TEXT, "What did Hugging Face do in April 2021?")
        self.assertEqual(result, "launched the BigScience Research Workshop")

    def test_exact_match_kwarg(self):
        result = self.tool(text=TEXT, question="What did Hugging Face do in April 2021?")
        self.assertEqual(result, "launched the BigScience Research Workshop")

    def test_exact_match_kwarg_remote(self):
        result = self.remote_tool(text=TEXT, question="What did Hugging Face do in April 2021?")
        self.assertEqual(result, "launched the BigScience Research Workshop")
tests/tools/test_text_summarization.py (new file)
@ -0,0 +1,64 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.

In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]

On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""


class TextSummarizationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("summarization")
        self.tool.setup()
        self.remote_tool = load_tool("summarization", remote=True)

    def test_exact_match_arg(self):
        result = self.tool(TEXT)
        self.assertEqual(
            result,
            "Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
        )

    def test_exact_match_arg_remote(self):
        result = self.remote_tool(TEXT)
        self.assertEqual(
            result,
            "Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
        )

    def test_exact_match_kwarg(self):
        result = self.tool(text=TEXT)
        self.assertEqual(
            result,
            "Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
        )

    def test_exact_match_kwarg_remote(self):
        result = self.remote_tool(text=TEXT)
        self.assertEqual(
            result,
            "Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
        )
tests/tools/test_text_to_speech.py (new file)
@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool
from transformers.utils import is_torch_available


if is_torch_available():
    import torch

from transformers.testing_utils import require_torch

from .test_tools_common import ToolTesterMixin


@require_torch
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("text-to-speech")
        self.tool.setup()

    def test_exact_match_arg(self):
        # SpeechT5 isn't deterministic
        torch.manual_seed(0)
        result = self.tool("hey")
        self.assertTrue(
            torch.allclose(
                result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
            )
        )

    def test_exact_match_kwarg(self):
        # SpeechT5 isn't deterministic
        torch.manual_seed(0)
        result = self.tool("hey")
        self.assertTrue(
            torch.allclose(
                result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
            )
        )
tests/tools/test_tools_common.py (new file)
@ -0,0 +1,100 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import List

from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir, is_tool_test


if is_torch_available():
    import torch

if is_vision_available():
    from PIL import Image


authorized_types = ["text", "image", "audio"]


def create_inputs(input_types: List[str]):
    inputs = []

    for input_type in input_types:
        if input_type == "text":
            inputs.append("Text input")
        elif input_type == "image":
            inputs.append(
                Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
            )
        elif input_type == "audio":
            inputs.append(torch.ones(3000))
        elif isinstance(input_type, list):
            inputs.append(create_inputs(input_type))
        else:
            raise ValueError(f"Invalid type requested: {input_type}")

    return inputs


def output_types(outputs: List):
    output_types = []

    for output in outputs:
        if isinstance(output, str):
            output_types.append("text")
        elif isinstance(output, Image.Image):
            output_types.append("image")
        elif isinstance(output, torch.Tensor):
            output_types.append("audio")
        else:
            raise ValueError(f"Invalid output: {output}")

    return output_types


@is_tool_test
class ToolTesterMixin:
    def test_inputs_outputs(self):
        self.assertTrue(hasattr(self.tool, "inputs"))
        self.assertTrue(hasattr(self.tool, "outputs"))

        inputs = self.tool.inputs
        for _input in inputs:
            if isinstance(_input, list):
                for __input in _input:
                    self.assertTrue(__input in authorized_types)
            else:
                self.assertTrue(_input in authorized_types)

        outputs = self.tool.outputs
        for _output in outputs:
            self.assertTrue(_output in authorized_types)

    def test_call(self):
        inputs = create_inputs(self.tool.inputs)
        outputs = self.tool(*inputs)

        # There is a single output
        if len(self.tool.outputs) == 1:
            outputs = [outputs]

        self.assertListEqual(output_types(outputs), self.tool.outputs)

    def test_common_attributes(self):
        self.assertTrue(hasattr(self.tool, "description"))
        self.assertTrue(hasattr(self.tool, "default_checkpoint"))
        self.assertTrue(self.tool.description.startswith("This is a tool that"))
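For context, a sketch of how a new tool test would plug into `ToolTesterMixin` (the `"my-new-tool"` identifier is hypothetical):

```py
import unittest

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


class MyNewToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("my-new-tool")  # hypothetical tool name
        self.tool.setup()

    # test_inputs_outputs, test_call and test_common_attributes are inherited from the mixin.
```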
tests/tools/test_translation.py (new file)
@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool

from .test_tools_common import ToolTesterMixin, output_types


class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("translation")
        self.tool.setup()
        self.remote_tool = load_tool("translation", remote=True)

    def test_exact_match_arg(self):
        result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_exact_match_arg_remote(self):
        result = self.remote_tool("Hey, what's up?", src_lang="English", tgt_lang="French")
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_exact_match_kwarg(self):
        result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_exact_match_kwarg_remote(self):
        result = self.remote_tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_call(self):
        inputs = ["Hey, what's up?", "English", "Spanish"]
        outputs = self.tool(*inputs)

        # There is a single output
        if len(self.tool.outputs) == 1:
            outputs = [outputs]

        self.assertListEqual(output_types(outputs), self.tool.outputs)
utils/prepare_for_doc_test.py (new file)
@ -0,0 +1,148 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Style utils to preprocess files for doc tests.

The doc preprocessing function can be run on a list of files and/or
directories of files. It will recursively check if the files have
a python code snippet by looking for a ```python or ```py syntax.
In the default mode - `remove_new_line==False` - the script will
add a new line before every python code ending ``` line to make
the docstrings ready for pytest doctests.
However, we don't want to have empty lines displayed in the
official documentation, which is why the new line command can be
reversed by adding the flag `--remove_new_line`, which sets
`remove_new_line==True`.

When debugging the doc tests locally, please make sure to
always run:

```python utils/prepare_for_doc_test.py src docs```

before running the doc tests:

```pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"```

Afterwards you should revert the changes by running

```python utils/prepare_for_doc_test.py src docs --remove_new_line```
"""

import argparse
import os


def process_code_block(code, add_new_line=True):
    if add_new_line:
        return maybe_append_new_line(code)
    else:
        return maybe_remove_new_line(code)


def maybe_append_new_line(code):
    """
    Append a new line if the code snippet is a
    Python code snippet.
    """
    lines = code.split("\n")

    if lines[0] in ["py", "python"]:
        # add new line before last line being ```
        last_line = lines[-1]
        lines.pop()
        lines.append("\n" + last_line)

    return "\n".join(lines)


def maybe_remove_new_line(code):
    """
    Remove the new line if the code snippet is a
    Python code snippet.
    """
    lines = code.split("\n")

    if lines[0] in ["py", "python"]:
        # remove the new line before last line being ```
        lines = lines[:-2] + lines[-1:]

    return "\n".join(lines)


def process_doc_file(code_file, add_new_line=True):
    """
    Process given file.

    Args:
        code_file (`str` or `os.PathLike`): The file in which we want to style the docstring.
    """
    with open(code_file, "r", encoding="utf-8", newline="\n") as f:
        code = f.read()

    # fmt: off
    splits = code.split("```")
    if len(splits) % 2 != 1:
        raise ValueError("The number of occurrences of ``` should be an even number.")

    splits = [s if i % 2 == 0 else process_code_block(s, add_new_line=add_new_line) for i, s in enumerate(splits)]
    clean_code = "```".join(splits)
    # fmt: on

    diff = clean_code != code
    if diff:
        print(f"Overwriting content of {code_file}.")
        with open(code_file, "w", encoding="utf-8", newline="\n") as f:
            f.write(clean_code)


def process_doc_files(*files, add_new_line=True):
    """
    Applies doc styling to a list of files or directories, recursing into folders.

    Args:
        files (several `str` or `os.PathLike`): The files or folders to treat.
        add_new_line (`bool`, *optional*, defaults to `True`):
            Whether to add a new line before each closing ``` of a Python code block, or to remove it.
    """
    for file in files:
        # Treat folders
        if os.path.isdir(file):
            files = [os.path.join(file, f) for f in os.listdir(file)]
            files = [f for f in files if os.path.isdir(f) or f.endswith(".mdx") or f.endswith(".py")]
            process_doc_files(*files, add_new_line=add_new_line)
        else:
            try:
                process_doc_file(file, add_new_line=add_new_line)
            except Exception:
                print(f"There is a problem in {file}.")
                raise


def main(*files, add_new_line=True):
    process_doc_files(*files, add_new_line=add_new_line)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.")
    parser.add_argument(
        "--remove_new_line",
        action="store_true",
        help="Whether to remove new line after each python code block instead of adding one.",
    )
    args = parser.parse_args()

    main(*args.files, add_new_line=not args.remove_new_line)
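A quick illustration of the round trip performed by the two helpers above (a sketch; `body` is a code-fence block body as produced by splitting on triple backticks in `process_doc_file`):

```py
body = "python\n>>> print(1 + 1)\n2\n"        # block body, as split on triple backticks in process_doc_file
styled = maybe_append_new_line(body)          # 'python\n>>> print(1 + 1)\n2\n\n' -> blank line before the closing fence
assert maybe_remove_new_line(styled) == body  # --remove_new_line reverses the styling exactly
```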