mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 01:23:56 +08:00
Compare commits
83 Commits
default-fa
...
v4.43.3
Author | SHA1 | Date | |
---|---|---|---|
47c29ccfaf | |||
54bc29c1ba | |||
cc75146d0e | |||
cd06184cc4 | |||
38d94bffa6 | |||
b4a0442dbd | |||
4672b4d79b | |||
a2b6a001c0 | |||
64a90d72a8 | |||
782bfffb2e | |||
cf0534913f | |||
7fa7508dad | |||
26b179c90d | |||
7d92009af6 | |||
63700628ad | |||
a009fbdab3 | |||
3263b34354 | |||
034b477847 | |||
bab32d6fe9 | |||
9ced33ca7f | |||
a5b226ce98 | |||
a1844a3209 | |||
2e113422b3 | |||
5a4a76edb7 | |||
1535a2c93d | |||
34b43211d7 | |||
7405c1c77e | |||
605f3245dc | |||
2782aadae2 | |||
f83c6f1d02 | |||
3aefb4ec7f | |||
251a2409c6 | |||
96a074fa7e | |||
bd9dca3b85 | |||
817a676bd7 | |||
74d0eb3fed | |||
7987710696 | |||
12b6880c81 | |||
d1ec36b94f | |||
7ba028fccb | |||
5a649ff3ec | |||
f2a1e3ca68 | |||
0fcfc5ccc9 | |||
c38c55f4fb | |||
aa8f86a421 | |||
b381880597 | |||
0fdea8607d | |||
fe008d6ebe | |||
62aa270f2a | |||
89575b567e | |||
46835ec6ae | |||
4bd8f12972 | |||
566b0f1fbf | |||
e316c5214f | |||
22f888b3fa | |||
cd48553fc8 | |||
56a7745704 | |||
b873234cb6 | |||
271fd8e60d | |||
8f0d26c55e | |||
c75969ee28 | |||
4c040aba02 | |||
c50e0551fd | |||
c25dde1fc9 | |||
673d30b826 | |||
765732e92c | |||
1c37e8c1a6 | |||
b31d595040 | |||
cb23d1b20b | |||
bc36c26fa6 | |||
63be8e6f39 | |||
72fb02c47d | |||
691586b0dc | |||
24cfcc2114 | |||
4037a2b5b1 | |||
6f40a213eb | |||
e391706420 | |||
c22efa6196 | |||
88e0813d8d | |||
036d3de23d | |||
89eec5cf20 | |||
999981daf4 | |||
693cb828ff |
17
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
17
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -1,6 +1,17 @@
|
||||
name: "\U0001F41B Bug Report"
|
||||
description: Submit a bug report to help us improve transformers
|
||||
labels: [ "bug" ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for taking the time to fill out this bug report! 🤗
|
||||
|
||||
Before you submit your bug report:
|
||||
|
||||
- If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
|
||||
- Try our [docs bot](https://huggingface.co/spaces/huggingchat/hf-docs-chat) -- it might be able to help you with your issue
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
@ -25,7 +36,7 @@ body:
|
||||
|
||||
Models:
|
||||
|
||||
- text models: @ArthurZucker
|
||||
- text models: @ArthurZucker
|
||||
- vision models: @amyeroberts
|
||||
- speech models: @sanchit-gandhi
|
||||
- graph models: @clefourrier
|
||||
@ -38,9 +49,9 @@ body:
|
||||
- tensorflow: @gante and @Rocketknight1
|
||||
- tokenizers: @ArthurZucker
|
||||
- trainer: @muellerzr @SunMarc
|
||||
|
||||
|
||||
Integrations:
|
||||
|
||||
|
||||
- deepspeed: HF Trainer/Accelerate: @muellerzr
|
||||
- ray/raytune: @richardliaw, @amogkam
|
||||
- Big Model Inference: @SunMarc
|
||||
|
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -58,9 +58,9 @@ Integrations:
|
||||
- deepspeed: HF Trainer/Accelerate: @muellerzr
|
||||
- ray/raytune: @richardliaw, @amogkam
|
||||
- Big Model Inference: @SunMarc
|
||||
- quantization (bitsandbytes, autogpt): @SunMarc
|
||||
- quantization (bitsandbytes, autogpt): @SunMarc
|
||||
|
||||
Documentation: @stevhliu and @MKhalusova
|
||||
Documentation: @stevhliu
|
||||
|
||||
HF projects:
|
||||
|
||||
|
@ -61,7 +61,10 @@ feedback.
|
||||
The 🤗 Transformers library is robust and reliable thanks to users who report the problems they encounter.
|
||||
|
||||
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
|
||||
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code. If you're unsure whether the bug is in your code or the library, please ask in the [forum](https://discuss.huggingface.co/) first. This helps us respond quicker to fixing issues related to the library versus general questions.
|
||||
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code. If you're unsure whether the bug is in your code or the library, please ask in the [forum](https://discuss.huggingface.co/) or on our [discord](https://discord.com/invite/hugging-face-879548962464493619) first. This helps us respond quicker to fixing issues related to the library versus general questions.
|
||||
|
||||
> [!TIP]
|
||||
> We have a [docs bot](https://huggingface.co/spaces/huggingchat/hf-docs-chat), and we highly encourage you to ask all your questions there. There is always a chance your bug can be fixed with a simple flag 👾🔫
|
||||
|
||||
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
|
||||
|
||||
@ -129,7 +132,7 @@ You will need basic `git` proficiency to contribute to
|
||||
manual. Type `git --help` in a shell and enjoy! If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
You'll need **[Python 3.8](https://github.com/huggingface/transformers/blob/main/setup.py#L426)** or above to contribute to 🤗 Transformers. Follow the steps below to start contributing:
|
||||
You'll need **[Python 3.8](https://github.com/huggingface/transformers/blob/main/setup.py#L449)** or above to contribute to 🤗 Transformers. Follow the steps below to start contributing:
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/transformers) by
|
||||
clicking on the **[Fork](https://github.com/huggingface/transformers/fork)** button on the repository's page. This creates a copy of the code
|
||||
@ -160,7 +163,7 @@ You'll need **[Python 3.8](https://github.com/huggingface/transformers/blob/main
|
||||
If 🤗 Transformers was already installed in the virtual environment, remove
|
||||
it with `pip uninstall transformers` before reinstalling it in editable
|
||||
mode with the `-e` flag.
|
||||
|
||||
|
||||
Depending on your OS, and since the number of optional dependencies of Transformers is growing, you might get a
|
||||
failure with this command. If that's the case make sure to install the Deep Learning framework you are working with
|
||||
(PyTorch, TensorFlow and/or Flax) then do:
|
||||
@ -219,7 +222,7 @@ You'll need **[Python 3.8](https://github.com/huggingface/transformers/blob/main
|
||||
|
||||
If you're modifying documents under the `docs/source` directory, make sure the documentation can still be built. This check will also run in the CI when you open a pull request. To run a local check
|
||||
make sure you install the documentation builder:
|
||||
|
||||
|
||||
```bash
|
||||
pip install ".[docs]"
|
||||
```
|
||||
@ -338,12 +341,12 @@ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/models/my_ne
|
||||
RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/text-classification
|
||||
```
|
||||
|
||||
Like the slow tests, there are other environment variables available which not enabled by default during testing:
|
||||
Like the slow tests, there are other environment variables available which are not enabled by default during testing:
|
||||
- `RUN_CUSTOM_TOKENIZERS`: Enables tests for custom tokenizers.
|
||||
- `RUN_PT_FLAX_CROSS_TESTS`: Enables tests for PyTorch + Flax integration.
|
||||
- `RUN_PT_TF_CROSS_TESTS`: Enables tests for TensorFlow + PyTorch integration.
|
||||
|
||||
More environment variables and additional information can be found in the [testing_utils.py](src/transformers/testing_utils.py).
|
||||
More environment variables and additional information can be found in the [testing_utils.py](https://github.com/huggingface/transformers/blob/main/src/transformers/testing_utils.py).
|
||||
|
||||
🤗 Transformers uses `pytest` as a test runner only. It doesn't use any
|
||||
`pytest`-specific features in the test suite itself.
|
||||
|
@ -92,6 +92,8 @@
|
||||
title: Visual Question Answering
|
||||
- local: tasks/text-to-speech
|
||||
title: Text to speech
|
||||
- local: tasks/image_text_to_text
|
||||
title: Image-text-to-text
|
||||
title: Multimodal
|
||||
- isExpanded: false
|
||||
sections:
|
||||
@ -155,6 +157,8 @@
|
||||
title: EETQ
|
||||
- local: quantization/hqq
|
||||
title: HQQ
|
||||
- local: quantization/fbgemm_fp8
|
||||
title: FBGEMM_FP8
|
||||
- local: quantization/optimum
|
||||
title: Optimum
|
||||
- local: quantization/contribute
|
||||
@ -758,6 +762,8 @@
|
||||
title: BridgeTower
|
||||
- local: model_doc/bros
|
||||
title: BROS
|
||||
- local: model_doc/chameleon
|
||||
title: Chameleon
|
||||
- local: model_doc/chinese_clip
|
||||
title: Chinese-CLIP
|
||||
- local: model_doc/clip
|
||||
|
@ -88,6 +88,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [ByT5](model_doc/byt5) | ✅ | ✅ | ✅ |
|
||||
| [CamemBERT](model_doc/camembert) | ✅ | ✅ | ❌ |
|
||||
| [CANINE](model_doc/canine) | ✅ | ❌ | ❌ |
|
||||
| [Chameleon](model_doc/chameleon) | ✅ | ❌ | ❌ |
|
||||
| [Chinese-CLIP](model_doc/chinese_clip) | ✅ | ❌ | ❌ |
|
||||
| [CLAP](model_doc/clap) | ✅ | ❌ | ❌ |
|
||||
| [CLIP](model_doc/clip) | ✅ | ✅ | ✅ |
|
||||
|
@ -25,11 +25,11 @@ A backbone is a model used for feature extraction for higher level computer visi
|
||||
|
||||
Backbones are supported for the following models:
|
||||
|
||||
* [BEiT](..model_doc/beit)
|
||||
* [BEiT](../model_doc/beit)
|
||||
* [BiT](../model_doc/bit)
|
||||
* [ConvNet](../model_doc/convnext)
|
||||
* [ConvNext](../model_doc/convnext)
|
||||
* [ConvNextV2](../model_doc/convnextv2)
|
||||
* [DiNAT](..model_doc/dinat)
|
||||
* [DiNAT](../model_doc/dinat)
|
||||
* [DINOV2](../model_doc/dinov2)
|
||||
* [FocalNet](../model_doc/focalnet)
|
||||
* [MaskFormer](../model_doc/maskformer)
|
||||
|
@ -56,3 +56,8 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
|
||||
## HqqConfig
|
||||
|
||||
[[autodoc]] HqqConfig
|
||||
|
||||
## FbgemmFp8Config
|
||||
|
||||
[[autodoc]] FbgemmFp8Config
|
||||
|
||||
|
192
docs/source/en/model_doc/chameleon.md
Normal file
192
docs/source/en/model_doc/chameleon.md
Normal file
@ -0,0 +1,192 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Chameleon
|
||||
|
||||
## Overview
|
||||
|
||||
The Chameleon model was proposed in [Chameleon: Mixed-Modal Early-Fusion Foundation Models
|
||||
](https://arxiv.org/abs/2405.09818v1) by META AI Chameleon Team. Chameleon is a Vision-Language Model that use vector quantization to tokenize images which enables the model to generate multimodal output. The model takes images and texts as input, including an interleaved format, and generates textual response. Image generation module is not released yet.
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We present Chameleon, a family of early-fusion token-based mixed-modal models capable of understanding and generating images and text in any arbitrary sequence. We outline a stable training
|
||||
approach from inception, an alignment recipe, and an architectural parameterization tailored for the
|
||||
early-fusion, token-based, mixed-modal setting. The models are evaluated on a comprehensive range
|
||||
of tasks, including visual question answering, image captioning, text generation, image generation, and
|
||||
long-form mixed modal generation. Chameleon demonstrates broad and general capabilities, including
|
||||
state-of-the-art performance in image captioning tasks, outperforms Llama-2 in text-only tasks while
|
||||
being competitive with models such as Mixtral 8x7B and Gemini-Pro, and performs non-trivial image
|
||||
generation, all in a single model. It also matches or exceeds the performance of much larger models,
|
||||
including Gemini Pro and GPT-4V, according to human judgments on a new long-form mixed-modal
|
||||
generation evaluation, where either the prompt or outputs contain mixed sequences of both images and
|
||||
text. Chameleon marks a significant step forward in unified modeling of full multimodal documents*
|
||||
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/chameleon_arch.png"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> Chameleon incorporates a vector quantizer module to transform images into discrete tokens. That also enables image generation using an auto-regressive transformer. Taken from the <a href="https://arxiv.org/abs/2405.09818v1">original paper.</a> </small>
|
||||
|
||||
This model was contributed by [joaogante](https://huggingface.co/joaogante) and [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
|
||||
The original code can be found [here](https://github.com/facebookresearch/chameleon).
|
||||
|
||||
|
||||
## Usage tips
|
||||
|
||||
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to set `processor.tokenizer.padding_side = "left"` before generating.
|
||||
|
||||
- Note that Chameleon was tuned for safety alignment. If the model is refusing to answer, consider asking a more concrete question, instead of an open question.
|
||||
|
||||
- Chameleon generates in chat format which means that the generated text will always be the "assistant's turn". You can enable a text completion generation by passing `return_for_text_completion=True` when calling the processor.
|
||||
|
||||
> [!NOTE]
|
||||
> Chameleon implementation in Transformers uses a special image token to indicate where to merge image embeddings. For special image token we didn't add a new one but used one of the reserved tokens: `<reserved08707>`. You have to add `<image>` to your prompt in the place where the image should be embedded for correct generation.
|
||||
|
||||
## Usage example
|
||||
|
||||
### Single image inference
|
||||
|
||||
Chameleon is a gated model so make sure to have access and login to Hugging Face Hub using a token.
|
||||
Here's how to load the model and perform inference in half-precision (`torch.bfloat16`):
|
||||
|
||||
```python
|
||||
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
|
||||
import torch
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
|
||||
model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16, device_map="cuda")
|
||||
|
||||
# prepare image and text prompt
|
||||
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
prompt = "What do you see in this image?<image>"
|
||||
|
||||
inputs = processor(prompt, image, return_tensors="pt").to(model.device)
|
||||
|
||||
# autoregressively complete prompt
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
### Multi image inference
|
||||
|
||||
Chameleon can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). Here is how you can do it:
|
||||
|
||||
```python
|
||||
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
|
||||
import torch
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
|
||||
|
||||
model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16, device_map="cuda")
|
||||
|
||||
# Get three different images
|
||||
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
image_stop = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image_cats = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
|
||||
image_snowman = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
# Prepare a batched prompt, where the first one is a multi-image prompt and the second is not
|
||||
prompts = [
|
||||
"What do these images have in common?<image><image>",
|
||||
"<image>What is shown in this image?"
|
||||
]
|
||||
|
||||
# We can simply feed images in the order they have to be used in the text prompt
|
||||
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
|
||||
inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Generate
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=50)
|
||||
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
```
|
||||
|
||||
## Model optimization
|
||||
|
||||
### Quantization using Bitsandbytes
|
||||
|
||||
The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with:
|
||||
|
||||
```python
|
||||
from transformers import ChameleonForConditionalGeneration, BitsAndBytesConfig
|
||||
|
||||
# specify how to quantize the model
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
)
|
||||
|
||||
model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", quantization_config=quantization_config, device_map="cuda")
|
||||
```
|
||||
|
||||
### Use Flash-Attention 2 and SDPA to further speed-up generation
|
||||
|
||||
The models supports both, Flash-Attention 2 and PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) which can be enables for optimization. SDPA is the default options when you load the model, If you want to switch for Flash Attention 2, first make sure to install flash-attn. Refer to the [original repository](https://github.com/Dao-AILab/flash-attention) regarding that package installation. Simply change the snippet above with:
|
||||
|
||||
```python
|
||||
from transformers import ChameleonForConditionalGeneration
|
||||
|
||||
model_id = "facebook/chameleon-7b"
|
||||
model = ChameleonForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
attn_implementation="flash_attention_2"
|
||||
).to(0)
|
||||
```
|
||||
|
||||
## ChameleonConfig
|
||||
|
||||
[[autodoc]] ChameleonConfig
|
||||
|
||||
## ChameleonVQVAEConfig
|
||||
|
||||
[[autodoc]] ChameleonVQVAEConfig
|
||||
|
||||
## ChameleonProcessor
|
||||
|
||||
[[autodoc]] ChameleonProcessor
|
||||
|
||||
## ChameleonImageProcessor
|
||||
|
||||
[[autodoc]] ChameleonImageProcessor
|
||||
- preprocess
|
||||
|
||||
## ChameleonVQVAE
|
||||
|
||||
[[autodoc]] ChameleonVQVAE
|
||||
- forward
|
||||
|
||||
## ChameleonModel
|
||||
|
||||
[[autodoc]] ChameleonModel
|
||||
- forward
|
||||
|
||||
## ChameleonForConditionalGeneration
|
||||
|
||||
[[autodoc]] ChameleonForConditionalGeneration
|
||||
- forward
|
@ -79,6 +79,123 @@ encode the text and prepare the images. The following example shows how to get t
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
```
|
||||
|
||||
|
||||
### Combining CLIP and Flash Attention 2
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2.
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16`)
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
For small batch sizes, you might notice a slowdown in your model when using flash attention. Refer to the section [Expected speedups with Flash Attention and SDPA](#Expected-speedups-with-Flash-Attention-and-SDPA) below and select an appropriate attention implementation.
|
||||
|
||||
</Tip>
|
||||
|
||||
To load and run a model using Flash Attention 2, refer to the snippet below:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> import requests
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
>>> device = "cuda"
|
||||
>>> torch_dtype = torch.float16
|
||||
|
||||
>>> model = CLIPModel.from_pretrained(
|
||||
... "openai/clip-vit-base-patch32",
|
||||
... attn_implementation="flash_attention_2",
|
||||
... device_map=device,
|
||||
... torch_dtype=torch_dtype,
|
||||
... )
|
||||
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
|
||||
>>> inputs.to(device)
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... with torch.autocast(device):
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
>>> print(probs)
|
||||
tensor([[0.9946, 0.0052]], device='cuda:0', dtype=torch.float16)
|
||||
```
|
||||
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```python
|
||||
from transformers import CLIPModel
|
||||
|
||||
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
|
||||
### Expected speedups with Flash Attention and SDPA
|
||||
|
||||
On a local benchmark (NVIDIA A10G, PyTorch 2.3.1+cu121) with `float16`, we saw the following speedups during inference for `"openai/clip-vit-large-patch14"` checkpoint ([code](https://gist.github.com/qubvel/ac691a54e54f9fae8144275f866a7ff8)):
|
||||
|
||||
#### CLIPTextModel
|
||||
|
||||
| Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|
||||
|------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
|
||||
| 4 | 0.009 | 0.012 | 0.737 | 0.007 | 1.269 |
|
||||
| 16 | 0.009 | 0.014 | 0.659 | 0.008 | 1.187 |
|
||||
| 32 | 0.018 | 0.021 | 0.862 | 0.016 | 1.142 |
|
||||
| 64 | 0.034 | 0.034 | 1.001 | 0.03 | 1.163 |
|
||||
| 128 | 0.063 | 0.058 | 1.09 | 0.054 | 1.174 |
|
||||
|
||||

|
||||
|
||||
#### CLIPVisionModel
|
||||
|
||||
| Image batch size | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|
||||
|-------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
|
||||
| 1 | 0.016 | 0.013 | 1.247 | 0.012 | 1.318 |
|
||||
| 4 | 0.025 | 0.021 | 1.198 | 0.021 | 1.202 |
|
||||
| 16 | 0.093 | 0.075 | 1.234 | 0.075 | 1.24 |
|
||||
| 32 | 0.181 | 0.147 | 1.237 | 0.146 | 1.241 |
|
||||
|
||||

|
||||
|
||||
#### CLIPModel
|
||||
|
||||
| Image batch size | Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|
||||
|-------------------:|------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
|
||||
| 1 | 4 | 0.025 | 0.026 | 0.954 | 0.02 | 1.217 |
|
||||
| 1 | 16 | 0.026 | 0.028 | 0.918 | 0.02 | 1.287 |
|
||||
| 1 | 64 | 0.042 | 0.046 | 0.906 | 0.036 | 1.167 |
|
||||
| 4 | 4 | 0.028 | 0.033 | 0.849 | 0.024 | 1.189 |
|
||||
| 4 | 16 | 0.034 | 0.035 | 0.955 | 0.029 | 1.169 |
|
||||
| 4 | 64 | 0.059 | 0.055 | 1.072 | 0.05 | 1.179 |
|
||||
| 16 | 4 | 0.096 | 0.088 | 1.091 | 0.078 | 1.234 |
|
||||
| 16 | 16 | 0.102 | 0.09 | 1.129 | 0.083 | 1.224 |
|
||||
| 16 | 64 | 0.127 | 0.11 | 1.157 | 0.105 | 1.218 |
|
||||
| 32 | 4 | 0.185 | 0.159 | 1.157 | 0.149 | 1.238 |
|
||||
| 32 | 16 | 0.19 | 0.162 | 1.177 | 0.154 | 1.233 |
|
||||
| 32 | 64 | 0.216 | 0.181 | 1.19 | 0.176 | 1.228 |
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with CLIP.
|
||||
|
@ -57,7 +57,7 @@ print((last_hidden_states - traced_outputs[0]).abs().max())
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DPT.
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DINOv2.
|
||||
|
||||
- Demo notebooks for DINOv2 can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DINOv2). 🌎
|
||||
|
||||
|
@ -26,8 +26,22 @@ The abstract from the paper is the following:
|
||||
|
||||
*Modern hierarchical vision transformers have added several vision-specific components in the pursuit of supervised classification performance. While these components lead to effective accuracies and attractive FLOP counts, the added complexity actually makes these transformers slower than their vanilla ViT counterparts. In this paper, we argue that this additional bulk is unnecessary. By pretraining with a strong visual pretext task (MAE), we can strip out all the bells-and-whistles from a state-of-the-art multi-stage vision transformer without losing accuracy. In the process, we create Hiera, an extremely simple hierarchical vision transformer that is more accurate than previous models while being significantly faster both at inference and during training. We evaluate Hiera on a variety of tasks for image and video recognition. Our code and models are available at https://github.com/facebookresearch/hiera.*
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/hiera_overview.png"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> Hiera architecture. Taken from the <a href="https://arxiv.org/abs/2306.00989">original paper.</a> </small>
|
||||
|
||||
This model was a joint contibution by [EduardoPacheco](https://huggingface.co/EduardoPacheco) and [namangarg110](https://huggingface.co/namangarg110). The original code can be found [here] (https://github.com/facebookresearch/hiera).
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Hiera. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
<PipelineTag pipeline="image-classification"/>
|
||||
|
||||
- [`HieraForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
|
||||
- See also: [Image classification task guide](../tasks/image_classification)
|
||||
|
||||
## HieraConfig
|
||||
|
||||
[[autodoc]] HieraConfig
|
||||
|
@ -43,6 +43,13 @@ The original code can be found [here](https://github.com/LLaVA-VL/LLaVA-NeXT/tre
|
||||
|
||||
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
- Llava-Next uses different number of patches for images and thus has to pad the inputs inside modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if model is in `eval()` mode, otherwise "right-padding".
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use tokenizer's `apply_chat_template` to format your prompts correctly. Below is an example of how to do that.
|
||||
|
||||
We will use [LLaVA-NeXT-Video-7B-hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) and a conversation history of videos and images. Each content field has to be a list of dicts, as follows:
|
||||
|
@ -40,7 +40,42 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
|
||||
|
||||
- Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results.
|
||||
|
||||
- For better results, we recommend users to prompt the model with the correct prompt format. Below is a list of prompt formats accepted by each llava checkpoint:
|
||||
- For better results, we recommend users to use the processor's `apply_chat_template()` method to format your prompt correctly. For that you need to construct a conversation history, passing in a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities, as follows:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What’s shown in this image?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "This image shows a red stop sign."},]
|
||||
},
|
||||
{
|
||||
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe the image in more details."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
|
||||
# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your images
|
||||
print(text_prompt)
|
||||
>>> "USER: <image>\n<What’s shown in this image? ASSISTANT: This image shows a red stop sign.</s>USER: Describe the image in more details. ASSISTANT:"
|
||||
```
|
||||
|
||||
- If you want to construct a chat prompt yourself, below is a list of prompt formats accepted by each llava checkpoint:
|
||||
|
||||
[llava-interleave models](https://huggingface.co/collections/llava-hf/llava-interleave-668e19a97da0036aad4a2f19) requires the following format:
|
||||
```bash
|
||||
@ -64,6 +99,7 @@ For multiple turns conversation:
|
||||
"USER: <image>\n<prompt1> ASSISTANT: <answer1></s>USER: <prompt2> ASSISTANT: <answer2></s>USER: <prompt3> ASSISTANT:"
|
||||
```
|
||||
|
||||
|
||||
### Using Flash Attention 2
|
||||
|
||||
Flash Attention 2 is an even faster, optimized version of the previous optimization, please refer to the [Flash Attention 2 section of performance docs](https://huggingface.co/docs/transformers/perf_infer_gpu_one).
|
||||
|
@ -46,26 +46,79 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
|
||||
|
||||
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.
|
||||
|
||||
- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. Below, we list the correct prompt formats to use for the text prompt "What is shown in this image?":
|
||||
<Tip warning={true}>
|
||||
|
||||
- Llava-Next uses different number of patches for images and thus has to pad the inputs inside modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if model is in `eval()` mode, otherwise "right-padding".
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the processor's `apply_chat_template` to format your prompts correctly. For that you have to construct a conversation history, passing a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities. Below is an example of how to do that and the list of formats accepted by each checkpoint.
|
||||
|
||||
We will use [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-hf/llava-v1.6-mistral-7b-hf) and a conversation history of text and image. Each content field has to be a list of dicts, as follows:
|
||||
|
||||
```python
|
||||
from transformers import LlavaNextProcessor
|
||||
|
||||
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-hf/llava-v1.6-mistral-7b-hf")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What’s shown in this image?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "This image shows a red stop sign."},]
|
||||
},
|
||||
{
|
||||
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe the image in more details."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
|
||||
# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your images
|
||||
print(text_prompt)
|
||||
>>> "[INST] <image>\nWhat's shown in this image? [/INST] This image shows a red stop sign. [INST] Describe the image in more details. [/INST]"
|
||||
```
|
||||
|
||||
- If you want to construct a chat prompt yourself, below is a list of possible formats
|
||||
.
|
||||
[llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) requires the following format:
|
||||
|
||||
```bash
|
||||
"[INST] <image>\nWhat is shown in this image? [/INST]"
|
||||
```
|
||||
|
||||
[llava-v1.6-vicuna-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf) and [llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) require the following format:
|
||||
|
||||
```bash
|
||||
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
|
||||
```
|
||||
|
||||
[llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) requires the following format:
|
||||
|
||||
```bash
|
||||
"<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
|
||||
```
|
||||
|
||||
[llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llava-next-8b-hf) requires the following format:
|
||||
|
||||
```bash
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n<image>\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
```
|
||||
|
||||
[llava-next-72b-hf](https://huggingface.co/llava-hf/llava-next-72b-hf) and [llava-next-110b-hf](https://huggingface.co/llava-hf/llava-next-110b-hf) require the following format:
|
||||
|
||||
```bash
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n"
|
||||
```
|
||||
|
||||
## Usage example
|
||||
|
||||
### Single image inference
|
||||
@ -86,8 +139,17 @@ model.to("cuda:0")
|
||||
# prepare image and text prompt, using the appropriate prompt template
|
||||
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
|
||||
|
||||
# autoregressively complete prompt
|
||||
@ -120,15 +182,47 @@ image_cats = Image.open(requests.get(url, stream=True).raw)
|
||||
url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
|
||||
image_snowman = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
# Prepare a batched prompt, where the first one is a multi-turn conversation and the second is not
|
||||
prompt = [
|
||||
"[INST] <image>\nWhat is shown in this image? [/INST] There is a red stop sign in the image. [INST] <image>\nWhat about this image? How many cats do you see [/INST]",
|
||||
"[INST] <image>\nWhat is shown in this image? [/INST]"
|
||||
# Prepare a batch of two prompts, where the first one is a multi-turn conversation and the second is not
|
||||
conversation_1 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "There is a red stop sign in the image."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What about this image? How many cats do you see?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
conversation_2 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt_1 = processor.apply_chat_template(conversation_1, add_generation_prompt=True)
|
||||
prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=True)
|
||||
prompts = [prompt_1, prompt_2]
|
||||
|
||||
# We can simply feed images in the order they have to be used in the text prompt
|
||||
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
|
||||
inputs = processor(text=prompt, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device)
|
||||
inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device)
|
||||
|
||||
# Generate
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=30)
|
||||
|
@ -105,7 +105,7 @@ from huggingface_hub import list_models
|
||||
|
||||
model_list = list_models()
|
||||
org = "Helsinki-NLP"
|
||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
||||
model_ids = [x.id for x in model_list if x.id.startswith(org)]
|
||||
suffix = [x.split("/")[1] for x in model_ids]
|
||||
old_style_multi_models = [f"{org}/{s}" for s in suffix if s != s.lower()]
|
||||
```
|
||||
|
@ -51,19 +51,19 @@ This model was contributed by [julien-c](https://huggingface.co/julien-c). The o
|
||||
|
||||
## Usage tips
|
||||
|
||||
- This implementation is the same as [`BertModel`] with a tiny embeddings tweak as well as a setup
|
||||
for Roberta pretrained models.
|
||||
- RoBERTa has the same architecture as BERT, but uses a byte-level BPE as a tokenizer (same as GPT-2) and uses a
|
||||
- This implementation is the same as [`BertModel`] with a minor tweak to the embeddings, as well as a setup
|
||||
for RoBERTa pretrained models.
|
||||
- RoBERTa has the same architecture as BERT but uses a byte-level BPE as a tokenizer (same as GPT-2) and uses a
|
||||
different pretraining scheme.
|
||||
- RoBERTa doesn't have `token_type_ids`, you don't need to indicate which token belongs to which segment. Just
|
||||
separate your segments with the separation token `tokenizer.sep_token` (or `</s>`)
|
||||
- Same as BERT with better pretraining tricks:
|
||||
- RoBERTa doesn't have `token_type_ids`, so you don't need to indicate which token belongs to which segment. Just
|
||||
separate your segments with the separation token `tokenizer.sep_token` (or `</s>`).
|
||||
- RoBERTa is similar to BERT but with better pretraining techniques:
|
||||
|
||||
* dynamic masking: tokens are masked differently at each epoch, whereas BERT does it once and for all
|
||||
* together to reach 512 tokens (so the sentences are in an order than may span several documents)
|
||||
* train with larger batches
|
||||
* use BPE with bytes as a subunit and not characters (because of unicode characters)
|
||||
- [CamemBERT](camembert) is a wrapper around RoBERTa. Refer to this page for usage examples.
|
||||
* Dynamic masking: tokens are masked differently at each epoch, whereas BERT does it once and for all.
|
||||
* Sentence packing: Sentences are packed together to reach 512 tokens (so the sentences are in an order that may span several documents).
|
||||
* Larger batches: Training uses larger batches.
|
||||
* Byte-level BPE vocabulary: Uses BPE with bytes as a subunit instead of characters, accommodating Unicode characters.
|
||||
- [CamemBERT](camembert) is a wrapper around RoBERTa. Refer to its model page for usage examples.
|
||||
|
||||
## Resources
|
||||
|
||||
|
@ -98,7 +98,7 @@ indices = np.arange(0, total_frames, total_frames / 8).astype(int)
|
||||
video = read_video_pyav(container, indices)
|
||||
|
||||
# For better results, we recommend to prompt the model in the following format
|
||||
prompt = "USER: <video>Why is this funny? ASSISTANT:"
|
||||
prompt = "USER: <video>\nWhy is this funny? ASSISTANT:"
|
||||
inputs = processor(text=prompt, videos=video, return_tensors="pt")
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=60)
|
||||
@ -108,7 +108,7 @@ processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spac
|
||||
For multiple turns conversation change the prompt format to:
|
||||
|
||||
```bash
|
||||
"USER: <video>What do you see in this video? ASSISTANT: A baby reading a book. USER: Why is the it funny? ASSISTANT:"
|
||||
"USER: <video>\nWhat do you see in this video? ASSISTANT: A baby reading a book. USER: Why is the it funny? ASSISTANT:"
|
||||
```
|
||||
|
||||
### Mixed Media Mode
|
||||
@ -123,7 +123,7 @@ import requests
|
||||
# Load and image and write a new prompt
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
prompt = "USER: <image> How many cats are there in the image? ASSISTANT: There are two cats. USER: <video>Why is this video funny? ASSISTANT:"
|
||||
prompt = "USER: <image>\nHow many cats are there in the image? ASSISTANT: There are two cats. USER: <video>\nWhy is this video funny? ASSISTANT:"
|
||||
|
||||
inputs = processor(text=prompt, images=image, videos=clip, padding=True, return_tensors="pt")
|
||||
|
||||
|
@ -26,7 +26,12 @@ The abstract from the paper is the following:
|
||||
|
||||
*While existing large vision-language multimodal models focus on whole image understanding, there is a prominent gap in achieving region-specific comprehension. Current approaches that use textual coordinates or spatial encodings often fail to provide a user-friendly interface for visual prompting. To address this challenge, we introduce a novel multimodal model capable of decoding arbitrary visual prompts. This allows users to intuitively mark images and interact with the model using natural cues like a "red bounding box" or "pointed arrow". Our simple design directly overlays visual markers onto the RGB image, eliminating the need for complex region encodings, yet achieves state-of-the-art performance on region-understanding tasks like Visual7W, PointQA, and Visual Commonsense Reasoning benchmark. Furthermore, we present ViP-Bench, a comprehensive benchmark to assess the capability of models in understanding visual prompts across multiple dimensions, enabling future research in this domain. Code, data, and model are publicly available.*
|
||||
|
||||
Tips:
|
||||
The original code can be found [here](https://github.com/mu-cai/ViP-LLaVA).
|
||||
|
||||
This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada)
|
||||
|
||||
|
||||
## Usage tips:
|
||||
|
||||
- The architecture is similar than llava architecture except that the multi-modal projector takes a set of concatenated vision hidden states and has an additional layernorm layer on that module.
|
||||
|
||||
@ -34,22 +39,51 @@ Tips:
|
||||
|
||||
- Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results.
|
||||
|
||||
- For better results, we recommend users to prompt the model with the correct prompt format:
|
||||
- For better results, we recommend users to use the processor's `apply_chat_template()` method to format your prompt correctly. For that you need to construct a conversation history, passing in a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities, as follows:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What’s shown in this image?"},
|
||||
,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "This image shows a red stop sign."},]
|
||||
},
|
||||
{
|
||||
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe the image in more details."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
|
||||
# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your images
|
||||
print(text_prompt)
|
||||
>>> "###Human: <image>\nWhat’s shown in this image?###Assistant: This image shows a red stop sign.###Human: Describe the image in more details.###Assistant:"
|
||||
```
|
||||
|
||||
- If you want to construct a chat prompt yourself, below is a list of prompt formats accepted by VipLLaVa checkpoints:
|
||||
```bash
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n<prompt>###Assistant:
|
||||
```
|
||||
|
||||
For multiple turns conversation:
|
||||
|
||||
```bash
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n<prompt1>###Assistant: <answer1>###Human: <prompt2>###Assistant:
|
||||
```
|
||||
|
||||
The original code can be found [here](https://github.com/mu-cai/ViP-LLaVA).
|
||||
|
||||
This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada)
|
||||
|
||||
|
||||
## VipLlavaConfig
|
||||
|
||||
|
@ -39,6 +39,8 @@ FlashAttention-2 is experimental and may change considerably in future versions.
|
||||
FlashAttention-2 is currently supported for the following architectures:
|
||||
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
|
||||
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
|
||||
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
|
||||
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
|
||||
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
|
||||
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
|
||||
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
|
||||
@ -198,6 +200,8 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
|
||||
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
|
||||
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
|
||||
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
|
||||
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
|
||||
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
|
||||
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
|
||||
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
|
||||
|
58
docs/source/en/quantization/fbgemm_fp8.md
Normal file
58
docs/source/en/quantization/fbgemm_fp8.md
Normal file
@ -0,0 +1,58 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# FBGEMM FP8
|
||||
|
||||
With FBGEMM FP8 quantization method, you can quantize your model in FP8 (W8A8):
|
||||
- the weights will be quantized in 8bit (FP8) per channel
|
||||
- the activation will be quantized in 8bit (FP8) per token
|
||||
|
||||
It relies on the [FBGEMM](https://github.com/pytorch/FBGEMM) library which provides efficient low-precision general matrix multiplication for small batch sizes and support for accuracy-loss minimizing techniques such as row-wise quantization and outlier-aware quantization.
|
||||
|
||||
> [!TIP]
|
||||
> You need a GPU with compute capability>=9 (e.g. H100)
|
||||
|
||||
Before you begin, make sure the following libraries are installed with their latest version:
|
||||
|
||||
```bash
|
||||
pip install --upgrade accelerate fbgemm-gpu torch
|
||||
```
|
||||
|
||||
If you are having issues with fbgemm-gpu and torch library, you might need to install the nighlty release. You can follow the instruction [here](https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries:~:text=found%20here.-,Install%20the%20FBGEMM_GPU%20Package,-Install%20through%20PyTorch)
|
||||
|
||||
|
||||
```py
|
||||
from transformers import FbgemmFp8Config, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_name = "meta-llama/Meta-Llama-3-8B"
|
||||
quantization_config = FbgemmFp8Config()
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
A quantized model can be saved via "saved_pretrained" and be reused again via the "from_pretrained".
|
||||
|
||||
```py
|
||||
quant_path = "/path/to/save/quantized/model"
|
||||
model.save_pretrained(quant_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")
|
||||
```
|
@ -55,4 +55,5 @@ Use the table below to help you decide which quantization method to use.
|
||||
| [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
|
||||
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
|
||||
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
|
||||
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
|
||||
|
||||
|
232
docs/source/en/tasks/image_text_to_text.md
Normal file
232
docs/source/en/tasks/image_text_to_text.md
Normal file
@ -0,0 +1,232 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Image-text-to-text
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
Image-text-to-text models, also known as vision language models (VLMs), are language models that take an image input. These models can tackle various tasks, from visual question answering to image segmentation. This task shares many similarities with image-to-text, but with some overlapping use cases like image captioning. Image-to-text models only take image inputs and often accomplish a specific task, whereas VLMs take open-ended text and image inputs and are more generalist models.
|
||||
|
||||
In this guide, we provide a brief overview of VLMs and show how to use them with Transformers for inference.
|
||||
|
||||
To begin with, there are multiple types of VLMs:
|
||||
- base models used for fine-tuning
|
||||
- chat fine-tuned models for conversation
|
||||
- instruction fine-tuned models
|
||||
|
||||
This guide focuses on inference with an instruction-tuned model.
|
||||
|
||||
Let's begin installing the dependencies.
|
||||
|
||||
```bash
|
||||
pip install -q transformers accelerate flash_attn
|
||||
```
|
||||
|
||||
Let's initialize the model and the processor.
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, Idefics2ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
device = torch.device("cuda")
|
||||
model = Idefics2ForConditionalGeneration.from_pretrained(
|
||||
"HuggingFaceM4/idefics2-8b",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
).to(device)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
|
||||
```
|
||||
|
||||
This model has a [chat template](./chat_templating) that helps user parse chat outputs. Moreover, the model can also accept multiple images as input in a single conversation or message. We will now prepare the inputs.
|
||||
|
||||
The image inputs look like the following.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png" alt="Two cats sitting on a net"/>
|
||||
</div>
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg" alt="A bee on a pink flower"/>
|
||||
</div>
|
||||
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
img_urls =["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png",
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"]
|
||||
images = [Image.open(requests.get(img_urls[0], stream=True).raw),
|
||||
Image.open(requests.get(img_urls[1], stream=True).raw)]
|
||||
```
|
||||
|
||||
Below is an example of the chat template. We can feed conversation turns and the last message as an input by appending it at the end of the template.
|
||||
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What do we see in this image?"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "In this image we can see two cats on the nets."},
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "And how about this image?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
```
|
||||
|
||||
We will now call the processors' [`~ProcessorMixin.apply_chat_template`] method to preprocess its output along with the image inputs.
|
||||
|
||||
```python
|
||||
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
inputs = processor(text=prompt, images=[images[0], images[1]], return_tensors="pt").to(device)
|
||||
```
|
||||
|
||||
We can now pass the preprocessed inputs to the model.
|
||||
|
||||
```python
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=500)
|
||||
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
print(generated_texts)
|
||||
## ['User: What do we see in this image? \nAssistant: In this image we can see two cats on the nets. \nUser: And how about this image? \nAssistant: In this image we can see flowers, plants and insect.']
|
||||
```
|
||||
|
||||
## Streaming
|
||||
|
||||
We can use [text streaming](./generation_strategies#streaming) for a better generation experience. Transformers supports streaming with the [`TextStreamer`] or [`TextIteratorStreamer`] classes. We will use the [`TextIteratorStreamer`] with IDEFICS-8B.
|
||||
|
||||
Assume we have an application that keeps chat history and takes in the new user input. We will preprocess the inputs as usual and initialize [`TextIteratorStreamer`] to handle the generation in a separate thread. This allows you to stream the generated text tokens in real-time. Any generation arguments can be passed to [`TextIteratorStreamer`].
|
||||
|
||||
|
||||
```python
|
||||
import time
|
||||
from transformers import TextIteratorStreamer
|
||||
from threading import Thread
|
||||
|
||||
def model_inference(
|
||||
user_prompt,
|
||||
chat_history,
|
||||
max_new_tokens,
|
||||
images
|
||||
):
|
||||
user_prompt = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": user_prompt},
|
||||
]
|
||||
}
|
||||
chat_history.append(user_prompt)
|
||||
streamer = TextIteratorStreamer(
|
||||
processor.tokenizer,
|
||||
skip_prompt=True,
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
generation_args = {
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"streamer": streamer,
|
||||
"do_sample": False
|
||||
}
|
||||
|
||||
# add_generation_prompt=True makes model generate bot response
|
||||
prompt = processor.apply_chat_template(chat_history, add_generation_prompt=True)
|
||||
inputs = processor(
|
||||
text=prompt,
|
||||
images=images,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
generation_args.update(inputs)
|
||||
|
||||
thread = Thread(
|
||||
target=model.generate,
|
||||
kwargs=generation_args,
|
||||
)
|
||||
thread.start()
|
||||
|
||||
acc_text = ""
|
||||
for text_token in streamer:
|
||||
time.sleep(0.04)
|
||||
acc_text += text_token
|
||||
if acc_text.endswith("<end_of_utterance>"):
|
||||
acc_text = acc_text[:-18]
|
||||
yield acc_text
|
||||
|
||||
thread.join()
|
||||
```
|
||||
|
||||
Now let's call the `model_inference` function we created and stream the values.
|
||||
|
||||
```python
|
||||
generator = model_inference(
|
||||
user_prompt="And what is in this image?",
|
||||
chat_history=messages,
|
||||
max_new_tokens=100,
|
||||
images=images
|
||||
)
|
||||
|
||||
for value in generator:
|
||||
print(value)
|
||||
|
||||
# In
|
||||
# In this
|
||||
# In this image ...
|
||||
```
|
||||
|
||||
## Fit models in smaller hardware
|
||||
|
||||
VLMs are often large and need to be optimized to fit in smaller hardware. Transformers supports many model quantization libraries, and here we will only show int8 quantization with [Quanto](./quantization/quanto#quanto). int8 quantization offers memory improvements up to 75 percent (if all weights are quantized). However it is no free lunch, since 8-bit is not a CUDA-native precision, the weights are quantized back and forth on the fly, which adds up to latency.
|
||||
|
||||
First, install dependencies.
|
||||
|
||||
```bash
|
||||
pip install -U quanto bitsandbytes
|
||||
```
|
||||
|
||||
To quantize a model during loading, we need to first create [`QuantoConfig`]. Then load the model as usual, but pass `quantization_config` during model initialization.
|
||||
|
||||
```python
|
||||
from transformers import Idefics2ForConditionalGeneration, AutoTokenizer, QuantoConfig
|
||||
|
||||
model_id = "HuggingFaceM4/idefics2-8b"
|
||||
quantization_config = QuantoConfig(weights="int8")
|
||||
quantized_model = Idefics2ForConditionalGeneration.from_pretrained(model_id, device_map="cuda", quantization_config=quantization_config)
|
||||
```
|
||||
|
||||
And that's it, we can use the model the same way with no changes.
|
||||
|
||||
## Further Reading
|
||||
|
||||
Here are some more resources for the image-text-to-text task.
|
||||
|
||||
- [Image-text-to-text task page](https://huggingface.co/tasks/image-text-to-text) covers model types, use cases, datasets, and more.
|
||||
- [Vision Language Models Explained](https://huggingface.co/blog/vlms) is a blog post that covers everything about vision language models and supervised fine-tuning using [TRL](https://huggingface.co/docs/trl/en/index).
|
@ -27,6 +27,8 @@
|
||||
title: 에이전트
|
||||
- local: llm_tutorial
|
||||
title: 대규모 언어 모델로 생성하기
|
||||
- local: in_translation
|
||||
title: (번역중)Chatting with Transformers
|
||||
title: 튜토리얼
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
@ -131,21 +133,41 @@
|
||||
title: (번역중) Notebooks with examples
|
||||
- local: community
|
||||
title: 커뮤니티 리소스
|
||||
- local: custom_tools
|
||||
title: 사용자 정의 도구와 프롬프트
|
||||
- local: troubleshooting
|
||||
title: 문제 해결
|
||||
- local: in_translation
|
||||
title: (번역중) Contribute new quantization method
|
||||
title: (번역중) Interoperability with GGUF files
|
||||
title: (번역중) 개발자 가이드
|
||||
- sections:
|
||||
- local: in_translation
|
||||
title: (번역중) Getting started
|
||||
- local: in_translation
|
||||
title: (번역중) bitsandbytes
|
||||
- local: in_translation
|
||||
title: (번역중) GPTQ
|
||||
- local: in_translation
|
||||
title: (번역중) AWQ
|
||||
- local: in_translation
|
||||
title: (번역중) AQLM
|
||||
- local: in_translation
|
||||
title: (번역중) Quanto
|
||||
- local: in_translation
|
||||
title: (번역중) EETQ
|
||||
- local: in_translation
|
||||
title: (번역중) HQQ
|
||||
- local: in_translation
|
||||
title: (번역중) Optimum
|
||||
- local: in_translation
|
||||
title: (번역중) Contribute new quantization method
|
||||
title: (번역중) 경량화 메소드
|
||||
- sections:
|
||||
- local: performance
|
||||
title: 성능 및 확장성
|
||||
- local: in_translation
|
||||
title: (번역중) Quantization
|
||||
title: (번역중) LLM inference optimization
|
||||
- sections:
|
||||
- local: in_translation
|
||||
title: (번역중) Training on one GPU
|
||||
title: (번역중) Methods and tools for efficient training on a single GPU
|
||||
- local: perf_train_gpu_many
|
||||
title: 다중 GPU에서 훈련 진행하기
|
||||
- local: in_translation
|
||||
@ -191,7 +213,7 @@
|
||||
title: 테스트
|
||||
- local: pr_checks
|
||||
title: Pull Request에 대한 검사
|
||||
title: (번역중) 기여하기
|
||||
title: 기여하기
|
||||
- sections:
|
||||
- local: philosophy
|
||||
title: 이념과 목표
|
||||
|
@ -1,22 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# 사용자 정의 도구와 프롬프트[[custom-tools-and-prompts]]
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The Agents framework has significantly changed in version v4.41.0.
|
||||
This document has been removed as it was referencing an older API.
|
||||
|
||||
We eagerly welcome new contributions for the updated API.
|
||||
|
||||
</Tip>
|
@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
@ -484,7 +484,7 @@ def main():
|
||||
label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
|
||||
else:
|
||||
logger.warning(
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: "
|
||||
f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
|
||||
"\nIgnoring the model labels as a result.",
|
||||
)
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
@ -428,7 +428,7 @@ def main():
|
||||
label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
|
||||
else:
|
||||
logger.warning(
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: "
|
||||
f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
|
||||
"\nIgnoring the model labels as a result.",
|
||||
)
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -370,7 +370,7 @@ def main():
|
||||
label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
|
||||
else:
|
||||
logger.warning(
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: "
|
||||
f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
|
||||
"\nIgnoring the model labels as a result.",
|
||||
)
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
@ -417,7 +417,7 @@ def main():
|
||||
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||
else:
|
||||
logger.warning(
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: "
|
||||
f"model labels: {sorted(model.config.label2id.keys())}, dataset labels:"
|
||||
f" {sorted(label_list)}.\nIgnoring the model labels as a result.",
|
||||
)
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
@ -458,7 +458,7 @@ def main():
|
||||
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||
else:
|
||||
logger.warning(
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: "
|
||||
f"model labels: {sorted(model.config.label2id.keys())}, dataset labels:"
|
||||
f" {sorted(label_list)}.\nIgnoring the model labels as a result.",
|
||||
)
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
@ -2,4 +2,4 @@ datasets==2.3.2
|
||||
transformers==4.38.0
|
||||
wandb==0.13.1
|
||||
evaluate==0.2.2
|
||||
scikit-learn==1.1.2
|
||||
scikit-learn==1.5.0
|
@ -187,7 +187,7 @@ rsa==4.8
|
||||
s3transfer==0.3.7
|
||||
sacrebleu==1.5.1
|
||||
sacremoses==0.0.49
|
||||
scikit-learn==1.0.2
|
||||
scikit-learn==1.5.0
|
||||
scipy==1.8.0
|
||||
segments==2.2.0
|
||||
sentencepiece==0.1.96
|
||||
|
@ -59,7 +59,7 @@ class GroupedBatchSampler(BatchSampler):
|
||||
|
||||
def __init__(self, sampler, group_ids, batch_size):
|
||||
if not isinstance(sampler, Sampler):
|
||||
raise ValueError(
|
||||
raise TypeError(
|
||||
"sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
||||
)
|
||||
self.sampler = sampler
|
||||
|
@ -48,7 +48,7 @@ def convert_to_float(value):
|
||||
if isinstance(value, int):
|
||||
return float(value)
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("Argument value is not a string. Can't parse it as float")
|
||||
raise TypeError("Argument value is not a string. Can't parse it as float")
|
||||
sanitized = value
|
||||
|
||||
try:
|
||||
@ -158,7 +158,7 @@ def _respect_conditions(table, row, conditions):
|
||||
cmp_value = _normalize_for_match(cmp_value)
|
||||
|
||||
if not isinstance(table_value, type(cmp_value)):
|
||||
raise ValueError("Type difference {} != {}".format(type(table_value), type(cmp_value)))
|
||||
raise TypeError("Type difference {} != {}".format(type(table_value), type(cmp_value)))
|
||||
|
||||
if not _compare(cond.operator, table_value, cmp_value):
|
||||
return False
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
@ -326,7 +326,7 @@ def main():
|
||||
label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
|
||||
else:
|
||||
logger.warning(
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: "
|
||||
f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
|
||||
"\nIgnoring the model labels as a result.",
|
||||
)
|
||||
|
@ -374,7 +374,7 @@ def main():
|
||||
label_to_id = label_name_to_id # Use the model's labels
|
||||
else:
|
||||
logger.warning(
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: "
|
||||
f"model labels: {sorted(label_name_to_id.keys())}, dataset labels:"
|
||||
f" {sorted(label_list)}.\nIgnoring the model labels as a result.",
|
||||
)
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.43.0.dev0")
|
||||
check_min_version("4.43.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
4
setup.py
4
setup.py
@ -132,7 +132,7 @@ _deps = [
|
||||
"librosa",
|
||||
"nltk",
|
||||
"natten>=0.14.6,<0.15.0",
|
||||
"numpy>=1.17,<2.0",
|
||||
"numpy>=1.17",
|
||||
"onnxconverter-common",
|
||||
"onnxruntime-tools>=1.4.2",
|
||||
"onnxruntime>=1.4.0",
|
||||
@ -430,7 +430,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.43.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.43.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.43.0.dev0"
|
||||
__version__ = "4.43.3"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -249,6 +249,11 @@ _import_structure = {
|
||||
"CanineConfig",
|
||||
"CanineTokenizer",
|
||||
],
|
||||
"models.chameleon": [
|
||||
"ChameleonConfig",
|
||||
"ChameleonProcessor",
|
||||
"ChameleonVQVAEConfig",
|
||||
],
|
||||
"models.chinese_clip": [
|
||||
"ChineseCLIPConfig",
|
||||
"ChineseCLIPProcessor",
|
||||
@ -929,6 +934,7 @@ _import_structure = {
|
||||
"AwqConfig",
|
||||
"BitsAndBytesConfig",
|
||||
"EetqConfig",
|
||||
"FbgemmFp8Config",
|
||||
"GPTQConfig",
|
||||
"HqqConfig",
|
||||
"QuantoConfig",
|
||||
@ -1125,6 +1131,7 @@ else:
|
||||
_import_structure["models.bit"].extend(["BitImageProcessor"])
|
||||
_import_structure["models.blip"].extend(["BlipImageProcessor"])
|
||||
_import_structure["models.bridgetower"].append("BridgeTowerImageProcessor")
|
||||
_import_structure["models.chameleon"].append("ChameleonImageProcessor")
|
||||
_import_structure["models.chinese_clip"].extend(["ChineseCLIPFeatureExtractor", "ChineseCLIPImageProcessor"])
|
||||
_import_structure["models.clip"].extend(["CLIPFeatureExtractor", "CLIPImageProcessor"])
|
||||
_import_structure["models.conditional_detr"].extend(
|
||||
@ -1286,8 +1293,9 @@ else:
|
||||
"WhisperTimeStampLogitsProcessor",
|
||||
]
|
||||
)
|
||||
_import_structure["modeling_flash_attention_utils"]: []
|
||||
_import_structure["modeling_flash_attention_utils"] = []
|
||||
_import_structure["modeling_outputs"] = []
|
||||
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
|
||||
_import_structure["modeling_utils"] = ["PreTrainedModel"]
|
||||
|
||||
# PyTorch models structure
|
||||
@ -1608,6 +1616,15 @@ else:
|
||||
"load_tf_weights_in_canine",
|
||||
]
|
||||
)
|
||||
_import_structure["models.chameleon"].extend(
|
||||
[
|
||||
"ChameleonForConditionalGeneration",
|
||||
"ChameleonModel",
|
||||
"ChameleonPreTrainedModel",
|
||||
"ChameleonProcessor",
|
||||
"ChameleonVQVAE",
|
||||
]
|
||||
)
|
||||
_import_structure["models.chinese_clip"].extend(
|
||||
[
|
||||
"ChineseCLIPModel",
|
||||
@ -4890,6 +4907,11 @@ if TYPE_CHECKING:
|
||||
CanineConfig,
|
||||
CanineTokenizer,
|
||||
)
|
||||
from .models.chameleon import (
|
||||
ChameleonConfig,
|
||||
ChameleonProcessor,
|
||||
ChameleonVQVAEConfig,
|
||||
)
|
||||
from .models.chinese_clip import (
|
||||
ChineseCLIPConfig,
|
||||
ChineseCLIPProcessor,
|
||||
@ -5645,6 +5667,7 @@ if TYPE_CHECKING:
|
||||
AwqConfig,
|
||||
BitsAndBytesConfig,
|
||||
EetqConfig,
|
||||
FbgemmFp8Config,
|
||||
GPTQConfig,
|
||||
HqqConfig,
|
||||
QuantoConfig,
|
||||
@ -5807,6 +5830,7 @@ if TYPE_CHECKING:
|
||||
from .models.bit import BitImageProcessor
|
||||
from .models.blip import BlipImageProcessor
|
||||
from .models.bridgetower import BridgeTowerImageProcessor
|
||||
from .models.chameleon import ChameleonImageProcessor
|
||||
from .models.chinese_clip import (
|
||||
ChineseCLIPFeatureExtractor,
|
||||
ChineseCLIPImageProcessor,
|
||||
@ -5987,6 +6011,7 @@ if TYPE_CHECKING:
|
||||
WatermarkLogitsProcessor,
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .models.albert import (
|
||||
AlbertForMaskedLM,
|
||||
@ -6254,6 +6279,13 @@ if TYPE_CHECKING:
|
||||
CaninePreTrainedModel,
|
||||
load_tf_weights_in_canine,
|
||||
)
|
||||
from .models.chameleon import (
|
||||
ChameleonForConditionalGeneration,
|
||||
ChameleonModel,
|
||||
ChameleonPreTrainedModel,
|
||||
ChameleonProcessor,
|
||||
ChameleonVQVAE,
|
||||
)
|
||||
from .models.chinese_clip import (
|
||||
ChineseCLIPModel,
|
||||
ChineseCLIPPreTrainedModel,
|
||||
|
@ -107,7 +107,7 @@ class AgentImage(AgentType, ImageType):
|
||||
elif isinstance(value, np.ndarray):
|
||||
self._tensor = torch.tensor(value)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
|
||||
raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
|
||||
|
||||
def _ipython_display_(self, include=None, exclude=None):
|
||||
"""
|
||||
|
@ -25,7 +25,19 @@ from ..utils.import_utils import is_pygments_available
|
||||
from .agent_types import AgentAudio, AgentImage, AgentText
|
||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
||||
from .llm_engine import HfEngine, MessageRole
|
||||
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
|
||||
from .prompts import (
|
||||
DEFAULT_CODE_SYSTEM_PROMPT,
|
||||
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||
PLAN_UPDATE_FINAL_PLAN_REDACTION,
|
||||
SYSTEM_PROMPT_FACTS,
|
||||
SYSTEM_PROMPT_FACTS_UPDATE,
|
||||
SYSTEM_PROMPT_PLAN,
|
||||
SYSTEM_PROMPT_PLAN_UPDATE,
|
||||
USER_PROMPT_FACTS_UPDATE,
|
||||
USER_PROMPT_PLAN,
|
||||
USER_PROMPT_PLAN_UPDATE,
|
||||
)
|
||||
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .tools import (
|
||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
@ -99,12 +111,19 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||
|
||||
def parse_code_blob(code_blob: str) -> str:
|
||||
try:
|
||||
pattern = r"```(?:py|python)?\n(.*?)```"
|
||||
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
||||
match = re.search(pattern, code_blob, re.DOTALL)
|
||||
return match.group(1).strip()
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"The code blob you used is invalid: due to the following error: {e}. This means that the regex pattern {pattern} was not respected. Make sure to correct its formatting. Code blob was: {code_blob}"
|
||||
f"""
|
||||
The code blob you used is invalid: due to the following error: {e}
|
||||
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
|
||||
Thoughts: Your thoughts
|
||||
Code:
|
||||
```py
|
||||
# Your python code here
|
||||
```<end_action>"""
|
||||
)
|
||||
|
||||
|
||||
@ -113,6 +132,8 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
|
||||
tool_call = parse_json_blob(json_blob)
|
||||
if "action" in tool_call and "action_input" in tool_call:
|
||||
return tool_call["action"], tool_call["action_input"]
|
||||
elif "action" in tool_call:
|
||||
return tool_call["action"], None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
|
||||
@ -208,7 +229,7 @@ class Toolbox:
|
||||
The tool to add to the toolbox.
|
||||
"""
|
||||
if tool.name in self._tools:
|
||||
raise KeyError(f"Error: tool {tool.name} already exists in the toolbox.")
|
||||
raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def remove_tool(self, tool_name: str):
|
||||
@ -359,12 +380,8 @@ class Agent:
|
||||
"""Get the toolbox currently available to the agent"""
|
||||
return self._toolbox
|
||||
|
||||
def initialize_for_run(self, task: str, **kwargs):
|
||||
def initialize_for_run(self):
|
||||
self.token_count = 0
|
||||
self.task = task
|
||||
if len(kwargs) > 0:
|
||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||
self.state = kwargs.copy()
|
||||
self.system_prompt = format_prompt_with_tools(
|
||||
self._toolbox,
|
||||
self.system_prompt_template,
|
||||
@ -380,7 +397,7 @@ class Agent:
|
||||
self.logger.debug("System prompt is as follows:")
|
||||
self.logger.debug(self.system_prompt)
|
||||
|
||||
def write_inner_memory_from_logs(self) -> List[Dict[str, str]]:
|
||||
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
||||
that can be used as input to the LLM.
|
||||
@ -390,43 +407,51 @@ class Agent:
|
||||
"role": MessageRole.USER,
|
||||
"content": "Task: " + self.logs[0]["task"],
|
||||
}
|
||||
memory = [prompt_message, task_message]
|
||||
if summary_mode:
|
||||
memory = [task_message]
|
||||
else:
|
||||
memory = [prompt_message, task_message]
|
||||
for i, step_log in enumerate(self.logs[1:]):
|
||||
if "llm_output" in step_log:
|
||||
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"] + "\n"}
|
||||
if "llm_output" in step_log and not summary_mode:
|
||||
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
|
||||
memory.append(thought_message)
|
||||
if "facts" in step_log:
|
||||
thought_message = {
|
||||
"role": MessageRole.ASSISTANT,
|
||||
"content": "[FACTS LIST]:\n" + step_log["facts"].strip(),
|
||||
}
|
||||
memory.append(thought_message)
|
||||
|
||||
if "error" in step_log:
|
||||
message_content = (
|
||||
"Error: "
|
||||
+ str(step_log["error"])
|
||||
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
||||
)
|
||||
elif "observation" in step_log:
|
||||
message_content = f"Observation: {step_log['observation']}"
|
||||
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
||||
memory.append(tool_response_message)
|
||||
if "plan" in step_log and not summary_mode:
|
||||
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
|
||||
memory.append(thought_message)
|
||||
|
||||
if "tool_call" in step_log and summary_mode:
|
||||
tool_call_message = {
|
||||
"role": MessageRole.ASSISTANT,
|
||||
"content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(),
|
||||
}
|
||||
memory.append(tool_call_message)
|
||||
|
||||
if "task" in step_log:
|
||||
tool_call_message = {
|
||||
"role": MessageRole.USER,
|
||||
"content": "New task:\n" + step_log["task"],
|
||||
}
|
||||
memory.append(tool_call_message)
|
||||
|
||||
if "error" in step_log or "observation" in step_log:
|
||||
if "error" in step_log:
|
||||
message_content = (
|
||||
f"[OUTPUT OF STEP {i}] Error: "
|
||||
+ str(step_log["error"])
|
||||
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
||||
)
|
||||
elif "observation" in step_log:
|
||||
message_content = f"[OUTPUT OF STEP {i}] Observation:\n{step_log['observation']}"
|
||||
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
||||
memory.append(tool_response_message)
|
||||
|
||||
if len(memory) % 3 == 0:
|
||||
reminder_content = (
|
||||
"Reminder: you are working towards solving the following task: " + self.logs[0]["task"]
|
||||
)
|
||||
reminder_content += "\nHere is a summary of your past tool calls and their results:"
|
||||
for j in range(i + 1):
|
||||
reminder_content += "\nStep " + str(j + 1)
|
||||
if "tool_call" in self.logs[j]:
|
||||
reminder_content += "\nTool call:" + str(self.logs[j]["tool_call"])
|
||||
if self.memory_verbose:
|
||||
if "observation" in self.logs[j]:
|
||||
reminder_content += "\nObservation:" + str(self.logs[j]["observation"])
|
||||
if "error" in self.logs[j]:
|
||||
reminder_content += "\nError:" + str(self.logs[j]["error"])
|
||||
memory.append(
|
||||
{
|
||||
"role": MessageRole.USER,
|
||||
"content": reminder_content,
|
||||
}
|
||||
)
|
||||
return memory
|
||||
|
||||
def get_succinct_logs(self):
|
||||
@ -459,7 +484,7 @@ class Agent:
|
||||
This method replaces arguments with the actual values from the state if they refer to state variables.
|
||||
|
||||
Args:
|
||||
tool_name (`str`): Name of the Tool to execute (shoulde be one from self.toolbox).
|
||||
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
||||
arguments (Dict[str, str]): Arguments passed to the Tool.
|
||||
"""
|
||||
if tool_name not in self.toolbox.tools:
|
||||
@ -559,7 +584,11 @@ class CodeAgent(Agent):
|
||||
agent.run("What is the result of 2 power 3.7384?")
|
||||
```
|
||||
"""
|
||||
self.initialize_for_run(task, **kwargs)
|
||||
self.task = task
|
||||
if len(kwargs) > 0:
|
||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||
self.state = kwargs.copy()
|
||||
self.initialize_for_run()
|
||||
|
||||
# Run LLM
|
||||
prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
|
||||
@ -598,7 +627,8 @@ class CodeAgent(Agent):
|
||||
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
||||
output = self.python_evaluator(
|
||||
code_action,
|
||||
available_tools,
|
||||
static_tools=available_tools,
|
||||
custom_tools={},
|
||||
state=self.state,
|
||||
authorized_imports=self.authorized_imports,
|
||||
)
|
||||
@ -623,6 +653,7 @@ class ReactAgent(Agent):
|
||||
llm_engine: Callable = HfEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
planning_interval: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -632,6 +663,7 @@ class ReactAgent(Agent):
|
||||
tool_description_template=tool_description_template,
|
||||
**kwargs,
|
||||
)
|
||||
self.planning_interval = planning_interval
|
||||
|
||||
def provide_final_answer(self, task) -> str:
|
||||
"""
|
||||
@ -655,11 +687,13 @@ class ReactAgent(Agent):
|
||||
except Exception as e:
|
||||
return f"Error in generating final llm output: {e}."
|
||||
|
||||
def run(self, task: str, stream: bool = False, **kwargs):
|
||||
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
|
||||
"""
|
||||
Runs the agent for the given task.
|
||||
|
||||
Args:
|
||||
task (`str`): The task to perform
|
||||
|
||||
Example:
|
||||
```py
|
||||
from transformers.agents import ReactCodeAgent
|
||||
@ -667,14 +701,23 @@ class ReactAgent(Agent):
|
||||
agent.run("What is the result of 2 power 3.7384?")
|
||||
```
|
||||
"""
|
||||
if stream:
|
||||
return self.stream_run(task, **kwargs)
|
||||
self.task = task
|
||||
if len(kwargs) > 0:
|
||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||
self.state = kwargs.copy()
|
||||
if reset:
|
||||
self.initialize_for_run()
|
||||
else:
|
||||
return self.direct_run(task, **kwargs)
|
||||
|
||||
def stream_run(self, task: str, **kwargs):
|
||||
self.initialize_for_run(task, **kwargs)
|
||||
self.logs.append({"task": task})
|
||||
if stream:
|
||||
return self.stream_run(task)
|
||||
else:
|
||||
return self.direct_run(task)
|
||||
|
||||
def stream_run(self, task: str):
|
||||
"""
|
||||
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
|
||||
"""
|
||||
final_answer = None
|
||||
iteration = 0
|
||||
while final_answer is None and iteration < self.max_iterations:
|
||||
@ -700,13 +743,16 @@ class ReactAgent(Agent):
|
||||
|
||||
yield final_answer
|
||||
|
||||
def direct_run(self, task: str, **kwargs):
|
||||
self.initialize_for_run(task, **kwargs)
|
||||
|
||||
def direct_run(self, task: str):
|
||||
"""
|
||||
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
|
||||
"""
|
||||
final_answer = None
|
||||
iteration = 0
|
||||
while final_answer is None and iteration < self.max_iterations:
|
||||
try:
|
||||
if self.planning_interval is not None and iteration % self.planning_interval == 0:
|
||||
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
|
||||
step_logs = self.step()
|
||||
if "final_answer" in step_logs:
|
||||
final_answer = step_logs["final_answer"]
|
||||
@ -726,6 +772,96 @@ class ReactAgent(Agent):
|
||||
|
||||
return final_answer
|
||||
|
||||
def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
|
||||
"""
|
||||
Used periodically by the agent to plan the next steps to reach the objective.
|
||||
|
||||
Args:
|
||||
task (`str`): The task to perform
|
||||
is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
|
||||
iteration (`int`): The number of the current step, used as an indication for the LLM.
|
||||
"""
|
||||
if is_first_step:
|
||||
message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
|
||||
message_prompt_task = {
|
||||
"role": MessageRole.USER,
|
||||
"content": f"""Here is the task:
|
||||
```
|
||||
{task}
|
||||
```
|
||||
Now begin!""",
|
||||
}
|
||||
|
||||
answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])
|
||||
|
||||
message_system_prompt_plan = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_PLAN}
|
||||
message_user_prompt_plan = {
|
||||
"role": MessageRole.USER,
|
||||
"content": USER_PROMPT_PLAN.format(
|
||||
task=task,
|
||||
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
||||
answer_facts=answer_facts,
|
||||
),
|
||||
}
|
||||
answer_plan = self.llm_engine(
|
||||
[message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
|
||||
)
|
||||
|
||||
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
|
||||
```
|
||||
{answer_plan}
|
||||
```"""
|
||||
final_facts_redaction = f"""Here are the facts that I know so far:
|
||||
```
|
||||
{answer_facts}
|
||||
```""".strip()
|
||||
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
||||
self.logger.debug("===== Initial plan: =====")
|
||||
self.logger.debug(final_plan_redaction)
|
||||
else: # update plan
|
||||
agent_memory = self.write_inner_memory_from_logs(
|
||||
summary_mode=False
|
||||
) # This will not log the plan but will log facts
|
||||
|
||||
# Redact updated facts
|
||||
facts_update_system_prompt = {
|
||||
"role": MessageRole.SYSTEM,
|
||||
"content": SYSTEM_PROMPT_FACTS_UPDATE,
|
||||
}
|
||||
facts_update_message = {
|
||||
"role": MessageRole.USER,
|
||||
"content": USER_PROMPT_FACTS_UPDATE,
|
||||
}
|
||||
facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])
|
||||
|
||||
# Redact updated plan
|
||||
plan_update_message = {
|
||||
"role": MessageRole.SYSTEM,
|
||||
"content": SYSTEM_PROMPT_PLAN_UPDATE.format(task=task),
|
||||
}
|
||||
plan_update_message_user = {
|
||||
"role": MessageRole.USER,
|
||||
"content": USER_PROMPT_PLAN_UPDATE.format(
|
||||
task=task,
|
||||
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
||||
facts_update=facts_update,
|
||||
remaining_steps=(self.max_iterations - iteration),
|
||||
),
|
||||
}
|
||||
plan_update = self.llm_engine(
|
||||
[plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
|
||||
)
|
||||
|
||||
# Log final facts and plan
|
||||
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
|
||||
final_facts_redaction = f"""Here is the updated list of the facts that I know:
|
||||
```
|
||||
{facts_update}
|
||||
```"""
|
||||
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
||||
self.logger.debug("===== Updated plan: =====")
|
||||
self.logger.debug(final_plan_redaction)
|
||||
|
||||
|
||||
class ReactJsonAgent(ReactAgent):
|
||||
"""
|
||||
@ -740,6 +876,7 @@ class ReactJsonAgent(ReactAgent):
|
||||
llm_engine: Callable = HfEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
planning_interval: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -747,6 +884,7 @@ class ReactJsonAgent(ReactAgent):
|
||||
llm_engine=llm_engine,
|
||||
system_prompt=system_prompt,
|
||||
tool_description_template=tool_description_template,
|
||||
planning_interval=planning_interval,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -792,11 +930,16 @@ class ReactJsonAgent(ReactAgent):
|
||||
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
|
||||
if tool_name == "final_answer":
|
||||
if isinstance(arguments, dict):
|
||||
answer = arguments["answer"]
|
||||
if "answer" in arguments:
|
||||
answer = arguments["answer"]
|
||||
if (
|
||||
isinstance(answer, str) and answer in self.state.keys()
|
||||
): # if the answer is a state variable, return the value
|
||||
answer = self.state[answer]
|
||||
else:
|
||||
answer = arguments
|
||||
else:
|
||||
answer = arguments
|
||||
if answer in self.state: # if the answer is a state variable, return the value
|
||||
answer = self.state[answer]
|
||||
current_step_logs["final_answer"] = answer
|
||||
return current_step_logs
|
||||
else:
|
||||
@ -835,6 +978,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
additional_authorized_imports: Optional[List[str]] = None,
|
||||
planning_interval: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -842,6 +986,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
llm_engine=llm_engine,
|
||||
system_prompt=system_prompt,
|
||||
tool_description_template=tool_description_template,
|
||||
planning_interval=planning_interval,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -856,10 +1001,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
||||
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
||||
self.available_tools = {
|
||||
**BASE_PYTHON_TOOLS.copy(),
|
||||
**self.toolbox.tools,
|
||||
} # This list can be augmented by the code agent creating some new functions
|
||||
self.custom_tools = {}
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
@ -911,7 +1053,11 @@ class ReactCodeAgent(ReactAgent):
|
||||
try:
|
||||
result = self.python_evaluator(
|
||||
code_action,
|
||||
tools=self.available_tools,
|
||||
static_tools={
|
||||
**BASE_PYTHON_TOOLS.copy(),
|
||||
**self.toolbox.tools,
|
||||
},
|
||||
custom_tools=self.custom_tools,
|
||||
state=self.state,
|
||||
authorized_imports=self.authorized_imports,
|
||||
)
|
||||
@ -920,7 +1066,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
self.logger.log(32, information)
|
||||
current_step_logs["observation"] = information
|
||||
except Exception as e:
|
||||
error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}"
|
||||
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
||||
if "'dict' object has no attribute 'read'" in str(e):
|
||||
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
|
||||
raise AgentExecutionError(error_msg)
|
||||
|
@ -173,7 +173,7 @@ class PythonInterpreterTool(Tool):
|
||||
|
||||
def forward(self, code):
|
||||
output = str(
|
||||
evaluate_python_code(code, tools=self.available_tools, authorized_imports=self.authorized_imports)
|
||||
evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
|
||||
)
|
||||
return output
|
||||
|
||||
|
@ -365,7 +365,118 @@ Here are the rules you should always follow to solve your task:
|
||||
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
||||
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
|
||||
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
||||
9. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
||||
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
||||
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
||||
|
||||
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
||||
"""
|
||||
|
||||
SYSTEM_PROMPT_FACTS = """Below I will present you a task.
|
||||
|
||||
You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
|
||||
To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
|
||||
Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
|
||||
|
||||
---
|
||||
### 1. Facts given in the task
|
||||
List here the specific facts given in the task that could help you (there might be nothing here).
|
||||
|
||||
### 2. Facts to look up
|
||||
List here any facts that we may need to look up.
|
||||
Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here.
|
||||
|
||||
### 3. Facts to derive
|
||||
List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation.
|
||||
|
||||
Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings:
|
||||
### 1. Facts given in the task
|
||||
### 2. Facts to look up
|
||||
### 3. Facts to derive
|
||||
Do not add anything else."""
|
||||
|
||||
SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
||||
|
||||
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
|
||||
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
|
||||
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
|
||||
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""
|
||||
|
||||
USER_PROMPT_PLAN = """
|
||||
Here is your task:
|
||||
|
||||
Task:
|
||||
```
|
||||
{task}
|
||||
```
|
||||
|
||||
Your plan can leverage any of these tools:
|
||||
{tool_descriptions}
|
||||
|
||||
List of facts that you know:
|
||||
```
|
||||
{answer_facts}
|
||||
```
|
||||
|
||||
Now begin! Write your plan below."""
|
||||
|
||||
SYSTEM_PROMPT_FACTS_UPDATE = """
|
||||
You are a world expert at gathering known and unknown facts based on a conversation.
|
||||
Below you will find a task, and ahistory of attempts made to solve the task. You will have to produce a list of these:
|
||||
### 1. Facts given in the task
|
||||
### 2. Facts that we have learned
|
||||
### 3. Facts still to look up
|
||||
### 4. Facts still to derive
|
||||
Find the task and history below."""
|
||||
|
||||
USER_PROMPT_FACTS_UPDATE = """Earlier we've built a list of facts.
|
||||
But since in your previous steps you may have learned useful new facts or invalidated some false ones.
|
||||
Please update your list of facts based on the previous history, and provide these headings:
|
||||
### 1. Facts given in the task
|
||||
### 2. Facts that we have learned
|
||||
### 3. Facts still to look up
|
||||
### 4. Facts still to derive
|
||||
|
||||
Now write your new list of facts below."""
|
||||
|
||||
SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
||||
|
||||
You have been given a task:
|
||||
```
|
||||
{task}
|
||||
```
|
||||
|
||||
Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task.
|
||||
If the previous tries so far have met some success, you can make an updated plan based on these actions.
|
||||
If you are stalled, you can make a completely new plan starting from scratch.
|
||||
"""
|
||||
|
||||
USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
|
||||
```
|
||||
{task}
|
||||
```
|
||||
|
||||
You have access to these tools:
|
||||
{tool_descriptions}
|
||||
|
||||
Here is the up to date list of facts that you know:
|
||||
```
|
||||
{facts_update}
|
||||
```
|
||||
|
||||
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
|
||||
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
|
||||
Beware that you have {remaining_steps} steps remaining.
|
||||
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
|
||||
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.
|
||||
|
||||
Now write your new plan below."""
|
||||
|
||||
PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
|
||||
```
|
||||
{task}
|
||||
```
|
||||
|
||||
Here is my new/updated plan of action to solve the task:
|
||||
```
|
||||
{plan_update}
|
||||
```"""
|
||||
|
@ -18,8 +18,17 @@ import ast
|
||||
import builtins
|
||||
import difflib
|
||||
from collections.abc import Mapping
|
||||
from importlib import import_module
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import is_pandas_available
|
||||
|
||||
|
||||
if is_pandas_available():
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class InterpreterError(ValueError):
|
||||
"""
|
||||
@ -50,7 +59,8 @@ LIST_SAFE_MODULES = [
|
||||
"unicodedata",
|
||||
]
|
||||
|
||||
PRINT_OUTPUTS = ""
|
||||
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
|
||||
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
|
||||
|
||||
|
||||
class BreakException(Exception):
|
||||
@ -75,8 +85,8 @@ def get_iterable(obj):
|
||||
raise InterpreterError("Object is not iterable")
|
||||
|
||||
|
||||
def evaluate_unaryop(expression, state, tools):
|
||||
operand = evaluate_ast(expression.operand, state, tools)
|
||||
def evaluate_unaryop(expression, state, static_tools, custom_tools):
|
||||
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
|
||||
if isinstance(expression.op, ast.USub):
|
||||
return -operand
|
||||
elif isinstance(expression.op, ast.UAdd):
|
||||
@ -89,25 +99,25 @@ def evaluate_unaryop(expression, state, tools):
|
||||
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
||||
|
||||
|
||||
def evaluate_lambda(lambda_expression, state, tools):
|
||||
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
|
||||
args = [arg.arg for arg in lambda_expression.args.args]
|
||||
|
||||
def lambda_func(*values):
|
||||
new_state = state.copy()
|
||||
for arg, value in zip(args, values):
|
||||
new_state[arg] = value
|
||||
return evaluate_ast(lambda_expression.body, new_state, tools)
|
||||
return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)
|
||||
|
||||
return lambda_func
|
||||
|
||||
|
||||
def evaluate_while(while_loop, state, tools):
|
||||
def evaluate_while(while_loop, state, static_tools, custom_tools):
|
||||
max_iterations = 1000
|
||||
iterations = 0
|
||||
while evaluate_ast(while_loop.test, state, tools):
|
||||
while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
|
||||
for node in while_loop.body:
|
||||
try:
|
||||
evaluate_ast(node, state, tools)
|
||||
evaluate_ast(node, state, static_tools, custom_tools)
|
||||
except BreakException:
|
||||
return None
|
||||
except ContinueException:
|
||||
@ -118,11 +128,11 @@ def evaluate_while(while_loop, state, tools):
|
||||
return None
|
||||
|
||||
|
||||
def create_function(func_def, state, tools):
|
||||
def create_function(func_def, state, static_tools, custom_tools):
|
||||
def new_func(*args, **kwargs):
|
||||
func_state = state.copy()
|
||||
arg_names = [arg.arg for arg in func_def.args.args]
|
||||
default_values = [evaluate_ast(d, state, tools) for d in func_def.args.defaults]
|
||||
default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
|
||||
|
||||
# Apply default values
|
||||
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
|
||||
@ -158,7 +168,7 @@ def create_function(func_def, state, tools):
|
||||
result = None
|
||||
try:
|
||||
for stmt in func_def.body:
|
||||
result = evaluate_ast(stmt, func_state, tools)
|
||||
result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
|
||||
except ReturnException as e:
|
||||
result = e.value
|
||||
return result
|
||||
@ -173,25 +183,25 @@ def create_class(class_name, class_bases, class_body):
|
||||
return type(class_name, tuple(class_bases), class_dict)
|
||||
|
||||
|
||||
def evaluate_function_def(func_def, state, tools):
|
||||
tools[func_def.name] = create_function(func_def, state, tools)
|
||||
return tools[func_def.name]
|
||||
def evaluate_function_def(func_def, state, static_tools, custom_tools):
|
||||
custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
|
||||
return custom_tools[func_def.name]
|
||||
|
||||
|
||||
def evaluate_class_def(class_def, state, tools):
|
||||
def evaluate_class_def(class_def, state, static_tools, custom_tools):
|
||||
class_name = class_def.name
|
||||
bases = [evaluate_ast(base, state, tools) for base in class_def.bases]
|
||||
bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
|
||||
class_dict = {}
|
||||
|
||||
for stmt in class_def.body:
|
||||
if isinstance(stmt, ast.FunctionDef):
|
||||
class_dict[stmt.name] = evaluate_function_def(stmt, state, tools)
|
||||
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
|
||||
elif isinstance(stmt, ast.Assign):
|
||||
for target in stmt.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
class_dict[target.id] = evaluate_ast(stmt.value, state, tools)
|
||||
class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
|
||||
elif isinstance(target, ast.Attribute):
|
||||
class_dict[target.attr] = evaluate_ast(stmt.value, state, tools)
|
||||
class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
|
||||
else:
|
||||
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
||||
|
||||
@ -200,17 +210,17 @@ def evaluate_class_def(class_def, state, tools):
|
||||
return new_class
|
||||
|
||||
|
||||
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
|
||||
def evaluate_augassign(expression, state, static_tools, custom_tools):
|
||||
# Helper function to get current value and set new value based on the target type
|
||||
def get_current_value(target):
|
||||
if isinstance(target, ast.Name):
|
||||
return state.get(target.id, 0)
|
||||
elif isinstance(target, ast.Subscript):
|
||||
obj = evaluate_ast(target.value, state, tools)
|
||||
key = evaluate_ast(target.slice, state, tools)
|
||||
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
||||
key = evaluate_ast(target.slice, state, static_tools, custom_tools)
|
||||
return obj[key]
|
||||
elif isinstance(target, ast.Attribute):
|
||||
obj = evaluate_ast(target.value, state, tools)
|
||||
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
||||
return getattr(obj, target.attr)
|
||||
elif isinstance(target, ast.Tuple):
|
||||
return tuple(get_current_value(elt) for elt in target.elts)
|
||||
@ -220,7 +230,7 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
|
||||
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
|
||||
|
||||
current_value = get_current_value(expression.target)
|
||||
value_to_add = evaluate_ast(expression.value, state, tools)
|
||||
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
|
||||
# Determine the operation and apply it
|
||||
if isinstance(expression.op, ast.Add):
|
||||
@ -256,28 +266,28 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
|
||||
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
|
||||
|
||||
# Update the state
|
||||
set_value(expression.target, updated_value, state, tools)
|
||||
set_value(expression.target, updated_value, state, static_tools, custom_tools)
|
||||
|
||||
return updated_value
|
||||
|
||||
|
||||
def evaluate_boolop(node, state, tools):
|
||||
def evaluate_boolop(node, state, static_tools, custom_tools):
|
||||
if isinstance(node.op, ast.And):
|
||||
for value in node.values:
|
||||
if not evaluate_ast(value, state, tools):
|
||||
if not evaluate_ast(value, state, static_tools, custom_tools):
|
||||
return False
|
||||
return True
|
||||
elif isinstance(node.op, ast.Or):
|
||||
for value in node.values:
|
||||
if evaluate_ast(value, state, tools):
|
||||
if evaluate_ast(value, state, static_tools, custom_tools):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def evaluate_binop(binop, state, tools):
|
||||
def evaluate_binop(binop, state, static_tools, custom_tools):
|
||||
# Recursively evaluate the left and right operands
|
||||
left_val = evaluate_ast(binop.left, state, tools)
|
||||
right_val = evaluate_ast(binop.right, state, tools)
|
||||
left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
|
||||
right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
|
||||
|
||||
# Determine the operation based on the type of the operator in the BinOp
|
||||
if isinstance(binop.op, ast.Add):
|
||||
@ -308,66 +318,92 @@ def evaluate_binop(binop, state, tools):
|
||||
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
|
||||
|
||||
|
||||
def evaluate_assign(assign, state, tools):
|
||||
result = evaluate_ast(assign.value, state, tools)
|
||||
def evaluate_assign(assign, state, static_tools, custom_tools):
|
||||
result = evaluate_ast(assign.value, state, static_tools, custom_tools)
|
||||
if len(assign.targets) == 1:
|
||||
target = assign.targets[0]
|
||||
set_value(target, result, state, tools)
|
||||
set_value(target, result, state, static_tools, custom_tools)
|
||||
else:
|
||||
if len(assign.targets) != len(result):
|
||||
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
|
||||
for tgt, val in zip(assign.targets, result):
|
||||
set_value(tgt, val, state, tools)
|
||||
expanded_values = []
|
||||
for tgt in assign.targets:
|
||||
if isinstance(tgt, ast.Starred):
|
||||
expanded_values.extend(result)
|
||||
else:
|
||||
expanded_values.append(result)
|
||||
for tgt, val in zip(assign.targets, expanded_values):
|
||||
set_value(tgt, val, state, static_tools, custom_tools)
|
||||
return result
|
||||
|
||||
|
||||
def set_value(target, value, state, tools):
|
||||
def set_value(target, value, state, static_tools, custom_tools):
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in tools:
|
||||
if target.id in static_tools:
|
||||
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
|
||||
state[target.id] = value
|
||||
elif isinstance(target, ast.Tuple):
|
||||
if not isinstance(value, tuple):
|
||||
raise InterpreterError("Cannot unpack non-tuple value")
|
||||
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
|
||||
value = tuple(value)
|
||||
else:
|
||||
raise InterpreterError("Cannot unpack non-tuple value")
|
||||
if len(target.elts) != len(value):
|
||||
raise InterpreterError("Cannot unpack tuple of wrong size")
|
||||
for i, elem in enumerate(target.elts):
|
||||
set_value(elem, value[i], state, tools)
|
||||
set_value(elem, value[i], state, static_tools, custom_tools)
|
||||
elif isinstance(target, ast.Subscript):
|
||||
obj = evaluate_ast(target.value, state, tools)
|
||||
key = evaluate_ast(target.slice, state, tools)
|
||||
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
||||
key = evaluate_ast(target.slice, state, static_tools, custom_tools)
|
||||
obj[key] = value
|
||||
elif isinstance(target, ast.Attribute):
|
||||
obj = evaluate_ast(target.value, state, tools)
|
||||
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
||||
setattr(obj, target.attr, value)
|
||||
|
||||
|
||||
def evaluate_call(call, state, tools):
|
||||
def evaluate_call(call, state, static_tools, custom_tools):
|
||||
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
|
||||
raise InterpreterError(
|
||||
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
|
||||
)
|
||||
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||
if isinstance(call.func, ast.Attribute):
|
||||
obj = evaluate_ast(call.func.value, state, tools)
|
||||
obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
|
||||
func_name = call.func.attr
|
||||
if not hasattr(obj, func_name):
|
||||
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
|
||||
func = getattr(obj, func_name)
|
||||
|
||||
elif isinstance(call.func, ast.Name):
|
||||
func_name = call.func.id
|
||||
if func_name in state:
|
||||
func = state[func_name]
|
||||
elif func_name in tools:
|
||||
func = tools[func_name]
|
||||
elif func_name in static_tools:
|
||||
func = static_tools[func_name]
|
||||
elif func_name in custom_tools:
|
||||
func = custom_tools[func_name]
|
||||
elif func_name in ERRORS:
|
||||
func = ERRORS[func_name]
|
||||
else:
|
||||
raise InterpreterError(
|
||||
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
|
||||
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
|
||||
)
|
||||
|
||||
args = [evaluate_ast(arg, state, tools) for arg in call.args]
|
||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
|
||||
args = []
|
||||
for arg in call.args:
|
||||
if isinstance(arg, ast.Starred):
|
||||
args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools))
|
||||
else:
|
||||
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
|
||||
|
||||
args = []
|
||||
for arg in call.args:
|
||||
if isinstance(arg, ast.Starred):
|
||||
unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
|
||||
if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
|
||||
raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
|
||||
args.extend(unpacked)
|
||||
else:
|
||||
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
|
||||
|
||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
|
||||
|
||||
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
|
||||
# Instantiate the class using its constructor
|
||||
@ -397,24 +433,31 @@ def evaluate_call(call, state, tools):
|
||||
output = " ".join(map(str, args))
|
||||
global PRINT_OUTPUTS
|
||||
PRINT_OUTPUTS += output + "\n"
|
||||
# cap the number of lines
|
||||
return output
|
||||
else: # Assume it's a callable object
|
||||
output = func(*args, **kwargs)
|
||||
return output
|
||||
|
||||
|
||||
def evaluate_subscript(subscript, state, tools):
|
||||
index = evaluate_ast(subscript.slice, state, tools)
|
||||
value = evaluate_ast(subscript.value, state, tools)
|
||||
if isinstance(index, slice):
|
||||
def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
||||
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
|
||||
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
|
||||
|
||||
if isinstance(value, pd.core.indexing._LocIndexer):
|
||||
parent_object = value.obj
|
||||
return parent_object.loc[index]
|
||||
if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
|
||||
return value[index]
|
||||
elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
|
||||
return value[index]
|
||||
elif isinstance(index, slice):
|
||||
return value[index]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# Ensure the index is within bounds
|
||||
if not (-len(value) <= index < len(value)):
|
||||
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
|
||||
return value[int(index)]
|
||||
elif isinstance(value, str):
|
||||
# Ensure the index is within bounds
|
||||
if not (-len(value) <= index < len(value)):
|
||||
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
|
||||
return value[index]
|
||||
@ -427,11 +470,11 @@ def evaluate_subscript(subscript, state, tools):
|
||||
raise InterpreterError(f"Could not index {value} with '{index}'.")
|
||||
|
||||
|
||||
def evaluate_name(name, state, tools):
|
||||
def evaluate_name(name, state, static_tools, custom_tools):
|
||||
if name.id in state:
|
||||
return state[name.id]
|
||||
elif name.id in tools:
|
||||
return tools[name.id]
|
||||
elif name.id in static_tools:
|
||||
return static_tools[name.id]
|
||||
elif name.id in ERRORS:
|
||||
return ERRORS[name.id]
|
||||
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
||||
@ -440,9 +483,9 @@ def evaluate_name(name, state, tools):
|
||||
raise InterpreterError(f"The variable `{name.id}` is not defined.")
|
||||
|
||||
|
||||
def evaluate_condition(condition, state, tools):
|
||||
left = evaluate_ast(condition.left, state, tools)
|
||||
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators]
|
||||
def evaluate_condition(condition, state, static_tools, custom_tools):
|
||||
left = evaluate_ast(condition.left, state, static_tools, custom_tools)
|
||||
comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
|
||||
ops = [type(op) for op in condition.ops]
|
||||
|
||||
result = True
|
||||
@ -450,63 +493,61 @@ def evaluate_condition(condition, state, tools):
|
||||
|
||||
for op, comparator in zip(ops, comparators):
|
||||
if op == ast.Eq:
|
||||
result = result and (current_left == comparator)
|
||||
current_result = current_left == comparator
|
||||
elif op == ast.NotEq:
|
||||
result = result and (current_left != comparator)
|
||||
current_result = current_left != comparator
|
||||
elif op == ast.Lt:
|
||||
result = result and (current_left < comparator)
|
||||
current_result = current_left < comparator
|
||||
elif op == ast.LtE:
|
||||
result = result and (current_left <= comparator)
|
||||
current_result = current_left <= comparator
|
||||
elif op == ast.Gt:
|
||||
result = result and (current_left > comparator)
|
||||
current_result = current_left > comparator
|
||||
elif op == ast.GtE:
|
||||
result = result and (current_left >= comparator)
|
||||
current_result = current_left >= comparator
|
||||
elif op == ast.Is:
|
||||
result = result and (current_left is comparator)
|
||||
current_result = current_left is comparator
|
||||
elif op == ast.IsNot:
|
||||
result = result and (current_left is not comparator)
|
||||
current_result = current_left is not comparator
|
||||
elif op == ast.In:
|
||||
result = result and (current_left in comparator)
|
||||
current_result = current_left in comparator
|
||||
elif op == ast.NotIn:
|
||||
result = result and (current_left not in comparator)
|
||||
current_result = current_left not in comparator
|
||||
else:
|
||||
raise InterpreterError(f"Operator not supported: {op}")
|
||||
|
||||
result = result & current_result
|
||||
current_left = comparator
|
||||
if not result:
|
||||
|
||||
if isinstance(result, bool) and not result:
|
||||
break
|
||||
|
||||
return result
|
||||
return result if isinstance(result, (bool, pd.Series)) else result.all()
|
||||
|
||||
|
||||
def evaluate_if(if_statement, state, tools):
|
||||
def evaluate_if(if_statement, state, static_tools, custom_tools):
|
||||
result = None
|
||||
test_result = evaluate_ast(if_statement.test, state, tools)
|
||||
test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
|
||||
if test_result:
|
||||
for line in if_statement.body:
|
||||
line_result = evaluate_ast(line, state, tools)
|
||||
line_result = evaluate_ast(line, state, static_tools, custom_tools)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
else:
|
||||
for line in if_statement.orelse:
|
||||
line_result = evaluate_ast(line, state, tools)
|
||||
line_result = evaluate_ast(line, state, static_tools, custom_tools)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_for(for_loop, state, tools):
|
||||
def evaluate_for(for_loop, state, static_tools, custom_tools):
|
||||
result = None
|
||||
iterator = evaluate_ast(for_loop.iter, state, tools)
|
||||
iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
|
||||
for counter in iterator:
|
||||
if isinstance(for_loop.target, ast.Tuple):
|
||||
for i, elem in enumerate(for_loop.target.elts):
|
||||
state[elem.id] = counter[i]
|
||||
else:
|
||||
state[for_loop.target.id] = counter
|
||||
set_value(for_loop.target, counter, state, static_tools, custom_tools)
|
||||
for node in for_loop.body:
|
||||
try:
|
||||
line_result = evaluate_ast(node, state, tools)
|
||||
line_result = evaluate_ast(node, state, static_tools, custom_tools)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
except BreakException:
|
||||
@ -519,55 +560,60 @@ def evaluate_for(for_loop, state, tools):
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_listcomp(listcomp, state, tools):
|
||||
result = []
|
||||
for generator in listcomp.generators:
|
||||
iter_value = evaluate_ast(generator.iter, state, tools)
|
||||
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
|
||||
def inner_evaluate(generators, index, current_state):
|
||||
if index >= len(generators):
|
||||
return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
|
||||
generator = generators[index]
|
||||
iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
|
||||
result = []
|
||||
for value in iter_value:
|
||||
new_state = state.copy()
|
||||
new_state = current_state.copy()
|
||||
if isinstance(generator.target, ast.Tuple):
|
||||
for idx, elem in enumerate(generator.target.elts):
|
||||
new_state[elem.id] = value[idx]
|
||||
else:
|
||||
new_state[generator.target.id] = value
|
||||
if all(evaluate_ast(if_clause, new_state, tools) for if_clause in generator.ifs):
|
||||
result.append(evaluate_ast(listcomp.elt, new_state, tools))
|
||||
return result
|
||||
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
|
||||
result.extend(inner_evaluate(generators, index + 1, new_state))
|
||||
return result
|
||||
|
||||
return inner_evaluate(listcomp.generators, 0, state)
|
||||
|
||||
|
||||
def evaluate_try(try_node, state, tools):
|
||||
def evaluate_try(try_node, state, static_tools, custom_tools):
|
||||
try:
|
||||
for stmt in try_node.body:
|
||||
evaluate_ast(stmt, state, tools)
|
||||
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||
except Exception as e:
|
||||
matched = False
|
||||
for handler in try_node.handlers:
|
||||
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, tools)):
|
||||
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
|
||||
matched = True
|
||||
if handler.name:
|
||||
state[handler.name] = e
|
||||
for stmt in handler.body:
|
||||
evaluate_ast(stmt, state, tools)
|
||||
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||
break
|
||||
if not matched:
|
||||
raise e
|
||||
else:
|
||||
if try_node.orelse:
|
||||
for stmt in try_node.orelse:
|
||||
evaluate_ast(stmt, state, tools)
|
||||
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||
finally:
|
||||
if try_node.finalbody:
|
||||
for stmt in try_node.finalbody:
|
||||
evaluate_ast(stmt, state, tools)
|
||||
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||
|
||||
|
||||
def evaluate_raise(raise_node, state, tools):
|
||||
def evaluate_raise(raise_node, state, static_tools, custom_tools):
|
||||
if raise_node.exc is not None:
|
||||
exc = evaluate_ast(raise_node.exc, state, tools)
|
||||
exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
|
||||
else:
|
||||
exc = None
|
||||
if raise_node.cause is not None:
|
||||
cause = evaluate_ast(raise_node.cause, state, tools)
|
||||
cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
|
||||
else:
|
||||
cause = None
|
||||
if exc is not None:
|
||||
@ -579,11 +625,11 @@ def evaluate_raise(raise_node, state, tools):
|
||||
raise InterpreterError("Re-raise is not supported without an active exception")
|
||||
|
||||
|
||||
def evaluate_assert(assert_node, state, tools):
|
||||
test_result = evaluate_ast(assert_node.test, state, tools)
|
||||
def evaluate_assert(assert_node, state, static_tools, custom_tools):
|
||||
test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
|
||||
if not test_result:
|
||||
if assert_node.msg:
|
||||
msg = evaluate_ast(assert_node.msg, state, tools)
|
||||
msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools)
|
||||
raise AssertionError(msg)
|
||||
else:
|
||||
# Include the failing condition in the assertion message
|
||||
@ -591,10 +637,10 @@ def evaluate_assert(assert_node, state, tools):
|
||||
raise AssertionError(f"Assertion failed: {test_code}")
|
||||
|
||||
|
||||
def evaluate_with(with_node, state, tools):
|
||||
def evaluate_with(with_node, state, static_tools, custom_tools):
|
||||
contexts = []
|
||||
for item in with_node.items:
|
||||
context_expr = evaluate_ast(item.context_expr, state, tools)
|
||||
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
|
||||
if item.optional_vars:
|
||||
state[item.optional_vars.id] = context_expr.__enter__()
|
||||
contexts.append(state[item.optional_vars.id])
|
||||
@ -604,7 +650,7 @@ def evaluate_with(with_node, state, tools):
|
||||
|
||||
try:
|
||||
for stmt in with_node.body:
|
||||
evaluate_ast(stmt, state, tools)
|
||||
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||
except Exception as e:
|
||||
for context in reversed(contexts):
|
||||
context.__exit__(type(e), e, e.__traceback__)
|
||||
@ -614,10 +660,51 @@ def evaluate_with(with_node, state, tools):
|
||||
context.__exit__(None, None, None)
|
||||
|
||||
|
||||
def import_modules(expression, state, authorized_imports):
|
||||
def check_module_authorized(module_name):
|
||||
module_path = module_name.split(".")
|
||||
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
||||
return any(subpath in authorized_imports for subpath in module_subpaths)
|
||||
|
||||
if isinstance(expression, ast.Import):
|
||||
for alias in expression.names:
|
||||
if check_module_authorized(alias.name):
|
||||
module = import_module(alias.name)
|
||||
state[alias.asname or alias.name] = module
|
||||
else:
|
||||
raise InterpreterError(
|
||||
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
||||
)
|
||||
return None
|
||||
elif isinstance(expression, ast.ImportFrom):
|
||||
if check_module_authorized(expression.module):
|
||||
module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
||||
for alias in expression.names:
|
||||
state[alias.asname or alias.name] = getattr(module, alias.name)
|
||||
else:
|
||||
raise InterpreterError(f"Import from {expression.module} is not allowed.")
|
||||
return None
|
||||
|
||||
|
||||
def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
|
||||
result = {}
|
||||
for gen in dictcomp.generators:
|
||||
iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools)
|
||||
for value in iter_value:
|
||||
new_state = state.copy()
|
||||
set_value(gen.target, value, new_state, static_tools, custom_tools)
|
||||
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
|
||||
key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
|
||||
val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools)
|
||||
result[key] = val
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_ast(
|
||||
expression: ast.AST,
|
||||
state: Dict[str, Any],
|
||||
tools: Dict[str, Callable],
|
||||
static_tools: Dict[str, Callable],
|
||||
custom_tools: Dict[str, Callable],
|
||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
||||
):
|
||||
"""
|
||||
@ -632,146 +719,128 @@ def evaluate_ast(
|
||||
state (`Dict[str, Any]`):
|
||||
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
|
||||
encounters assignements.
|
||||
tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||
`InterpreterError`.
|
||||
static_tools (`Dict[str, Callable]`):
|
||||
Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
|
||||
custom_tools (`Dict[str, Callable]`):
|
||||
Functions that may be called during the evaluation. These static_tools can be overwritten.
|
||||
authorized_imports (`List[str]`):
|
||||
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
|
||||
Add more at your own risk!
|
||||
"""
|
||||
global OPERATIONS_COUNT
|
||||
if OPERATIONS_COUNT >= MAX_OPERATIONS:
|
||||
raise InterpreterError(
|
||||
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
|
||||
)
|
||||
OPERATIONS_COUNT += 1
|
||||
if isinstance(expression, ast.Assign):
|
||||
# Assignement -> we evaluate the assignment which should update the state
|
||||
# We return the variable assigned as it may be used to determine the final result.
|
||||
return evaluate_assign(expression, state, tools)
|
||||
return evaluate_assign(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.AugAssign):
|
||||
return evaluate_augassign(expression, state, tools)
|
||||
return evaluate_augassign(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Call):
|
||||
# Function call -> we return the value of the function call
|
||||
return evaluate_call(expression, state, tools)
|
||||
return evaluate_call(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Constant):
|
||||
# Constant -> just return the value
|
||||
return expression.value
|
||||
elif isinstance(expression, ast.Tuple):
|
||||
return tuple(evaluate_ast(elt, state, tools) for elt in expression.elts)
|
||||
elif isinstance(expression, ast.ListComp):
|
||||
return evaluate_listcomp(expression, state, tools)
|
||||
return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
|
||||
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
|
||||
return evaluate_listcomp(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.UnaryOp):
|
||||
return evaluate_unaryop(expression, state, tools)
|
||||
return evaluate_unaryop(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Starred):
|
||||
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.BoolOp):
|
||||
# Boolean operation -> evaluate the operation
|
||||
return evaluate_boolop(expression, state, tools)
|
||||
return evaluate_boolop(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Break):
|
||||
raise BreakException()
|
||||
elif isinstance(expression, ast.Continue):
|
||||
raise ContinueException()
|
||||
elif isinstance(expression, ast.BinOp):
|
||||
# Binary operation -> execute operation
|
||||
return evaluate_binop(expression, state, tools)
|
||||
return evaluate_binop(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Compare):
|
||||
# Comparison -> evaluate the comparison
|
||||
return evaluate_condition(expression, state, tools)
|
||||
return evaluate_condition(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Lambda):
|
||||
return evaluate_lambda(expression, state, tools)
|
||||
return evaluate_lambda(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.FunctionDef):
|
||||
return evaluate_function_def(expression, state, tools)
|
||||
return evaluate_function_def(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Dict):
|
||||
# Dict -> evaluate all keys and values
|
||||
keys = [evaluate_ast(k, state, tools) for k in expression.keys]
|
||||
values = [evaluate_ast(v, state, tools) for v in expression.values]
|
||||
keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
|
||||
values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
|
||||
return dict(zip(keys, values))
|
||||
elif isinstance(expression, ast.Expr):
|
||||
# Expression -> evaluate the content
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.For):
|
||||
# For loop -> execute the loop
|
||||
return evaluate_for(expression, state, tools)
|
||||
return evaluate_for(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.FormattedValue):
|
||||
# Formatted value (part of f-string) -> evaluate the content and return
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.If):
|
||||
# If -> execute the right branch
|
||||
return evaluate_if(expression, state, tools)
|
||||
return evaluate_if(expression, state, static_tools, custom_tools)
|
||||
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.JoinedStr):
|
||||
return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
|
||||
return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
|
||||
elif isinstance(expression, ast.List):
|
||||
# List -> evaluate all elements
|
||||
return [evaluate_ast(elt, state, tools) for elt in expression.elts]
|
||||
return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
|
||||
elif isinstance(expression, ast.Name):
|
||||
# Name -> pick up the value in the state
|
||||
return evaluate_name(expression, state, tools)
|
||||
return evaluate_name(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Subscript):
|
||||
# Subscript -> return the value of the indexing
|
||||
return evaluate_subscript(expression, state, tools)
|
||||
return evaluate_subscript(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.IfExp):
|
||||
test_val = evaluate_ast(expression.test, state, tools)
|
||||
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
|
||||
if test_val:
|
||||
return evaluate_ast(expression.body, state, tools)
|
||||
return evaluate_ast(expression.body, state, static_tools, custom_tools)
|
||||
else:
|
||||
return evaluate_ast(expression.orelse, state, tools)
|
||||
return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Attribute):
|
||||
obj = evaluate_ast(expression.value, state, tools)
|
||||
return getattr(obj, expression.attr)
|
||||
value = evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
return getattr(value, expression.attr)
|
||||
elif isinstance(expression, ast.Slice):
|
||||
return slice(
|
||||
evaluate_ast(expression.lower, state, tools) if expression.lower is not None else None,
|
||||
evaluate_ast(expression.upper, state, tools) if expression.upper is not None else None,
|
||||
evaluate_ast(expression.step, state, tools) if expression.step is not None else None,
|
||||
evaluate_ast(expression.lower, state, static_tools, custom_tools)
|
||||
if expression.lower is not None
|
||||
else None,
|
||||
evaluate_ast(expression.upper, state, static_tools, custom_tools)
|
||||
if expression.upper is not None
|
||||
else None,
|
||||
evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
|
||||
)
|
||||
elif isinstance(expression, ast.ListComp) or isinstance(expression, ast.GeneratorExp):
|
||||
result = []
|
||||
vars = {}
|
||||
for generator in expression.generators:
|
||||
var_name = generator.target.id
|
||||
iter_value = evaluate_ast(generator.iter, state, tools)
|
||||
for value in iter_value:
|
||||
vars[var_name] = value
|
||||
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
|
||||
elem = evaluate_ast(expression.elt, {**state, **vars}, tools)
|
||||
result.append(elem)
|
||||
return result
|
||||
elif isinstance(expression, ast.DictComp):
|
||||
result = {}
|
||||
for gen in expression.generators:
|
||||
for container in get_iterable(evaluate_ast(gen.iter, state, tools)):
|
||||
state[gen.target.id] = container
|
||||
key = evaluate_ast(expression.key, state, tools)
|
||||
value = evaluate_ast(expression.value, state, tools)
|
||||
result[key] = value
|
||||
return result
|
||||
elif isinstance(expression, ast.Import):
|
||||
for alias in expression.names:
|
||||
if alias.name in authorized_imports:
|
||||
module = __import__(alias.name)
|
||||
state[alias.asname or alias.name] = module
|
||||
else:
|
||||
raise InterpreterError(f"Import of {alias.name} is not allowed.")
|
||||
return None
|
||||
return evaluate_dictcomp(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.While):
|
||||
return evaluate_while(expression, state, tools)
|
||||
elif isinstance(expression, ast.ImportFrom):
|
||||
if expression.module in authorized_imports:
|
||||
module = __import__(expression.module)
|
||||
for alias in expression.names:
|
||||
state[alias.asname or alias.name] = getattr(module, alias.name)
|
||||
else:
|
||||
raise InterpreterError(f"Import from {expression.module} is not allowed.")
|
||||
return None
|
||||
return evaluate_while(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
|
||||
return import_modules(expression, state, authorized_imports)
|
||||
elif isinstance(expression, ast.ClassDef):
|
||||
return evaluate_class_def(expression, state, tools)
|
||||
return evaluate_class_def(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Try):
|
||||
return evaluate_try(expression, state, tools)
|
||||
return evaluate_try(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Raise):
|
||||
return evaluate_raise(expression, state, tools)
|
||||
return evaluate_raise(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Assert):
|
||||
return evaluate_assert(expression, state, tools)
|
||||
return evaluate_assert(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.With):
|
||||
return evaluate_with(expression, state, tools)
|
||||
return evaluate_with(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Set):
|
||||
return {evaluate_ast(elt, state, tools) for elt in expression.elts}
|
||||
return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
|
||||
elif isinstance(expression, ast.Return):
|
||||
raise ReturnException(evaluate_ast(expression.value, state, tools) if expression.value else None)
|
||||
raise ReturnException(
|
||||
evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
|
||||
)
|
||||
else:
|
||||
# For now we refuse anything else. Let's add things as we need them.
|
||||
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
||||
@ -779,7 +848,8 @@ def evaluate_ast(
|
||||
|
||||
def evaluate_python_code(
|
||||
code: str,
|
||||
tools: Optional[Dict[str, Callable]] = None,
|
||||
static_tools: Optional[Dict[str, Callable]] = None,
|
||||
custom_tools: Optional[Dict[str, Callable]] = None,
|
||||
state: Optional[Dict[str, Any]] = None,
|
||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
||||
):
|
||||
@ -792,9 +862,12 @@ def evaluate_python_code(
|
||||
Args:
|
||||
code (`str`):
|
||||
The code to evaluate.
|
||||
tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||
`InterpreterError`.
|
||||
static_tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation.
|
||||
These tools cannot be overwritten in the code: any assignment to their name will raise an error.
|
||||
custom_tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation.
|
||||
These tools can be overwritten in the code: any assignment to their name will overwrite them.
|
||||
state (`Dict[str, Any]`):
|
||||
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
|
||||
updated by this function to contain all variables as they are evaluated.
|
||||
@ -806,20 +879,34 @@ def evaluate_python_code(
|
||||
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
|
||||
if state is None:
|
||||
state = {}
|
||||
if tools is None:
|
||||
tools = {}
|
||||
if static_tools is None:
|
||||
static_tools = {}
|
||||
if custom_tools is None:
|
||||
custom_tools = {}
|
||||
result = None
|
||||
global PRINT_OUTPUTS
|
||||
PRINT_OUTPUTS = ""
|
||||
global OPERATIONS_COUNT
|
||||
OPERATIONS_COUNT = 0
|
||||
for node in expression.body:
|
||||
try:
|
||||
result = evaluate_ast(node, state, tools, authorized_imports)
|
||||
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
||||
except InterpreterError as e:
|
||||
msg = f"Evaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
||||
msg = ""
|
||||
if len(PRINT_OUTPUTS) > 0:
|
||||
msg += f"Executing code yielded these outputs:\n{PRINT_OUTPUTS}\n====\n"
|
||||
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
|
||||
msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
|
||||
else:
|
||||
msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n"
|
||||
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
||||
raise InterpreterError(msg)
|
||||
finally:
|
||||
state["print_outputs"] = PRINT_OUTPUTS
|
||||
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
|
||||
state["print_outputs"] = PRINT_OUTPUTS
|
||||
else:
|
||||
state["print_outputs"] = (
|
||||
PRINT_OUTPUTS[:MAX_LEN_OUTPUT]
|
||||
+ f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
|
||||
)
|
||||
|
||||
return result
|
||||
|
@ -862,8 +862,18 @@ class StaticCache(Cache):
|
||||
k_out.copy_(key_states)
|
||||
v_out.copy_(value_states)
|
||||
else:
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
# Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
|
||||
# `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
|
||||
# operation, that avoids copies and uses less memory.
|
||||
try:
|
||||
# If using several devices (e.g.: multiple GPUs), we need to ensure everything is on the same one
|
||||
cache_position.to(device=k_out.device)
|
||||
k_out.index_copy_(2, cache_position, key_states)
|
||||
v_out.index_copy_(2, cache_position, value_states)
|
||||
except NotImplementedError:
|
||||
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
@ -958,8 +968,14 @@ class SlidingWindowCache(StaticCache):
|
||||
k_out = k_out[:, :, indices]
|
||||
v_out = v_out[:, :, indices]
|
||||
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
try:
|
||||
cache_position.to(device=k_out.device)
|
||||
k_out.index_copy_(2, cache_position, key_states)
|
||||
v_out.index_copy_(2, cache_position, value_states)
|
||||
except NotImplementedError:
|
||||
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
||||
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
|
||||
self.key_cache[layer_idx].zero_()
|
||||
@ -1249,3 +1265,77 @@ class HybridCache(Cache):
|
||||
# In-place ops prevent breaking the static address
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
|
||||
class MambaCache:
|
||||
"""
|
||||
Cache for mamba model which does not have attention mechanism and key value states.
|
||||
|
||||
Arguments:
|
||||
config: MambaConfig
|
||||
max_batch_size: int
|
||||
dtype: torch.dtype
|
||||
device: torch.device
|
||||
|
||||
Attributes:
|
||||
dtype: torch.dtype
|
||||
intermediate_size: int
|
||||
ssm_state_size: int
|
||||
conv_kernel_size: int
|
||||
conv_states: torch.Tensor [layer_idx, batch_size, intermediate_size, conv_kernel_size]
|
||||
ssm_states: torch.Tensor [layer_idx, batch_size, intermediate_size, ssm_state_size]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
max_batch_size: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
device: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.dtype = dtype
|
||||
self.max_batch_size = max_batch_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.ssm_state_size = config.state_size
|
||||
self.conv_kernel_size = config.conv_kernel
|
||||
|
||||
self.conv_states: torch.Tensor = torch.zeros(
|
||||
config.num_hidden_layers,
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.conv_kernel_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.ssm_states: torch.Tensor = torch.zeros(
|
||||
config.num_hidden_layers,
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.ssm_state_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
torch._dynamo.mark_static_address(self.conv_states)
|
||||
torch._dynamo.mark_static_address(self.ssm_states)
|
||||
|
||||
def update_conv_state(
|
||||
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
|
||||
) -> torch.Tensor:
|
||||
conv_state = self.conv_states[layer_idx]
|
||||
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
|
||||
|
||||
conv_state = conv_state.roll(shifts=-1, dims=-1)
|
||||
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
|
||||
self.conv_states[layer_idx].zero_()
|
||||
self.conv_states[layer_idx] += conv_state
|
||||
return self.conv_states[layer_idx]
|
||||
|
||||
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
|
||||
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
|
||||
return self.ssm_states[layer_idx]
|
||||
|
||||
def reset(self):
|
||||
self.conv_states.zero_()
|
||||
self.ssm_states.zero_()
|
||||
|
@ -202,9 +202,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
"""
|
||||
|
||||
def _get_audio_input():
|
||||
ds = load_dataset(
|
||||
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True
|
||||
)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
speech_samples = ds.sort("id").select(range(2))[:2]["audio"]
|
||||
raw_samples = [x["array"] for x in speech_samples]
|
||||
return raw_samples
|
||||
|
@ -1004,7 +1004,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
elif isinstance(old_v, float):
|
||||
v = float(v)
|
||||
elif not isinstance(old_v, str):
|
||||
raise ValueError(
|
||||
raise TypeError(
|
||||
f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
|
||||
)
|
||||
|
||||
|
@ -47,11 +47,11 @@ class XnliProcessor(DataProcessor):
|
||||
text_b = line[1]
|
||||
label = "contradiction" if line[2] == "contradictory" else line[2]
|
||||
if not isinstance(text_a, str):
|
||||
raise ValueError(f"Training input {text_a} is not a string")
|
||||
raise TypeError(f"Training input {text_a} is not a string")
|
||||
if not isinstance(text_b, str):
|
||||
raise ValueError(f"Training input {text_b} is not a string")
|
||||
raise TypeError(f"Training input {text_b} is not a string")
|
||||
if not isinstance(label, str):
|
||||
raise ValueError(f"Training label {label} is not a string")
|
||||
raise TypeError(f"Training label {label} is not a string")
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
@ -70,11 +70,11 @@ class XnliProcessor(DataProcessor):
|
||||
text_b = line[7]
|
||||
label = line[1]
|
||||
if not isinstance(text_a, str):
|
||||
raise ValueError(f"Training input {text_a} is not a string")
|
||||
raise TypeError(f"Training input {text_a} is not a string")
|
||||
if not isinstance(text_b, str):
|
||||
raise ValueError(f"Training input {text_b} is not a string")
|
||||
raise TypeError(f"Training input {text_b} is not a string")
|
||||
if not isinstance(label, str):
|
||||
raise ValueError(f"Training label {label} is not a string")
|
||||
raise TypeError(f"Training label {label} is not a string")
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
@ -38,7 +38,7 @@ deps = {
|
||||
"librosa": "librosa",
|
||||
"nltk": "nltk",
|
||||
"natten": "natten>=0.14.6,<0.15.0",
|
||||
"numpy": "numpy>=1.17,<2.0",
|
||||
"numpy": "numpy>=1.17",
|
||||
"onnxconverter-common": "onnxconverter-common",
|
||||
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
||||
"onnxruntime": "onnxruntime>=1.4.0",
|
||||
|
@ -156,7 +156,7 @@ class PhrasalConstraint(Constraint):
|
||||
|
||||
def does_advance(self, token_id: int):
|
||||
if not isinstance(token_id, int):
|
||||
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
|
||||
raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
|
||||
|
||||
if self.completed:
|
||||
return False
|
||||
@ -165,7 +165,7 @@ class PhrasalConstraint(Constraint):
|
||||
|
||||
def update(self, token_id: int):
|
||||
if not isinstance(token_id, int):
|
||||
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
|
||||
raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
|
||||
|
||||
stepped = False
|
||||
completed = False
|
||||
@ -300,7 +300,7 @@ class DisjunctiveConstraint(Constraint):
|
||||
|
||||
def does_advance(self, token_id: int):
|
||||
if not isinstance(token_id, int):
|
||||
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
|
||||
raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
|
||||
|
||||
next_tokens = self.trie.next_tokens(self.current_seq)
|
||||
|
||||
@ -308,7 +308,7 @@ class DisjunctiveConstraint(Constraint):
|
||||
|
||||
def update(self, token_id: int):
|
||||
if not isinstance(token_id, int):
|
||||
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
|
||||
raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
|
||||
|
||||
stepped = False
|
||||
completed = False
|
||||
@ -432,7 +432,7 @@ class ConstraintListState:
|
||||
|
||||
def add(self, token_id: int):
|
||||
if not isinstance(token_id, int):
|
||||
raise ValueError(f"`token_id` should be an `int`, but is `{token_id}`.")
|
||||
raise TypeError(f"`token_id` should be an `int`, but is `{token_id}`.")
|
||||
|
||||
complete, stepped = False, False
|
||||
|
||||
|
@ -1760,7 +1760,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
|
||||
|
||||
>>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
|
||||
@ -1812,7 +1812,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
|
||||
|
||||
>>> # Whisper has a long list of suppressed tokens. For instance, in this case, the token 1 is suppressed by default.
|
||||
@ -1901,7 +1901,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> inputs = processor(ds[3]["audio"]["array"], return_tensors="pt")
|
||||
>>> input_features = inputs.input_features
|
||||
|
||||
|
@ -32,6 +32,7 @@ from ..cache_utils import (
|
||||
EncoderDecoderCache,
|
||||
HQQQuantizedCache,
|
||||
HybridCache,
|
||||
MambaCache,
|
||||
QuantizedCacheConfig,
|
||||
QuantoQuantizedCache,
|
||||
SlidingWindowCache,
|
||||
@ -116,7 +117,12 @@ logger = logging.get_logger(__name__)
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
|
||||
|
||||
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache}
|
||||
NEED_SETUP_CACHE_CLASSES_MAPPING = {
|
||||
"static": StaticCache,
|
||||
"sliding_window": SlidingWindowCache,
|
||||
"hybrid": HybridCache,
|
||||
"mamba": MambaCache,
|
||||
}
|
||||
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
||||
|
||||
|
||||
@ -748,12 +754,12 @@ class GenerationMixin:
|
||||
warpers = LogitsProcessorList()
|
||||
|
||||
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
|
||||
# better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
|
||||
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
|
||||
if generation_config.num_beams > 1:
|
||||
if isinstance(generation_config.eos_token_id, list):
|
||||
min_tokens_to_keep = len(generation_config.eos_token_id) + 1
|
||||
elif isinstance(generation_config.eos_token_id, torch.Tensor):
|
||||
min_tokens_to_keep = generation_config.eos_token_id.shape[0] + 1
|
||||
if isinstance(generation_config._eos_token_tensor, list):
|
||||
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
|
||||
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
|
||||
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
|
||||
else:
|
||||
min_tokens_to_keep = 2
|
||||
else:
|
||||
@ -857,31 +863,31 @@ class GenerationMixin:
|
||||
processors.append(
|
||||
NoBadWordsLogitsProcessor(
|
||||
generation_config.bad_words_ids,
|
||||
generation_config.eos_token_id,
|
||||
generation_config._eos_token_tensor,
|
||||
)
|
||||
)
|
||||
if (
|
||||
generation_config.min_length is not None
|
||||
and generation_config.eos_token_id is not None
|
||||
and generation_config._eos_token_tensor is not None
|
||||
and generation_config.min_length > 0
|
||||
):
|
||||
processors.append(
|
||||
MinLengthLogitsProcessor(
|
||||
generation_config.min_length,
|
||||
generation_config.eos_token_id,
|
||||
generation_config._eos_token_tensor,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
if (
|
||||
generation_config.min_new_tokens is not None
|
||||
and generation_config.eos_token_id is not None
|
||||
and generation_config._eos_token_tensor is not None
|
||||
and generation_config.min_new_tokens > 0
|
||||
):
|
||||
processors.append(
|
||||
MinNewTokensLengthLogitsProcessor(
|
||||
input_ids_seq_length,
|
||||
generation_config.min_new_tokens,
|
||||
generation_config.eos_token_id,
|
||||
generation_config._eos_token_tensor,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
@ -912,7 +918,7 @@ class GenerationMixin:
|
||||
processors.append(
|
||||
ExponentialDecayLengthPenalty(
|
||||
generation_config.exponential_decay_length_penalty,
|
||||
generation_config.eos_token_id,
|
||||
generation_config._eos_token_tensor,
|
||||
input_ids_seq_length,
|
||||
)
|
||||
)
|
||||
@ -991,8 +997,8 @@ class GenerationMixin:
|
||||
"stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
|
||||
)
|
||||
criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
|
||||
if generation_config.eos_token_id is not None:
|
||||
criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
|
||||
if generation_config._eos_token_tensor is not None:
|
||||
criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
|
||||
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
|
||||
return criteria
|
||||
|
||||
@ -1343,13 +1349,15 @@ class GenerationMixin:
|
||||
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
|
||||
) -> Tuple[GenerationConfig, Dict]:
|
||||
"""
|
||||
Prepares the base generation config, then applies any generation configuration options from kwargs.
|
||||
Prepares the base generation config, then applies any generation configuration options from kwargs. This
|
||||
function handles retrocompatibility with respect to configuration files.
|
||||
"""
|
||||
# TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
|
||||
# replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
|
||||
# the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.
|
||||
|
||||
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
||||
using_model_generation_config = False
|
||||
if generation_config is None:
|
||||
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
|
||||
# three conditions must be met
|
||||
@ -1372,6 +1380,7 @@ class GenerationMixin:
|
||||
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
||||
)
|
||||
self.generation_config = new_generation_config
|
||||
using_model_generation_config = True
|
||||
generation_config = self.generation_config
|
||||
|
||||
# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
|
||||
@ -1389,6 +1398,16 @@ class GenerationMixin:
|
||||
else:
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
model_kwargs = generation_config.update(**kwargs)
|
||||
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
|
||||
if not using_model_generation_config:
|
||||
if generation_config.bos_token_id is None:
|
||||
generation_config.bos_token_id = self.generation_config.bos_token_id
|
||||
if generation_config.eos_token_id is None:
|
||||
generation_config.eos_token_id = self.generation_config.eos_token_id
|
||||
if generation_config.pad_token_id is None:
|
||||
generation_config.pad_token_id = self.generation_config.pad_token_id
|
||||
if generation_config.decoder_start_token_id is None:
|
||||
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
|
||||
|
||||
return generation_config, model_kwargs
|
||||
|
||||
@ -1431,8 +1450,9 @@ class GenerationMixin:
|
||||
not hasattr(self, "_cache")
|
||||
or (not isinstance(cache_to_check, cache_cls))
|
||||
or cache_to_check.max_batch_size != max_batch_size
|
||||
or cache_to_check.max_cache_len < max_cache_len
|
||||
)
|
||||
if cache_implementation != "mamba":
|
||||
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
|
||||
|
||||
if requires_cross_attention_cache and hasattr(self, "_cache"):
|
||||
need_new_cache = (
|
||||
@ -1486,52 +1506,43 @@ class GenerationMixin:
|
||||
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
|
||||
"""
|
||||
|
||||
# Convert special tokens to tensors (if they exist either in kwargs or in self.config)
|
||||
def _tensor_or_none(token_kwargs, token_self, device=None):
|
||||
if device is None:
|
||||
device = self.device
|
||||
|
||||
token = token_kwargs if token_kwargs is not None else token_self
|
||||
# Convert special tokens to tensors
|
||||
def _tensor_or_none(token, device=None):
|
||||
if token is None:
|
||||
return token
|
||||
elif isinstance(token, torch.Tensor):
|
||||
return token.to(device)
|
||||
|
||||
device = device if device is not None else self.device
|
||||
if isinstance(token, torch.Tensor):
|
||||
return token.to(device)
|
||||
return torch.tensor(token, device=device, dtype=torch.long)
|
||||
|
||||
bos_token_id = _tensor_or_none(
|
||||
generation_config.bos_token_id, self.generation_config.bos_token_id, device=device
|
||||
)
|
||||
eos_token_id = _tensor_or_none(
|
||||
generation_config.eos_token_id, self.generation_config.eos_token_id, device=device
|
||||
)
|
||||
pad_token_id = _tensor_or_none(
|
||||
generation_config.pad_token_id, self.generation_config.pad_token_id, device=device
|
||||
)
|
||||
decoder_start_token_id = _tensor_or_none(
|
||||
generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device
|
||||
)
|
||||
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
|
||||
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
|
||||
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
|
||||
decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
|
||||
|
||||
# for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
|
||||
if self.config.is_encoder_decoder:
|
||||
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
|
||||
decoder_start_token_tensor = (
|
||||
decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor
|
||||
)
|
||||
|
||||
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
|
||||
if eos_token_id is not None and eos_token_id.ndim == 0:
|
||||
eos_token_id = eos_token_id.unsqueeze(0)
|
||||
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
|
||||
eos_token_tensor = eos_token_tensor.unsqueeze(0)
|
||||
|
||||
# Set pad token if unset (and there are conditions to do so)
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
if pad_token_tensor is None and eos_token_tensor is not None:
|
||||
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
||||
logger.warning(
|
||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||
)
|
||||
pad_token_id = eos_token_id[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
|
||||
pad_token_tensor = eos_token_tensor[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
|
||||
|
||||
# we can't infer attn mask if pad token is set to be eos token in model's generation config
|
||||
if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
|
||||
if eos_token_tensor is not None and pad_token_tensor in eos_token_tensor:
|
||||
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
||||
logger.warning_once(
|
||||
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
|
||||
@ -1540,21 +1551,26 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
# Sanity checks/warnings
|
||||
if self.config.is_encoder_decoder and decoder_start_token_id is None:
|
||||
if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
|
||||
raise ValueError(
|
||||
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
||||
)
|
||||
if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
|
||||
if eos_token_tensor is not None and (
|
||||
torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
|
||||
):
|
||||
logger.warning(
|
||||
f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not "
|
||||
f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation will not "
|
||||
"stop until the maximum length is reached. Depending on other flags, it may even crash."
|
||||
)
|
||||
|
||||
# Update generation config with the updated special tokens tensors
|
||||
generation_config.bos_token_id = bos_token_id
|
||||
generation_config.eos_token_id = eos_token_id
|
||||
generation_config.pad_token_id = pad_token_id
|
||||
generation_config.decoder_start_token_id = decoder_start_token_id
|
||||
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
|
||||
# (in their non-tensor form), in order to enable end-to-end compilation. See
|
||||
# https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
|
||||
generation_config._bos_token_tensor = bos_token_tensor
|
||||
generation_config._eos_token_tensor = eos_token_tensor
|
||||
generation_config._pad_token_tensor = pad_token_tensor
|
||||
generation_config._decoder_start_token_tensor = decoder_start_token_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
@ -1689,10 +1705,10 @@ class GenerationMixin:
|
||||
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
|
||||
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
|
||||
if (
|
||||
generation_config.pad_token_id is not None
|
||||
generation_config._pad_token_tensor is not None
|
||||
and batch_size > 1
|
||||
and len(inputs_tensor.shape) == 2
|
||||
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
|
||||
and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
|
||||
):
|
||||
logger.warning(
|
||||
"A decoder-only architecture is being used, but right-padding was detected! For correct "
|
||||
@ -1709,7 +1725,7 @@ class GenerationMixin:
|
||||
|
||||
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
|
||||
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
|
||||
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
|
||||
)
|
||||
|
||||
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
|
||||
@ -1724,7 +1740,7 @@ class GenerationMixin:
|
||||
batch_size=batch_size,
|
||||
model_input_name=model_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=generation_config.decoder_start_token_id,
|
||||
decoder_start_token_id=generation_config._decoder_start_token_tensor,
|
||||
device=inputs_tensor.device,
|
||||
)
|
||||
else:
|
||||
@ -1750,9 +1766,13 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
use_dynamic_cache_by_default = False
|
||||
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
|
||||
if "mamba" in self.__class__.__name__.lower():
|
||||
cache_name = "cache_params"
|
||||
else:
|
||||
cache_name = "past_key_values"
|
||||
if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
|
||||
raise ValueError(
|
||||
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
|
||||
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
|
||||
"Cache object) is unsupported. Please use only one of the two."
|
||||
)
|
||||
elif generation_config.cache_implementation is not None:
|
||||
@ -1762,7 +1782,7 @@ class GenerationMixin:
|
||||
"This model does not support `cache_implementation='static'`. Please check the following "
|
||||
"issue: https://github.com/huggingface/transformers/issues/28981"
|
||||
)
|
||||
model_kwargs["past_key_values"] = self._get_cache(
|
||||
model_kwargs[cache_name] = self._get_cache(
|
||||
generation_config.cache_implementation,
|
||||
getattr(generation_config, "num_beams", 1) * batch_size,
|
||||
generation_config.max_length,
|
||||
@ -1793,23 +1813,23 @@ class GenerationMixin:
|
||||
"Please install it via with `pip install hqq`"
|
||||
)
|
||||
|
||||
model_kwargs["past_key_values"] = cache_class(cache_config)
|
||||
model_kwargs[cache_name] = cache_class(cache_config)
|
||||
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
|
||||
# keeps copying the cache thus using much more memory
|
||||
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
|
||||
past = model_kwargs.get("past_key_values", None)
|
||||
past = model_kwargs.get(cache_name, None)
|
||||
requires_cross_attention_cache = (
|
||||
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
||||
)
|
||||
if past is None:
|
||||
model_kwargs["past_key_values"] = (
|
||||
model_kwargs[cache_name] = (
|
||||
DynamicCache()
|
||||
if not requires_cross_attention_cache
|
||||
else EncoderDecoderCache(DynamicCache(), DynamicCache())
|
||||
)
|
||||
use_dynamic_cache_by_default = True
|
||||
elif isinstance(past, tuple):
|
||||
model_kwargs["past_key_values"] = (
|
||||
model_kwargs[cache_name] = (
|
||||
DynamicCache.from_legacy_cache(past)
|
||||
if not requires_cross_attention_cache
|
||||
else EncoderDecoderCache.from_legacy_cache(past)
|
||||
@ -2268,7 +2288,7 @@ class GenerationMixin:
|
||||
raise ValueError("DoLa decoding is only available for decoder-only models.")
|
||||
# init values
|
||||
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
pad_token_id = generation_config._pad_token_tensor
|
||||
output_attentions = generation_config.output_attentions
|
||||
output_hidden_states = generation_config.output_hidden_states
|
||||
output_scores = generation_config.output_scores
|
||||
@ -2475,7 +2495,7 @@ class GenerationMixin:
|
||||
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
||||
top_k = generation_config.top_k
|
||||
penalty_alpha = generation_config.penalty_alpha
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
pad_token_id = generation_config._pad_token_tensor
|
||||
output_attentions = generation_config.output_attentions
|
||||
output_hidden_states = generation_config.output_hidden_states
|
||||
output_scores = generation_config.output_scores
|
||||
@ -2866,7 +2886,7 @@ class GenerationMixin:
|
||||
`model.config.is_encoder_decoder=True`.
|
||||
"""
|
||||
# init values
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
pad_token_id = generation_config._pad_token_tensor
|
||||
output_attentions = generation_config.output_attentions
|
||||
output_hidden_states = generation_config.output_hidden_states
|
||||
output_scores = generation_config.output_scores
|
||||
@ -3073,8 +3093,8 @@ class GenerationMixin:
|
||||
`model.config.is_encoder_decoder=True`.
|
||||
"""
|
||||
# init values
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
pad_token_id = generation_config._pad_token_tensor
|
||||
eos_token_id = generation_config._eos_token_tensor
|
||||
output_attentions = generation_config.output_attentions
|
||||
output_hidden_states = generation_config.output_hidden_states
|
||||
output_scores = generation_config.output_scores
|
||||
@ -3355,8 +3375,8 @@ class GenerationMixin:
|
||||
`model.config.is_encoder_decoder=True`.
|
||||
"""
|
||||
# init values
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
pad_token_id = generation_config._pad_token_tensor
|
||||
eos_token_id = generation_config._eos_token_tensor
|
||||
output_attentions = generation_config.output_attentions
|
||||
output_hidden_states = generation_config.output_hidden_states
|
||||
output_scores = generation_config.output_scores
|
||||
@ -3647,8 +3667,8 @@ class GenerationMixin:
|
||||
`model.config.is_encoder_decoder=True`.
|
||||
"""
|
||||
# init values
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
pad_token_id = generation_config._pad_token_tensor
|
||||
eos_token_id = generation_config._eos_token_tensor
|
||||
output_attentions = generation_config.output_attentions
|
||||
output_hidden_states = generation_config.output_hidden_states
|
||||
output_scores = generation_config.output_scores
|
||||
@ -4261,7 +4281,7 @@ def _split(data, full_batch_size: int, split_size: int = None):
|
||||
for i in range(0, full_batch_size, split_size)
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unexpected attribute type: {type(data)}")
|
||||
raise TypeError(f"Unexpected attribute type: {type(data)}")
|
||||
|
||||
|
||||
def _split_model_inputs(
|
||||
@ -4368,7 +4388,7 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
|
||||
# If the elements are integers or floats, return a tensor
|
||||
return torch.tensor(data)
|
||||
else:
|
||||
raise ValueError(f"Unexpected attribute type: {type(data[0])}")
|
||||
raise TypeError(f"Unexpected attribute type: {type(data[0])}")
|
||||
|
||||
# Use a dictionary comprehension to gather attributes from all objects and concatenate them
|
||||
concatenated_data = {
|
||||
|
@ -544,7 +544,7 @@ class ImageProcessingMixin(PushToHubMixin):
|
||||
response.raise_for_status()
|
||||
return Image.open(BytesIO(response.content))
|
||||
else:
|
||||
raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
|
||||
raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
|
||||
|
||||
|
||||
ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub)
|
||||
|
@ -75,7 +75,7 @@ def to_channel_dimension_format(
|
||||
`np.ndarray`: The image with the channel dimension set to `channel_dim`.
|
||||
"""
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
|
||||
raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
|
||||
|
||||
if input_channel_dim is None:
|
||||
input_channel_dim = infer_channel_dimension_format(image)
|
||||
@ -121,7 +121,7 @@ def rescale(
|
||||
`np.ndarray`: The rescaled image.
|
||||
"""
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
|
||||
raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
|
||||
|
||||
rescaled_image = image * scale
|
||||
if data_format is not None:
|
||||
@ -453,7 +453,7 @@ def center_crop(
|
||||
return_numpy = True if return_numpy is None else return_numpy
|
||||
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
|
||||
raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
|
||||
|
||||
if not isinstance(size, Iterable) or len(size) != 2:
|
||||
raise ValueError("size must have 2 elements representing the height and width of the output image")
|
||||
|
@ -64,7 +64,6 @@ if is_vision_available():
|
||||
PILImageResampling.HAMMING: InterpolationMode.HAMMING,
|
||||
PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
|
||||
PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
|
||||
PILImageResampling.NEAREST: InterpolationMode.NEAREST,
|
||||
}
|
||||
|
||||
|
||||
@ -378,7 +377,7 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image = image
|
||||
else:
|
||||
raise ValueError(
|
||||
raise TypeError(
|
||||
"Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
|
||||
)
|
||||
image = PIL.ImageOps.exif_transpose(image)
|
||||
|
@ -45,6 +45,7 @@ _import_structure = {
|
||||
"unset_hf_deepspeed_config",
|
||||
],
|
||||
"eetq": ["replace_with_eetq_linear"],
|
||||
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
|
||||
"ggml": [
|
||||
"GGUF_CONFIG_MAPPING",
|
||||
"GGUF_TENSOR_MAPPING",
|
||||
@ -126,6 +127,7 @@ if TYPE_CHECKING:
|
||||
unset_hf_deepspeed_config,
|
||||
)
|
||||
from .eetq import replace_with_eetq_linear
|
||||
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
|
||||
from .ggml import (
|
||||
GGUF_CONFIG_MAPPING,
|
||||
GGUF_TENSOR_MAPPING,
|
||||
|
@ -199,7 +199,7 @@ def get_modules_to_fuse(model, quantization_config):
|
||||
The quantization configuration to use.
|
||||
"""
|
||||
if not isinstance(model, PreTrainedModel):
|
||||
raise ValueError(f"The model should be an instance of `PreTrainedModel`, got {model.__class__.__name__}")
|
||||
raise TypeError(f"The model should be an instance of `PreTrainedModel`, got {model.__class__.__name__}")
|
||||
|
||||
# Always default to `quantization_config.modules_to_fuse`
|
||||
if quantization_config.modules_to_fuse is not None:
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user