Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)

Compare commits: v4.42.1...dynamic_le (41 commits)
| SHA1 |
|---|
| 450b1d26e4 |
| 4df64e62e9 |
| 84e694d4ec |
| 8d4e17b311 |
| f47f4a8998 |
| 843876bdb7 |
| dcd292bab6 |
| 3ca52cf3ab |
| 33ef0b14ea |
| b6f30f5376 |
| 8029b7fc2a |
| f93b239ed7 |
| 01cd35fb97 |
| 888a2c0007 |
| b7eaf50cd5 |
| c1098f9766 |
| 81e795a600 |
| 9ab68d0d9f |
| dc72fd7edd |
| 7f91f168a1 |
| f91c16d270 |
| cd0935dd55 |
| 82486e5995 |
| a9701953ff |
| 57d7594a79 |
| 93cd94b79d |
| cf85e86e9a |
| 3345ae733b |
| e655029515 |
| bbf1e61864 |
| cb298978ad |
| 82a1fc7256 |
| 5e89b335ab |
| 0142aab7f8 |
| 1c68f2cafb |
| 464aa74659 |
| e44b878c02 |
| 75a6319864 |
| 727eea4ab0 |
| 0cf60f13ab |
| 4aa17d0069 |
@@ -382,6 +382,8 @@
      title: Fuyu
    - local: model_doc/gemma
      title: Gemma
    - local: model_doc/gemma2
      title: Gemma2
    - local: model_doc/openai-gpt
      title: GPT
    - local: model_doc/gpt_neo
@@ -145,6 +145,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Funnel Transformer](model_doc/funnel) | ✅ | ✅ | ❌ |
| [Fuyu](model_doc/fuyu) | ✅ | ❌ | ❌ |
| [Gemma](model_doc/gemma) | ✅ | ❌ | ✅ |
| [Gemma2](model_doc/gemma2) | ✅ | ❌ | ❌ |
| [GIT](model_doc/git) | ✅ | ❌ | ❌ |
| [GLPN](model_doc/glpn) | ✅ | ❌ | ❌ |
| [GPT Neo](model_doc/gpt_neo) | ✅ | ❌ | ✅ |
@@ -391,6 +391,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens
    - get_seq_length
    - reset

[[autodoc]] EncoderDecoderCache
    - get_seq_length
    - to_legacy_cache
    - from_legacy_cache
    - reset
    - reorder_cache
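Not part of the diff itself: a minimal sketch of how the newly documented `EncoderDecoderCache` can be assembled from two `DynamicCache` objects and round-tripped through the legacy tuple format, assuming both classes are importable from the top-level `transformers` package (as the `__init__.py` changes later in this compare suggest).

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

# Pair a decoder self-attention cache with a cross-attention cache.
cache = EncoderDecoderCache(
    self_attention_cache=DynamicCache(),
    cross_attention_cache=DynamicCache(),
)

# Store key/value states for layer 0: (batch, num_heads, seq_len, head_dim).
keys = values = torch.randn(1, 8, 4, 64)
cache.self_attention_cache.update(keys, values, layer_idx=0)

print(cache.get_seq_length(0))                    # 4 tokens cached so far
legacy = cache.to_legacy_cache()                  # legacy tuple-of-tuples format
restored = EncoderDecoderCache.from_legacy_cache(legacy)
```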

## Watermark Utils
docs/source/en/model_doc/gemma2.md (new file, 58 lines)
@@ -0,0 +1,58 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Gemma2

## Overview

The Gemma2 model was proposed in [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/google-gemma-2/) by the Gemma2 Team at Google.
Two Gemma2 models are released, with parameter sizes of 9 billion (9B) and 27 billion (27B).

The abstract from the blog post is the following:

*Now we’re officially releasing Gemma 2 to researchers and developers globally. Available in both 9 billion (9B) and 27 billion (27B) parameter sizes, Gemma 2 is higher-performing and more efficient at inference than the first generation, with significant safety advancements built in. In fact, at 27B, it offers competitive alternatives to models more than twice its size, delivering the kind of performance that was only possible with proprietary models as recently as December.*

Tips:

- The original checkpoints can be converted using the conversion script `src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py`.

This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen]().
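Not part of the original file: a brief, hypothetical usage sketch for the classes documented below. The checkpoint id `google/gemma-2-9b` is an assumption about how the 9B weights are published on the Hub; substitute whichever Gemma2 checkpoint you actually have access to (gated checkpoints also require `huggingface-cli login`).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # assumed Hub id for the 9B checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```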

## Gemma2Config

[[autodoc]] Gemma2Config

## Gemma2Model

[[autodoc]] Gemma2Model
    - forward

## Gemma2ForCausalLM

[[autodoc]] Gemma2ForCausalLM
    - forward

## Gemma2ForSequenceClassification

[[autodoc]] Gemma2ForSequenceClassification
    - forward

## Gemma2ForTokenClassification

[[autodoc]] Gemma2ForTokenClassification
    - forward
@@ -16,6 +16,15 @@ rendered properly in your Markdown viewer.

# Llama3

```py3
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B"

pipeline = transformers.pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
pipeline("Hey how are you doing today?")
```

## Overview

@@ -66,20 +75,7 @@ model = AutoModelForCausalLM.from_pretrained("/output/path")

Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even if the biggest versions
come in several checkpoints, each checkpoint contains a part of each weight of the model, so we need to load them all in RAM). For the 70B model, it's thus 140GB of RAM needed.

- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method and use Automatic Mixed-Precision training. When using `Trainer`, simply set either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because Flash Attention only supports the `fp16` and `bf16` data types.
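Not part of the original page: a minimal sketch of the pattern described in the tip above — load the model with Flash Attention 2 (the `flash-attn` package must be installed and a supported GPU available), omit `torch_dtype`, and let `torch.autocast` supply the half-precision dtype.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# No torch_dtype here, as recommended when combining Flash Attention 2 with AMP.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

inputs = tokenizer("Hey how are you doing today?", return_tensors="pt").to(model.device)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(**inputs)
```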

## Quick usage

```py3
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B"

pipeline = transformers.pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
pipeline("Hey how are you doing today?")
```

## Resources

A ton of cool resources are already available on the documentation page of [~llama2], inviting contributors to add new resources curated for Llama3 here! 🤗
A ton of cool resources are already available on the documentation page of [Llama2](./llama2), inviting contributors to add new resources curated for Llama3 here! 🤗
@@ -52,8 +52,6 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
>>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"]
>>> waveform = audio_sample["array"]
>>> sampling_rate = audio_sample["sampling_rate"]

>>> # Load the Whisper model in Hugging Face format:
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
@@ -61,7 +59,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained

>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
...     waveform, sampling_rate=sampling_rate, return_tensors="pt"
...     audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features

>>> # Generate token ids
@@ -74,6 +72,49 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```

Whisper is compatible with the following optimisations:
- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.

As an example, the following code snippet enables SDPA and `torch.compile` for up to 5x faster inference:

```python
>>> import torch
>>> from datasets import load_dataset
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration

>>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"]

>>> # Load the Whisper model with SDPA attention
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")

>>> # Enable static cache and compile the forward pass
>>> model.generation_config.cache_implementation = "static"
>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
...     audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features

>>> # Compile the forward pass: the first call triggers compilation
>>> _ = model.generate(input_features)

>>> # Generate token ids using compiled graph (fast!)
>>> predicted_ids = model.generate(input_features)

>>> # Decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

>>> transcription[0]
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```

For more details on each optimisation, refer to the documentation linked above.
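Not in the original page: a minimal sketch of the Flash Attention 2 path from the list above (requires the `flash-attn` package and a supported GPU); relative to the SDPA example, only the `attn_implementation` and the dtype change.

```python
import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny.en",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")
```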

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
@@ -43,6 +43,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
@@ -202,6 +203,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
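Not part of the diff: a minimal sketch of how either backend listed above is selected at load time for the newly added Gemma2 entries; the checkpoint id is a placeholder.

```python
import torch
from transformers import AutoModelForCausalLM

# "sdpa" is the default on torch>=2.1.1; pass "flash_attention_2" to use FlashAttention-2 instead.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",  # placeholder checkpoint id
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```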
@@ -15,7 +15,7 @@
    title: Préparation des données
  - local: in_translation
    title: Fine-tune un modèle pré-entraîné
  - local: in_translation
  - local: run_scripts_fr
    title: Entraînement avec un script
  - local: in_translation
    title: Entraînement distribué avec 🤗 Accelerate
docs/source/fr/run_scripts_fr.md (new file, 355 lines)
@@ -0,0 +1,355 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Entraîner avec un script

En plus des [notebooks](./notebooks) de 🤗 Transformers, il existe également des exemples de scripts démontrant comment entraîner un modèle pour une tâche avec [PyTorch](https://github.com/huggingface/transformers/tree/main/examples/pytorch), [TensorFlow](https://github.com/huggingface/transformers/tree/main/examples/tensorflow) ou [JAX/Flax](https://github.com/huggingface/transformers/tree/main/examples/flax).

Vous trouverez également des scripts que nous avons utilisés dans nos [projets de recherche](https://github.com/huggingface/transformers/tree/main/examples/research_projects) et des [exemples "legacy"](https://github.com/huggingface/transformers/tree/main/examples/legacy) qui sont des contributions de la communauté. Ces scripts ne sont pas activement maintenus et nécessitent une version spécifique de 🤗 Transformers qui sera probablement incompatible avec la dernière version de la librairie.

Les exemples de scripts ne sont pas censés fonctionner immédiatement pour chaque problème, et il se peut que vous ayez besoin d'adapter le script au problème que vous essayez de résoudre. Pour vous aider dans cette tâche, la plupart des scripts exposent entièrement la manière dont les données sont prétraitées, vous permettant de les modifier selon vos besoins.

Pour toute fonctionnalité que vous souhaitez implémenter dans un script d'exemple, veuillez en discuter sur le [forum](https://discuss.huggingface.co/) ou dans une [issue](https://github.com/huggingface/transformers/issues) avant de soumettre une Pull Request. Bien que nous acceptions les corrections de bugs, il est peu probable que nous fusionnions une Pull Request (opération "merge" dans Git) ajoutant plus de fonctionnalités au détriment de la lisibilité.

Ce guide vous montrera comment exécuter un exemple de script d'entraînement pour le résumé de texte avec [PyTorch](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization) et [TensorFlow](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/summarization). Tous les exemples sont censés fonctionner avec les deux frameworks, sauf indication contraire.

## Configuration

Pour exécuter avec succès la dernière version des scripts d'exemple, vous devez **installer 🤗 Transformers à partir du code source** dans un nouvel environnement virtuel :

```bash
git clone https://github.com/huggingface/transformers
cd transformers
pip install .
```

Pour les versions plus anciennes des exemples de scripts, cliquez sur le bouton ci-dessous :

<details>
  <summary>Exemples pour les anciennes versions de Transformers 🤗</summary>
    <ul>
        <li><a href="https://github.com/huggingface/transformers/tree/v4.5.1/examples">v4.5.1</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v4.4.2/examples">v4.4.2</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v4.3.3/examples">v4.3.3</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v4.2.2/examples">v4.2.2</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v4.1.1/examples">v4.1.1</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v4.0.1/examples">v4.0.1</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v3.5.1/examples">v3.5.1</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v3.4.0/examples">v3.4.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v3.3.1/examples">v3.3.1</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v3.2.0/examples">v3.2.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v3.1.0/examples">v3.1.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v3.0.2/examples">v3.0.2</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.11.0/examples">v2.11.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.10.0/examples">v2.10.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.9.1/examples">v2.9.1</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.8.0/examples">v2.8.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.7.0/examples">v2.7.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.6.0/examples">v2.6.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.5.1/examples">v2.5.1</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.4.0/examples">v2.4.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.3.0/examples">v2.3.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.2.0/examples">v2.2.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.1.0/examples">v2.1.1</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v2.0.0/examples">v2.0.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v1.2.0/examples">v1.2.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v1.1.0/examples">v1.1.0</a></li>
        <li><a href="https://github.com/huggingface/transformers/tree/v1.0.0/examples">v1.0.0</a></li>
    </ul>
</details>

Ensuite, changez votre clone actuel de 🤗 Transformers pour une version spécifique, comme par exemple v3.5.1 :

```bash
git checkout tags/v3.5.1
```

Après avoir configuré la bonne version de la librairie, accédez au dossier d'exemple de votre choix et installez les prérequis spécifiques à l'exemple :

```bash
pip install -r requirements.txt
```

## Exécuter un script

<frameworkcontent>
<pt>

Le script d'exemple télécharge et prétraite un jeu de données à partir de la bibliothèque 🤗 [Datasets](https://huggingface.co/docs/datasets/). Ensuite, le script affine un modèle sur ce jeu de données à l'aide de [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer), sur une architecture qui prend en charge la tâche de résumé. L'exemple suivant montre comment ajuster le modèle [T5-small](https://huggingface.co/google-t5/t5-small) sur les données [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail). Le modèle T5 nécessite un argument supplémentaire `source_prefix` en raison de la façon dont il a été entraîné. Cette invite permet à T5 de savoir qu'il s'agit d'une tâche de résumé.

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```
</pt>
<tf>

Le script d'exemple télécharge et prétraite un jeu de données à partir de la bibliothèque 🤗 [Datasets](https://huggingface.co/docs/datasets/). Ensuite, le script ajuste un modèle à l'aide de Keras sur une architecture qui prend en charge la tâche de résumé. L'exemple suivant montre comment ajuster le modèle [T5-small](https://huggingface.co/google-t5/t5-small) sur le jeu de données [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail). Le modèle T5 nécessite un argument supplémentaire `source_prefix` en raison de la façon dont il a été entraîné. Cette invite permet à T5 de savoir qu'il s'agit d'une tâche de résumé.

```bash
python examples/tensorflow/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --num_train_epochs 3 \
    --do_train \
    --do_eval
```
</tf>
</frameworkcontent>

## Entraînement distribué et précision mixte

[Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) prend en charge l'entraînement distribué et la précision mixte, ce qui signifie que vous pouvez également les utiliser dans un script. Pour activer ces deux fonctionnalités :

- Ajoutez l'argument `fp16` pour activer la précision mixte.
- Définissez le nombre de GPU à utiliser avec l'argument `nproc_per_node`.

```bash
torchrun \
    --nproc_per_node 8 pytorch/summarization/run_summarization.py \
    --fp16 \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```

Les scripts TensorFlow utilisent la stratégie [`MirroredStrategy`](https://www.tensorflow.org/guide/distributed_training#mirroredstrategy) pour l'entraînement distribué, et vous n'avez pas besoin d'ajouter d'arguments supplémentaires au script d'entraînement. Le script TensorFlow utilisera plusieurs GPU par défaut s'ils sont disponibles.

## Exécuter un script sur un TPU

<frameworkcontent>
<pt>

Les unités de traitement de tenseurs (TPU) sont spécialement conçues pour accélérer les performances. PyTorch prend en charge les TPU avec le compilateur de deep learning [XLA](https://www.tensorflow.org/xla). Pour utiliser un TPU, lancez le script `xla_spawn.py` et utilisez l'argument `num_cores` pour définir le nombre de cœurs TPU que vous souhaitez utiliser.

```bash
python xla_spawn.py --num_cores 8 \
    summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```
</pt>
<tf>
Les scripts TensorFlow utilisent une [`TPUStrategy`](https://www.tensorflow.org/guide/distributed_training#tpustrategy) pour l'entraînement sur TPU. Pour utiliser un TPU, passez le nom de la ressource TPU à l'argument `tpu`.

```bash
python run_summarization.py \
    --tpu name_of_tpu_resource \
    --model_name_or_path google-t5/t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --num_train_epochs 3 \
    --do_train \
    --do_eval
```
</tf>
</frameworkcontent>

## Exécuter un script avec 🤗 Accelerate

🤗 [Accelerate](https://huggingface.co/docs/accelerate) est une bibliothèque uniquement pour PyTorch qui offre une méthode unifiée pour entraîner un modèle sur plusieurs types de configurations (CPU uniquement, plusieurs GPU, TPU) tout en maintenant une visibilité complète sur la boucle d'entraînement PyTorch. Assurez-vous d'avoir installé 🤗 Accelerate si ce n'est pas déjà le cas.

> Note : Comme Accelerate est en développement rapide, la version git d'accelerate doit être installée pour exécuter les scripts.
```bash
pip install git+https://github.com/huggingface/accelerate
```

Au lieu du script `run_summarization.py`, vous devez utiliser le script `run_summarization_no_trainer.py`. Les scripts compatibles avec 🤗 Accelerate auront un fichier `task_no_trainer.py` dans le dossier. Commencez par exécuter la commande suivante pour créer et enregistrer un fichier de configuration :

```bash
accelerate config
```

Testez votre configuration pour vous assurer qu'elle est correctement configurée :

```bash
accelerate test
```

Maintenant, vous êtes prêt à lancer l'entraînement :

```bash
accelerate launch run_summarization_no_trainer.py \
    --model_name_or_path google-t5/t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir ~/tmp/tst-summarization
```

## Utiliser un jeu de données personnalisé

Le script de résumé prend en charge les jeux de données personnalisés tant qu'ils sont au format CSV ou JSON Lines. Lorsque vous utilisez votre propre jeu de données, vous devez spécifier plusieurs arguments supplémentaires :

- `train_file` et `validation_file` spécifient le chemin vers vos fichiers d'entraînement et de validation.
- `text_column` est le texte d'entrée à résumer.
- `summary_column` est le texte cible à produire.

Un exemple de script de résumé utilisant un jeu de données personnalisé ressemblerait à ceci :

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --train_file path_to_csv_or_jsonlines_file \
    --validation_file path_to_csv_or_jsonlines_file \
    --text_column text_column_name \
    --summary_column summary_column_name \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --overwrite_output_dir \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --predict_with_generate
```
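Hors du guide d'origine : une esquisse minimale et hypothétique du format JSON Lines attendu pour `train_file`/`validation_file`, générée ici en Python avec des noms de colonnes d'exemple correspondant à `text_column` et `summary_column`.

```python
import json

# Exemple hypothétique : deux enregistrements au format JSON Lines,
# avec les colonnes "text" (texte à résumer) et "summary" (résumé cible).
exemples = [
    {"text": "Le premier article complet à résumer...", "summary": "Résumé du premier article."},
    {"text": "Le deuxième article complet à résumer...", "summary": "Résumé du deuxième article."},
]

with open("train.json", "w", encoding="utf-8") as f:
    for exemple in exemples:
        f.write(json.dumps(exemple, ensure_ascii=False) + "\n")
```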

## Tester un script

Il est souvent judicieux d'exécuter votre script sur un plus petit nombre d'exemples du jeu de données pour s'assurer que tout fonctionne comme prévu avant de s'engager sur un jeu de données complet qui pourrait prendre des heures à traiter. Utilisez les arguments suivants pour tronquer le jeu de données à un nombre maximal d'échantillons :

- `max_train_samples`
- `max_eval_samples`
- `max_predict_samples`

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --max_train_samples 50 \
    --max_eval_samples 50 \
    --max_predict_samples 50 \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```

Tous les scripts d'exemple ne prennent pas en charge l'argument `max_predict_samples`. Si vous n'êtes pas sûr que votre script prenne en charge cet argument, ajoutez l'argument `-h` pour vérifier :

```bash
examples/pytorch/summarization/run_summarization.py -h
```

## Reprendre l'entraînement à partir d'un point de contrôle

Une autre option utile est de reprendre l'entraînement à partir d'un point de contrôle précédent. Cela vous permettra de reprendre là où vous vous étiez arrêté sans recommencer si votre entraînement est interrompu. Il existe deux méthodes pour reprendre l'entraînement à partir d'un point de contrôle.

La première méthode utilise l'argument `output_dir previous_output_dir` pour reprendre l'entraînement à partir du dernier point de contrôle stocké dans `output_dir`. Dans ce cas, vous devez supprimer l'argument `overwrite_output_dir` :

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --output_dir previous_output_dir \
    --predict_with_generate
```

La seconde méthode utilise l'argument `resume_from_checkpoint path_to_specific_checkpoint` pour reprendre l'entraînement à partir d'un dossier de point de contrôle spécifique :

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --resume_from_checkpoint path_to_specific_checkpoint \
    --predict_with_generate
```

## Partager votre modèle

Tous les scripts peuvent téléverser votre modèle final sur le Model Hub. Assurez-vous d'être connecté à Hugging Face avant de commencer :

```bash
huggingface-cli login
```

Ensuite, ajoutez l'argument `push_to_hub` au script. Cet argument créera un dépôt avec votre nom d'utilisateur Hugging Face et le nom du dossier spécifié dans `output_dir`.

Pour donner un nom spécifique à votre dépôt, utilisez l'argument `push_to_hub_model_id`. Le dépôt sera automatiquement listé sous votre namespace.

L'exemple suivant montre comment téléverser un modèle avec un nom de dépôt spécifique :

```bash
python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --push_to_hub \
    --push_to_hub_model_id finetuned-t5-cnn_dailymail \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate
```
@@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

Array = Any
Dataset = datasets.arrow_dataset.Dataset

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

Array = Any
Dataset = datasets.arrow_dataset.Dataset

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = get_logger(__name__)

@@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -46,7 +46,8 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -52,7 +52,8 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = get_logger(__name__)

@@ -58,7 +58,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = get_logger(__name__)

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = get_logger(__name__)
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = logging.getLogger(__name__)

@@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = get_logger(__name__)

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = get_logger(__name__)

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version(
    "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = logging.getLogger(__name__)

@@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

logger = logging.getLogger(__name__)

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version

# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

task_to_keys = {
    "cola": ("sentence", None),

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")
check_min_version("4.43.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
setup.py (4 lines changed)

@@ -128,7 +128,7 @@ _deps = [
    "kenlm",
    # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
    "keras>2.9,<2.16",
    "keras-nlp>=0.3.1",
    "keras-nlp>=0.3.1,<0.14.0",  # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
    "librosa",
    "nltk",
    "natten>=0.14.6,<0.15.0",
@@ -430,7 +430,7 @@ install_requires = [

setup(
    name="transformers",
    version="4.42.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    version="4.43.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
    author_email="transformers@huggingface.co",
    description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).

__version__ = "4.42.0.dev0"
__version__ = "4.43.0.dev0"

from typing import TYPE_CHECKING

@@ -435,6 +435,7 @@ _import_structure = {
    ],
    "models.fuyu": ["FuyuConfig"],
    "models.gemma": ["GemmaConfig"],
    "models.gemma2": ["Gemma2Config"],
    "models.git": [
        "GitConfig",
        "GitProcessor",
@@ -1211,6 +1212,7 @@ else:
        "Cache",
        "CacheConfig",
        "DynamicCache",
        "EncoderDecoderCache",
        "HQQQuantizedCache",
        "QuantizedCache",
        "QuantizedCacheConfig",
@@ -2181,6 +2183,15 @@ else:
            "GemmaPreTrainedModel",
        ]
    )
    _import_structure["models.gemma2"].extend(
        [
            "Gemma2ForCausalLM",
            "Gemma2ForSequenceClassification",
            "Gemma2ForTokenClassification",
            "Gemma2Model",
            "Gemma2PreTrainedModel",
        ]
    )
    _import_structure["models.git"].extend(
        [
            "GitForCausalLM",
@@ -5062,6 +5073,7 @@ if TYPE_CHECKING:
    )
    from .models.fuyu import FuyuConfig
    from .models.gemma import GemmaConfig
    from .models.gemma2 import Gemma2Config
    from .models.git import (
        GitConfig,
        GitProcessor,
@@ -5884,6 +5896,7 @@ if TYPE_CHECKING:
        Cache,
        CacheConfig,
        DynamicCache,
        EncoderDecoderCache,
        HQQQuantizedCache,
        QuantizedCache,
        QuantizedCacheConfig,
@@ -6694,6 +6707,13 @@ if TYPE_CHECKING:
        GemmaModel,
        GemmaPreTrainedModel,
    )
    from .models.gemma2 import (
        Gemma2ForCausalLM,
        Gemma2ForSequenceClassification,
        Gemma2ForTokenClassification,
        Gemma2Model,
        Gemma2PreTrainedModel,
    )
    from .models.git import (
        GitForCausalLM,
        GitModel,
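Not part of the diff: the comments at the top of the hunk describe the lazy-import machinery, so once the `_import_structure` and `TYPE_CHECKING` entries above are in place, the new Gemma2 names resolve from the top-level package on first access. A minimal sketch:

```python
# Importing from the top-level package works without eagerly loading the modeling code;
# transformers' _LazyModule resolves each attribute the first time it is accessed.
from transformers import Gemma2Config, Gemma2ForCausalLM

print(Gemma2Config().model_type)      # "gemma2"
print(Gemma2ForCausalLM.__module__)   # the module is only imported at this point
```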
@@ -858,8 +858,12 @@ class StaticCache(Cache):
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states

        return k_out, v_out
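As a quick, hedged illustration of the branch added above (not part of the diff): with `cache_position=None`, `StaticCache.update` now overwrites the whole pre-allocated buffer instead of indexing into it. The tiny `GemmaConfig` values below are invented so the sketch stays small and offline.

```python
import torch
from transformers import GemmaConfig, StaticCache

# Made-up, tiny config just to allocate a small cache.
config = GemmaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16)
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=8, device="cpu", dtype=torch.float32)

# Indexed update (the usual path): write 2 new tokens into slots 0 and 1 of layer 0.
keys = torch.randn(1, config.num_key_value_heads, 2, config.head_dim)
values = torch.randn(1, config.num_key_value_heads, 2, config.head_dim)
k_out, v_out = cache.update(keys, values, layer_idx=0, cache_kwargs={"cache_position": torch.arange(2)})
print(k_out.shape)  # torch.Size([1, 2, 8, 16]) -- the full pre-allocated buffer

# This second call relies on the `cache_position is None` branch added in this diff:
# the whole buffer is copied over, so the states must already have the full cache shape.
full_k = torch.randn(1, config.num_key_value_heads, 8, config.head_dim)
full_v = torch.randn(1, config.num_key_value_heads, 8, config.head_dim)
cache.update(full_k, full_v, layer_idx=0, cache_kwargs={"cache_position": None})
```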
@@ -970,3 +974,277 @@ class SlidingWindowCache(StaticCache):
        # in theory there is no limit because the sliding window size is fixed
        # no matter how long the sentence is
        return None

    def reset(self):
        self.key_cache.zero_()
        self.value_cache.zero_()


class EncoderDecoderCache(Cache):
    """
    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
    cross-attention caches.
    """

    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

        self.is_updated = {}
        for layer_idx in range(len(cross_attention_cache.key_cache)):
            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (
                self.self_attention_cache.key_cache[layer_idx],
                self.self_attention_cache.value_cache[layer_idx],
                self.cross_attention_cache.key_cache[layer_idx],
                self.cross_attention_cache.value_cache[layer_idx],
            )
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.self_attention_cache)

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        if len(self.cross_attention_cache) > 0:
            for self_attn, cross_attn in zip(
                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
            ):
                legacy_cache += (self_attn + cross_attn,)
        else:
            legacy_cache = self.self_attention_cache.to_legacy_cache()
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "EncoderDecoderCache":
        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
        cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx][:2]
                cache.self_attention_cache.update(key_states, value_states, layer_idx)
                if len(past_key_values[layer_idx]) > 2:
                    key_states, value_states = past_key_values[layer_idx][2:]
                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                    cache.is_updated[layer_idx] = True
        return cache

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.self_attention_cache.key_cache) <= layer_idx:
            return 0
        return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def reset(self):
        if hasattr(self.self_attention_cache, "reset"):
            self.self_attention_cache.reset()
        if hasattr(self.cross_attention_cache, "reset"):
            self.cross_attention_cache.reset()
        elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
            raise ValueError(
                "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
                "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
                f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
                f"{self.cross_attention_cache.__str__()} for the cross attention cache."
            )
        for layer_idx in self.is_updated:
            self.is_updated[layer_idx] = False

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        self.self_attention_cache.reorder_cache(beam_idx)
        self.cross_attention_cache.reorder_cache(beam_idx)

    def check_dynamic_cache(self, method: str):
        if not (
            isinstance(self.self_attention_cache, DynamicCache)
            and isinstance(self.cross_attention_cache, DynamicCache)
        ):
            raise ValueError(
                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
            )

    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
    def crop(self, maximum_length: int):
        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
        self.check_dynamic_cache(self.crop.__name__)
        self.self_attention_cache.crop(maximum_length)

    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        self.check_dynamic_cache(self.batch_split.__name__)
        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)

        out = []
        for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
            out.append(EncoderDecoderCache(self_attn, cross_attn))
        return out

    @classmethod
    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`"""
        self_attention_cache = DynamicCache()
        cross_attention_cache = DynamicCache()
        for idx in range(len(splits[0])):
            layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
            self_attention_cache.update(layer_keys, layer_values, idx)

            layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
            cross_attention_cache.update(layer_keys, layer_values, idx)
        return cls(self_attention_cache, cross_attention_cache)

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
        self.self_attention_cache.batch_repeat_interleave(repeats)
        self.cross_attention_cache.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_select_indices.__name__)
        self.self_attention_cache.batch_select_indices(indices)
        self.cross_attention_cache.batch_select_indices(indices)

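A minimal usage sketch of the new `EncoderDecoderCache` (not part of the diff); it assumes the top-level export added in this branch and uses made-up tensor shapes.

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

# Two independent caches: one for decoder self-attention, one for cross-attention.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# Fake key/value states for a single layer: (batch, num_heads, seq_len, head_dim).
self_k, self_v = torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8)
cross_k, cross_v = torch.randn(1, 4, 10, 8), torch.randn(1, 4, 10, 8)

cache.self_attention_cache.update(self_k, self_v, layer_idx=0)
cache.cross_attention_cache.update(cross_k, cross_v, layer_idx=0)

# Legacy format: one 4-tuple per layer (self k, self v, cross k, cross v).
legacy = cache.to_legacy_cache()
assert len(legacy[0]) == 4

# And back again.
restored = EncoderDecoderCache.from_legacy_cache(legacy)
print(restored.get_seq_length(0))  # 3 decoder tokens cached
```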
class HybridCache(Cache):
    def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        self.max_cache_len = max_cache_len
        self.max_batch_size = max_batch_size
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )
        self.is_sliding = torch.tensor(
            [i % 2 for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
        )
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
        sliding_cache_shape = (
            max_batch_size,
            self.num_key_value_heads,
            min(config.sliding_window, max_cache_len),
            self.head_dim,
        )
        for i in range(config.num_hidden_layers):
            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
            # breaks when updating the cache.
            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        if cache_position.shape[0] > max_cache_len:
            k_out = key_states[:, :, -max_cache_len:, :]
            v_out = value_states[:, :, -max_cache_len:, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # we should return the whole states instead of k_out, v_out to take the whole prompt
            # into consideration when building kv cache instead of just throwing away tokens outside of the window
            return key_states, value_states

        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, max_cache_len - 1)
        to_shift = cache_position >= max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
        # `_.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out
        return k_out, v_out

    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
        sliding_window: Optional[int] = None,
    ) -> Tuple[torch.Tensor]:
        cache_position = cache_kwargs.get("cache_position")
        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        if sliding_window:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update

        return update_fn(
            cache_position,
            layer_idx,
            key_states,
            value_states,
            k_out,
            v_out,
            k_out.shape[2],
        )

    def get_max_length(self) -> Optional[int]:
        # in theory there is no limit because the sliding window size is fixed
        # no matter how long the sentence is
        return self.max_cache_len

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        return None

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()

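A small sketch (not from the diff) of how `HybridCache` allocates its per-layer buffers, alternating between the full static length and the sliding window; the tiny `Gemma2Config` values are invented for illustration.

```python
import torch
from transformers import Gemma2Config
from transformers.cache_utils import HybridCache

# Made-up, tiny configuration so the example stays small and offline.
config = Gemma2Config(
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
    hidden_size=64,
    intermediate_size=128,
    head_dim=16,
    sliding_window=8,
)
cache = HybridCache(config=config, max_batch_size=1, max_cache_len=32, device="cpu", dtype=torch.float32)

# Odd layers use the sliding window (length 8), even layers get the full static length (32).
for i, k in enumerate(cache.key_cache):
    print(i, tuple(k.shape))
# 0 (1, 2, 32, 16)
# 1 (1, 2, 8, 16)
# 2 (1, 2, 32, 16)
# 3 (1, 2, 8, 16)
```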
@@ -34,7 +34,7 @@ deps = {
    "jinja2": "jinja2>=3.1.0",
    "kenlm": "kenlm",
    "keras": "keras>2.9,<2.16",
    "keras-nlp": "keras-nlp>=0.3.1",
    "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
    "librosa": "librosa",
    "nltk": "nltk",
    "natten": "natten>=0.14.6,<0.15.0",
@@ -400,7 +400,7 @@ class GenerationConfig(PushToHubMixin):
        # Cache implementation
        self.cache_implementation = kwargs.pop("cache_implementation", None)
        self.cache_config = kwargs.pop("cache_config", None)
        if self.cache_implementation is not None:
        if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
            cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
            if self.cache_config is None:
                self.cache_config = cache_config_class()
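A hedged sketch of what the relaxed check above allows: `cache_implementation` values without an entry in `NEEDS_CACHE_CONFIG` (such as `"hybrid"`) no longer trigger the config lookup, while quantized caches still receive a default `cache_config`.

```python
from transformers import GenerationConfig

# Per the diff above, the old `if` looked "hybrid" up in NEEDS_CACHE_CONFIG and would fail;
# with the added `in NEEDS_CACHE_CONFIG` guard, the lookup is simply skipped.
gen_config = GenerationConfig(cache_implementation="hybrid", max_new_tokens=32)

# Quantized caches still pick up a default cache_config automatically.
quant_config = GenerationConfig(cache_implementation="quantized")
print(quant_config.cache_config)  # a default QuantizedCacheConfig instance
```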
@@ -27,7 +27,9 @@ from torch import nn
from ..cache_utils import (
    Cache,
    DynamicCache,
    EncoderDecoderCache,
    HQQQuantizedCache,
    HybridCache,
    QuantizedCacheConfig,
    QuantoQuantizedCache,
    SlidingWindowCache,
@@ -112,7 +114,7 @@ logger = logging.get_logger(__name__)
if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}

@@ -1395,10 +1397,12 @@ class GenerationMixin:

        past_length = 0
        if model_kwargs.get("past_key_values") is not None:
            if isinstance(model_kwargs["past_key_values"], Cache):
                past_length = model_kwargs["past_key_values"].get_seq_length()
            else:
                past_length = model_kwargs["past_key_values"][0][0].shape[2]
            cache = model_kwargs["past_key_values"]
            if not isinstance(cache, Cache):
                past_length = cache[0][0].shape[2]
            elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
                past_length = cache.get_seq_length()

        if "inputs_embeds" in model_kwargs:
            cur_len = model_kwargs["inputs_embeds"].shape[1]
        else:
@@ -1406,7 +1410,7 @@ class GenerationMixin:
        model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
        return model_kwargs

    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
        """
        Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized if a
        new `generate` call requires a larger cache.
@@ -1414,28 +1418,46 @@
        Returns the resulting cache object.
        """
        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
        requires_cross_attention_cache = (
            self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
        )

        if hasattr(self, "_cache"):
            cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache

        if cache_implementation == "sliding_window":
            max_cache_len = min(self.config.sliding_window, max_cache_len)

        need_new_cache = (
            not hasattr(self, "_cache")
            or (not isinstance(self._cache, cache_cls))
            or self._cache.max_batch_size != max_batch_size
            or self._cache.max_cache_len < max_cache_len
            or (not isinstance(cache_to_check, cache_cls))
            or cache_to_check.max_batch_size != max_batch_size
            or cache_to_check.max_cache_len < max_cache_len
        )

        if requires_cross_attention_cache and hasattr(self, "_cache"):
            need_new_cache = (
                need_new_cache
                or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
            )

        if need_new_cache:
            if hasattr(self.config, "_pre_quantization_dtype"):
                cache_dtype = self.config._pre_quantization_dtype
            else:
                cache_dtype = self.dtype
            self._cache = cache_cls(
                config=self.config,
                max_batch_size=max_batch_size,
                max_cache_len=max_cache_len,
                device=self.device,
                dtype=cache_dtype,
            )
            cache_kwargs = {
                "config": self.config,
                "max_batch_size": max_batch_size,
                "max_cache_len": max_cache_len,
                "device": self.device,
                "dtype": cache_dtype,
            }
            self._cache = cache_cls(**cache_kwargs)
            if requires_cross_attention_cache:
                encoder_kwargs = cache_kwargs.copy()
                encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
                self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs))
        else:
            self._cache.reset()
        return self._cache
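A hedged end-to-end sketch (not from the diff): with the extended `_get_cache` signature, `generate` forwards `model_kwargs`, and a compiled-cache request looks unchanged from the user side. The checkpoint name below is only an example of a model that supports `cache_implementation="static"`; for encoder-decoder models the same path would instead return an `EncoderDecoderCache` wrapping two static caches.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is", return_tensors="pt")
# `cache_implementation="static"` routes through `_get_cache` and reuses the cache across calls.
out = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```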
@@ -1471,8 +1493,11 @@ class GenerationMixin:
            device = self.device

            token = token_kwargs if token_kwargs is not None else token_self
            if token is None or isinstance(token, torch.Tensor):
            if token is None:
                return token
            elif isinstance(token, torch.Tensor):
                return token.to(device)

            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_id = _tensor_or_none(
@@ -1739,7 +1764,10 @@ class GenerationMixin:
                    "issue: https://github.com/huggingface/transformers/issues/28981"
                )
            model_kwargs["past_key_values"] = self._get_cache(
                generation_config.cache_implementation, batch_size, generation_config.max_length
                generation_config.cache_implementation,
                getattr(generation_config, "num_beams", 1) * batch_size,
                generation_config.max_length,
                model_kwargs,
            )
        elif generation_config.cache_implementation == "quantized":
            if not self._supports_quantized_cache:
@@ -1771,11 +1799,22 @@ class GenerationMixin:
        # keeps copying the cache thus using much more memory
        elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
            past = model_kwargs.get("past_key_values", None)
            requires_cross_attention_cache = (
                self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
            )
            if past is None:
                model_kwargs["past_key_values"] = DynamicCache()
                model_kwargs["past_key_values"] = (
                    DynamicCache()
                    if not requires_cross_attention_cache
                    else EncoderDecoderCache(DynamicCache(), DynamicCache())
                )
                use_dynamic_cache_by_default = True
            elif isinstance(past, tuple):
                model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
                model_kwargs["past_key_values"] = (
                    DynamicCache.from_legacy_cache(past)
                    if not requires_cross_attention_cache
                    else EncoderDecoderCache.from_legacy_cache(past)
                )
                use_dynamic_cache_by_default = True

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
@@ -2059,7 +2098,7 @@ class GenerationMixin:
        # Convert to legacy cache if needed
        if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
            if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
                if isinstance(result.past_key_values, DynamicCache):
                if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
                    result.past_key_values = result.past_key_values.to_legacy_cache()
        return result

@@ -2229,7 +2268,7 @@ class GenerationMixin:
            # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
            # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
            if model_kwargs.get("past_key_values") is None or (
                isinstance(model_kwargs["past_key_values"], Cache)
                isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
                and model_kwargs["past_key_values"].get_seq_length() == 0
            ):
                # prepare inputs
@@ -2318,7 +2357,9 @@ class GenerationMixin:
            # Replicates the new past_key_values to match the `top_k` candidates
            past = model_kwargs["past_key_values"]
            # If it is a static cache, modify it in-place layer after layer to save memory
            if isinstance(past, DynamicCache):
            if isinstance(past, DynamicCache) or (
                isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
            ):
                past.batch_repeat_interleave(top_k)
            else:
                new_key_values = []
@@ -2345,7 +2386,10 @@ class GenerationMixin:
                output_hidden_states=True,
                output_attentions=output_attentions,
            )
            if isinstance(outputs["past_key_values"], DynamicCache):
            if isinstance(outputs["past_key_values"], DynamicCache) or (
                isinstance(outputs["past_key_values"], EncoderDecoderCache)
                and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
            ):
                # Remove past K-V from output since we don't need to stack later
                outputs["past_key_values"] = None
                # Remove last token from past K-V since we don't want to append it at this point
@@ -2420,7 +2464,10 @@ class GenerationMixin:
            else:
                _, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
                # Do it in-place layer per layer to save memory
                if isinstance(next_past_key_values, DynamicCache):
                if isinstance(next_past_key_values, DynamicCache) or (
                    isinstance(next_past_key_values, EncoderDecoderCache)
                    and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
                ):
                    next_past_key_values.batch_select_indices(augmented_idx)
                else:
                    new_key_values = []
@@ -2493,7 +2540,10 @@ class GenerationMixin:
        # Contrastive search works by forward looking at the next token, so we need to exclude it from
        # `past_key_values` to be consistent with the other decoding methods
        if model_kwargs.get("past_key_values") is not None:
            if isinstance(model_kwargs["past_key_values"], DynamicCache):
            if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
                isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
                and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
            ):
                model_kwargs["past_key_values"].crop(-1)
            else:
                past_key_values = []
@@ -2752,7 +2802,7 @@ class GenerationMixin:
        # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
        # cache format is standardized, to avoid adding complexity to the codebase.
        elif "bloom" in model_class or "gptbigcode" in model_class:
            if not isinstance(past_key_values, DynamicCache):
            if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
                raise ValueError(
                    f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
                    "legacy tuple format or `DynamicCache`"
@@ -3698,8 +3748,12 @@ class GenerationMixin:

        # This is needed if return_dict_in_generate is True
        start_from_empty_dynamic_cache = False
        if isinstance(model_kwargs.get("past_key_values", None), DynamicCache):
            if len(model_kwargs["past_key_values"]) == 0:
        past_key_values = model_kwargs.get("past_key_values", None)
        if isinstance(past_key_values, DynamicCache) or (
            isinstance(past_key_values, EncoderDecoderCache)
            and isinstance(past_key_values.self_attention_cache, DynamicCache)
        ):
            if len(past_key_values) == 0:
                start_from_empty_dynamic_cache = True

        this_peer_finished = False
@@ -4017,7 +4071,9 @@ def _split(data, full_batch_size: int, split_size: int = None):
    if isinstance(data, torch.Tensor):
        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
    # New cache format
    elif isinstance(data, DynamicCache):
    elif isinstance(data, DynamicCache) or (
        isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
    ):
        return data.batch_split(full_batch_size, split_size)
    elif isinstance(data, tuple):
        # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
@@ -4125,6 +4181,8 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
    # New cache format
    elif isinstance(data[0], DynamicCache):
        return DynamicCache.from_batch_splits(data)
    elif isinstance(data[0], EncoderDecoderCache):
        return EncoderDecoderCache.from_batch_splits(data)
    elif isinstance(data[0], tuple):
        # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
        if isinstance(data[0][0], tuple):
@@ -92,6 +92,7 @@ from . import (
    funnel,
    fuyu,
    gemma,
    gemma2,
    git,
    glpn,
    gpt2,

@@ -108,6 +108,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("funnel", "FunnelConfig"),
        ("fuyu", "FuyuConfig"),
        ("gemma", "GemmaConfig"),
        ("gemma2", "Gemma2Config"),
        ("git", "GitConfig"),
        ("glpn", "GLPNConfig"),
        ("gpt-sw3", "GPT2Config"),
@@ -385,6 +386,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("funnel", "Funnel Transformer"),
        ("fuyu", "Fuyu"),
        ("gemma", "Gemma"),
        ("gemma2", "Gemma2"),
        ("git", "GIT"),
        ("glpn", "GLPN"),
        ("gpt-sw3", "GPT-Sw3"),
@@ -105,6 +105,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("fsmt", "FSMTModel"),
        ("funnel", ("FunnelModel", "FunnelBaseModel")),
        ("gemma", "GemmaModel"),
        ("gemma2", "Gemma2Model"),
        ("git", "GitModel"),
        ("glpn", "GLPNModel"),
        ("gpt-sw3", "GPT2Model"),
@@ -454,6 +455,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
        ("falcon", "FalconForCausalLM"),
        ("fuyu", "FuyuForCausalLM"),
        ("gemma", "GemmaForCausalLM"),
        ("gemma2", "Gemma2ForCausalLM"),
        ("git", "GitForCausalLM"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
@@ -863,6 +865,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("fnet", "FNetForSequenceClassification"),
        ("funnel", "FunnelForSequenceClassification"),
        ("gemma", "GemmaForSequenceClassification"),
        ("gemma2", "Gemma2ForSequenceClassification"),
        ("gpt-sw3", "GPT2ForSequenceClassification"),
        ("gpt2", "GPT2ForSequenceClassification"),
        ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
@@ -1044,6 +1047,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        ("gemma", "GemmaForTokenClassification"),
        ("gemma2", "Gemma2ForTokenClassification"),
        ("gpt-sw3", "GPT2ForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
@@ -188,6 +188,13 @@ else:
                "GemmaTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "gemma2",
            (
                "GemmaTokenizer" if is_sentencepiece_available() else None,
                "GemmaTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
        ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
@@ -729,7 +729,7 @@ class EncodecModel(EncodecPreTrainedModel):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        """
        return_dict = return_dict or self.config.return_dict
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        chunk_length = self.config.chunk_length
        if chunk_length is None:
@@ -786,7 +786,7 @@ class EncodecModel(EncodecPreTrainedModel):
        >>> audio_codes = outputs.audio_codes
        >>> audio_values = outputs.audio_values
        ```"""
        return_dict = return_dict or self.config.return_dict
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()
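A tiny illustration (not from the diff) of why the replaced `or` form was buggy: an explicit `return_dict=False` is falsy, so the config value silently won.

```python
config_return_dict = True

return_dict = False
print(return_dict or config_return_dict)                               # True  -> the user's False is ignored
print(return_dict if return_dict is not None else config_return_dict)  # False -> the user's choice is respected
```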
@@ -257,6 +257,7 @@ class GemmaAttention(nn.Module):
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
@@ -305,7 +306,7 @@ class GemmaAttention(nn.Module):
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -240,6 +240,7 @@ class GemmaAttention(nn.Module):
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
@@ -265,7 +266,7 @@ class GemmaAttention(nn.Module):
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        cache_length: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

@@ -280,6 +281,13 @@ class GemmaAttention(nn.Module):
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if q_len > 1:
            # prefill
            cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
        else:
            # decoding
            cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -288,7 +296,7 @@ class GemmaAttention(nn.Module):
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -339,7 +347,7 @@ class GemmaFlashAttention2(GemmaAttention):
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        cache_length: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
@@ -365,6 +373,13 @@ class GemmaFlashAttention2(GemmaAttention):
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if q_len > 1:
            # prefill
            cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
        else:
            # decoding
            cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
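A small illustration (not from the diff) of the prefill/decoding split introduced above: during prefill the whole `[0, cache_length)` range is written, while a decoding step touches only slot `cache_length - 1`.

```python
import torch

cache_length = 7

prefill_positions = torch.arange(cache_length, dtype=torch.int32)      # tensor([0, 1, 2, 3, 4, 5, 6])
decode_position = torch.tensor([cache_length - 1], dtype=torch.int32)  # tensor([6])
print(prefill_positions, decode_position)
```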
@@ -530,7 +545,7 @@ class GemmaSdpaAttention(GemmaAttention):
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        cache_length: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -545,7 +560,7 @@ class GemmaSdpaAttention(GemmaAttention):
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                cache_length=cache_length,
            )

        bsz, q_len, _ = hidden_states.size()
@@ -561,6 +576,13 @@ class GemmaSdpaAttention(GemmaAttention):
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if q_len > 1:
            # prefill
            cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
        else:
            # decoding
            cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -584,6 +606,11 @@ class GemmaSdpaAttention(GemmaAttention):
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        if cache_length > 0:
            key_states = key_states[:, :, :cache_length, :]
            value_states = value_states[:, :, :cache_length, :]
            causal_mask = causal_mask[:, :, :, :cache_length] if causal_mask is not None else causal_mask

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
@@ -627,7 +654,7 @@ class GemmaDecoderLayer(nn.Module):
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        cache_length: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
@@ -661,7 +688,7 @@ class GemmaDecoderLayer(nn.Module):
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            cache_length=cache_length,
        )
        hidden_states = residual + hidden_states

@@ -849,7 +876,7 @@ class GemmaModel(GemmaPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cache_length: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
@@ -877,17 +904,18 @@ class GemmaModel(GemmaPreTrainedModel):
            return_legacy_cache = True  # noqa: F841
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_length is None:
            cache_length = past_seen_tokens + inputs_embeds.shape[1]

        if position_ids is None:
            cache_position = torch.arange(
                past_seen_tokens, cache_length, device=inputs_embeds.device
            )
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            attention_mask, inputs_embeds, cache_length, past_key_values, output_attentions
        )

        # embed positions
@@ -898,6 +926,13 @@ class GemmaModel(GemmaPreTrainedModel):
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer
        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
@@ -917,7 +952,7 @@ class GemmaModel(GemmaPreTrainedModel):
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                cache_length,
            )
        else:
            layer_outputs = decoder_layer(
@@ -927,7 +962,7 @@ class GemmaModel(GemmaPreTrainedModel):
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                cache_length=cache_length,
            )

            hidden_states = layer_outputs[0]
@@ -961,7 +996,7 @@ class GemmaModel(GemmaPreTrainedModel):
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        cache_length: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
@@ -1009,12 +1044,12 @@ class GemmaModel(GemmaPreTrainedModel):
                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
            causal_mask = attention_mask
        else:
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            else:
                # This computation is only required when `sequence_length = 1` in the case of static cache.
                causal_mask *= torch.arange(target_length, device=device) > cache_length
            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
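A short sketch (not from the diff) of the single-token masking branch above: with a cached length of 5 and a target length of 8, only positions past the cache are masked.

```python
import torch

target_length, cache_length = 8, 5
min_dtype = torch.finfo(torch.float32).min

causal_mask = torch.full((1, target_length), fill_value=min_dtype)
causal_mask *= torch.arange(target_length) > cache_length
print(causal_mask != 0)
# tensor([[False, False, False, False, False, False,  True,  True]])  <- slots 6 and 7 are masked out
```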
@@ -1082,7 +1117,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cache_length: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
@@ -1126,7 +1161,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            cache_length=cache_length,
        )

        hidden_states = outputs[0]
@@ -1163,14 +1198,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        cached_length=None,
        use_cache=True,
        **kwargs,
    ):
        past_length = 0
        if past_key_values is not None:
            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
            past_length = past_key_values.get_seq_length()
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                if past_key_values.get_max_length() is not None
@@ -1215,15 +1250,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        elif use_cache:
            cache_position = cache_position[-input_length:]
        if cached_length is None:
            # It must be a python int
            cached_length = int(past_length + input_length)

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "cache_length": cached_length,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
@@ -1397,7 +1431,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -1407,7 +1441,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
src/transformers/models/gemma2/__init__.py (new file)
@@ -0,0 +1,61 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_gemma2": ["Gemma2Config"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_gemma2"] = [
        "Gemma2ForCausalLM",
        "Gemma2Model",
        "Gemma2PreTrainedModel",
        "Gemma2ForSequenceClassification",
        "Gemma2ForTokenClassification",
    ]

if TYPE_CHECKING:
    from .configuration_gemma2 import Gemma2Config

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_gemma2 import (
            Gemma2ForCausalLM,
            Gemma2ForSequenceClassification,
            Gemma2ForTokenClassification,
            Gemma2Model,
            Gemma2PreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
src/transformers/models/gemma2/configuration_gemma2.py (new file)
@@ -0,0 +1,152 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import PretrainedConfig


class Gemma2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate a Gemma2
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma2-7B.
    e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Gemma2Model`]
        hidden_size (`int`, *optional*, defaults to 3072):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 24576):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
        query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
        sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
            size of the sliding window.
    ```python
    >>> from transformers import Gemma2Model, Gemma2Config
    >>> # Initializing a Gemma2 gemma2-9b style configuration
    >>> configuration = Gemma2Config()
    >>> # Initializing a model from the gemma2-9b style configuration
    >>> model = Gemma2Model(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_activation="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        final_logit_softcapping=30.0,
        attn_logit_softcapping=50.0,
        query_pre_attn_scalar=224,
        sliding_window=4096,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.attn_logit_softcapping = attn_logit_softcapping

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.final_logit_softcapping = final_logit_softcapping
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.sliding_window = sliding_window
        self.cache_implementation = "hybrid"
239
src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py
Normal file
239
src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py
Normal file
@ -0,0 +1,239 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer
|
||||
|
||||
|
||||
try:
|
||||
from transformers import GemmaTokenizerFast
|
||||
except ImportError as e:
|
||||
warnings.warn(e)
|
||||
warnings.warn(
|
||||
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
|
||||
)
|
||||
GemmaTokenizerFast = None
|
||||
|
||||
"""
|
||||
Sample usage:
|
||||
|
||||
```
|
||||
python src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/gemma/weights --model_size 9B --output_dir /output/path
|
||||
```
|
||||
|
||||
Thereafter, models can be loaded via:
|
||||
|
||||
```py
|
||||
from transformers import Gemma2ForCausalLM, GemmaTokenizerFast
|
||||
|
||||
model = Gemma2ForCausalLM.from_pretrained("/output/path")
|
||||
tokenizer = GemmaTokenizerFast.from_pretrained("/output/path")
|
||||
```
|
||||
|
||||
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints, they each contain a part of each weight of the model, so we need to load them all in RAM).
|
||||
"""
|
||||
|
||||
gemma_9b_config = Gemma2Config(
|
||||
num_hidden_layers=42,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=8,
|
||||
hidden_size=3584,
|
||||
intermediate_size=14336,
|
||||
final_logit_softcapping=30.0,
|
||||
attn_logit_softcapping=50.0,
|
||||
head_dim=256,
|
||||
sliding_window=4096,
|
||||
query_pre_attn_scalar=224,
|
||||
)
|
||||
|
||||
gemma_27b_config = Gemma2Config(
|
||||
num_hidden_layers=46,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=16,
|
||||
hidden_size=4608,
|
||||
intermediate_size=36864,
|
||||
final_logit_softcapping=30.0,
|
||||
attn_logit_softcapping=50.0,
|
||||
head_dim=128,
|
||||
sliding_window=4096,
|
||||
query_pre_attn_scalar=144,
|
||||
)
|
||||
|
||||
CONFIG_MAPPING = {"9B": gemma_9b_config, "27B": gemma_27b_config}
|
||||
LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"}
|
||||
|
||||
|
||||
def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False, dtype=torch.float32):
|
||||
num_attn_heads = config.num_attention_heads
|
||||
hidden_size = config.hidden_size
|
||||
num_kv_heads = config.num_key_value_heads
|
||||
head_dim = config.head_dim
|
||||
|
||||
print(f"Fetching all parameters from the checkpoint at '{input_base_path}'")
|
||||
|
||||
if os.path.isdir(input_base_path):
|
||||
print("Model seems sharded")
|
||||
|
||||
model_state_dict = {}
|
||||
files = [file for file in os.listdir(input_base_path) if file.endswith(".bin")]
|
||||
|
||||
for file in files:
|
||||
print(file)
|
||||
loaded_state_dict = torch.load(os.path.join(input_base_path, file), map_location="cpu")
|
||||
model_state_dict.update(loaded_state_dict)
|
||||
else:
|
||||
print("Model does not seem to be sharded")
|
||||
model_state_dict = torch.load(input_base_path, map_location="cpu")["model_state_dict"]
|
||||
model_state_dict.pop("freqs_cis")
|
||||
|
||||
state_dict = {}
|
||||
for k, v in model_state_dict.items():
|
||||
if "qkv_proj" in k:
|
||||
if num_kv_heads == 1:
|
||||
v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim, hidden_size)
|
||||
q_proj = v[:num_attn_heads, ...]
|
||||
k_proj = v[num_attn_heads : num_attn_heads + num_kv_heads, ...].repeat(num_kv_heads, 1, 1)
|
||||
v_proj = v[-num_kv_heads:, ...].repeat(num_kv_heads, 1, 1)
|
||||
|
||||
state_dict[k.replace("qkv_proj", "q_proj")] = q_proj.reshape(
|
||||
num_attn_heads * head_dim, hidden_size
|
||||
).clone()
|
||||
state_dict[k.replace("qkv_proj", "k_proj")] = k_proj.reshape(
|
||||
num_kv_heads * head_dim, hidden_size
|
||||
).clone()
|
||||
state_dict[k.replace("qkv_proj", "v_proj")] = v_proj[0].clone()
|
||||
else:
|
||||
q_proj, k_proj, v_proj = torch.split(
|
||||
v, [num_attn_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], 0
|
||||
)
|
||||
state_dict[k.replace("qkv_proj", "q_proj")] = q_proj.reshape(
|
||||
num_attn_heads * head_dim, hidden_size
|
||||
).clone()
|
||||
state_dict[k.replace("qkv_proj", "k_proj")] = k_proj.reshape(
|
||||
num_kv_heads * head_dim, hidden_size
|
||||
).clone()
|
||||
state_dict[k.replace("qkv_proj", "v_proj")] = v_proj.reshape(
|
||||
num_kv_heads * head_dim, hidden_size
|
||||
).clone()
|
||||
|
||||
elif k == "embedder.weight":
|
||||
state_dict[LAYER_NAME_MAPPING[k]] = v
|
||||
state_dict["lm_head.weight"] = v
|
||||
else:
|
||||
state_dict[k] = v
|
||||
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
print("Loading the checkpoint in a Gemma2 model.")
|
||||
with init_empty_weights():
|
||||
model = Gemma2ForCausalLM(config)
|
||||
model.load_state_dict(state_dict, assign=True, strict=False)
|
||||
|
||||
model.config.torch_dtype = torch.float32
|
||||
del model.config._name_or_path
|
||||
print("Saving in the Transformers format.")
|
||||
|
||||
if push_to_hub:
|
||||
print(f"pushing the model to {save_path}")
|
||||
model.push_to_hub(save_path, safe_serialization=safe_serialization, private=True)
|
||||
else:
|
||||
model.save_pretrained(save_path, safe_serialization=safe_serialization)
|
||||
|
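# Minimal standalone sketch (illustrative, not from the original file) of the fused
# `qkv_proj` split performed in `write_model` above, using made-up small dimensions. The
# fused weight stacks all query heads, then the key heads, then the value heads along dim 0.
import torch

num_attn_heads, num_kv_heads, head_dim, hidden_size = 4, 2, 8, 32
qkv = torch.randn((num_attn_heads + 2 * num_kv_heads) * head_dim, hidden_size)

q_proj, k_proj, v_proj = torch.split(
    qkv, [num_attn_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], dim=0
)
assert q_proj.shape == (num_attn_heads * head_dim, hidden_size)
assert k_proj.shape == v_proj.shape == (num_kv_heads * head_dim, hidden_size)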
||||
|
||||
def write_tokenizer(input_tokenizer_path, save_path, push_to_hub=False):
|
||||
# Initialize the tokenizer based on the `spm` model
|
||||
tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast
|
||||
print(f"Saving a {tokenizer_class.__name__} to {save_path}.")
|
||||
tokenizer = tokenizer_class(input_tokenizer_path)
|
||||
if push_to_hub:
|
||||
tokenizer.push_to_hub(save_path)
|
||||
else:
|
||||
tokenizer.save_pretrained(save_path)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input_checkpoint",
|
||||
help="Absolute path to the target Gemma2 weights.",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_checkpoint",
|
||||
help="Location of Gemma2 tokenizer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
default="9B",
|
||||
choices=["9B", "27B", "tokenizer_only"],
|
||||
help="'f' models correspond to the finetuned versions, and are specific to the Gemma22 official release. For more details on Gemma2, checkout the original repo: https://huggingface.co/google/gemma-7b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default="google/gemma-9b",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pickle_serialization",
|
||||
help="Whether or not to save using `safetensors`.",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--convert_tokenizer",
|
||||
help="Whether or not to convert the tokenizer as well.",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="float32",
|
||||
help="Target dtype of the converted model",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.convert_tokenizer:
|
||||
if args.tokenizer_checkpoint is None:
|
||||
raise ValueError("Path to the tokenizer is required when passing --convert_tokenizer")
|
||||
|
||||
spm_path = os.path.join(args.tokenizer_checkpoint)
|
||||
write_tokenizer(spm_path, args.output_dir, args.push_to_hub)
|
||||
if not args.model_size == "tokenizer_only":
|
||||
config = CONFIG_MAPPING[args.model_size]
|
||||
dtype = getattr(torch, args.dtype)
|
||||
write_model(
|
||||
config=config,
|
||||
input_base_path=args.input_checkpoint,
|
||||
save_path=args.output_dir,
|
||||
safe_serialization=not args.pickle_serialization,
|
||||
push_to_hub=args.push_to_hub,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
src/transformers/models/gemma2/diff_gemma2.py (new file, 781 lines)
@ -0,0 +1,781 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.models.gemma.configuration_gemma import GemmaConfig
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
GemmaDecoderLayer,
|
||||
GemmaForCausalLM,
|
||||
GemmaForSequenceClassification,
|
||||
GemmaForTokenClassification,
|
||||
GemmaModel,
|
||||
GemmaRMSNorm,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return (
|
||||
indices,
|
||||
cu_seqlens,
|
||||
max_seqlen_in_batch,
|
||||
)
|
||||
|
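# Minimal standalone sketch (illustrative, not from the original file) of what
# `_get_unpad_data` computes for a small padded batch, where the attention mask marks real
# tokens with 1 and padding with 0.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)                      # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()           # positions of real tokens
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                                   # 3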
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma2Config(GemmaConfig):
|
||||
cache_implementation = "hybrid" # TODO this is not properly ported, but cls attr is better
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_pre_attn_scalar=224,
|
||||
sliding_window=4096,
|
||||
final_logit_softcapping=30.0,
|
||||
**super_kwargs,
|
||||
):
|
||||
super().__init__(**super_kwargs)
|
||||
self.query_pre_attn_scalar = query_pre_attn_scalar
|
||||
self.sliding_window = sliding_window
|
||||
self.cache_implementation = "hybrid"
|
||||
self.final_logit_softcapping = final_logit_softcapping
|
||||
|
||||
|
||||
class Gemma2RMSNorm(GemmaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma2Attention(GemmaAttention):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
|
||||
self.scaling = config.query_pre_attn_scalar**-0.5
|
||||
|
||||
super().__init__(config, layer_idx)
|
||||
|
||||
|
||||
class Gemma2FlashAttention2(Gemma2Attention):
|
||||
"""
|
||||
Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stay
untouched. The only required change would be on the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (Gemma2RMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
softmax_scale=self.scaling,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
def _flash_attention_forward(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
query_length,
|
||||
dropout=0.0,
|
||||
softmax_scale=None,
|
||||
cache_position=0,
|
||||
):
|
||||
"""
|
||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||
first unpad the input, then computes the attention scores and pad the final attention scores.
|
||||
|
||||
Args:
|
||||
query_states (`torch.Tensor`):
|
||||
Input query states to be passed to Flash Attention API
|
||||
key_states (`torch.Tensor`):
|
||||
Input key states to be passed to Flash Attention API
|
||||
value_states (`torch.Tensor`):
|
||||
Input value states to be passed to Flash Attention API
|
||||
attention_mask (`torch.Tensor`):
|
||||
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
||||
position of padding tokens and 1 for the position of non-padding tokens.
|
||||
dropout (`float`):
|
||||
Attention dropout
|
||||
softmax_scale (`float`, *optional*):
|
||||
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
||||
"""
|
||||
if not self._flash_attn_uses_top_left_mask:
|
||||
causal = self.is_causal
|
||||
else:
|
||||
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in Gemma2FlashAttention2 __init__.
|
||||
causal = self.is_causal and query_length != 1
|
||||
|
||||
use_sliding_windows = (
|
||||
_flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window
|
||||
)
|
||||
flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {}
|
||||
# Contains at least one padding token in the sequence
|
||||
if attention_mask is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
||||
query_states, key_states, value_states, attention_mask, query_length
|
||||
)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
**flash_kwargs,
|
||||
)
|
||||
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||
|
||||
key_layer = index_first_axis(
|
||||
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
value_layer = index_first_axis(
|
||||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
if query_length == kv_seq_len:
|
||||
query_layer = index_first_axis(
|
||||
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
||||
)
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||
indices_q = indices_k
|
||||
elif query_length == 1:
|
||||
max_seqlen_in_batch_q = 1
|
||||
cu_seqlens_q = torch.arange(
|
||||
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||
) # There is a memcpy here, that is very bad.
|
||||
indices_q = cu_seqlens_q[:-1]
|
||||
query_layer = query_layer.squeeze(1)
|
||||
else:
|
||||
# The -q_len: slice assumes left padding.
|
||||
attention_mask = attention_mask[:, -query_length:]
|
||||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||
|
||||
return (
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
indices_q,
|
||||
(cu_seqlens_q, cu_seqlens_k),
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||
)
|
||||
|
||||
|
||||
class Gemma2SdpaAttention(Gemma2Attention):
|
||||
"""
|
||||
Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Gemma2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
the SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from Gemma2Attention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
scale=self.scaling,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
class Gemma2DecoderLayer(GemmaDecoderLayer):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
|
||||
self.is_sliding = bool(layer_idx % 2)
|
||||
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
attention_mask = attention_mask * torch.tril(
|
||||
torch.ones_like(attention_mask), diagonal=(self.sliding_window - cache_position[-1])
|
||||
)
|
||||
if cache_position[0] > 0:
|
||||
attention_mask = attention_mask[:, -self.sliding_window :]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.pre_feedforward_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Gemma2Model(GemmaModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# normalized
|
||||
# Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
hidden_states = hidden_states * normalizer
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = past_key_values if use_cache else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if past_key_values is not None:
|
||||
target_length = past_key_values.get_max_length()
|
||||
else:
|
||||
target_length = attention_mask.shape[-1]
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
|
||||
if attention_mask.max() != 0:
|
||||
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||
|
||||
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
if self.config.final_logit_softcapping is not None:
|
||||
logits = logits / self.config.final_logit_softcapping
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * self.config.final_logit_softcapping
|
||||
|
||||
logits = logits.float()
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else torch.tensor(0, device=input_ids.device)
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
|
||||
class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma2ForTokenClassification(GemmaForTokenClassification):
|
||||
pass
|
src/transformers/models/gemma2/modeling_gemma2.py (new file, 1402 lines)
File diff suppressed because it is too large
@ -1344,9 +1344,6 @@ class Idefics2PreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Idefics2 isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/LLaVA/tree/main/idefics2 should serve for that purpose
|
||||
std = (
|
||||
self.config.text_config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
|
@ -131,23 +131,5 @@ class LlavaConfig(PretrainedConfig):
|
||||
text_config = CONFIG_MAPPING["llama"]()
|
||||
|
||||
self.text_config = text_config
|
||||
self._vocab_size = self.text_config.vocab_size
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
warnings.warn(
|
||||
"The `vocab_size` attribute is deprecated and will be removed in v4.42, Please use `text_config.vocab_size` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self._vocab_size
|
||||
|
||||
@vocab_size.setter
|
||||
def vocab_size(self, value):
|
||||
self._vocab_size = value
|
||||
|
||||
def to_dict(self):
|
||||
output = super().to_dict()
|
||||
output.pop("_vocab_size", None)
|
||||
return output
|
||||
|
@ -227,7 +227,6 @@ class MistralAttention(nn.Module):
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Mistral
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1090,8 +1089,9 @@ class MistralModel(MistralPreTrainedModel):
|
||||
exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
if self.config.sliding_window is not None:
|
||||
if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
|
||||
exclude_mask |= torch.arange(target_length, device=device) <= (
|
||||
cache_position.reshape(-1, 1) - self.config.sliding_window
|
||||
exclude_mask.bitwise_or_(
|
||||
torch.arange(target_length, device=device)
|
||||
<= (cache_position.reshape(-1, 1) - self.config.sliding_window)
|
||||
)
|
||||
causal_mask *= exclude_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||
|
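# Minimal standalone sketch (illustrative, not from the original file) of the
# sliding-window exclusion logic used just above, with toy sizes: a key position is
# excluded when it lies in the future of the query, or more than `sliding_window` tokens
# behind it. The in-place `bitwise_or_` only avoids allocating an extra temporary tensor.
import torch

sliding_window, target_length = 3, 8
cache_position = torch.arange(4, 8)  # query positions 4..7

exclude_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)               # future keys
exclude_mask.bitwise_or_(
    torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)      # keys too far back
)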
@ -1276,7 +1276,7 @@ class Owlv2ClassPredictionHead(nn.Module):
|
||||
if query_mask.ndim > 1:
|
||||
query_mask = torch.unsqueeze(query_mask, dim=-2)
|
||||
|
||||
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
|
||||
pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits)
|
||||
pred_logits = pred_logits.to(torch.float32)
|
||||
|
||||
return (pred_logits, image_class_embeds)
|
||||
|
@ -1257,7 +1257,7 @@ class OwlViTClassPredictionHead(nn.Module):
|
||||
if query_mask.ndim > 1:
|
||||
query_mask = torch.unsqueeze(query_mask, dim=-2)
|
||||
|
||||
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
|
||||
pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits)
|
||||
pred_logits = pred_logits.to(torch.float32)
|
||||
|
||||
return (pred_logits, image_class_embeds)
|
||||
|
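# Note (assumption, not stated in the diff): the hard-coded -1e6 fill value is replaced by
# the dtype minimum presumably because -1e6 is not representable in float16 (max magnitude
# ~65504) and silently overflows to -inf, while torch.finfo(dtype).min stays finite for any
# floating dtype. Minimal illustration:
import torch

print(torch.tensor(-1e6, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)
print(torch.finfo(torch.float16).min)           # -65504.0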
@ -129,9 +129,6 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["VideoLlavaVisionAttention"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of VideoLlava isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/LLaVA/tree/main/video_llava should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
|
@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
"""VipLlava model configuration"""
|
||||
|
||||
import warnings
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ..auto import CONFIG_MAPPING
|
||||
@ -90,13 +88,6 @@ class VipLlavaConfig(PretrainedConfig):
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.projector_layernorm_eps = projector_layernorm_eps
|
||||
self.vision_feature_layers = vision_feature_layers
|
||||
|
||||
if "vocab_size" in kwargs:
|
||||
warnings.warn(
|
||||
"The `vocab_size` argument is deprecated and will be removed in v4.42, since it can be inferred from the `text_config`. Passing this argument has no effect",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
self.vision_config = vision_config
|
||||
|
||||
if isinstance(self.vision_config, dict):
|
||||
@ -123,19 +114,5 @@ class VipLlavaConfig(PretrainedConfig):
|
||||
text_config = CONFIG_MAPPING["llama"]()
|
||||
|
||||
self.text_config = text_config
|
||||
self._vocab_size = self.text_config.vocab_size
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
warnings.warn(
|
||||
"The `vocab_size` attribute is deprecated and will be removed in v4.42, Please use `text_config.vocab_size` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self._vocab_size
|
||||
|
||||
def to_dict(self):
|
||||
output = super().to_dict()
|
||||
output.pop("_vocab_size", None)
|
||||
return output
|
||||
|
@ -189,7 +189,11 @@ class WhisperConfig(PretrainedConfig):
|
||||
|
||||
model_type = "whisper"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
||||
attribute_map = {
|
||||
"num_key_value_heads": "encoder_attention_heads",
|
||||
"num_attention_heads": "encoder_attention_heads",
|
||||
"hidden_size": "d_model",
|
||||
}
|
||||
|
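# Minimal sketch (assumes a transformers build that already includes this change): with the
# extended `attribute_map`, `num_key_value_heads` becomes an alias that resolves to
# `encoder_attention_heads`, just like the pre-existing aliases.
from transformers import WhisperConfig

config = WhisperConfig()
assert config.num_key_value_heads == config.encoder_attention_heads
assert config.num_attention_heads == config.encoder_attention_heads
assert config.hidden_size == config.d_model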
||||
def __init__(
|
||||
self,
|
||||
|
@ -25,7 +25,8 @@ from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -244,6 +245,7 @@ class WhisperAttention(nn.Module):
|
||||
is_decoder: bool = False,
|
||||
bias: bool = True,
|
||||
is_causal: bool = False,
|
||||
layer_idx: Optional[int] = None,
|
||||
config: Optional[WhisperConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@ -262,6 +264,14 @@ class WhisperAttention(nn.Module):
|
||||
self.is_decoder = is_decoder
|
||||
self.is_causal = is_causal
|
||||
|
||||
if layer_idx is None and is_decoder:
|
||||
logger.warning_once(
|
||||
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
|
||||
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
@ -271,84 +281,56 @@ class WhisperAttention(nn.Module):
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
|
||||
|
||||
if past_key_value is not None:
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
past_key_value = past_key_value.self_attention_cache
|
||||
|
||||
# use key_value_states if cross attention
|
||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||
if is_cross_attention and past_key_value and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
key_states = past_key_value.key_cache[self.layer_idx]
|
||||
value_states = past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.reshape(*proj_shape)
|
||||
value_states = value_states.reshape(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
key_states = self._shape(self.k_proj(current_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(current_states), -1, bsz)
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -358,42 +340,27 @@ class WhisperAttention(nn.Module):
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}"
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to be reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_probs, value_states)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Whisper
|
||||
class WhisperFlashAttention2(WhisperAttention):
|
||||
"""
|
||||
Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stay
|
||||
@ -410,18 +377,21 @@ class WhisperFlashAttention2(WhisperAttention):
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if isinstance(past_key_value, StaticCache):
|
||||
raise ValueError(
|
||||
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
|
||||
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
|
||||
)
|
||||
# WhisperFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("WhisperFlashAttention2 attention does not support output_attentions")
|
||||
@ -429,51 +399,45 @@ class WhisperFlashAttention2(WhisperAttention):
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0].transpose(1, 2)
|
||||
value_states = past_key_value[1].transpose(1, 2)
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
|
||||
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
||||
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
past_key_value = past_key_value.self_attention_cache
|
||||
|
||||
# use key_value_states if cross attention
|
||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||
if is_cross_attention and past_key_value and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value.key_cache[self.layer_idx]
|
||||
value_states = past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self._shape(self.k_proj(current_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(current_states), -1, bsz)
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
|
||||
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
@ -502,10 +466,10 @@ class WhisperFlashAttention2(WhisperAttention):
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
|
||||
query_states, key_states, value_states, causal_mask, tgt_len, dropout=self.dropout
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
@ -614,15 +578,15 @@ class WhisperFlashAttention2(WhisperAttention):
|
||||
|
||||
|
||||
class WhisperSdpaAttention(WhisperAttention):
|
||||
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with BART->whisper, Bart->Whisper
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
@ -638,59 +602,50 @@ class WhisperSdpaAttention(WhisperAttention):
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states)
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
|
||||
|
||||
if past_key_value is not None:
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
past_key_value = past_key_value.self_attention_cache
|
||||
|
||||
# use key_value_states if cross attention
|
||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||
if is_cross_attention and past_key_value and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
key_states = past_key_value.key_cache[self.layer_idx]
|
||||
value_states = past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = self._shape(self.k_proj(current_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(current_states), -1, bsz)
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
query_states = self._shape(query_states, tgt_len, bsz)
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
|
||||
is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
|
||||
|
||||
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||
# but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
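Illustrative aside (not part of the diff): a standalone sketch of the `is_causal` dispatch described in the comment above, using only public PyTorch APIs; the tensor shapes are arbitrary.

# Minimal sketch: with no explicit mask and a query length > 1, is_causal=True lets SDPA
# build the causal mask internally and dispatch to its fused causal kernels.
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)  # (batch, num_heads, seq_len, head_dim)
tgt_len = q.shape[2]

is_causal = True if tgt_len > 1 else False  # mirrors the dispatch above when no mask is given
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal)
print(out.shape)  # torch.Size([1, 2, 4, 8])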
|
||||
@ -698,7 +653,7 @@ class WhisperSdpaAttention(WhisperAttention):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
@ -798,9 +753,8 @@ class WhisperEncoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER
|
||||
class WhisperDecoderLayer(nn.Module):
|
||||
def __init__(self, config: WhisperConfig):
|
||||
def __init__(self, config: WhisperConfig, layer_idx: int = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
@ -810,6 +764,7 @@ class WhisperDecoderLayer(nn.Module):
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
is_causal=True,
|
||||
layer_idx=layer_idx,
|
||||
config=config,
|
||||
)
|
||||
self.dropout = config.dropout
|
||||
@ -822,6 +777,7 @@ class WhisperDecoderLayer(nn.Module):
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
layer_idx=layer_idx,
|
||||
config=config,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
@ -837,9 +793,10 @@ class WhisperDecoderLayer(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -863,41 +820,35 @@ class WhisperDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Cross-Attention Block
|
||||
cross_attn_present_key_value = None
|
||||
cross_attn_weights = None
|
||||
if encoder_hidden_states is not None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# add cross-attn to positions 3,4 of present_key_value tuple
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
# add cross-attn to position 1 of present_key_value tuple
|
||||
present_key_value = (present_key_value, cross_attn_present_key_value)
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
@ -927,6 +878,8 @@ class WhisperPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
@ -1024,14 +977,18 @@ WHISPER_INPUTS_DOCSTRING = r"""
|
||||
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
||||
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
|
||||
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
|
||||
four sets of pre-computed hidden-states: key and value states in the self-attention blocks (2) and
in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
when `config.use_cache=True`.
|
||||
|
||||
Two formats are allowed:
|
||||
- An [`~cache_utils.EncoderDecoderCache`] instance;
|
||||
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
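Illustrative usage sketch (not part of the diff) of passing the new `EncoderDecoderCache` format described above to a Whisper model; the checkpoint name and the dummy feature shapes are assumptions made for the example.

import torch
from transformers import WhisperForConditionalGeneration, EncoderDecoderCache, DynamicCache

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
input_features = torch.randn(1, 80, 3000)  # dummy log-mel features: (batch, feature_size, frames)
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

# New format: an EncoderDecoderCache wrapping a self-attention and a cross-attention cache
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
outputs = model(
    input_features,
    decoder_input_ids=decoder_input_ids,
    past_key_values=cache,
    use_cache=True,
)
# The updated cache is returned and can be fed back in on the next step, together with only
# the newest decoder_input_ids of shape (batch_size, 1).
cache = outputs.past_key_values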
@ -1051,6 +1008,9 @@ WHISPER_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
|
||||
in the correct position and to infer the complete sequence length.
|
||||
"""
|
||||
|
||||
WHISPER_ENCODER_INPUTS_DOCSTRING = r"""
|
||||
@ -1256,7 +1216,9 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
|
||||
|
||||
self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layers = nn.ModuleList(
|
||||
[WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
|
||||
)
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||
|
||||
@ -1286,6 +1248,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
cache_position=None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
@ -1320,13 +1283,17 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
||||
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
|
||||
four sets of pre-computed hidden-states: key and value states in the self-attention blocks (2) and
in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
when `config.use_cache=True`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
||||
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
Two formats are allowed:
|
||||
- An [`~cache_utils.EncoderDecoderCache`] instance;
|
||||
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
||||
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||
@ -1344,6 +1311,9 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
||||
cache in the correct position and to infer the complete sequence length.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1363,26 +1333,38 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and head_mask is None and not output_attentions:
|
||||
# output_attentions=True & head_mask can not be supported when using SDPA.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
return_legacy_cache = False
|
||||
return_self_attention_cache = False
|
||||
if use_cache or past_key_values is not None:
|
||||
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
|
||||
return_self_attention_cache = True
|
||||
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
|
||||
elif not isinstance(past_key_values, EncoderDecoderCache):
|
||||
return_legacy_cache = True
|
||||
logger.warning_once(
|
||||
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
|
||||
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
|
||||
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
|
||||
)
|
||||
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
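Small illustrative aside (not part of the diff): the conversion the deprecation warning recommends, round-tripped with arbitrary toy shapes.

import torch
from transformers import EncoderDecoderCache

k = v = torch.zeros(1, 2, 3, 4)  # (batch, num_heads, seq_len, head_dim), arbitrary
legacy_past = ((k, v, k, v),)    # one layer: self-attn key/value, then cross-attn key/value

cache = EncoderDecoderCache.from_legacy_cache(legacy_past)
print(cache.get_seq_length())    # 3, taken from the self-attention cache
legacy_again = cache.to_legacy_cache()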
|
||||
|
||||
past_key_values_length = 0
|
||||
if cache_position is not None:
|
||||
past_key_values_length = cache_position[0]
|
||||
elif past_key_values is not None:
|
||||
past_key_values_length = past_key_values.get_seq_length()
|
||||
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# embed positions
|
||||
if input_ids is not None:
|
||||
positions = self.embed_positions(
|
||||
@ -1396,6 +1378,14 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
past_key_values.self_attention_cache if past_key_values is not None else None,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
@ -1406,7 +1396,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
@ -1424,13 +1413,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_mask,
|
||||
encoder_hidden_states,
|
||||
None, # encoder attention mask
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
@ -1438,25 +1425,24 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
None, # past_key_value
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=causal_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values if use_cache else None,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
@ -1468,7 +1454,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
next_cache = past_key_values if use_cache else None
|
||||
if return_self_attention_cache:
|
||||
next_cache = past_key_values.self_attention_cache
|
||||
if return_legacy_cache:
|
||||
next_cache = past_key_values.to_legacy_cache()
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
@ -1483,6 +1473,87 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_length()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
|
||||
if attention_mask.max() != 0:
|
||||
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type == "cuda"
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
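Illustrative aside (not part of the diff): a standalone numeric replay of the mask construction above, with small arbitrary shapes.

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 3, 5                   # 3 new tokens, 5 total key/value positions
cache_position = torch.arange(2, 2 + sequence_length)   # the new tokens sit at positions 2, 3, 4

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
# Row i now masks exactly the positions greater than cache_position[i]; everything at or before
# the current position (including the cached prefix) stays attendable.
print(causal_mask)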
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Whisper Model outputting raw hidden-states without any specific head on top.",
|
||||
@ -1571,13 +1642,14 @@ class WhisperModel(WhisperPreTrainedModel):
|
||||
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
|
||||
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
|
||||
r"""
|
||||
Returns:
|
||||
@ -1637,6 +1709,7 @@ class WhisperModel(WhisperPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
@ -1704,7 +1777,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
||||
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
|
||||
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
@ -1712,6 +1785,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1766,6 +1840,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
lm_logits = self.proj_out(outputs[0])
|
||||
|
||||
@ -1800,14 +1875,19 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
||||
encoder_outputs=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
cache_position=None,
|
||||
**kwargs,
|
||||
):
|
||||
decoder_position_ids = None
|
||||
if decoder_attention_mask is not None:
|
||||
decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
|
||||
|
||||
past_length = 0
|
||||
if past_key_values is not None:
|
||||
past_length = past_key_values[0][0].shape[2]
|
||||
if isinstance(past_key_values, EncoderDecoderCache):
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
else:
|
||||
past_length = past_key_values[0][0].shape[2]
|
||||
|
||||
# Some generation methods already pass only the last input ID
|
||||
if decoder_input_ids.shape[1] > past_length:
|
||||
@ -1821,6 +1901,13 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
||||
if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
|
||||
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
|
||||
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
|
||||
)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-decoder_input_ids.shape[1] :]
|
||||
|
||||
return {
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
@ -1828,6 +1915,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
||||
"use_cache": use_cache,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"decoder_position_ids": decoder_position_ids,
|
||||
"cache_position": cache_position,
|
||||
}
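Toy illustration (not part of the diff) of the `cache_position` bookkeeping in `prepare_inputs_for_generation` above, with made-up lengths.

import torch

past_length, new_tokens = 5, 2                 # 5 tokens already cached, 2 new decoder tokens
cache_position = torch.arange(past_length, past_length + new_tokens)
print(cache_position)                          # tensor([5, 6]) -> cache slots to update

# When a longer cache_position is already available, only the tail matching the new tokens is kept:
print(torch.arange(0, 7)[-new_tokens:])        # tensor([5, 6])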
|
||||
|
||||
@staticmethod
|
||||
@ -1914,6 +2002,7 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
Args:
|
||||
@ -1968,6 +2057,9 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
|
||||
in the correct position and to infer the complete sequence length.
|
||||
|
||||
Returns:
|
||||
|
||||
@ -2019,6 +2111,7 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = self.proj_out(outputs[0])
|
||||
@ -2049,10 +2142,15 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
||||
use_cache=None,
|
||||
encoder_outputs=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
if past_key_values is not None:
|
||||
past_length = past_key_values[0][0].shape[2]
|
||||
if isinstance(past_key_values, (Cache, EncoderDecoderCache)):
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
else:
|
||||
past_length = past_key_values[0][0].shape[2]
|
||||
|
||||
# Some generation methods already pass only the last input ID
|
||||
if input_ids.shape[1] > past_length:
|
||||
@ -2063,12 +2161,18 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
||||
|
||||
input_ids = input_ids[:, remove_prefix_length:]
|
||||
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_ids.shape[1] :]
|
||||
|
||||
return {
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
"input_ids": input_ids,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
@ -266,31 +266,33 @@ class TextGenerationPipeline(Pipeline):
|
||||
prompt_text,
|
||||
prefix="",
|
||||
handle_long_generation=None,
|
||||
add_special_tokens=False,
|
||||
add_special_tokens=None,
|
||||
truncation=None,
|
||||
padding=False,
|
||||
padding=None,
|
||||
max_length=None,
|
||||
**generate_kwargs,
|
||||
):
|
||||
if isinstance(prompt_text, Chat):
|
||||
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
|
||||
tokenizer_kwargs = {}
|
||||
for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]:
|
||||
if locals()[tokenizer_kwarg_name] is not None:
|
||||
tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
|
||||
inputs = self.tokenizer.apply_chat_template(
|
||||
prompt_text.messages,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
max_length=max_length,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
return_tensors=self.framework,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
else:
|
||||
inputs = self.tokenizer(
|
||||
prefix + prompt_text,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
max_length=max_length,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=self.framework,
|
||||
)
|
||||
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
|
||||
tokenizer_kwargs = {}
|
||||
for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]:
|
||||
if locals()[tokenizer_kwarg_name] is not None:
|
||||
tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
|
||||
inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)
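Hedged usage sketch (not part of the diff): with this change, tokenizer kwargs the caller leaves unset fall back to the tokenizer's own defaults instead of being forced to `False`/`None`. The tiny checkpoint name is an assumption for the example.

from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

out_default = generator("Hello", max_new_tokens=5)  # tokenizer defaults apply
out_explicit = generator(
    "Hello",
    max_new_tokens=5,
    add_special_tokens=False,  # forwarded to the tokenizer only because it was set explicitly
)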
|
||||
|
||||
inputs["prompt_text"] = prompt_text
|
||||
|
||||
if handle_long_generation == "hole":
|
||||
|
@ -4605,6 +4605,11 @@ class Trainer:
|
||||
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
|
||||
self.gather_function = self.accelerator.gather_for_metrics
|
||||
|
||||
if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
|
||||
self.gather_function = functools.partial(
|
||||
self.gather_function, use_gather_object=self.args.eval_use_gather_object
|
||||
)
|
||||
|
||||
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
|
||||
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
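Standalone sketch (not part of the diff) of the pattern used in the hunk above: bind the extra keyword only when the gather callable actually accepts it, so older `accelerate` versions keep working. The gather function here is a stand-in, not the real accelerate API.

import functools
import inspect

def gather_for_metrics(tensor, use_gather_object=False):  # stand-in for the accelerate method
    return tensor

gather_function = gather_for_metrics
if "use_gather_object" in inspect.signature(gather_function).parameters.keys():
    gather_function = functools.partial(gather_function, use_gather_object=True)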
|
||||
|
@ -773,8 +773,11 @@ class TrainingArguments:
|
||||
that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
|
||||
summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
|
||||
|
||||
eval_on_start(`bool`, *optional*, defaults to `False`):
|
||||
eval_on_start (`bool`, *optional*, defaults to `False`):
|
||||
Whether to perform an evaluation step (sanity check) before training to ensure the validation step works correctly.
|
||||
|
||||
eval_use_gather_object (`bool`, *optional*, defaults to `False`):
|
||||
Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices.
|
||||
"""
|
||||
|
||||
framework = "pt"
|
||||
@ -1465,6 +1468,13 @@ class TrainingArguments:
|
||||
},
|
||||
)
|
||||
|
||||
eval_use_gather_object: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Parse in args that could be `dict` sent in from the CLI as a string
|
||||
for field in _VALID_DICT_FIELDS:
|
||||
@ -1992,6 +2002,12 @@ class TrainingArguments:
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
|
||||
raise ValueError(
|
||||
"--eval_use_gather_object requires Accelerate to be version of `accelerate` < 0.30.0."
|
||||
"This is not supported and we recommend you to update your version."
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
self_as_dict = asdict(self)
|
||||
|
||||
|
@ -80,7 +80,7 @@ def _parse_type_hint(hint: str) -> Dict:
|
||||
return_dict = subtypes[0]
|
||||
elif all(isinstance(subtype["type"], str) for subtype in subtypes):
|
||||
# A union of basic types can be expressed as a list in the schema
|
||||
return_dict = {"type": [subtype["type"] for subtype in subtypes]}
|
||||
return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])}
|
||||
else:
|
||||
# A union of more complex types requires "anyOf"
|
||||
return_dict = {"anyOf": subtypes}
|
||||
|
@ -37,6 +37,13 @@ class DynamicCache(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class EncoderDecoderCache(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class HQQQuantizedCache(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -4197,6 +4204,41 @@ class GemmaPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Gemma2ForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Gemma2ForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Gemma2ForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Gemma2Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Gemma2PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GitForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -762,7 +762,7 @@ def torch_int(x):
|
||||
|
||||
import torch
|
||||
|
||||
return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
|
||||
return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
|
||||
|
||||
|
||||
def torch_float(x):
|
||||
@ -774,7 +774,7 @@ def torch_float(x):
|
||||
|
||||
import torch
|
||||
|
||||
return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
|
||||
return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else float(x)
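Standalone sketch (not part of the diff) of the guarded behaviour, without importing the helpers themselves; the function name is hypothetical.

import torch

def torch_int_sketch(x):
    # Mirrors the guard above: only call Tensor.to() while tracing and when x really is a tensor.
    return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)

print(torch_int_sketch(3.7))                # 3  (plain Python number)
print(torch_int_sketch(torch.tensor(3.7)))  # 3  (eager mode also falls back to int())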
|
||||
|
||||
|
||||
def filter_out_non_signature_kwargs(extra: Optional[list] = None):
|
||||
|
@ -30,7 +30,9 @@ from transformers.testing_utils import (
|
||||
require_auto_gptq,
|
||||
require_quanto,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@ -57,7 +59,7 @@ if is_torch_available():
|
||||
ImageGPTForCausalImageModeling,
|
||||
SpeechEncoderDecoderModel,
|
||||
)
|
||||
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
|
||||
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
|
||||
from transformers.generation import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
@ -1636,7 +1638,6 @@ class GenerationTesterMixin:
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
@ -1652,15 +1653,21 @@ class GenerationTesterMixin:
|
||||
set_seed(seed)
|
||||
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
set_seed(seed)
|
||||
if config.is_encoder_decoder:
|
||||
cache_cls = EncoderDecoderCache
|
||||
past_key_values = cache_cls(DynamicCache(), DynamicCache())
|
||||
else:
|
||||
cache_cls = DynamicCache
|
||||
past_key_values = cache_cls()
|
||||
new_results = model.generate(
|
||||
input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
|
||||
input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs
|
||||
)
|
||||
|
||||
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
||||
# different
|
||||
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
|
||||
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
|
||||
self.assertTrue(isinstance(new_results.past_key_values, DynamicCache))
|
||||
self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
|
||||
|
||||
# The contents of the two caches, when converted to the same format (in both directions!), must match
|
||||
legacy_cache = legacy_results.past_key_values
|
||||
@ -1675,7 +1682,7 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
new_cache = new_results.past_key_values
|
||||
legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
|
||||
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
|
||||
for layer_idx in range(len(new_cache)):
|
||||
for kv_idx in range(len(new_cache[layer_idx])):
|
||||
self.assertTrue(
|
||||
@ -2082,6 +2089,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
[1, 18],
|
||||
)
|
||||
|
||||
# TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
|
||||
def test_stop_sequence_stopping_criteria(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
prompt = """Hello I believe in"""
|
||||
@ -2089,17 +2097,11 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
output = generator(prompt)
|
||||
self.assertEqual(
|
||||
output,
|
||||
[
|
||||
{
|
||||
"generated_text": (
|
||||
"Hello I believe in in in number number number number number number number number number"
|
||||
)
|
||||
}
|
||||
],
|
||||
[{"generated_text": ("Hello I believe in we we we we we we we we we")}],
|
||||
)
|
||||
|
||||
output = generator(prompt, stop_sequence=" number")
|
||||
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
|
||||
output = generator(prompt, stop_sequence=" we")
|
||||
self.assertEqual(output, [{"generated_text": "Hello I believe in we"}])
|
||||
|
||||
def test_generate_non_nlp_input_ids_as_kwarg(self):
|
||||
# PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
|
||||
@ -3097,6 +3099,54 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
|
||||
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
|
||||
|
||||
@slow
|
||||
@require_torch_multi_gpu
|
||||
def test_assisted_decoding_in_different_gpu(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
|
||||
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||
"cuda:1"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||
model.config.pad_token_id = tokenizer.eos_token_id
|
||||
assistant.config.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
input_length = input_ids.shape[-1]
|
||||
|
||||
out = model.generate(
|
||||
input_ids,
|
||||
assistant_model=assistant,
|
||||
max_new_tokens=20,
|
||||
)
|
||||
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_assisted_decoding_in_gpu_cpu(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
|
||||
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||
"cpu"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||
model.config.pad_token_id = tokenizer.eos_token_id
|
||||
assistant.config.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
input_length = input_ids.shape[-1]
|
||||
|
||||
out = model.generate(
|
||||
input_ids,
|
||||
assistant_model=assistant,
|
||||
max_new_tokens=20,
|
||||
)
|
||||
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
@ -19,7 +19,6 @@ import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from datasets import Audio, load_dataset
|
||||
@@ -385,31 +384,21 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (List, Tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif isinstance(tuple_object, Dict):
                        for tuple_iterable_value, dict_iterable_value in zip(
                            tuple_object.values(), dict_object.values()
                        ):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                            ),
                            msg=(
                                "Tuple and dict output are not equal. Difference:"
                                f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                                f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
                                f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
                            ),
                        )
                self.assertTrue(isinstance(tuple_output, tuple))
                self.assertTrue(isinstance(dict_output, dict))

                recursive_check(tuple_output, dict_output)
                for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
                    self.assertTrue(
                        torch.allclose(
                            set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5
                        ),
                        msg=(
                            "Tuple and dict output are not equal. Difference:"
                            f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:"
                            f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has"
                            f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}."
                        ),
                    )

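The simplified check above works because, as far as I read it, a `ModelOutput` exposes its non-`None` fields in declaration order, so the tuple returned with `return_dict=False` can be zipped directly against `dict_output.values()`. A hedged illustration of the idea outside the test harness, assuming tensor-valued outputs (names are placeholders):

```python
import torch

def assert_tuple_and_dict_outputs_match(model, inputs, atol=1e-5):
    # return_dict=False yields a plain tuple; return_dict=True yields a dict-like ModelOutput.
    tuple_output = model(**inputs, return_dict=False)
    dict_output = model(**inputs, return_dict=True)
    # Both follow the output dataclass field order, so a positional zip pairs matching entries.
    for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
        torch.testing.assert_close(tuple_value, dict_value, atol=atol, rtol=0)
```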
        for model_class in self.all_model_classes:
            model = model_class(config)
@@ -47,11 +47,18 @@ if is_torch_available():
        GemmaForSequenceClassification,
        GemmaForTokenClassification,
        GemmaModel,
        GemmaTokenizer,
    )


@require_torch
class GemmaModelTester:
    config_class = GemmaConfig
    if is_torch_available():
        model_class = GemmaModel
        for_causal_lm_class = GemmaForCausalLM
        for_sequence_class = GemmaForSequenceClassification
        for_token_class = GemmaForTokenClassification

    def __init__(
        self,
        parent,
@@ -129,9 +136,8 @@ class GemmaModelTester:

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    # Ignore copy
    def get_config(self):
        return GemmaConfig(
        return self.config_class(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
@@ -149,18 +155,16 @@ class GemmaModelTester:
            head_dim=self.head_dim,
        )

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Gemma
    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = GemmaModel(config=config)
        model = self.model_class(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Gemma
    def create_and_check_model_as_decoder(
        self,
        config,
@@ -174,7 +178,7 @@ class GemmaModelTester:
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = GemmaModel(config)
        model = self.model_class(config)
        model.to(torch_device)
        model.eval()
        result = model(
@@ -191,7 +195,6 @@ class GemmaModelTester:
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Gemma
    def create_and_check_for_causal_lm(
        self,
        config,
@@ -204,13 +207,12 @@ class GemmaModelTester:
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = GemmaForCausalLM(config=config)
        model = self.for_causal_lm_class(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Gemma
    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
@@ -225,7 +227,7 @@ class GemmaModelTester:
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = GemmaForCausalLM(config=config)
        model = self.for_causal_lm_class(config=config)
        model.to(torch_device)
        model.eval()

@@ -348,7 +350,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = GemmaForSequenceClassification(config)
        model = self.model_tester.for_sequence_class(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
@@ -361,7 +363,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = GemmaForSequenceClassification(config)
        model = self.model_tester.for_sequence_class(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
@@ -376,20 +378,19 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
        sequence_labels = ids_tensor(
            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
        ).to(torch.float)
        model = GemmaForSequenceClassification(config)
        model = self.model_tester.for_sequence_class(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma
    def test_Gemma_token_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
        model = GemmaForTokenClassification(config=config)
        model = self.model_tester.for_token_class(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
@@ -539,47 +540,9 @@ class GemmaIntegrationTest(unittest.TestCase):
            # 8 is for A100 / A10 and 7 for T4
            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

    @require_read_token
    def test_model_2b_fp32(self):
        model_id = "google/gemma-2b"
        EXPECTED_TEXTS = [
            "Hello I am doing a project on the 1990s and I need to know what the most popular music",
            "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
        ]

        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(torch_device)

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

        self.assertEqual(output_text, EXPECTED_TEXTS)

    @require_read_token
    def test_model_2b_fp16(self):
        model_id = "google/gemma-2b"
        EXPECTED_TEXTS = [
            "Hello I am doing a project on the 1990s and I need to know what the most popular music",
            "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
        ]

        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
            torch_device
        )

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

        self.assertEqual(output_text, EXPECTED_TEXTS)

    @require_read_token
    def test_model_2b_fp16_static_cache(self):
        model_id = "google/gemma-2b"
        model_id = "google/gemma-2-9b"
        EXPECTED_TEXTS = [
            "Hello I am doing a project on the 1990s and I need to know what the most popular music",
            "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
@@ -903,7 +866,7 @@ class GemmaIntegrationTest(unittest.TestCase):
        }

        prompts = ["Hello I am doing", "Hi today"]
        tokenizer = GemmaTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
        model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

0    tests/models/gemma2/__init__.py    Normal file
142  tests/models/gemma2/test_modeling_gemma2.py    Normal file
@@ -0,0 +1,142 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Gemma2 model."""

import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available
from transformers.testing_utils import (
    require_read_token,
    require_torch,
    require_torch_gpu,
    slow,
    torch_device,
)

from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester
from ...test_configuration_common import ConfigTester


if is_torch_available():
    import torch

    from transformers import (
        Gemma2ForCausalLM,
        Gemma2ForSequenceClassification,
        Gemma2ForTokenClassification,
        Gemma2Model,
    )


class Gemma2ModelTester(GemmaModelTester):
    if is_torch_available():
        config_class = Gemma2Config
        model_class = Gemma2Model
        for_causal_lm_class = Gemma2ForCausalLM
        for_sequence_class = Gemma2ForSequenceClassification
        for_token_class = Gemma2ForTokenClassification


@require_torch
class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
    all_model_classes = (
        (Gemma2Model, Gemma2ForCausalLM, Gemma2ForSequenceClassification, Gemma2ForTokenClassification)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = ()
    pipeline_model_mapping = (
        {
            "feature-extraction": Gemma2Model,
            "text-classification": Gemma2ForSequenceClassification,
            "token-classification": Gemma2ForTokenClassification,
            "text-generation": Gemma2ForCausalLM,
            "zero-shot": Gemma2ForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    _is_stateful = True
    model_split_percents = [0.5, 0.6]
    _torch_compile_test_ckpt = "google/gemma-2-9b"

    def setUp(self):
        self.model_tester = Gemma2ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Gemma2Config, hidden_size=37)

    @unittest.skip("Eager and SDPA do not produce the same outputs, thus this test fails")
    def test_model_outputs_equivalence(self, **kwargs):
        pass

    @unittest.skip("Gemma2's outputs are expected to be different")
    def test_eager_matches_sdpa_inference(self):
        pass


@slow
@require_torch_gpu
class Gemma2IntegrationTest(unittest.TestCase):
    input_text = ["Hello I am doing", "Hi today"]
    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
    # Depending on the hardware we get different logits / generations
    cuda_compute_capability_major_version = None

    @classmethod
    def setUpClass(cls):
        if is_torch_available() and torch.cuda.is_available():
            # 8 is for A100 / A10 and 7 for T4
            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

    @require_read_token
    def test_model_2b_bf16(self):
        model_id = "google/gemma-2-9b"
        EXPECTED_TEXTS = [
            "<bos>Hello I am doing a project for a class and I am trying to use the <code><a-image></code>",
            "<pad><pad><bos>Hi today. So, I'm going to show you how to do a problem from the textbook. So",
        ]

        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
            torch_device
        )

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

        self.assertEqual(output_text, EXPECTED_TEXTS)

    @require_read_token
    def test_model_2b_fp16(self):
        model_id = "google/gemma-2-9b"
        EXPECTED_TEXTS = [
            "<bos>Hello I am doing a project on the effect of the temperature on the rate of a reaction. I am using a ",
            "<pad><pad><bos>Hi today I'm going to be talking about the 1000-4000-",
        ]

        model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
            torch_device
        )

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

        self.assertEqual(output_text, EXPECTED_TEXTS)
@@ -1539,6 +1539,46 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
    def test_longform_generate_multi_batch_cond_prev(self):
        self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)

    def test_custom_4d_attention_mask(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
        model.eval()

        (
            input_ids,
            position_ids,
            input_ids_shared_prefix,
            mask_shared_prefix,
            position_ids_shared_prefix,
        ) = self._get_custom_4d_mask_test_data()

        with torch.no_grad():
            logits = model.forward(
                decoder_input_ids=input_ids,
                input_features=input_dict["input_features"],
                decoder_position_ids=position_ids,
            ).logits
            # logits.shape == torch.Size([3, 4, ...])

            logits_shared_prefix = model(
                decoder_input_ids=input_ids_shared_prefix,
                input_features=input_dict["input_features"],
                decoder_attention_mask=mask_shared_prefix,
                decoder_position_ids=position_ids_shared_prefix,
            )[0]
            # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # comparing greedily-chosen tokens:
        assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)

        # comparing softmax-normalized logits:
        normalized_0 = torch.nn.functional.softmax(out_last_tokens)
        normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)


@require_torch
@require_torchaudio
@@ -2961,6 +3001,34 @@ class WhisperModelIntegrationTests(unittest.TestCase):
        torch.manual_seed(0)
        model.generate(**inputs, **gen_kwargs)

    @slow
    def test_tiny_static_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model.to(torch_device)

        input_speech = self._load_datasamples(4)
        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
        input_features = input_features.to(torch_device)
        eager_generated_ids = model.generate(input_features, max_new_tokens=64)

        model.generation_config.cache_implementation = "static"
        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

        # compile the forward pass and assert equivalence
        static_generated_ids = model.generate(input_features, max_new_tokens=64)
        assert (eager_generated_ids == static_generated_ids).all()

        # check the compiled graph can be re-used and that the cache is correctly reset
        # reverse the ordering of the input features
        permutation_idx = (
            torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
        )
        input_features = input_features[permutation_idx, ...]
        static_generated_ids = model.generate(input_features, max_new_tokens=64)
        # assert re-ordered generations match those from eager
        assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()

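The static-cache-plus-`torch.compile` recipe in the test above is not Whisper-specific; the same two steps (switch `generation_config.cache_implementation` to `"static"`, then compile the forward pass) apply to decoder-only models as well. A hedged sketch reusing the tiny random checkpoint from the generation tests earlier; this is illustrative only and not part of the diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# The static cache pre-allocates the KV cache to a fixed shape, which keeps the
# compiled graph shape-stable across decoding steps.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Hello world", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=16)  # first call compiles; later calls reuse the graph
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```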
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
    if head_mask is None:
@@ -3564,6 +3632,10 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
            config=config, input_ids=inputs_dict["input_ids"]
        )

    @unittest.skip(reason="Tested implicitly through the encoder-decoder tests")
    def test_custom_4d_attention_mask(self):
        pass

    @unittest.skip(reason="Generate needs input ids")
    def test_generate_without_input_ids(self):
        # generate only works with input ids for whisper
@@ -398,7 +398,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
            self.assertEqual(outputs, [{"generated_text": ANY(str)}])
        else:
            with self.assertRaises((ValueError, AssertionError)):
                outputs = text_generator("")
                outputs = text_generator("", add_special_tokens=False)

        if text_generator.framework == "tf":
            # TF generation does not support max_new_tokens, and it's impossible
Some files were not shown because too many files have changed in this diff.