mirror of
https://github.com/huggingface/transformers.git
synced 2025-11-06 05:24:35 +08:00
Compare commits
5 Commits
add_chat_t
...
mllama_int
| Author | SHA1 | Date | |
|---|---|---|---|
| da767ce0cb | |||
| 49f0f18f1e | |||
| e0bcc4a10f | |||
| f722aae48f | |||
| f202533977 |
@ -242,7 +242,7 @@ pipeline(
|
||||
|
||||
- This library is not a modular toolbox of building blocks for neural nets. The code in the model files is not refactored with additional abstractions on purpose, so that researchers can quickly iterate on each of the models without diving into additional abstractions/files.
|
||||
- The training API is optimized to work with PyTorch models provided by Transformers. For generic machine learning loops, you should use another library like [Accelerate](https://huggingface.co/docs/accelerate).
|
||||
- The [example scripts](https://github.com/huggingface/transformers/tree/main/examples) are only *examples*. They may not necessarily work out-of-the-box on your specific use case and you'll need to adapt the code for it to work.
|
||||
- The [example scripts]((https://github.com/huggingface/transformers/tree/main/examples)) are only *examples*. They may not necessarily work out-of-the-box on your specific use case and you'll need to adapt the code for it to work.
|
||||
|
||||
## 100 projects using Transformers
|
||||
|
||||
|
||||
@ -13,11 +13,11 @@
|
||||
|
||||
في هذا الدليل، سنستعرض التقنيات الفعالة لتُحسِّن من كفاءة نشر نماذج اللغة الكبيرة:
|
||||
|
||||
1. سنتناول تقنية "دقة أقل" التي أثبتت الأبحاث فعاليتها في تحقيق مزايا حسابية دون التأثير بشكل ملحوظ على أداء النموذج عن طريق العمل بدقة رقمية أقل [8 بت و4 بت](/main_classes/quantization).
|
||||
1. سنتناول تقنية "دقة أقل" التي أثبتت الأبحاث فعاليتها في تحقيق مزايا حسابية دون التأثير بشكل ملحوظ على أداء النموذج عن طريق العمل بدقة رقمية أقل [8 بت و4 بت](/main_classes/quantization.md).
|
||||
|
||||
2. **اFlash Attention:** إن Flash Attention وهي نسخة مُعدَّلة من خوارزمية الانتباه التي لا توفر فقط نهجًا أكثر كفاءة في استخدام الذاكرة، ولكنها تحقق أيضًا كفاءة متزايدة بسبب الاستخدام الأمثل لذاكرة GPU.
|
||||
|
||||
3. **الابتكارات المعمارية:** حيث تم اقتراح هياكل متخصصة تسمح باستدلال أكثر فعالية نظرًا لأن نماذج اللغة الكبيرة يتم نشرها دائمًا بنفس الطريقة أثناء عملية الاستدلال، أي توليد النص التنبؤي التلقائي مع سياق الإدخال الطويل، فقد تم اقتراح بنيات نموذج متخصصة تسمح بالاستدلال الأكثر كفاءة. أهم تقدم في بنيات النماذج هنا هو [عذر](https://huggingface.co/papers/2108.12409)، [الترميز الدوار](https://huggingface.co/papers/2104.09864)، [الاهتمام متعدد الاستعلامات (MQA)](https://huggingface.co/papers/1911.02150) و [مجموعة الانتباه بالاستعلام (GQA)](https://huggingface.co/papers/2305.13245).
|
||||
3. **الابتكارات المعمارية:** حيث تم اقتراح هياكل متخصصة تسمح باستدلال أكثر فعالية نظرًا لأن نماذج اللغة الكبيرة يتم نشرها دائمًا بنفس الطريقة أثناء عملية الاستدلال، أي توليد النص التنبؤي التلقائي مع سياق الإدخال الطويل، فقد تم اقتراح بنيات نموذج متخصصة تسمح بالاستدلال الأكثر كفاءة. أهم تقدم في بنيات النماذج هنا هو [عذر](https://huggingface.co/papers/2108.12409)، [الترميز الدوار](https://huggingface.co/papers/2104.09864)، [الاهتمام متعدد الاستعلامات (MQA)](https://huggingface.co/papers/1911.02150) و [مجموعة الانتباه بالاستعلام (GQA)]((https://huggingface.co/papers/2305.13245)).
|
||||
|
||||
على مدار هذا الدليل، سنقدم تحليلًا للتوليد التنبؤي التلقائي من منظور المُوتِّرات. نتعمق في مزايا وعيوب استخدام دقة أقل، ونقدم استكشافًا شاملاً لخوارزميات الانتباه الأحدث، ونناقش بنيات نماذج نماذج اللغة الكبيرة المحسنة. سندعم الشرح بأمثلة عملية تُبرِز كل تحسين على حدة.
|
||||
|
||||
|
||||
@ -607,8 +607,6 @@
|
||||
title: OLMoE
|
||||
- local: model_doc/open-llama
|
||||
title: Open-Llama
|
||||
- local: model_doc/openai_moe
|
||||
title: OpenAIMoe
|
||||
- local: model_doc/opt
|
||||
title: OPT
|
||||
- local: model_doc/pegasus
|
||||
@ -973,8 +971,6 @@
|
||||
title: CLIPSeg
|
||||
- local: model_doc/clvp
|
||||
title: CLVP
|
||||
- local: model_doc/cohere2_vision
|
||||
title: Cohere2Vision
|
||||
- local: model_doc/colpali
|
||||
title: ColPali
|
||||
- local: model_doc/colqwen2
|
||||
@ -1053,8 +1049,6 @@
|
||||
title: Mistral3
|
||||
- local: model_doc/mllama
|
||||
title: mllama
|
||||
- local: model_doc/mm-grounding-dino
|
||||
title: MM Grounding DINO
|
||||
- local: model_doc/nougat
|
||||
title: Nougat
|
||||
- local: model_doc/omdet-turbo
|
||||
|
||||
@ -27,7 +27,7 @@ This guide shows you how to quickly start chatting with Transformers from the co
|
||||
|
||||
## chat CLI
|
||||
|
||||
After you've [installed Transformers](./installation), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session.
|
||||
After you've [installed Transformers](./installation.md), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session.
|
||||
|
||||
```bash
|
||||
transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||
@ -158,4 +158,4 @@ The easiest solution for improving generation speed is to either quantize a mode
|
||||
You can also try techniques like [speculative decoding](./generation_strategies#speculative-decoding), where a smaller model generates candidate tokens that are verified by the larger model. If the candidate tokens are correct, the larger model can generate more than one token per `forward` pass. This significantly alleviates the bandwidth bottleneck and improves generation speed.
|
||||
|
||||
> [!TIP]
|
||||
> Parameters may not be active for every generated token in MoE models such as [Mixtral](./model_doc/mixtral), [Qwen2MoE](./model_doc/qwen2_moe), and [DBRX](./model_doc/dbrx). As a result, MoE models generally have much lower memory bandwidth requirements and can be faster than a regular LLM of the same size. However, techniques like speculative decoding are ineffective with MoE models because parameters become activated with each new speculated token.
|
||||
> Parameters may not be active for every generated token in MoE models such as [Mixtral](./model_doc/mixtral), [Qwen2MoE](./model_doc/qwen2_moe.md), and [DBRX](./model_doc/dbrx). As a result, MoE models generally have much lower memory bandwidth requirements and can be faster than a regular LLM of the same size. However, techniques like speculative decoding are ineffective with MoE models because parameters become activated with each new speculated token.
|
||||
|
||||
@ -148,9 +148,9 @@ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
| Option name | Type | Simplified description |
|
||||
|---|---|---|
|
||||
| `max_new_tokens` | `int` | Controls the maximum generation length. Be sure to define it, as it usually defaults to a small value. |
|
||||
| `do_sample` | `bool` | Defines whether generation will sample the next token (`True`), or is greedy instead (`False`). Most use cases should set this flag to `True`. Check [this guide](./generation_strategies) for more information. |
|
||||
| `do_sample` | `bool` | Defines whether generation will sample the next token (`True`), or is greedy instead (`False`). Most use cases should set this flag to `True`. Check [this guide](./generation_strategies.md) for more information. |
|
||||
| `temperature` | `float` | How unpredictable the next selected token will be. High values (`>0.8`) are good for creative tasks, low values (e.g. `<0.4`) for tasks that require "thinking". Requires `do_sample=True`. |
|
||||
| `num_beams` | `int` | When set to `>1`, activates the beam search algorithm. Beam search is good on input-grounded tasks. Check [this guide](./generation_strategies) for more information. |
|
||||
| `num_beams` | `int` | When set to `>1`, activates the beam search algorithm. Beam search is good on input-grounded tasks. Check [this guide](./generation_strategies.md) for more information. |
|
||||
| `repetition_penalty` | `float` | Set it to `>1.0` if you're seeing the model repeat itself often. Larger values apply a larger penalty. |
|
||||
| `eos_token_id` | `list[int]` | The token(s) that will cause generation to stop. The default value is usually good, but you can specify a different token. |
|
||||
|
||||
|
||||
@ -23,11 +23,11 @@ The crux of these challenges lies in augmenting the computational and memory cap
|
||||
|
||||
In this guide, we will go over the effective techniques for efficient LLM deployment:
|
||||
|
||||
1. **Lower Precision:** Research has shown that operating at reduced numerical precision, namely [8-bit and 4-bit](./main_classes/quantization) can achieve computational advantages without a considerable decline in model performance.
|
||||
1. **Lower Precision:** Research has shown that operating at reduced numerical precision, namely [8-bit and 4-bit](./main_classes/quantization.md) can achieve computational advantages without a considerable decline in model performance.
|
||||
|
||||
2. **Flash Attention:** Flash Attention is a variation of the attention algorithm that not only provides a more memory-efficient approach but also realizes increased efficiency due to optimized GPU memory utilization.
|
||||
|
||||
3. **Architectural Innovations:** Considering that LLMs are always deployed in the same way during inference, namely autoregressive text generation with a long input context, specialized model architectures have been proposed that allow for more efficient inference. The most important advancement in model architectures hereby are [Alibi](https://huggingface.co/papers/2108.12409), [Rotary embeddings](https://huggingface.co/papers/2104.09864), [Multi-Query Attention (MQA)](https://huggingface.co/papers/1911.02150) and [Grouped-Query-Attention (GQA)](https://huggingface.co/papers/2305.13245).
|
||||
3. **Architectural Innovations:** Considering that LLMs are always deployed in the same way during inference, namely autoregressive text generation with a long input context, specialized model architectures have been proposed that allow for more efficient inference. The most important advancement in model architectures hereby are [Alibi](https://huggingface.co/papers/2108.12409), [Rotary embeddings](https://huggingface.co/papers/2104.09864), [Multi-Query Attention (MQA)](https://huggingface.co/papers/1911.02150) and [Grouped-Query-Attention (GQA)]((https://huggingface.co/papers/2305.13245)).
|
||||
|
||||
Throughout this guide, we will offer an analysis of auto-regressive generation from a tensor's perspective. We delve into the pros and cons of adopting lower precision, provide a comprehensive exploration of the latest attention algorithms, and discuss improved LLM architectures. While doing so, we run practical examples showcasing each of the feature improvements.
|
||||
|
||||
|
||||
@ -14,81 +14,49 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
|
||||
">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# BARThez
|
||||
|
||||
[BARThez](https://huggingface.co/papers/2010.12321) is a [BART](./bart) model designed for French language tasks. Unlike existing French BERT models, BARThez includes a pretrained encoder-decoder, allowing it to generate text as well. This model is also available as a multilingual variant, mBARThez, by continuing pretraining multilingual BART on a French corpus.
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
|
||||
">
|
||||
</div>
|
||||
|
||||
You can find all of the original BARThez checkpoints under the [BARThez](https://huggingface.co/collections/dascim/barthez-670920b569a07aa53e3b6887) collection.
|
||||
## Overview
|
||||
|
||||
> [!TIP]
|
||||
> This model was contributed by [moussakam](https://huggingface.co/moussakam).
|
||||
> Refer to the [BART](./bart) docs for more usage examples.
|
||||
The BARThez model was proposed in [BARThez: a Skilled Pretrained French Sequence-to-Sequence Model](https://huggingface.co/papers/2010.12321) by Moussa Kamal Eddine, Antoine J.-P. Tixier, Michalis Vazirgiannis on 23 Oct,
|
||||
2020.
|
||||
|
||||
The abstract of the paper:
|
||||
|
||||
|
||||
The example below demonstrates how to predict the `<mask>` token with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
*Inductive transfer learning, enabled by self-supervised learning, have taken the entire Natural Language Processing
|
||||
(NLP) field by storm, with models such as BERT and BART setting new state of the art on countless natural language
|
||||
understanding tasks. While there are some notable exceptions, most of the available models and research have been
|
||||
conducted for the English language. In this work, we introduce BARThez, the first BART model for the French language
|
||||
(to the best of our knowledge). BARThez was pretrained on a very large monolingual French corpus from past research
|
||||
that we adapted to suit BART's perturbation schemes. Unlike already existing BERT-based French language models such as
|
||||
CamemBERT and FlauBERT, BARThez is particularly well-suited for generative tasks, since not only its encoder but also
|
||||
its decoder is pretrained. In addition to discriminative tasks from the FLUE benchmark, we evaluate BARThez on a novel
|
||||
summarization dataset, OrangeSum, that we release with this paper. We also continue the pretraining of an already
|
||||
pretrained multilingual BART on BARThez's corpus, and we show that the resulting model, which we call mBARTHez,
|
||||
provides a significant boost over vanilla BARThez, and is on par with or outperforms CamemBERT and FlauBERT.*
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
This model was contributed by [moussakam](https://huggingface.co/moussakam). The Authors' code can be found [here](https://github.com/moussaKam/BARThez).
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
<Tip>
|
||||
|
||||
pipeline = pipeline(
|
||||
task="fill-mask",
|
||||
model="moussaKam/barthez",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Les plantes produisent <mask> grâce à un processus appelé photosynthèse.")
|
||||
```
|
||||
BARThez implementation is the same as BART, except for tokenization. Refer to [BART documentation](bart) for information on
|
||||
configuration classes and their parameters. BARThez-specific tokenizers are documented below.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
</Tip>
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
## Resources
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"moussaKam/barthez",
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
"moussaKam/barthez",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
inputs = tokenizer("Les plantes produisent <mask> grâce à un processus appelé photosynthèse.", return_tensors="pt").to("cuda")
|
||||
- BARThez can be fine-tuned on sequence-to-sequence tasks in a similar way as BART, check:
|
||||
[examples/pytorch/summarization/](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization/README.md).
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
predictions = outputs.logits
|
||||
|
||||
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
|
||||
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
|
||||
predicted_token = tokenizer.decode(predicted_token_id)
|
||||
|
||||
print(f"The predicted token is: {predicted_token}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers CLI">
|
||||
|
||||
```bash
|
||||
echo -e "Les plantes produisent <mask> grâce à un processus appelé photosynthèse." | transformers run --task fill-mask --model moussaKam/barthez --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## BarthezTokenizer
|
||||
|
||||
|
||||
@ -1,115 +1,43 @@
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Tensor parallelism" src="https://img.shields.io/badge/Tensor%20parallelism-06b6d4?style=flat&logoColor=white">
|
||||
</div>
|
||||
# Cohere
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Tensor parallelism" src="https://img.shields.io/badge/Tensor%20parallelism-06b6d4?style=flat&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
[C4AI Command R7B](https://cohere.com/blog/command-r7b) is an open weights research release of a 7B billion parameter model developed by Cohere and Cohere For AI. It has advanced capabilities optimized for various use cases, including reasoning, summarization, question answering, and code. The model is trained to perform sophisticated tasks including Retrieval Augmented Generation (RAG) and tool use. The model also has powerful agentic capabilities that can use and combine multiple tools over multiple steps to accomplish more difficult tasks. It obtains top performance on enterprise-relevant code use cases. C4AI Command R7B is a multilingual model trained on 23 languages.
|
||||
|
||||
# Cohere2
|
||||
The model features three layers with sliding window attention (window size 4096) and ROPE for efficient local context modeling and relative positional encoding. A fourth layer uses global attention without positional embeddings, enabling unrestricted token interactions across the entire sequence.
|
||||
|
||||
[Cohere Command R7B](https://cohere.com/blog/command-r7b) is an open weights research release of a 7B billion parameter model. It is a multilingual model trained on 23 languages and has a context window of 128k. The model features three layers with sliding window attention and ROPE for efficient local context modeling and relative positional encoding. A fourth layer uses global attention without positional embeddings, enabling unrestricted token interactions across the entire sequence.
|
||||
The model has been trained on 23 languages: English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Arabic, Chinese, Russian, Polish, Turkish, Vietnamese, Dutch, Czech, Indonesian, Ukrainian, Romanian, Greek, Hindi, Hebrew, and Persian.
|
||||
|
||||
This model is optimized for speed, cost-performance, and compute resources.
|
||||
|
||||
You can find all the original Command-R checkpoints under the [Command Models](https://huggingface.co/collections/CohereForAI/command-models-67652b401665205e17b192ad) collection.
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Cohere models in the right sidebar for more examples of how to apply Cohere to different language tasks.
|
||||
|
||||
The example below demonstrates how to generate text with [`Pipeline`] or the [`AutoModel`] class, and from the command line.
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
## Usage tips
|
||||
The model and tokenizer can be loaded via:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(
|
||||
task="text-generation",
|
||||
model="CohereLabs/c4ai-command-r7b-12-2024",
|
||||
torch_dtype=torch.float16,
|
||||
device_map=0
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello, can you please help me book a hotel in Japan?"},
|
||||
]
|
||||
pipeline(messages)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
import torch
|
||||
# pip install transformers
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("CohereLabs/c4ai-command-r7b-12-2024")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"CohereLabs/c4ai-command-r7b-12-2024",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
# format message with the Command-R chat template
|
||||
messages = [{"role": "user", "content": "Hello, can you please help me book a hotel in Japan?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
|
||||
output = model.generate(
|
||||
# Format message with the command-r chat template
|
||||
messages = [{"role": "user", "content": "Hello, how are you?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
||||
|
||||
gen_tokens = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
cache_implementation="static",
|
||||
)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers CLI">
|
||||
|
||||
```bash
|
||||
# pip install -U flash-attn --no-build-isolation
|
||||
transformers-cli chat CohereLabs/c4ai-command-r7b-12-2024 --torch_dtype auto --attn_implementation flash_attention_2
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview.md) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to 4-bits.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained("CohereLabs/c4ai-command-r7b-12-2024")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"CohereLabs/c4ai-command-r7b-12-2024",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
quantization_config=bnb_config,
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
# format message with the Command-R chat template
|
||||
messages = [{"role": "user", "content": "Hello, can you please help me book a hotel in Japan?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
|
||||
output = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
cache_implementation="static",
|
||||
)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
gen_text = tokenizer.decode(gen_tokens[0])
|
||||
print(gen_text)
|
||||
```
|
||||
|
||||
## Cohere2Config
|
||||
|
||||
@ -1,123 +0,0 @@
|
||||
# Command A Vision
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Tensor parallelism" src="https://img.shields.io/badge/Tensor%20parallelism-06b6d4?style=flat&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
Command A Vision is a state-of-the-art multimodal model designed to seamlessly integrate visual and textual information for a wide range of applications. By combining advanced computer vision techniques with natural language processing capabilities, Command A Vision enables users to analyze, understand, and generate insights from both visual and textual data.
|
||||
|
||||
The model excels at tasks including image captioning, visual question answering, document understanding, and chart understanding. This makes it a versatile tool for AI practitioners. Its ability to process complex visual and textual inputs makes it useful in settings where text-only representations are imprecise or unavailable, like real-world image understanding and graphics-heavy document processing.
|
||||
|
||||
Command A Vision is built upon a robust architecture that leverages the latest advancements in VLMs. It's highly performant and efficient, even when dealing with large-scale datasets. The model's flexibility makes it suitable for a wide range of use cases, from content moderation and image search to medical imaging analysis and robotics.
|
||||
|
||||
## Usage tips
|
||||
|
||||
The model and image processor can be loaded as follows:
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
|
||||
model_id = "CohereLabs/command-a-vision-07-2025"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id, device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
# Format message with the Command-A-Vision chat template
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg",
|
||||
},
|
||||
{"type": "text", "text": "what is in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
padding=True,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
gen_tokens = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=300,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
print(
|
||||
processor.tokenizer.decode(
|
||||
gen_tokens[0][inputs.input_ids.shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline(model="CohereLabs/command-a-vision-07-2025", task="image-text-to-text", device_map="auto")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://media.istockphoto.com/id/458012057/photo/istanbul-turkey.jpg?s=612x612&w=0&k=20&c=qogAOVvkpfUyqLUMr_XJQyq-HkACXyYUSZbKhBlPrxo=",
|
||||
},
|
||||
{"type": "text", "text": "Where was this taken ?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
outputs = pipe(text=messages, max_new_tokens=300, return_full_text=False)
|
||||
|
||||
print(outputs)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Cohere2VisionConfig
|
||||
|
||||
[[autodoc]] Cohere2VisionConfig
|
||||
|
||||
## Cohere2VisionForConditionalGeneration
|
||||
|
||||
[[autodoc]] Cohere2VisionForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## Cohere2VisionModel
|
||||
|
||||
[[autodoc]] Cohere2VisionModel
|
||||
- forward
|
||||
|
||||
## Cohere2VisionImageProcessorFast
|
||||
|
||||
[[autodoc]] Cohere2VisionImageProcessorFast
|
||||
- preprocess
|
||||
|
||||
## Cohere2VisionProcessor
|
||||
|
||||
[[autodoc]] Cohere2VisionProcessor
|
||||
@ -95,7 +95,7 @@ images = [
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to int4.
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
@ -99,7 +99,7 @@ images = [
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to int4.
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
@ -21,7 +21,7 @@ rendered properly in your Markdown viewer.
|
||||
The Conversational Speech Model (CSM) is the first open-source contextual text-to-speech model [released by Sesame](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice). It is designed to generate natural-sounding speech with or without conversational context. This context typically consists of multi-turn dialogue between speakers, represented as sequences of text and corresponding spoken audio.
|
||||
|
||||
**Model Architecture:**
|
||||
CSM is composed of two LLaMA-style auto-regressive transformer decoders: a backbone decoder that predicts the first codebook token and a depth decoder that generates the remaining tokens. It uses the pretrained codec model [Mimi](./mimi), introduced by Kyutai, to encode speech into discrete codebook tokens and decode them back into audio.
|
||||
CSM is composed of two LLaMA-style auto-regressive transformer decoders: a backbone decoder that predicts the first codebook token and a depth decoder that generates the remaining tokens. It uses the pretrained codec model [Mimi](./mimi.md), introduced by Kyutai, to encode speech into discrete codebook tokens and decode them back into audio.
|
||||
|
||||
The original csm-1b checkpoint is available under the [Sesame](https://huggingface.co/sesame/csm-1b) organization on Hugging Face.
|
||||
|
||||
|
||||
@ -209,10 +209,6 @@ model = DeepseekVLForConditionalGeneration.from_pretrained(
|
||||
|
||||
[[autodoc]] DeepseekVLImageProcessor
|
||||
|
||||
## DeepseekVLImageProcessorFast
|
||||
|
||||
[[autodoc]] DeepseekVLImageProcessorFast
|
||||
|
||||
## DeepseekVLModel
|
||||
|
||||
[[autodoc]] DeepseekVLModel
|
||||
|
||||
@ -208,10 +208,6 @@ model = DeepseekVLHybridForConditionalGeneration.from_pretrained(
|
||||
|
||||
[[autodoc]] DeepseekVLHybridImageProcessor
|
||||
|
||||
## DeepseekVLHybridImageProcessorFast
|
||||
|
||||
[[autodoc]] DeepseekVLHybridImageProcessorFast
|
||||
|
||||
## DeepseekVLHybridModel
|
||||
|
||||
[[autodoc]] DeepseekVLHybridModel
|
||||
|
||||
@ -26,14 +26,14 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
## Overview
|
||||
|
||||
Dia is an open-source text-to-speech (TTS) model (1.6B parameters) developed by [Nari Labs](https://huggingface.co/nari-labs).
|
||||
It can generate highly realistic dialogue from transcript including non-verbal communications such as laughter and coughing.
|
||||
Dia is an opensource text-to-speech (TTS) model (1.6B parameters) developed by [Nari Labs](https://huggingface.co/nari-labs).
|
||||
It can generate highly realistic dialogue from transcript including nonverbal communications such as laughter and coughing.
|
||||
Furthermore, emotion and tone control is also possible via audio conditioning (voice cloning).
|
||||
|
||||
**Model Architecture:**
|
||||
Dia is an encoder-decoder transformer based on the original transformer architecture. However, some more modern features such as
|
||||
rotational positional embeddings (RoPE) are also included. For its text portion (encoder), a byte tokenizer is utilized while
|
||||
for the audio portion (decoder), a pretrained codec model [DAC](./dac) is used - DAC encodes speech into discrete codebook
|
||||
for the audio portion (decoder), a pretrained codec model [DAC](./dac.md) is used - DAC encodes speech into discrete codebook
|
||||
tokens and decodes them back into audio.
|
||||
|
||||
## Usage Tips
|
||||
|
||||
@ -27,7 +27,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
ERNIE (Enhanced Representation through kNowledge IntEgration) is designed to learn language representation enhanced by knowledge masking strategies, which includes entity-level masking and phrase-level masking.
|
||||
|
||||
Other ERNIE models released by baidu can be found at [Ernie 4.5](./ernie4_5), and [Ernie 4.5 MoE](./ernie4_5_moe).
|
||||
Other ERNIE models released by baidu can be found at [Ernie 4.5](./ernie4_5.md), and [Ernie 4.5 MoE](./ernie4_5_moe.md).
|
||||
|
||||
> [!TIP]
|
||||
> This model was contributed by [nghuyong](https://huggingface.co/nghuyong), and the official code can be found in [PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP) (in PaddlePaddle).
|
||||
|
||||
@ -29,9 +29,9 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
The Ernie 4.5 model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu.
|
||||
This family of models contains multiple different architectures and model sizes. This model in specific targets the base text
|
||||
model without mixture of experts (moe) with 0.3B parameters in total. It uses the standard [Llama](./llama) at its core.
|
||||
model without mixture of experts (moe) with 0.3B parameters in total. It uses the standard [Llama](./llama.md) at its core.
|
||||
|
||||
Other models from the family can be found at [Ernie 4.5 Moe](./ernie4_5_moe).
|
||||
Other models from the family can be found at [Ernie 4.5 Moe](./ernie4_5_moe.md).
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>
|
||||
|
||||
@ -30,10 +30,10 @@ rendered properly in your Markdown viewer.
|
||||
The Ernie 4.5 Moe model was released in the [Ernie 4.5 Model Family](https://ernie.baidu.com/blog/posts/ernie4.5/) release by baidu.
|
||||
This family of models contains multiple different architectures and model sizes. This model in specific targets the base text
|
||||
model with mixture of experts (moe) - one with 21B total, 3B active parameters and another one with 300B total, 47B active parameters.
|
||||
It uses the standard [Llama](./llama) at its core combined with a specialized MoE based on [Mixtral](./mixtral) with additional shared
|
||||
It uses the standard [Llama](./llama.md) at its core combined with a specialized MoE based on [Mixtral](./mixtral.md) with additional shared
|
||||
experts.
|
||||
|
||||
Other models from the family can be found at [Ernie 4.5](./ernie4_5).
|
||||
Other models from the family can be found at [Ernie 4.5](./ernie4_5.md).
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://ernie.baidu.com/blog/posts/ernie4.5/overview.png"/>
|
||||
|
||||
@ -30,7 +30,7 @@ Gemma3n is a multimodal model with pretrained and instruction-tuned variants, av
|
||||
large portions of the language model architecture are shared with prior Gemma releases, there are many new additions in
|
||||
this model, including [Alternating Updates][altup] (AltUp), [Learned Augmented Residual Layer][laurel] (LAuReL),
|
||||
[MatFormer][matformer], Per-Layer Embeddings (PLE), [Activation Sparsity with Statistical Top-k][spark-transformer], and KV cache sharing. The language model uses
|
||||
a similar attention pattern to [Gemma 3](./gemma3) with alternating 4 local sliding window self-attention layers for
|
||||
a similar attention pattern to [Gemma 3](./gemma3.md) with alternating 4 local sliding window self-attention layers for
|
||||
every global self-attention layer with a maximum context length of 32k tokens. Gemma 3n introduces
|
||||
[MobileNet v5][mobilenetv5] as the vision encoder, using a default resolution of 768x768 pixels, and adds a newly
|
||||
trained audio encoder based on the [Universal Speech Model][usm] (USM) architecture.
|
||||
|
||||
@ -169,9 +169,9 @@ model = Idefics2ForConditionalGeneration.from_pretrained(
|
||||
|
||||
## Shrinking down Idefics2 using quantization
|
||||
|
||||
As the Idefics2 model has 8 billion parameters, that would require about 16GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization). If the model is quantized to 4 bits (or half a byte per parameter), that requires only about 3.5GB of RAM.
|
||||
As the Idefics2 model has 8 billion parameters, that would require about 16GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter), that requires only about 3.5GB of RAM.
|
||||
|
||||
Quantizing a model is as simple as passing a `quantization_config` to the model. One can change the code snippet above with the changes below. We'll leverage the BitsAndyBytes quantization (but refer to [this page](../quantization) for other quantization methods):
|
||||
Quantizing a model is as simple as passing a `quantization_config` to the model. One can change the code snippet above with the changes below. We'll leverage the BitsAndyBytes quantization (but refer to [this page](../quantization.md) for other quantization methods):
|
||||
|
||||
```diff
|
||||
+ from transformers import BitsAndBytesConfig
|
||||
@ -193,7 +193,7 @@ model = Idefics2ForConditionalGeneration.from_pretrained(
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Idefics2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
- A notebook on how to fine-tune Idefics2 on a custom dataset using the [Trainer](../main_classes/trainer) can be found [here](https://colab.research.google.com/drive/1NtcTgRbSBKN7pYD3Vdx1j9m8pt3fhFDB?usp=sharing). It supports both full fine-tuning as well as (quantized) LoRa.
|
||||
- A notebook on how to fine-tune Idefics2 on a custom dataset using the [Trainer](../main_classes/trainer.md) can be found [here](https://colab.research.google.com/drive/1NtcTgRbSBKN7pYD3Vdx1j9m8pt3fhFDB?usp=sharing). It supports both full fine-tuning as well as (quantized) LoRa.
|
||||
- A script regarding how to fine-tune Idefics2 using the TRL library can be found [here](https://gist.github.com/edbeeching/228652fc6c2b29a1641be5a5778223cb).
|
||||
- Demo notebook regarding fine-tuning Idefics2 for JSON extraction use cases can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/Idefics2). 🌎
|
||||
|
||||
|
||||
@ -44,11 +44,11 @@ Here is the example of visual understanding with a single image.
|
||||
> Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import JanusForConditionalGeneration, JanusProcessor
|
||||
from transformers import JanusForConditionalGeneration, JanusProcessor
|
||||
|
||||
model_id = "deepseek-community/Janus-Pro-1B"
|
||||
# Prepare Input for generation.
|
||||
@ -64,7 +64,7 @@ messages = [
|
||||
|
||||
# Set generation mode to `text` to perform text generation.
|
||||
processor = JanusProcessor.from_pretrained(model_id)
|
||||
model = JanusForConditionalGeneration.from_pretrained(model_id,
|
||||
model = JanusForConditionalGeneration.from_pretrained(model_id,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto")
|
||||
|
||||
@ -209,10 +209,6 @@ for i, image in enumerate(images['pixel_values']):
|
||||
|
||||
[[autodoc]] JanusImageProcessor
|
||||
|
||||
## JanusImageProcessorFast
|
||||
|
||||
[[autodoc]] JanusImageProcessorFast
|
||||
|
||||
## JanusVisionModel
|
||||
|
||||
[[autodoc]] JanusVisionModel
|
||||
|
||||
@ -107,7 +107,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
|
||||
|
||||
```py
|
||||
# Easy visualization using the built-in plotting method
|
||||
processor.visualize_keypoint_matching(images, processed_outputs)
|
||||
processor.plot_keypoint_matching(images, processed_outputs)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
@ -128,7 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
|
||||
|
||||
- preprocess
|
||||
- post_process_keypoint_matching
|
||||
- visualize_keypoint_matching
|
||||
- plot_keypoint_matching
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
@ -33,7 +33,7 @@ alt="drawing" width="600"/>
|
||||
|
||||
<small> MGP-STR architecture. Taken from the <a href="https://huggingface.co/papers/2209.03592">original paper</a>. </small>
|
||||
|
||||
MGP-STR is trained on two synthetic datasets [MJSynth](http://www.robots.ox.ac.uk/~vgg/data/text/) (MJ) and [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) (ST) without fine-tuning on other datasets. It achieves state-of-the-art results on six standard Latin scene text benchmarks, including 3 regular text datasets (IC13, SVT, IIIT) and 3 irregular ones (IC15, SVTP, CUTE).
|
||||
MGP-STR is trained on two synthetic datasets [MJSynth]((http://www.robots.ox.ac.uk/~vgg/data/text/)) (MJ) and [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) (ST) without fine-tuning on other datasets. It achieves state-of-the-art results on six standard Latin scene text benchmarks, including 3 regular text datasets (IC13, SVT, IIIT) and 3 irregular ones (IC15, SVTP, CUTE).
|
||||
This model was contributed by [yuekun](https://huggingface.co/yuekun). The original code can be found [here](https://github.com/AlibabaResearch/AdvancedLiterateMachinery/tree/main/OCR/MGP-STR).
|
||||
|
||||
## Inference example
|
||||
|
||||
@ -30,7 +30,7 @@ The abstract from the paper is the following:
|
||||
|
||||
*We introduce Moshi, a speech-text foundation model and full-duplex spoken dialogue framework. Current systems for spoken dialogue rely on pipelines of independent components, namely voice activity detection, speech recognition, textual dialogue and text-to-speech. Such frameworks cannot emulate the experience of real conversations. First, their complexity induces a latency of several seconds between interactions. Second, text being the intermediate modality for dialogue, non-linguistic information that modifies meaning— such as emotion or non-speech sounds— is lost in the interaction. Finally, they rely on a segmentation into speaker turns, which does not take into account overlapping speech, interruptions and interjections. Moshi solves these independent issues altogether by casting spoken dialogue as speech-to-speech generation. Starting from a text language model backbone, Moshi generates speech as tokens from the residual quantizer of a neural audio codec, while modeling separately its own speech and that of the user into parallel streams. This allows for the removal of explicit speaker turns, and the modeling of arbitrary conversational dynamics. We moreover extend the hierarchical semantic-to-acoustic token generation of previous work to first predict time-aligned text tokens as a prefix to audio tokens. Not only this “Inner Monologue” method significantly improves the linguistic quality of generated speech, but we also illustrate how it can provide streaming speech recognition and text-to-speech. Our resulting model is the first real-time full-duplex spoken large language model, with a theoretical latency of 160ms, 200ms in practice, and is available at github.com/kyutai-labs/moshi.*
|
||||
|
||||
Its architecture is based on [Encodec](./encodec) with several major differences:
|
||||
Its architecture is based on [Encodec](model_doc/encodec) with several major differences:
|
||||
* it uses a much lower frame-rate.
|
||||
* it uses additional transformers for encoding and decoding for better latent contextualization
|
||||
* it uses a different quantization scheme: one codebook is dedicated to semantic projection.
|
||||
|
||||
@ -115,9 +115,9 @@ The Flash Attention-2 model uses also a more memory efficient cache slicing mech
|
||||
|
||||
## Shrinking down MiniMax using quantization
|
||||
|
||||
As the MiniMax model has 456 billion parameters, that would require about 912GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization). If the model is quantized to 4 bits (or half a byte per parameter), about 228 GB of RAM is required.
|
||||
As the MiniMax model has 456 billion parameters, that would require about 912GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter), about 228 GB of RAM is required.
|
||||
|
||||
Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization) for alternative quantization methods):
|
||||
Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization.md) for alternative quantization methods):
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
|
||||
@ -146,9 +146,9 @@ The Flash Attention-2 model uses also a more memory efficient cache slicing mech
|
||||
|
||||
## Shrinking down Mixtral using quantization
|
||||
|
||||
As the Mixtral model has 45 billion parameters, that would require about 90GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization). If the model is quantized to 4 bits (or half a byte per parameter), a single A100 with 40GB of RAM is enough to fit the entire model, as in that case only about 27 GB of RAM is required.
|
||||
As the Mixtral model has 45 billion parameters, that would require about 90GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter), a single A100 with 40GB of RAM is enough to fit the entire model, as in that case only about 27 GB of RAM is required.
|
||||
|
||||
Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization) for alternative quantization methods):
|
||||
Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the bitsandbytes quantization library (but refer to [this page](../quantization.md) for alternative quantization methods):
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
|
||||
@ -1,124 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# MM Grounding DINO
|
||||
|
||||
[MM Grounding DINO](https://arxiv.org/abs/2401.02361) model was proposed in [An Open and Comprehensive Pipeline for Unified Object Grounding and Detection](https://arxiv.org/abs/2401.02361) by Xiangyu Zhao, Yicheng Chen, Shilin Xu, Xiangtai Li, Xinjiang Wang, Yining Li, Haian Huang>.
|
||||
|
||||
MM Grounding DINO improves upon the [Grounding DINO](https://huggingface.co/docs/transformers/model_doc/grounding-dino) by improving the contrastive class head and removing the parameter sharing in the decoder, improving zero-shot detection performance on both COCO (50.6(+2.2) AP) and LVIS (31.9(+11.8) val AP and 41.4(+12.6) minival AP).
|
||||
|
||||
You can find all the original MM Grounding DINO checkpoints under the [MM Grounding DINO](https://huggingface.co/collections/openmmlab-community/mm-grounding-dino-688cbde05b814c4e2832f9df) collection. This model also supports LLMDet inference. You can find LLMDet checkpoints under the [LLMDet](https://huggingface.co/collections/iSEE-Laboratory/llmdet-688475906dc235d5f1dc678e) collection.
|
||||
|
||||
> [!TIP]
|
||||
> Click on the MM Grounding DINO models in the right sidebar for more examples of how to apply MM Grounding DINO to different MM Grounding DINO tasks.
|
||||
|
||||
The example below demonstrates how to generate text based on an image with the [`AutoModelForZeroShotObjectDetection`] class.
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
|
||||
from transformers.image_utils import load_image
|
||||
|
||||
|
||||
# Prepare processor and model
|
||||
model_id = "openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_v3det"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
||||
|
||||
# Prepare inputs
|
||||
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = load_image(image_url)
|
||||
text_labels = [["a cat", "a remote control"]]
|
||||
inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device)
|
||||
|
||||
# Run inference
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Postprocess outputs
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs,
|
||||
threshold=0.4,
|
||||
target_sizes=[(image.height, image.width)]
|
||||
)
|
||||
|
||||
# Retrieve the first image result
|
||||
result = results[0]
|
||||
for box, score, labels in zip(result["boxes"], result["scores"], result["labels"]):
|
||||
box = [round(x, 2) for x in box.tolist()]
|
||||
print(f"Detected {labels} with confidence {round(score.item(), 3)} at location {box}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Notes
|
||||
|
||||
- Here's a table of models and their object detection performance results on COCO (results from [official repo](https://github.com/open-mmlab/mmdetection/blob/main/configs/mm_grounding_dino/README.md)):
|
||||
|
||||
| Model | Backbone | Pre-Train Data | Style | COCO mAP |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------ | -------- | ------------------------ | --------- | ---------- |
|
||||
| [mm_grounding_dino_tiny_o365v1_goldg](https://huggingface.co/openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg) | Swin-T | O365,GoldG | Zero-shot | 50.4(+2.3) |
|
||||
| [mm_grounding_dino_tiny_o365v1_goldg_grit](https://huggingface.co/openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_grit) | Swin-T | O365,GoldG,GRIT | Zero-shot | 50.5(+2.1) |
|
||||
| [mm_grounding_dino_tiny_o365v1_goldg_v3det](https://huggingface.co/openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_v3det) | Swin-T | O365,GoldG,V3Det | Zero-shot | 50.6(+2.2) |
|
||||
| [mm_grounding_dino_tiny_o365v1_goldg_grit_v3det](https://huggingface.co/openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_grit_v3det) | Swin-T | O365,GoldG,GRIT,V3Det | Zero-shot | 50.4(+2.0) |
|
||||
| [mm_grounding_dino_base_o365v1_goldg_v3det](https://huggingface.co/openmmlab-community/mm_grounding_dino_base_o365v1_goldg_v3det) | Swin-B | O365,GoldG,V3Det | Zero-shot | 52.5 |
|
||||
| [mm_grounding_dino_base_all](https://huggingface.co/openmmlab-community/mm_grounding_dino_base_all) | Swin-B | O365,ALL | - | 59.5 |
|
||||
| [mm_grounding_dino_large_o365v2_oiv6_goldg](https://huggingface.co/openmmlab-community/mm_grounding_dino_large_o365v2_oiv6_goldg) | Swin-L | O365V2,OpenImageV6,GoldG | Zero-shot | 53.0 |
|
||||
| [mm_grounding_dino_large_all](https://huggingface.co/openmmlab-community/mm_grounding_dino_large_all) | Swin-L | O365V2,OpenImageV6,ALL | - | 60.3 |
|
||||
|
||||
- Here's a table of MM Grounding DINO tiny models and their object detection performance on LVIS (results from [official repo](https://github.com/open-mmlab/mmdetection/blob/main/configs/mm_grounding_dino/README.md)):
|
||||
|
||||
| Model | Pre-Train Data | MiniVal APr | MiniVal APc | MiniVal APf | MiniVal AP | Val1.0 APr | Val1.0 APc | Val1.0 APf | Val1.0 AP |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------ | --------------------- | ----------- | ----------- | ----------- | ----------- | ---------- | ---------- | ---------- | ----------- |
|
||||
| [mm_grounding_dino_tiny_o365v1_goldg](https://huggingface.co/openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg) | O365,GoldG | 28.1 | 30.2 | 42.0 | 35.7(+6.9) | 17.1 | 22.4 | 36.5 | 27.0(+6.9) |
|
||||
| [mm_grounding_dino_tiny_o365v1_goldg_grit](https://huggingface.co/openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_grit) | O365,GoldG,GRIT | 26.6 | 32.4 | 41.8 | 36.5(+7.7) | 17.3 | 22.6 | 36.4 | 27.1(+7.0) |
|
||||
| [mm_grounding_dino_tiny_o365v1_goldg_v3det](https://huggingface.co/openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_v3det) | O365,GoldG,V3Det | 33.0 | 36.0 | 45.9 | 40.5(+11.7) | 21.5 | 25.5 | 40.2 | 30.6(+10.5) |
|
||||
| [mm_grounding_dino_tiny_o365v1_goldg_grit_v3det](https://huggingface.co/openmmlab-community/mm_grounding_dino_tiny_o365v1_goldg_grit_v3det) | O365,GoldG,GRIT,V3Det | 34.2 | 37.4 | 46.2 | 41.4(+12.6) | 23.6 | 27.6 | 40.5 | 31.9(+11.8) |
|
||||
|
||||
|
||||
- This implementation also supports inference for [LLMDet](https://github.com/iSEE-Laboratory/LLMDet). Here's a table of LLMDet models and their performance on LVIS (results from [official repo](https://github.com/iSEE-Laboratory/LLMDet)):
|
||||
|
||||
| Model | Pre-Train Data | MiniVal APr | MiniVal APc | MiniVal APf | MiniVal AP | Val1.0 APr | Val1.0 APc | Val1.0 APf | Val1.0 AP |
|
||||
| --------------------------------------------------------- | -------------------------------------------- | ------------ | ----------- | ----------- | ----------- | ---------- | ---------- | ---------- | ----------- |
|
||||
| [llmdet_tiny](https://huggingface.co/iSEE-Laboratory/llmdet_tiny) | (O365,GoldG,GRIT,V3Det) + GroundingCap-1M | 44.7 | 37.3 | 39.5 | 50.7 | 34.9 | 26.0 | 30.1 | 44.3 |
|
||||
| [llmdet_base](https://huggingface.co/iSEE-Laboratory/llmdet_base) | (O365,GoldG,V3Det) + GroundingCap-1M | 48.3 | 40.8 | 43.1 | 54.3 | 38.5 | 28.2 | 34.3 | 47.8 |
|
||||
| [llmdet_large](https://huggingface.co/iSEE-Laboratory/llmdet_large) | (O365V2,OpenImageV6,GoldG) + GroundingCap-1M | 51.1 | 45.1 | 46.1 | 56.6 | 42.0 | 31.6 | 38.8 | 50.2 |
|
||||
|
||||
|
||||
## MMGroundingDinoConfig
|
||||
|
||||
[[autodoc]] MMGroundingDinoConfig
|
||||
|
||||
## MMGroundingDinoModel
|
||||
|
||||
[[autodoc]] MMGroundingDinoModel
|
||||
- forward
|
||||
|
||||
## MMGroundingDinoForObjectDetection
|
||||
|
||||
[[autodoc]] MMGroundingDinoForObjectDetection
|
||||
- forward
|
||||
@ -14,115 +14,54 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# mT5
|
||||
|
||||
[mT5](https://huggingface.co/papers/2010.11934) is a multilingual variant of [T5](./t5), training on 101 languages. It also incorporates a new "accidental translation" technique to prevent the model from incorrectly translating predictions into the wrong language.
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
|
||||
">
|
||||
</div>
|
||||
|
||||
You can find all the original [mT5] checkpoints under the [mT5](https://huggingface.co/collections/google/mt5-release-65005f1a520f8d7b4d039509) collection.
|
||||
## Overview
|
||||
|
||||
> [!TIP]
|
||||
> This model was contributed by [patrickvonplaten](https://huggingface.co/patrickvonplaten).
|
||||
>
|
||||
> Click on the mT5 models in the right sidebar for more examples of how to apply mT5 to different language tasks.
|
||||
The mT5 model was presented in [mT5: A massively multilingual pre-trained text-to-text transformer](https://huggingface.co/papers/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya
|
||||
Siddhant, Aditya Barua, Colin Raffel.
|
||||
|
||||
The example below demonstrates how to summarize text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
The abstract from the paper is the following:
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
*The recent "Text-to-Text Transfer Transformer" (T5) leveraged a unified text-to-text format and scale to attain
|
||||
state-of-the-art results on a wide variety of English-language NLP tasks. In this paper, we introduce mT5, a
|
||||
multilingual variant of T5 that was pre-trained on a new Common Crawl-based dataset covering 101 languages. We detail
|
||||
the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual
|
||||
benchmarks. We also describe a simple technique to prevent "accidental translation" in the zero-shot setting, where a
|
||||
generative model chooses to (partially) translate its prediction into the wrong language. All of the code and model
|
||||
checkpoints used in this work are publicly available.*
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
Note: mT5 was only pre-trained on [mC4](https://huggingface.co/datasets/mc4) excluding any supervised training.
|
||||
Therefore, this model has to be fine-tuned before it is usable on a downstream task, unlike the original T5 model.
|
||||
Since mT5 was pre-trained unsupervisedly, there's no real advantage to using a task prefix during single-task
|
||||
fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix.
|
||||
|
||||
pipeline = pipeline(
|
||||
task="text2text-generation",
|
||||
model="csebuetnlp/mT5_multilingual_XLSum",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("""Plants are remarkable organisms that produce their own food using a method called photosynthesis.
|
||||
This process involves converting sunlight, carbon dioxide, and water into glucose, which provides energy for growth.
|
||||
Plants play a crucial role in sustaining life on Earth by generating oxygen and serving as the foundation of most ecosystems.""")
|
||||
```
|
||||
Google has released the following variants:
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
- [google/mt5-small](https://huggingface.co/google/mt5-small)
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
- [google/mt5-base](https://huggingface.co/google/mt5-base)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"csebuetnlp/mT5_multilingual_XLSum"
|
||||
)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"csebuetnlp/mT5_multilingual_XLSum",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
- [google/mt5-large](https://huggingface.co/google/mt5-large)
|
||||
|
||||
input_text = """Plants are remarkable organisms that produce their own food using a method called photosynthesis.
|
||||
This process involves converting sunlight, carbon dioxide, and water into glucose, which provides energy for growth.
|
||||
Plants play a crucial role in sustaining life on Earth by generating oxygen and serving as the foundation of most ecosystems."""
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
- [google/mt5-xl](https://huggingface.co/google/mt5-xl)
|
||||
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
- [google/mt5-xxl](https://huggingface.co/google/mt5-xxl).
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers CLI">
|
||||
This model was contributed by [patrickvonplaten](https://huggingface.co/patrickvonplaten). The original code can be
|
||||
found [here](https://github.com/google-research/multilingual-t5).
|
||||
|
||||
```bash
|
||||
echo -e "Plants are remarkable organisms that produce their own food using a method called photosynthesis." | transformers run --task text2text-generation --model csebuetnlp/mT5_multilingual_XLSum --device 0
|
||||
```
|
||||
## Resources
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import BitsAndBytesConfig, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_quant_type="nf4"
|
||||
)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"csebuetnlp/mT5_multilingual_XLSum",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"csebuetnlp/mT5_multilingual_XLSum"
|
||||
)
|
||||
input_text = """Plants are remarkable organisms that produce their own food using a method called photosynthesis.
|
||||
This process involves converting sunlight, carbon dioxide, and water into glucose, which provides energy for growth.
|
||||
Plants play a crucial role in sustaining life on Earth by generating oxygen and serving as the foundation of most ecosystems."""
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- mT5 must be fine-tuned for downstream tasks because it was only pretrained on the [mc4](https://huggingface.co/datasets/mc4) dataset.
|
||||
- [Translation task guide](../tasks/translation)
|
||||
- [Summarization task guide](../tasks/summarization)
|
||||
|
||||
## MT5Config
|
||||
|
||||
|
||||
@ -1,58 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# OpenAIMoE
|
||||
|
||||
## Overview
|
||||
|
||||
The OpenAIMoE model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
||||
<INSERT SHORT SUMMARY HERE>
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*<INSERT PAPER ABSTRACT HERE>*
|
||||
|
||||
Tips:
|
||||
|
||||
<INSERT TIPS ABOUT MODEL HERE>
|
||||
|
||||
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
|
||||
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
|
||||
|
||||
|
||||
## OpenAIMoeConfig
|
||||
|
||||
[[autodoc]] OpenAIMoeConfig
|
||||
|
||||
## OpenAIMoeModel
|
||||
|
||||
[[autodoc]] OpenAIMoeModel
|
||||
- forward
|
||||
|
||||
## OpenAIMoeForCausalLM
|
||||
|
||||
[[autodoc]] OpenAIMoeForCausalLM
|
||||
- forward
|
||||
@ -38,7 +38,7 @@ This model was contributed by [ajati](https://huggingface.co/ajati), [vijaye12](
|
||||
|
||||
## Usage example
|
||||
|
||||
The code snippet below shows how to randomly initialize a PatchTSMixer model. The model is compatible with the [Trainer API](../trainer).
|
||||
The code snippet below shows how to randomly initialize a PatchTSMixer model. The model is compatible with the [Trainer API](../trainer.md).
|
||||
|
||||
```python
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ rendered properly in your Markdown viewer.
|
||||
# Qwen2MoE
|
||||
|
||||
|
||||
[Qwen2MoE](https://huggingface.co/papers/2407.10671) is a Mixture-of-Experts (MoE) variant of [Qwen2](./qwen2), available as a base model and an aligned chat model. It uses SwiGLU activation, group query attention and a mixture of sliding window attention and full attention. The tokenizer can also be adapted to multiple languages and codes.
|
||||
[Qwen2MoE]((https://huggingface.co/papers/2407.10671) ) is a Mixture-of-Experts (MoE) variant of [Qwen2](./qwen2), available as a base model and an aligned chat model. It uses SwiGLU activation, group query attention and a mixture of sliding window attention and full attention. The tokenizer can also be adapted to multiple languages and codes.
|
||||
|
||||
The MoE architecture uses upcyled models from the dense language models. For example, Qwen1.5-MoE-A2.7B is upcycled from Qwen-1.8B. It has 14.3B parameters but only 2.7B parameters are activated during runtime.
|
||||
|
||||
|
||||
@ -103,11 +103,38 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
|
||||
print(f"Keypoint at {keypoint0.numpy()} matches with keypoint at {keypoint1.numpy()} with score {matching_score}")
|
||||
```
|
||||
|
||||
- Visualize the matches between the images using the built-in plotting functionality.
|
||||
- The example below demonstrates how to visualize matches between two images.
|
||||
|
||||
```py
|
||||
# Easy visualization using the built-in plotting method
|
||||
processor.visualize_keypoint_matching(images, processed_outputs)
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
# Create side by side image
|
||||
merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
|
||||
merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
|
||||
merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
|
||||
plt.imshow(merged_image)
|
||||
plt.axis("off")
|
||||
|
||||
# Retrieve the keypoints and matches
|
||||
output = processed_outputs[0]
|
||||
keypoints0 = output["keypoints0"]
|
||||
keypoints1 = output["keypoints1"]
|
||||
matching_scores = output["matching_scores"]
|
||||
|
||||
# Plot the matches
|
||||
for keypoint0, keypoint1, matching_score in zip(keypoints0, keypoints1, matching_scores):
|
||||
plt.plot(
|
||||
[keypoint0[0], keypoint1[0] + image1.width],
|
||||
[keypoint0[1], keypoint1[1]],
|
||||
color=plt.get_cmap("RdYlGn")(matching_score.item()),
|
||||
alpha=0.9,
|
||||
linewidth=0.5,
|
||||
)
|
||||
plt.scatter(keypoint0[0], keypoint0[1], c="black", s=2)
|
||||
plt.scatter(keypoint1[0] + image1.width, keypoint1[1], c="black", s=2)
|
||||
|
||||
plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
@ -128,7 +155,6 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
|
||||
|
||||
- preprocess
|
||||
- post_process_keypoint_matching
|
||||
- visualize_keypoint_matching
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
@ -69,11 +69,11 @@ print(tokenizer.decode(outputs[0]))
|
||||
## Model card
|
||||
|
||||
The model cards can be found at:
|
||||
* [Zamba-7B](https://huggingface.co/Zyphra/Zamba-7B-v1)
|
||||
* [Zamba-7B](MODEL_CARD_ZAMBA-7B-v1.md)
|
||||
|
||||
|
||||
## Issues
|
||||
For issues with model output, or community discussion, please use the Hugging Face community [forum](https://huggingface.co/Zyphra/Zamba-7B-v1/discussions)
|
||||
For issues with model output, or community discussion, please use the Hugging Face community [forum](https://huggingface.co/zyphra/zamba-7b)
|
||||
|
||||
|
||||
## License
|
||||
|
||||
@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# SpQR
|
||||
|
||||
The [SpQR](https://hf.co/papers/2306.03078) quantization algorithm involves a 16x16 tiled bi-level group 3-bit quantization structure with sparse outliers.
|
||||
The [SpQR]((https://hf.co/papers/2306.03078)) quantization algorithm involves a 16x16 tiled bi-level group 3-bit quantization structure with sparse outliers.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/spqr-diagram.png">
|
||||
|
||||
@ -18,7 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
[ONNX](http://onnx.ai) is an open standard that defines a common set of operators and a file format to represent deep learning models in different frameworks, including PyTorch and TensorFlow. When a model is exported to ONNX, the operators construct a computational graph (or *intermediate representation*) which represents the flow of data through the model. Standardized operators and data types makes it easy to switch between frameworks.
|
||||
|
||||
The [Optimum](https://huggingface.co/docs/optimum/index) library exports a model to ONNX with configuration objects which are supported for [many architectures](https://huggingface.co/docs/optimum/exporters/onnx/overview) and can be easily extended. If a model isn't supported, feel free to make a [contribution](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/contribute) to Optimum.
|
||||
The [Optimum](https://huggingface.co/docs/optimum/index) library exports a model to ONNX with configuration objects which are supported for [many architectures]((https://huggingface.co/docs/optimum/exporters/onnx/overview)) and can be easily extended. If a model isn't supported, feel free to make a [contribution](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/contribute) to Optimum.
|
||||
|
||||
The benefits of exporting to ONNX include the following.
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ Keypoint detection identifies and locates specific points of interest within an
|
||||
|
||||
In this guide, we will show how to extract keypoints from images.
|
||||
|
||||
For this tutorial, we will use [SuperPoint](./model_doc/superpoint), a foundation model for keypoint detection.
|
||||
For this tutorial, we will use [SuperPoint](./model_doc/superpoint.md), a foundation model for keypoint detection.
|
||||
|
||||
```python
|
||||
from transformers import AutoImageProcessor, SuperPointForKeypointDetection
|
||||
|
||||
@ -20,7 +20,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
Video-text-to-text models, also known as video language models or vision language models with video input, are language models that take a video input. These models can tackle various tasks, from video question answering to video captioning.
|
||||
|
||||
These models have nearly the same architecture as [image-text-to-text](../image_text_to_text) models except for some changes to accept video data, since video data is essentially image frames with temporal dependencies. Some image-text-to-text models take in multiple images, but this alone is inadequate for a model to accept videos. Moreover, video-text-to-text models are often trained with all vision modalities. Each example might have videos, multiple videos, images and multiple images. Some of these models can also take interleaved inputs. For example, you can refer to a specific video inside a string of text by adding a video token in text like "What is happening in this video? `<video>`".
|
||||
These models have nearly the same architecture as [image-text-to-text](../image_text_to_text.md) models except for some changes to accept video data, since video data is essentially image frames with temporal dependencies. Some image-text-to-text models take in multiple images, but this alone is inadequate for a model to accept videos. Moreover, video-text-to-text models are often trained with all vision modalities. Each example might have videos, multiple videos, images and multiple images. Some of these models can also take interleaved inputs. For example, you can refer to a specific video inside a string of text by adding a video token in text like "What is happening in this video? `<video>`".
|
||||
|
||||
In this guide, we provide a brief overview of video LMs and show how to use them with Transformers for inference.
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
[LiteRT](https://ai.google.dev/edge/litert) (previously known as TensorFlow Lite) is a high-performance runtime designed for on-device machine learning.
|
||||
|
||||
The [Optimum](https://huggingface.co/docs/optimum/index) library exports a model to LiteRT for [many architectures](https://huggingface.co/docs/optimum/exporters/onnx/overview).
|
||||
The [Optimum](https://huggingface.co/docs/optimum/index) library exports a model to LiteRT for [many architectures]((https://huggingface.co/docs/optimum/exporters/onnx/overview)).
|
||||
|
||||
The benefits of exporting to LiteRT include the following.
|
||||
|
||||
|
||||
@ -307,7 +307,7 @@ culture, and they allow us to design the'
|
||||
|
||||
アシストデコーディングを有効にするには、`assistant_model` 引数をモデルで設定します。
|
||||
|
||||
このガイドは、さまざまなデコーディング戦略を可能にする主要なパラメーターを説明しています。さらに高度なパラメーターは [`generate`] メソッドに存在し、[`generate`] メソッドの動作をさらに制御できます。使用可能なパラメーターの完全なリストについては、[APIドキュメント](./main_classes/text_generation) を参照してください。
|
||||
このガイドは、さまざまなデコーディング戦略を可能にする主要なパラメーターを説明しています。さらに高度なパラメーターは [`generate`] メソッドに存在し、[`generate`] メソッドの動作をさらに制御できます。使用可能なパラメーターの完全なリストについては、[APIドキュメント](./main_classes/text_generation.md) を参照してください。
|
||||
|
||||
|
||||
```python
|
||||
|
||||
@ -111,7 +111,7 @@ BART を始めるのに役立つ公式 Hugging Face およびコミュニティ
|
||||
- [`TFBartForConditionalGeneration`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/summarization) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization-tf.ipynb)。
|
||||
- [`FlaxBartForConditionalGeneration`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization) でサポートされています。
|
||||
- [要約](https://huggingface.co/course/chapter7/5?fw=pt#summarization) 🤗 ハグフェイスコースの章。
|
||||
- [要約タスクガイド](../tasks/summarization)
|
||||
- [要約タスクガイド](../tasks/summarization.md)
|
||||
|
||||
<PipelineTag pipeline="fill-mask"/>
|
||||
|
||||
|
||||
@ -357,7 +357,7 @@
|
||||
title: 메인 클래스
|
||||
- sections:
|
||||
- sections:
|
||||
- local: model_doc/albert
|
||||
- local: in_translation
|
||||
title: ALBERT
|
||||
- local: in_translation
|
||||
title: Arcee
|
||||
@ -1081,7 +1081,7 @@
|
||||
title: TrOCR
|
||||
- local: in_translation
|
||||
title: TVLT
|
||||
- local: model_doc/tvp
|
||||
- local: in_translation
|
||||
title: TVP
|
||||
- local: in_translation
|
||||
title: UDOP
|
||||
@ -1155,4 +1155,4 @@
|
||||
- local: in_translation
|
||||
title: (번역중)Environment Variables
|
||||
title: Reference
|
||||
title: API
|
||||
title: API
|
||||
@ -289,7 +289,7 @@ time."\n\nHe added: "I am very proud of the work I have been able to do in the l
|
||||
culture, and they allow us to design the'
|
||||
```
|
||||
|
||||
이 가이드에서는 다양한 디코딩 전략을 가능하게 하는 주요 매개변수를 보여줍니다. [`generate`] 메서드에 대한 고급 매개변수가 존재하므로 [`generate`] 메서드의 동작을 더욱 세부적으로 제어할 수 있습니다. 사용 가능한 매개변수의 전체 목록은 [API 문서](./main_classes/text_generation)를 참조하세요.
|
||||
이 가이드에서는 다양한 디코딩 전략을 가능하게 하는 주요 매개변수를 보여줍니다. [`generate`] 메서드에 대한 고급 매개변수가 존재하므로 [`generate`] 메서드의 동작을 더욱 세부적으로 제어할 수 있습니다. 사용 가능한 매개변수의 전체 목록은 [API 문서](./main_classes/text_generation.md)를 참조하세요.
|
||||
|
||||
### 추론 디코딩(Speculative Decoding)[[speculative-decoding]]
|
||||
|
||||
|
||||
@ -21,11 +21,11 @@ GPT3/4, [Falcon](https://huggingface.co/tiiuae/falcon-40b), [Llama](https://hugg
|
||||
|
||||
이 가이드에서는 효율적인 대규모 언어 모델 배포를 위한 효과적인 기법들을 살펴보겠습니다.
|
||||
|
||||
1. **낮은 정밀도:** 연구에 따르면, [8비트와 4비트](./main_classes/quantization)와 같이 낮은 수치 정밀도로 작동하면 모델 성능의 큰 저하 없이 계산상의 이점을 얻을 수 있습니다.
|
||||
1. **낮은 정밀도:** 연구에 따르면, [8비트와 4비트](./main_classes/quantization.md)와 같이 낮은 수치 정밀도로 작동하면 모델 성능의 큰 저하 없이 계산상의 이점을 얻을 수 있습니다.
|
||||
|
||||
2. **플래시 어텐션:** 플래시 어텐션은 메모리 효율성을 높일 뿐만 아니라 최적화된 GPU 메모리 활용을 통해 효율성을 향상시키는 어텐션 알고리즘의 변형입니다.
|
||||
|
||||
3. **아키텍처 혁신:** 추론 시 대규모 언어 모델은 주로 동일한 방식(긴 입력 맥락을 가진 자기회귀 텍스트 생성 방식)으로 배포되는데, 더 효율적인 추론을 가능하게 하는 특화된 모델 아키텍처가 제안되었습니다. 이러한 모델 아키텍처의 가장 중요한 발전으로는 [Alibi](https://huggingface.co/papers/2108.12409), [Rotary embeddings](https://huggingface.co/papers/2104.09864), [Multi-Query Attention (MQA)](https://huggingface.co/papers/1911.02150), [Grouped-Query-Attention (GQA)](https://huggingface.co/papers/2305.13245)이 있습니다.
|
||||
3. **아키텍처 혁신:** 추론 시 대규모 언어 모델은 주로 동일한 방식(긴 입력 맥락을 가진 자기회귀 텍스트 생성 방식)으로 배포되는데, 더 효율적인 추론을 가능하게 하는 특화된 모델 아키텍처가 제안되었습니다. 이러한 모델 아키텍처의 가장 중요한 발전으로는 [Alibi](https://huggingface.co/papers/2108.12409), [Rotary embeddings](https://huggingface.co/papers/2104.09864), [Multi-Query Attention (MQA)](https://huggingface.co/papers/1911.02150), [Grouped-Query-Attention (GQA)]((https://huggingface.co/papers/2305.13245))이 있습니다.
|
||||
|
||||
이 가이드에서는 텐서의 관점에서 자기회귀 생성에 대한 분석을 제공합니다. 낮은 정밀도를 채택하는 것의 장단점을 논의하고, 최신 어텐션 알고리즘을 포괄적으로 탐구하며, 향상된 대규모 언어 모델 아키텍처에 대해 논합니다. 이 과정에서 각 기능의 개선 사항을 보여주는 실용적인 예제를 확인합니다.
|
||||
|
||||
@ -756,4 +756,4 @@ GQA의 가장 주목할 만한 적용 사례는 [Llama-v2](https://huggingface.c
|
||||
|
||||
연구 커뮤니티는 점점 더 큰 대규모 언어 모델의 추론 시간을 가속화하기 위한 새로운 기발한 방법들을 끊임없이 찾아내고 있습니다. 예를 들어, [추측 디코딩](https://huggingface.co/papers/2211.17192)이라는 유망한 연구 방향이 있습니다. 여기서 "쉬운 토큰"은 더 작고 빠른 언어 모델에 의해 생성되고, "어려운 토큰"만 대규모 언어 모델 자체에 의해 생성됩니다. 자세한 내용은 이 노트북의 범위를 벗어나지만, [멋진 블로그 포스트](https://huggingface.co/blog/assisted-generation)에서 읽어볼 수 있습니다.
|
||||
|
||||
GPT3/4, Llama-2-70b, Claude, PaLM과 같은 거대한 대규모 언어 모델이 [Hugging Face Chat](https://huggingface.co/chat/) 또는 ChatGPT와 같은 채팅 인터페이스에서 빠르게 실행될 수 있는 이유는 위에서 언급한 정밀도, 알고리즘, 아키텍처의 개선 덕분입니다. 앞으로 GPU, TPU 등과 같은 가속기는 점점 더 빨라지고 더 많은 메모리를 사용할 것입니다. 따라서 가장 좋은 알고리즘과 아키텍처를 사용하여 최고의 효율을 얻는 것이 중요합니다 🤗
|
||||
GPT3/4, Llama-2-70b, Claude, PaLM과 같은 거대한 대규모 언어 모델이 [Hugging Face Chat](https://huggingface.co/chat/) 또는 ChatGPT와 같은 채팅 인터페이스에서 빠르게 실행될 수 있는 이유는 위에서 언급한 정밀도, 알고리즘, 아키텍처의 개선 덕분입니다. 앞으로 GPU, TPU 등과 같은 가속기는 점점 더 빨라지고 더 많은 메모리를 사용할 것입니다. 따라서 가장 좋은 알고리즘과 아키텍처를 사용하여 최고의 효율을 얻는 것이 중요합니다 🤗
|
||||
@ -136,9 +136,9 @@ pip install -U flash-attn --no-build-isolation
|
||||
|
||||
## 양자화로 미스트랄 크기 줄이기[[shrinking-down-mistral-using-quantization]]
|
||||
|
||||
미스트랄 모델은 70억 개의 파라미터를 가지고 있어, 절반의 정밀도(float16)로 약 14GB의 GPU RAM이 필요합니다. 각 파라미터가 2바이트로 저장되기 때문입니다. 하지만 [양자화](../quantization)를 사용하면 모델 크기를 줄일 수 있습니다. 모델을 4비트(즉, 파라미터당 반 바이트)로 양자화하면 약 3.5GB의 RAM만 필요합니다.
|
||||
미스트랄 모델은 70억 개의 파라미터를 가지고 있어, 절반의 정밀도(float16)로 약 14GB의 GPU RAM이 필요합니다. 각 파라미터가 2바이트로 저장되기 때문입니다. 하지만 [양자화](../quantization.md)를 사용하면 모델 크기를 줄일 수 있습니다. 모델을 4비트(즉, 파라미터당 반 바이트)로 양자화하면 약 3.5GB의 RAM만 필요합니다.
|
||||
|
||||
모델을 양자화하는 것은 `quantization_config`를 모델에 전달하는 것만큼 간단합니다. 아래에서는 BitsAndBytes 양자화를 사용하지만, 다른 양자화 방법은 [이 페이지](../quantization)를 참고하세요:
|
||||
모델을 양자화하는 것은 `quantization_config`를 모델에 전달하는 것만큼 간단합니다. 아래에서는 BitsAndBytes 양자화를 사용하지만, 다른 양자화 방법은 [이 페이지](../quantization.md)를 참고하세요:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
|
||||
@ -35,7 +35,7 @@ PatchTSMixer는 MLP-Mixer 아키텍처를 기반으로 한 경량 시계열 모
|
||||
## 사용 예[[usage-example]]
|
||||
|
||||
아래의 코드 스니펫은 PatchTSMixer 모델을 무작위로 초기화하는 방법을 보여줍니다.
|
||||
PatchTSMixer 모델은 [Trainer API](../trainer)와 호환됩니다.
|
||||
PatchTSMixer 모델은 [Trainer API](../trainer.md)와 호환됩니다.
|
||||
|
||||
```python
|
||||
|
||||
|
||||
@ -57,8 +57,8 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
| 模型 | PyTorch 支持 | TensorFlow 支持 | Flax 支持 |
|
||||
|:------------------------------------------------------------------------:|:---------------:|:------------------:|:------------:|
|
||||
| [ALBERT](../en/model_doc/albert) | ✅ | ✅ | ✅ |
|
||||
| [ALIGN](../en/model_doc/align) | ✅ | ❌ | ❌ |
|
||||
| [ALBERT](../en/model_doc/albert.md) | ✅ | ✅ | ✅ |
|
||||
| [ALIGN](../en/model_doc/align.md) | ✅ | ❌ | ❌ |
|
||||
| [AltCLIP](../en/model_doc/altclip) | ✅ | ❌ | ❌ |
|
||||
| [Audio Spectrogram Transformer](../en/model_doc/audio-spectrogram-transformer) | ✅ | ❌ | ❌ |
|
||||
| [Autoformer](../en/model_doc/autoformer) | ✅ | ❌ | ❌ |
|
||||
|
||||
@ -202,7 +202,7 @@ def replace_batch_norm(model):
|
||||
if isinstance(module, nn.BatchNorm2d):
|
||||
new_module = TestDetrFrozenBatchNorm2d(module.num_features)
|
||||
|
||||
if module.weight.device != torch.device("meta"):
|
||||
if not module.weight.device == torch.device("meta"):
|
||||
new_module.weight.data.copy_(module.weight)
|
||||
new_module.bias.data.copy_(module.bias)
|
||||
new_module.running_mean.data.copy_(module.running_mean)
|
||||
|
||||
@ -10,27 +10,26 @@ from transformers.generation import GenerationConfig
|
||||
torch.set_float32_matmul_precision("high")
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-3b-Instruct"
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="paged_attention|kernels-community/flash-attn",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
|
||||
).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=512,
|
||||
# use_cuda_graph=False,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=False,
|
||||
use_cache=False,
|
||||
num_blocks=2048,
|
||||
block_size=128,
|
||||
do_sample=True,
|
||||
max_batch_tokens=1024, # Maximum number of tokens to process in a single batch
|
||||
scheduler="prefill_first",
|
||||
)
|
||||
|
||||
train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
|
||||
train_dataset = train_dataset.select(range(500)) # Use only 5 examples for the simple version
|
||||
|
||||
# --- Example 1: Simple Version using generate_batch ---
|
||||
print("--- Running CB Generation Example ---")
|
||||
|
||||
|
||||
@ -42,21 +41,19 @@ tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
|
||||
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
|
||||
|
||||
start_time_simple = time.time()
|
||||
model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
|
||||
# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True)
|
||||
batch_outputs = model.generate_batch(
|
||||
inputs=simple_batch_inputs,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
end_time_simple = time.time()
|
||||
token_count = 0
|
||||
|
||||
for request in batch_outputs:
|
||||
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
|
||||
try:
|
||||
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
|
||||
token_count += len(batch_outputs[request].generated_tokens[1:])
|
||||
except Exception as e:
|
||||
print(f"Decoding failed for request {request}: {e}")
|
||||
token_count += len(batch_outputs[request].generated_tokens[1:])
|
||||
output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
|
||||
if len(output_text) > 0:
|
||||
print("-" * 20)
|
||||
@ -68,9 +65,7 @@ print("-" * 20)
|
||||
print("--- Finished CB Generation Example ---\n\n")
|
||||
|
||||
|
||||
print(
|
||||
f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {token_count} tokens. {token_count / (end_time_simple - start_time_simple)}tok/s"
|
||||
)
|
||||
print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
|
||||
|
||||
|
||||
# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version
|
||||
|
||||
@ -155,7 +155,7 @@ accelerate launch run_semantic_segmentation_no_trainer.py --output_dir segformer
|
||||
|
||||
and boom, you're training, possibly on multiple GPUs, logging everything to all trackers found in your environment (like Weights and Biases, Tensorboard) and regularly pushing your model to the hub (with the repo name being equal to `args.output_dir` at your HF username) 🤗
|
||||
|
||||
With the default settings, the script fine-tunes a [SegFormer](https://huggingface.co/docs/transformers/main/en/model_doc/segformer) model on the [segments/sidewalk-semantic](https://huggingface.co/datasets/segments/sidewalk-semantic) dataset.
|
||||
With the default settings, the script fine-tunes a [SegFormer]((https://huggingface.co/docs/transformers/main/en/model_doc/segformer)) model on the [segments/sidewalk-semantic](https://huggingface.co/datasets/segments/sidewalk-semantic) dataset.
|
||||
|
||||
The resulting model can be seen here: https://huggingface.co/nielsr/segformer-finetuned-sidewalk. Note that the script usually requires quite a few epochs to achieve great results, e.g. the SegFormer authors fine-tuned their model for 160k steps (batches) on [`scene_parse_150`](https://huggingface.co/datasets/scene_parse_150).
|
||||
|
||||
|
||||
@ -495,7 +495,7 @@ def main():
|
||||
|
||||
# region Training and validation
|
||||
if training_args.do_train:
|
||||
if training_args.do_eval and data_args.task_name != "mnli":
|
||||
if training_args.do_eval and not data_args.task_name == "mnli":
|
||||
# Do both evaluation and training in the Keras fit loop, unless the task is MNLI
|
||||
# because MNLI has two validation sets
|
||||
validation_data = tf_data["validation"]
|
||||
|
||||
4
setup.py
4
setup.py
@ -128,7 +128,7 @@ _deps = [
|
||||
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
|
||||
"keras>2.9,<2.16",
|
||||
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
|
||||
"kernels>=0.6.1,<=0.9",
|
||||
"kernels>=0.6.1,<0.7",
|
||||
"librosa",
|
||||
"natten>=0.14.6,<0.15.0",
|
||||
"nltk<=3.8.1",
|
||||
@ -137,7 +137,7 @@ _deps = [
|
||||
"onnxconverter-common",
|
||||
"onnxruntime-tools>=1.4.2",
|
||||
"onnxruntime>=1.4.0",
|
||||
"openai>=1.98.0",
|
||||
"openai",
|
||||
"opencv-python",
|
||||
"optimum-benchmark>=0.3.0",
|
||||
"optuna",
|
||||
|
||||
@ -28,31 +28,27 @@ from . import dependency_versions_check
|
||||
from .utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_bitsandbytes_available,
|
||||
is_essentia_available,
|
||||
is_flax_available,
|
||||
is_g2p_en_available,
|
||||
is_keras_nlp_available,
|
||||
is_librosa_available,
|
||||
is_mistral_common_available,
|
||||
is_pretty_midi_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_speech_available,
|
||||
is_tensorflow_text_available,
|
||||
is_tf_available,
|
||||
is_timm_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torchaudio_available,
|
||||
is_torchvision_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
# Note: the following symbols are deliberately exported with `as`
|
||||
# so that mypy, pylint or other static linters can recognize them,
|
||||
# given that they are not exported using `__all__` in this file.
|
||||
from .utils import is_bitsandbytes_available as is_bitsandbytes_available
|
||||
from .utils import is_flax_available as is_flax_available
|
||||
from .utils import is_keras_nlp_available as is_keras_nlp_available
|
||||
from .utils import is_scipy_available as is_scipy_available
|
||||
from .utils import is_sentencepiece_available as is_sentencepiece_available
|
||||
from .utils import is_speech_available as is_speech_available
|
||||
from .utils import is_tensorflow_text_available as is_tensorflow_text_available
|
||||
from .utils import is_tf_available as is_tf_available
|
||||
from .utils import is_timm_available as is_timm_available
|
||||
from .utils import is_tokenizers_available as is_tokenizers_available
|
||||
from .utils import is_torch_available as is_torch_available
|
||||
from .utils import is_torchaudio_available as is_torchaudio_available
|
||||
from .utils import is_torchvision_available as is_torchvision_available
|
||||
from .utils import is_vision_available as is_vision_available
|
||||
from .utils import logging as logging
|
||||
from .utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
@ -277,7 +273,6 @@ _import_structure = {
|
||||
"GPTQConfig",
|
||||
"HiggsConfig",
|
||||
"HqqConfig",
|
||||
"Mxfp4Config",
|
||||
"QuantoConfig",
|
||||
"QuarkConfig",
|
||||
"FPQuantConfig",
|
||||
@ -584,363 +579,400 @@ else:
|
||||
# Direct imports for type-checking
|
||||
if TYPE_CHECKING:
|
||||
# All modeling imports
|
||||
from .cache_utils import Cache as Cache
|
||||
from .cache_utils import CacheConfig as CacheConfig
|
||||
from .cache_utils import DynamicCache as DynamicCache
|
||||
from .cache_utils import EncoderDecoderCache as EncoderDecoderCache
|
||||
from .cache_utils import HQQQuantizedCache as HQQQuantizedCache
|
||||
from .cache_utils import HybridCache as HybridCache
|
||||
from .cache_utils import MambaCache as MambaCache
|
||||
from .cache_utils import OffloadedCache as OffloadedCache
|
||||
from .cache_utils import OffloadedStaticCache as OffloadedStaticCache
|
||||
from .cache_utils import QuantizedCache as QuantizedCache
|
||||
from .cache_utils import QuantizedCacheConfig as QuantizedCacheConfig
|
||||
from .cache_utils import QuantoQuantizedCache as QuantoQuantizedCache
|
||||
from .cache_utils import SinkCache as SinkCache
|
||||
from .cache_utils import SlidingWindowCache as SlidingWindowCache
|
||||
from .cache_utils import StaticCache as StaticCache
|
||||
from .configuration_utils import PretrainedConfig as PretrainedConfig
|
||||
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS as SLOW_TO_FAST_CONVERTERS
|
||||
from .convert_slow_tokenizer import convert_slow_tokenizer as convert_slow_tokenizer
|
||||
from .cache_utils import (
|
||||
Cache,
|
||||
CacheConfig,
|
||||
DynamicCache,
|
||||
EncoderDecoderCache,
|
||||
HQQQuantizedCache,
|
||||
HybridCache,
|
||||
MambaCache,
|
||||
OffloadedCache,
|
||||
OffloadedStaticCache,
|
||||
QuantizedCache,
|
||||
QuantizedCacheConfig,
|
||||
QuantoQuantizedCache,
|
||||
SinkCache,
|
||||
SlidingWindowCache,
|
||||
StaticCache,
|
||||
)
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .convert_slow_tokenizer import (
|
||||
SLOW_TO_FAST_CONVERTERS,
|
||||
convert_slow_tokenizer,
|
||||
)
|
||||
|
||||
# Data
|
||||
from .data import DataProcessor as DataProcessor
|
||||
from .data import InputExample as InputExample
|
||||
from .data import InputFeatures as InputFeatures
|
||||
from .data import SingleSentenceClassificationProcessor as SingleSentenceClassificationProcessor
|
||||
from .data import SquadExample as SquadExample
|
||||
from .data import SquadFeatures as SquadFeatures
|
||||
from .data import SquadV1Processor as SquadV1Processor
|
||||
from .data import SquadV2Processor as SquadV2Processor
|
||||
from .data import glue_compute_metrics as glue_compute_metrics
|
||||
from .data import glue_convert_examples_to_features as glue_convert_examples_to_features
|
||||
from .data import glue_output_modes as glue_output_modes
|
||||
from .data import glue_processors as glue_processors
|
||||
from .data import glue_tasks_num_labels as glue_tasks_num_labels
|
||||
from .data import squad_convert_examples_to_features as squad_convert_examples_to_features
|
||||
from .data import xnli_compute_metrics as xnli_compute_metrics
|
||||
from .data import xnli_output_modes as xnli_output_modes
|
||||
from .data import xnli_processors as xnli_processors
|
||||
from .data import xnli_tasks_num_labels as xnli_tasks_num_labels
|
||||
from .data.data_collator import DataCollator as DataCollator
|
||||
from .data.data_collator import DataCollatorForLanguageModeling as DataCollatorForLanguageModeling
|
||||
from .data.data_collator import DataCollatorForMultipleChoice as DataCollatorForMultipleChoice
|
||||
from .data.data_collator import (
|
||||
DataCollatorForPermutationLanguageModeling as DataCollatorForPermutationLanguageModeling,
|
||||
from .data import (
|
||||
DataProcessor,
|
||||
InputExample,
|
||||
InputFeatures,
|
||||
SingleSentenceClassificationProcessor,
|
||||
SquadExample,
|
||||
SquadFeatures,
|
||||
SquadV1Processor,
|
||||
SquadV2Processor,
|
||||
glue_compute_metrics,
|
||||
glue_convert_examples_to_features,
|
||||
glue_output_modes,
|
||||
glue_processors,
|
||||
glue_tasks_num_labels,
|
||||
squad_convert_examples_to_features,
|
||||
xnli_compute_metrics,
|
||||
xnli_output_modes,
|
||||
xnli_processors,
|
||||
xnli_tasks_num_labels,
|
||||
)
|
||||
from .data.data_collator import DataCollatorForSeq2Seq as DataCollatorForSeq2Seq
|
||||
from .data.data_collator import DataCollatorForSOP as DataCollatorForSOP
|
||||
from .data.data_collator import DataCollatorForTokenClassification as DataCollatorForTokenClassification
|
||||
from .data.data_collator import DataCollatorForWholeWordMask as DataCollatorForWholeWordMask
|
||||
from .data.data_collator import DataCollatorWithFlattening as DataCollatorWithFlattening
|
||||
from .data.data_collator import DataCollatorWithPadding as DataCollatorWithPadding
|
||||
from .data.data_collator import DefaultDataCollator as DefaultDataCollator
|
||||
from .data.data_collator import default_data_collator as default_data_collator
|
||||
from .data.datasets import GlueDataset as GlueDataset
|
||||
from .data.datasets import GlueDataTrainingArguments as GlueDataTrainingArguments
|
||||
from .data.datasets import LineByLineTextDataset as LineByLineTextDataset
|
||||
from .data.datasets import LineByLineWithRefDataset as LineByLineWithRefDataset
|
||||
from .data.datasets import LineByLineWithSOPTextDataset as LineByLineWithSOPTextDataset
|
||||
from .data.datasets import SquadDataset as SquadDataset
|
||||
from .data.datasets import SquadDataTrainingArguments as SquadDataTrainingArguments
|
||||
from .data.datasets import TextDataset as TextDataset
|
||||
from .data.datasets import TextDatasetForNextSentencePrediction as TextDatasetForNextSentencePrediction
|
||||
from .feature_extraction_sequence_utils import SequenceFeatureExtractor as SequenceFeatureExtractor
|
||||
from .data.data_collator import (
|
||||
DataCollator,
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForMultipleChoice,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorForSOP,
|
||||
DataCollatorForTokenClassification,
|
||||
DataCollatorForWholeWordMask,
|
||||
DataCollatorWithFlattening,
|
||||
DataCollatorWithPadding,
|
||||
DefaultDataCollator,
|
||||
default_data_collator,
|
||||
)
|
||||
from .data.datasets import (
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
LineByLineTextDataset,
|
||||
LineByLineWithRefDataset,
|
||||
LineByLineWithSOPTextDataset,
|
||||
SquadDataset,
|
||||
SquadDataTrainingArguments,
|
||||
TextDataset,
|
||||
TextDatasetForNextSentencePrediction,
|
||||
)
|
||||
from .feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
|
||||
# Feature Extractor
|
||||
from .feature_extraction_utils import BatchFeature as BatchFeature
|
||||
from .feature_extraction_utils import FeatureExtractionMixin as FeatureExtractionMixin
|
||||
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||
|
||||
# Generation
|
||||
from .generation import AlternatingCodebooksLogitsProcessor as AlternatingCodebooksLogitsProcessor
|
||||
from .generation import AsyncTextIteratorStreamer as AsyncTextIteratorStreamer
|
||||
from .generation import BayesianDetectorConfig as BayesianDetectorConfig
|
||||
from .generation import BayesianDetectorModel as BayesianDetectorModel
|
||||
from .generation import BeamScorer as BeamScorer
|
||||
from .generation import BeamSearchScorer as BeamSearchScorer
|
||||
from .generation import ClassifierFreeGuidanceLogitsProcessor as ClassifierFreeGuidanceLogitsProcessor
|
||||
from .generation import CompileConfig as CompileConfig
|
||||
from .generation import ConstrainedBeamSearchScorer as ConstrainedBeamSearchScorer
|
||||
from .generation import Constraint as Constraint
|
||||
from .generation import ConstraintListState as ConstraintListState
|
||||
from .generation import DisjunctiveConstraint as DisjunctiveConstraint
|
||||
from .generation import EncoderNoRepeatNGramLogitsProcessor as EncoderNoRepeatNGramLogitsProcessor
|
||||
from .generation import EncoderRepetitionPenaltyLogitsProcessor as EncoderRepetitionPenaltyLogitsProcessor
|
||||
from .generation import EosTokenCriteria as EosTokenCriteria
|
||||
from .generation import EpsilonLogitsWarper as EpsilonLogitsWarper
|
||||
from .generation import EtaLogitsWarper as EtaLogitsWarper
|
||||
from .generation import ExponentialDecayLengthPenalty as ExponentialDecayLengthPenalty
|
||||
from .generation import FlaxForcedBOSTokenLogitsProcessor as FlaxForcedBOSTokenLogitsProcessor
|
||||
from .generation import FlaxForcedEOSTokenLogitsProcessor as FlaxForcedEOSTokenLogitsProcessor
|
||||
from .generation import FlaxForceTokensLogitsProcessor as FlaxForceTokensLogitsProcessor
|
||||
from .generation import FlaxGenerationMixin as FlaxGenerationMixin
|
||||
from .generation import FlaxLogitsProcessor as FlaxLogitsProcessor
|
||||
from .generation import FlaxLogitsProcessorList as FlaxLogitsProcessorList
|
||||
from .generation import FlaxLogitsWarper as FlaxLogitsWarper
|
||||
from .generation import FlaxMinLengthLogitsProcessor as FlaxMinLengthLogitsProcessor
|
||||
from .generation import FlaxSuppressTokensAtBeginLogitsProcessor as FlaxSuppressTokensAtBeginLogitsProcessor
|
||||
from .generation import FlaxSuppressTokensLogitsProcessor as FlaxSuppressTokensLogitsProcessor
|
||||
from .generation import FlaxTemperatureLogitsWarper as FlaxTemperatureLogitsWarper
|
||||
from .generation import FlaxTopKLogitsWarper as FlaxTopKLogitsWarper
|
||||
from .generation import FlaxTopPLogitsWarper as FlaxTopPLogitsWarper
|
||||
from .generation import FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor
|
||||
from .generation import ForcedBOSTokenLogitsProcessor as ForcedBOSTokenLogitsProcessor
|
||||
from .generation import ForcedEOSTokenLogitsProcessor as ForcedEOSTokenLogitsProcessor
|
||||
from .generation import GenerationConfig as GenerationConfig
|
||||
from .generation import GenerationMixin as GenerationMixin
|
||||
from .generation import HammingDiversityLogitsProcessor as HammingDiversityLogitsProcessor
|
||||
from .generation import InfNanRemoveLogitsProcessor as InfNanRemoveLogitsProcessor
|
||||
from .generation import LogitNormalization as LogitNormalization
|
||||
from .generation import LogitsProcessor as LogitsProcessor
|
||||
from .generation import LogitsProcessorList as LogitsProcessorList
|
||||
from .generation import MaxLengthCriteria as MaxLengthCriteria
|
||||
from .generation import MaxTimeCriteria as MaxTimeCriteria
|
||||
from .generation import MinLengthLogitsProcessor as MinLengthLogitsProcessor
|
||||
from .generation import MinNewTokensLengthLogitsProcessor as MinNewTokensLengthLogitsProcessor
|
||||
from .generation import MinPLogitsWarper as MinPLogitsWarper
|
||||
from .generation import NoBadWordsLogitsProcessor as NoBadWordsLogitsProcessor
|
||||
from .generation import NoRepeatNGramLogitsProcessor as NoRepeatNGramLogitsProcessor
|
||||
from .generation import PhrasalConstraint as PhrasalConstraint
|
||||
from .generation import PrefixConstrainedLogitsProcessor as PrefixConstrainedLogitsProcessor
|
||||
from .generation import RepetitionPenaltyLogitsProcessor as RepetitionPenaltyLogitsProcessor
|
||||
from .generation import SequenceBiasLogitsProcessor as SequenceBiasLogitsProcessor
|
||||
from .generation import StoppingCriteria as StoppingCriteria
|
||||
from .generation import StoppingCriteriaList as StoppingCriteriaList
|
||||
from .generation import StopStringCriteria as StopStringCriteria
|
||||
from .generation import SuppressTokensAtBeginLogitsProcessor as SuppressTokensAtBeginLogitsProcessor
|
||||
from .generation import SuppressTokensLogitsProcessor as SuppressTokensLogitsProcessor
|
||||
from .generation import SynthIDTextWatermarkDetector as SynthIDTextWatermarkDetector
|
||||
from .generation import SynthIDTextWatermarkingConfig as SynthIDTextWatermarkingConfig
|
||||
from .generation import SynthIDTextWatermarkLogitsProcessor as SynthIDTextWatermarkLogitsProcessor
|
||||
from .generation import TemperatureLogitsWarper as TemperatureLogitsWarper
|
||||
from .generation import TextIteratorStreamer as TextIteratorStreamer
|
||||
from .generation import TextStreamer as TextStreamer
|
||||
from .generation import TFForcedBOSTokenLogitsProcessor as TFForcedBOSTokenLogitsProcessor
|
||||
from .generation import TFForcedEOSTokenLogitsProcessor as TFForcedEOSTokenLogitsProcessor
|
||||
from .generation import TFForceTokensLogitsProcessor as TFForceTokensLogitsProcessor
|
||||
from .generation import TFGenerationMixin as TFGenerationMixin
|
||||
from .generation import TFLogitsProcessor as TFLogitsProcessor
|
||||
from .generation import TFLogitsProcessorList as TFLogitsProcessorList
|
||||
from .generation import TFLogitsWarper as TFLogitsWarper
|
||||
from .generation import TFMinLengthLogitsProcessor as TFMinLengthLogitsProcessor
|
||||
from .generation import TFNoBadWordsLogitsProcessor as TFNoBadWordsLogitsProcessor
|
||||
from .generation import TFNoRepeatNGramLogitsProcessor as TFNoRepeatNGramLogitsProcessor
|
||||
from .generation import TFRepetitionPenaltyLogitsProcessor as TFRepetitionPenaltyLogitsProcessor
|
||||
from .generation import TFSuppressTokensAtBeginLogitsProcessor as TFSuppressTokensAtBeginLogitsProcessor
|
||||
from .generation import TFSuppressTokensLogitsProcessor as TFSuppressTokensLogitsProcessor
|
||||
from .generation import TFTemperatureLogitsWarper as TFTemperatureLogitsWarper
|
||||
from .generation import TFTopKLogitsWarper as TFTopKLogitsWarper
|
||||
from .generation import TFTopPLogitsWarper as TFTopPLogitsWarper
|
||||
from .generation import TopKLogitsWarper as TopKLogitsWarper
|
||||
from .generation import TopPLogitsWarper as TopPLogitsWarper
|
||||
from .generation import TypicalLogitsWarper as TypicalLogitsWarper
|
||||
from .generation import (
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor as UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
AlternatingCodebooksLogitsProcessor,
|
||||
AsyncTextIteratorStreamer,
|
||||
BayesianDetectorConfig,
|
||||
BayesianDetectorModel,
|
||||
BeamScorer,
|
||||
BeamSearchScorer,
|
||||
ClassifierFreeGuidanceLogitsProcessor,
|
||||
CompileConfig,
|
||||
ConstrainedBeamSearchScorer,
|
||||
Constraint,
|
||||
ConstraintListState,
|
||||
DisjunctiveConstraint,
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EncoderRepetitionPenaltyLogitsProcessor,
|
||||
EosTokenCriteria,
|
||||
EpsilonLogitsWarper,
|
||||
EtaLogitsWarper,
|
||||
ExponentialDecayLengthPenalty,
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxForceTokensLogitsProcessor,
|
||||
FlaxGenerationMixin,
|
||||
FlaxLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxLogitsWarper,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxSuppressTokensAtBeginLogitsProcessor,
|
||||
FlaxSuppressTokensLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
FlaxWhisperTimeStampLogitsProcessor,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
GenerationConfig,
|
||||
GenerationMixin,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitNormalization,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MaxLengthCriteria,
|
||||
MaxTimeCriteria,
|
||||
MinLengthLogitsProcessor,
|
||||
MinNewTokensLengthLogitsProcessor,
|
||||
MinPLogitsWarper,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
PhrasalConstraint,
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
SequenceBiasLogitsProcessor,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
StopStringCriteria,
|
||||
SuppressTokensAtBeginLogitsProcessor,
|
||||
SuppressTokensLogitsProcessor,
|
||||
SynthIDTextWatermarkDetector,
|
||||
SynthIDTextWatermarkingConfig,
|
||||
SynthIDTextWatermarkLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TextIteratorStreamer,
|
||||
TextStreamer,
|
||||
TFForcedBOSTokenLogitsProcessor,
|
||||
TFForcedEOSTokenLogitsProcessor,
|
||||
TFForceTokensLogitsProcessor,
|
||||
TFGenerationMixin,
|
||||
TFLogitsProcessor,
|
||||
TFLogitsProcessorList,
|
||||
TFLogitsWarper,
|
||||
TFMinLengthLogitsProcessor,
|
||||
TFNoBadWordsLogitsProcessor,
|
||||
TFNoRepeatNGramLogitsProcessor,
|
||||
TFRepetitionPenaltyLogitsProcessor,
|
||||
TFSuppressTokensAtBeginLogitsProcessor,
|
||||
TFSuppressTokensLogitsProcessor,
|
||||
TFTemperatureLogitsWarper,
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkDetector,
|
||||
WatermarkingConfig,
|
||||
WatermarkLogitsProcessor,
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .generation import WatermarkDetector as WatermarkDetector
|
||||
from .generation import WatermarkingConfig as WatermarkingConfig
|
||||
from .generation import WatermarkLogitsProcessor as WatermarkLogitsProcessor
|
||||
from .generation import WhisperTimeStampLogitsProcessor as WhisperTimeStampLogitsProcessor
|
||||
from .hf_argparser import HfArgumentParser as HfArgumentParser
|
||||
from .image_processing_base import ImageProcessingMixin as ImageProcessingMixin
|
||||
from .image_processing_utils import BaseImageProcessor as BaseImageProcessor
|
||||
from .image_processing_utils_fast import BaseImageProcessorFast as BaseImageProcessorFast
|
||||
from .image_utils import ImageFeatureExtractionMixin as ImageFeatureExtractionMixin
|
||||
from .hf_argparser import HfArgumentParser
|
||||
from .image_processing_base import ImageProcessingMixin
|
||||
from .image_processing_utils import BaseImageProcessor
|
||||
from .image_processing_utils_fast import BaseImageProcessorFast
|
||||
from .image_utils import ImageFeatureExtractionMixin
|
||||
|
||||
# Integrations
|
||||
from .integrations import is_clearml_available as is_clearml_available
|
||||
from .integrations import is_comet_available as is_comet_available
|
||||
from .integrations import is_dvclive_available as is_dvclive_available
|
||||
from .integrations import is_neptune_available as is_neptune_available
|
||||
from .integrations import is_optuna_available as is_optuna_available
|
||||
from .integrations import is_ray_available as is_ray_available
|
||||
from .integrations import is_ray_tune_available as is_ray_tune_available
|
||||
from .integrations import is_sigopt_available as is_sigopt_available
|
||||
from .integrations import is_swanlab_available as is_swanlab_available
|
||||
from .integrations import is_tensorboard_available as is_tensorboard_available
|
||||
from .integrations import is_trackio_available as is_trackio_available
|
||||
from .integrations import is_wandb_available as is_wandb_available
|
||||
from .integrations.executorch import TorchExportableModuleWithStaticCache as TorchExportableModuleWithStaticCache
|
||||
from .integrations.executorch import convert_and_export_with_cache as convert_and_export_with_cache
|
||||
from .keras_callbacks import KerasMetricCallback as KerasMetricCallback
|
||||
from .keras_callbacks import PushToHubCallback as PushToHubCallback
|
||||
from .masking_utils import AttentionMaskInterface as AttentionMaskInterface
|
||||
from .model_debugging_utils import model_addition_debugger_context as model_addition_debugger_context
|
||||
from .integrations import (
|
||||
is_clearml_available,
|
||||
is_comet_available,
|
||||
is_dvclive_available,
|
||||
is_neptune_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
is_ray_tune_available,
|
||||
is_sigopt_available,
|
||||
is_swanlab_available,
|
||||
is_tensorboard_available,
|
||||
is_trackio_available,
|
||||
is_wandb_available,
|
||||
)
|
||||
from .integrations.executorch import (
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
|
||||
from .masking_utils import AttentionMaskInterface
|
||||
from .model_debugging_utils import (
|
||||
model_addition_debugger_context,
|
||||
)
|
||||
|
||||
# Model Cards
|
||||
from .modelcard import ModelCard as ModelCard
|
||||
from .modeling_flax_utils import FlaxPreTrainedModel as FlaxPreTrainedModel
|
||||
from .modeling_layers import GradientCheckpointingLayer as GradientCheckpointingLayer
|
||||
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS as ROPE_INIT_FUNCTIONS
|
||||
from .modeling_rope_utils import dynamic_rope_update as dynamic_rope_update
|
||||
from .modelcard import ModelCard
|
||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||
from .modeling_layers import GradientCheckpointingLayer
|
||||
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
|
||||
# TF 2.0 <=> PyTorch conversion utilities
|
||||
from .modeling_tf_pytorch_utils import (
|
||||
convert_tf_weight_name_to_pt_weight_name as convert_tf_weight_name_to_pt_weight_name,
|
||||
convert_tf_weight_name_to_pt_weight_name,
|
||||
load_pytorch_checkpoint_in_tf2_model,
|
||||
load_pytorch_model_in_tf2_model,
|
||||
load_pytorch_weights_in_tf2_model,
|
||||
load_tf2_checkpoint_in_pytorch_model,
|
||||
load_tf2_model_in_pytorch_model,
|
||||
load_tf2_weights_in_pytorch_model,
|
||||
)
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model as load_pytorch_checkpoint_in_tf2_model
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_model_in_tf2_model as load_pytorch_model_in_tf2_model
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_weights_in_tf2_model as load_pytorch_weights_in_tf2_model
|
||||
from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model as load_tf2_checkpoint_in_pytorch_model
|
||||
from .modeling_tf_pytorch_utils import load_tf2_model_in_pytorch_model as load_tf2_model_in_pytorch_model
|
||||
from .modeling_tf_pytorch_utils import load_tf2_weights_in_pytorch_model as load_tf2_weights_in_pytorch_model
|
||||
from .modeling_tf_utils import TFPreTrainedModel as TFPreTrainedModel
|
||||
from .modeling_tf_utils import TFSequenceSummary as TFSequenceSummary
|
||||
from .modeling_tf_utils import TFSharedEmbeddings as TFSharedEmbeddings
|
||||
from .modeling_tf_utils import shape_list as shape_list
|
||||
from .modeling_utils import AttentionInterface as AttentionInterface
|
||||
from .modeling_utils import PreTrainedModel as PreTrainedModel
|
||||
from .modeling_tf_utils import (
|
||||
TFPreTrainedModel,
|
||||
TFSequenceSummary,
|
||||
TFSharedEmbeddings,
|
||||
shape_list,
|
||||
)
|
||||
from .modeling_utils import AttentionInterface, PreTrainedModel
|
||||
from .models import *
|
||||
from .models.timm_wrapper import TimmWrapperImageProcessor as TimmWrapperImageProcessor
|
||||
from .models.timm_wrapper import TimmWrapperImageProcessor
|
||||
|
||||
# Optimization
|
||||
from .optimization import Adafactor as Adafactor
|
||||
from .optimization import get_constant_schedule as get_constant_schedule
|
||||
from .optimization import get_constant_schedule_with_warmup as get_constant_schedule_with_warmup
|
||||
from .optimization import get_cosine_schedule_with_warmup as get_cosine_schedule_with_warmup
|
||||
from .optimization import (
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup as get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
Adafactor,
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
get_inverse_sqrt_schedule,
|
||||
get_linear_schedule_with_warmup,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
get_scheduler,
|
||||
get_wsd_schedule,
|
||||
)
|
||||
from .optimization import get_inverse_sqrt_schedule as get_inverse_sqrt_schedule
|
||||
from .optimization import get_linear_schedule_with_warmup as get_linear_schedule_with_warmup
|
||||
from .optimization import get_polynomial_decay_schedule_with_warmup as get_polynomial_decay_schedule_with_warmup
|
||||
from .optimization import get_scheduler as get_scheduler
|
||||
from .optimization import get_wsd_schedule as get_wsd_schedule
|
||||
|
||||
# Optimization
|
||||
from .optimization_tf import AdamWeightDecay as AdamWeightDecay
|
||||
from .optimization_tf import GradientAccumulator as GradientAccumulator
|
||||
from .optimization_tf import WarmUp as WarmUp
|
||||
from .optimization_tf import create_optimizer as create_optimizer
|
||||
from .optimization_tf import (
|
||||
AdamWeightDecay,
|
||||
GradientAccumulator,
|
||||
WarmUp,
|
||||
create_optimizer,
|
||||
)
|
||||
|
||||
# Pipelines
|
||||
from .pipelines import AudioClassificationPipeline as AudioClassificationPipeline
|
||||
from .pipelines import AutomaticSpeechRecognitionPipeline as AutomaticSpeechRecognitionPipeline
|
||||
from .pipelines import CsvPipelineDataFormat as CsvPipelineDataFormat
|
||||
from .pipelines import DepthEstimationPipeline as DepthEstimationPipeline
|
||||
from .pipelines import DocumentQuestionAnsweringPipeline as DocumentQuestionAnsweringPipeline
|
||||
from .pipelines import FeatureExtractionPipeline as FeatureExtractionPipeline
|
||||
from .pipelines import FillMaskPipeline as FillMaskPipeline
|
||||
from .pipelines import ImageClassificationPipeline as ImageClassificationPipeline
|
||||
from .pipelines import ImageFeatureExtractionPipeline as ImageFeatureExtractionPipeline
|
||||
from .pipelines import ImageSegmentationPipeline as ImageSegmentationPipeline
|
||||
from .pipelines import ImageTextToTextPipeline as ImageTextToTextPipeline
|
||||
from .pipelines import ImageToImagePipeline as ImageToImagePipeline
|
||||
from .pipelines import ImageToTextPipeline as ImageToTextPipeline
|
||||
from .pipelines import JsonPipelineDataFormat as JsonPipelineDataFormat
|
||||
from .pipelines import MaskGenerationPipeline as MaskGenerationPipeline
|
||||
from .pipelines import NerPipeline as NerPipeline
|
||||
from .pipelines import ObjectDetectionPipeline as ObjectDetectionPipeline
|
||||
from .pipelines import PipedPipelineDataFormat as PipedPipelineDataFormat
|
||||
from .pipelines import Pipeline as Pipeline
|
||||
from .pipelines import PipelineDataFormat as PipelineDataFormat
|
||||
from .pipelines import QuestionAnsweringPipeline as QuestionAnsweringPipeline
|
||||
from .pipelines import SummarizationPipeline as SummarizationPipeline
|
||||
from .pipelines import TableQuestionAnsweringPipeline as TableQuestionAnsweringPipeline
|
||||
from .pipelines import Text2TextGenerationPipeline as Text2TextGenerationPipeline
|
||||
from .pipelines import TextClassificationPipeline as TextClassificationPipeline
|
||||
from .pipelines import TextGenerationPipeline as TextGenerationPipeline
|
||||
from .pipelines import TextToAudioPipeline as TextToAudioPipeline
|
||||
from .pipelines import TokenClassificationPipeline as TokenClassificationPipeline
|
||||
from .pipelines import TranslationPipeline as TranslationPipeline
|
||||
from .pipelines import VideoClassificationPipeline as VideoClassificationPipeline
|
||||
from .pipelines import VisualQuestionAnsweringPipeline as VisualQuestionAnsweringPipeline
|
||||
from .pipelines import ZeroShotAudioClassificationPipeline as ZeroShotAudioClassificationPipeline
|
||||
from .pipelines import ZeroShotClassificationPipeline as ZeroShotClassificationPipeline
|
||||
from .pipelines import ZeroShotImageClassificationPipeline as ZeroShotImageClassificationPipeline
|
||||
from .pipelines import ZeroShotObjectDetectionPipeline as ZeroShotObjectDetectionPipeline
|
||||
from .pipelines import pipeline as pipeline
|
||||
from .processing_utils import ProcessorMixin as ProcessorMixin
|
||||
from .pytorch_utils import Conv1D as Conv1D
|
||||
from .pytorch_utils import apply_chunking_to_forward as apply_chunking_to_forward
|
||||
from .pytorch_utils import prune_layer as prune_layer
|
||||
from .pipelines import (
|
||||
AudioClassificationPipeline,
|
||||
AutomaticSpeechRecognitionPipeline,
|
||||
CsvPipelineDataFormat,
|
||||
DepthEstimationPipeline,
|
||||
DocumentQuestionAnsweringPipeline,
|
||||
FeatureExtractionPipeline,
|
||||
FillMaskPipeline,
|
||||
ImageClassificationPipeline,
|
||||
ImageFeatureExtractionPipeline,
|
||||
ImageSegmentationPipeline,
|
||||
ImageTextToTextPipeline,
|
||||
ImageToImagePipeline,
|
||||
ImageToTextPipeline,
|
||||
JsonPipelineDataFormat,
|
||||
MaskGenerationPipeline,
|
||||
NerPipeline,
|
||||
ObjectDetectionPipeline,
|
||||
PipedPipelineDataFormat,
|
||||
Pipeline,
|
||||
PipelineDataFormat,
|
||||
QuestionAnsweringPipeline,
|
||||
SummarizationPipeline,
|
||||
TableQuestionAnsweringPipeline,
|
||||
Text2TextGenerationPipeline,
|
||||
TextClassificationPipeline,
|
||||
TextGenerationPipeline,
|
||||
TextToAudioPipeline,
|
||||
TokenClassificationPipeline,
|
||||
TranslationPipeline,
|
||||
VideoClassificationPipeline,
|
||||
VisualQuestionAnsweringPipeline,
|
||||
ZeroShotAudioClassificationPipeline,
|
||||
ZeroShotClassificationPipeline,
|
||||
ZeroShotImageClassificationPipeline,
|
||||
ZeroShotObjectDetectionPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from .processing_utils import ProcessorMixin
|
||||
from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer
|
||||
|
||||
# Tokenization
|
||||
from .tokenization_utils import PreTrainedTokenizer as PreTrainedTokenizer
|
||||
from .tokenization_utils_base import AddedToken as AddedToken
|
||||
from .tokenization_utils_base import BatchEncoding as BatchEncoding
|
||||
from .tokenization_utils_base import CharSpan as CharSpan
|
||||
from .tokenization_utils_base import PreTrainedTokenizerBase as PreTrainedTokenizerBase
|
||||
from .tokenization_utils_base import SpecialTokensMixin as SpecialTokensMixin
|
||||
from .tokenization_utils_base import TokenSpan as TokenSpan
|
||||
from .tokenization_utils_fast import PreTrainedTokenizerFast as PreTrainedTokenizerFast
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils_base import (
|
||||
AddedToken,
|
||||
BatchEncoding,
|
||||
CharSpan,
|
||||
PreTrainedTokenizerBase,
|
||||
SpecialTokensMixin,
|
||||
TokenSpan,
|
||||
)
|
||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
|
||||
# Trainer
|
||||
from .trainer import Trainer as Trainer
|
||||
from .trainer import Trainer
|
||||
|
||||
# Trainer
|
||||
from .trainer_callback import DefaultFlowCallback as DefaultFlowCallback
|
||||
from .trainer_callback import EarlyStoppingCallback as EarlyStoppingCallback
|
||||
from .trainer_callback import PrinterCallback as PrinterCallback
|
||||
from .trainer_callback import ProgressCallback as ProgressCallback
|
||||
from .trainer_callback import TrainerCallback as TrainerCallback
|
||||
from .trainer_callback import TrainerControl as TrainerControl
|
||||
from .trainer_callback import TrainerState as TrainerState
|
||||
from .trainer_pt_utils import torch_distributed_zero_first as torch_distributed_zero_first
|
||||
from .trainer_seq2seq import Seq2SeqTrainer as Seq2SeqTrainer
|
||||
from .trainer_utils import EvalPrediction as EvalPrediction
|
||||
from .trainer_utils import IntervalStrategy as IntervalStrategy
|
||||
from .trainer_utils import SchedulerType as SchedulerType
|
||||
from .trainer_utils import enable_full_determinism as enable_full_determinism
|
||||
from .trainer_utils import set_seed as set_seed
|
||||
from .training_args import TrainingArguments as TrainingArguments
|
||||
from .training_args_seq2seq import Seq2SeqTrainingArguments as Seq2SeqTrainingArguments
|
||||
from .training_args_tf import TFTrainingArguments as TFTrainingArguments
|
||||
from .trainer_callback import (
|
||||
DefaultFlowCallback,
|
||||
EarlyStoppingCallback,
|
||||
PrinterCallback,
|
||||
ProgressCallback,
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
)
|
||||
from .trainer_pt_utils import torch_distributed_zero_first
|
||||
from .trainer_seq2seq import Seq2SeqTrainer
|
||||
from .trainer_utils import (
|
||||
EvalPrediction,
|
||||
IntervalStrategy,
|
||||
SchedulerType,
|
||||
enable_full_determinism,
|
||||
set_seed,
|
||||
)
|
||||
from .training_args import TrainingArguments
|
||||
from .training_args_seq2seq import Seq2SeqTrainingArguments
|
||||
from .training_args_tf import TFTrainingArguments
|
||||
|
||||
# Files and general utilities
|
||||
from .utils import CONFIG_NAME as CONFIG_NAME
|
||||
from .utils import MODEL_CARD_NAME as MODEL_CARD_NAME
|
||||
from .utils import PYTORCH_PRETRAINED_BERT_CACHE as PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from .utils import PYTORCH_TRANSFORMERS_CACHE as PYTORCH_TRANSFORMERS_CACHE
|
||||
from .utils import SPIECE_UNDERLINE as SPIECE_UNDERLINE
|
||||
from .utils import TF2_WEIGHTS_NAME as TF2_WEIGHTS_NAME
|
||||
from .utils import TF_WEIGHTS_NAME as TF_WEIGHTS_NAME
|
||||
from .utils import TRANSFORMERS_CACHE as TRANSFORMERS_CACHE
|
||||
from .utils import WEIGHTS_NAME as WEIGHTS_NAME
|
||||
from .utils import TensorType as TensorType
|
||||
from .utils import add_end_docstrings as add_end_docstrings
|
||||
from .utils import add_start_docstrings as add_start_docstrings
|
||||
from .utils import is_apex_available as is_apex_available
|
||||
from .utils import is_av_available as is_av_available
|
||||
from .utils import is_datasets_available as is_datasets_available
|
||||
from .utils import is_faiss_available as is_faiss_available
|
||||
from .utils import is_matplotlib_available as is_matplotlib_available
|
||||
from .utils import is_phonemizer_available as is_phonemizer_available
|
||||
from .utils import is_psutil_available as is_psutil_available
|
||||
from .utils import is_py3nvml_available as is_py3nvml_available
|
||||
from .utils import is_pyctcdecode_available as is_pyctcdecode_available
|
||||
from .utils import is_sacremoses_available as is_sacremoses_available
|
||||
from .utils import is_safetensors_available as is_safetensors_available
|
||||
from .utils import is_sklearn_available as is_sklearn_available
|
||||
from .utils import is_torch_hpu_available as is_torch_hpu_available
|
||||
from .utils import is_torch_mlu_available as is_torch_mlu_available
|
||||
from .utils import is_torch_musa_available as is_torch_musa_available
|
||||
from .utils import is_torch_neuroncore_available as is_torch_neuroncore_available
|
||||
from .utils import is_torch_npu_available as is_torch_npu_available
|
||||
from .utils import is_torch_xla_available as is_torch_xla_available
|
||||
from .utils import is_torch_xpu_available as is_torch_xpu_available
|
||||
from .utils import logging as logging
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
MODEL_CARD_NAME,
|
||||
PYTORCH_PRETRAINED_BERT_CACHE,
|
||||
PYTORCH_TRANSFORMERS_CACHE,
|
||||
SPIECE_UNDERLINE,
|
||||
TF2_WEIGHTS_NAME,
|
||||
TF_WEIGHTS_NAME,
|
||||
TRANSFORMERS_CACHE,
|
||||
WEIGHTS_NAME,
|
||||
TensorType,
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
is_apex_available,
|
||||
is_av_available,
|
||||
is_bitsandbytes_available,
|
||||
is_datasets_available,
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
is_keras_nlp_available,
|
||||
is_matplotlib_available,
|
||||
is_phonemizer_available,
|
||||
is_psutil_available,
|
||||
is_py3nvml_available,
|
||||
is_pyctcdecode_available,
|
||||
is_sacremoses_available,
|
||||
is_safetensors_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_sklearn_available,
|
||||
is_speech_available,
|
||||
is_tensorflow_text_available,
|
||||
is_tf_available,
|
||||
is_timm_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torch_hpu_available,
|
||||
is_torch_mlu_available,
|
||||
is_torch_musa_available,
|
||||
is_torch_neuroncore_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_xla_available,
|
||||
is_torch_xpu_available,
|
||||
is_torchvision_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
# bitsandbytes config
|
||||
from .utils.quantization_config import AqlmConfig as AqlmConfig
|
||||
from .utils.quantization_config import AutoRoundConfig as AutoRoundConfig
|
||||
from .utils.quantization_config import AwqConfig as AwqConfig
|
||||
from .utils.quantization_config import BitNetQuantConfig as BitNetQuantConfig
|
||||
from .utils.quantization_config import BitsAndBytesConfig as BitsAndBytesConfig
|
||||
from .utils.quantization_config import CompressedTensorsConfig as CompressedTensorsConfig
|
||||
from .utils.quantization_config import EetqConfig as EetqConfig
|
||||
from .utils.quantization_config import FbgemmFp8Config as FbgemmFp8Config
|
||||
from .utils.quantization_config import FineGrainedFP8Config as FineGrainedFP8Config
|
||||
from .utils.quantization_config import FPQuantConfig as FPQuantConfig
|
||||
from .utils.quantization_config import GPTQConfig as GPTQConfig
|
||||
from .utils.quantization_config import HiggsConfig as HiggsConfig
|
||||
from .utils.quantization_config import HqqConfig as HqqConfig
|
||||
from .utils.quantization_config import QuantoConfig as QuantoConfig
|
||||
from .utils.quantization_config import QuarkConfig as QuarkConfig
|
||||
from .utils.quantization_config import SpQRConfig as SpQRConfig
|
||||
from .utils.quantization_config import TorchAoConfig as TorchAoConfig
|
||||
from .utils.quantization_config import VptqConfig as VptqConfig
|
||||
from .video_processing_utils import BaseVideoProcessor as BaseVideoProcessor
|
||||
from .utils.quantization_config import (
|
||||
AqlmConfig,
|
||||
AutoRoundConfig,
|
||||
AwqConfig,
|
||||
BitNetQuantConfig,
|
||||
BitsAndBytesConfig,
|
||||
CompressedTensorsConfig,
|
||||
EetqConfig,
|
||||
FbgemmFp8Config,
|
||||
FineGrainedFP8Config,
|
||||
FPQuantConfig,
|
||||
GPTQConfig,
|
||||
HiggsConfig,
|
||||
HqqConfig,
|
||||
QuantoConfig,
|
||||
QuarkConfig,
|
||||
SpQRConfig,
|
||||
TorchAoConfig,
|
||||
VptqConfig,
|
||||
)
|
||||
from .video_processing_utils import BaseVideoProcessor
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
@ -1164,7 +1164,7 @@ class Cache:
|
||||
while len(self.layers) <= layer_idx:
|
||||
kwargs = self.layer_init_kwargs.copy()
|
||||
if self.layer_init_kwargs.get("layer_device_map", None) is not None:
|
||||
kwargs["device"] = kwargs.pop("layer_device_map")[len(self.layers)]
|
||||
kwargs["device"] = kwargs.pop("layer_device_map")[layer_idx]
|
||||
|
||||
new_layer_class = (
|
||||
self.layer_classes[len(self.layers)] if isinstance(self.layer_classes, list) else self.layer_classes
|
||||
|
||||
@ -396,9 +396,6 @@ class ServeArguments:
|
||||
log_level: str = field(
|
||||
default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."}
|
||||
)
|
||||
default_seed: Optional[int] = field(
|
||||
default=None, metadata={"help": "The default seed for torch, should be an integer."}
|
||||
)
|
||||
enable_cors: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@ -454,9 +451,6 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
|
||||
self.enable_cors = self.args.enable_cors
|
||||
|
||||
if self.args.default_seed is not None:
|
||||
torch.manual_seed(self.args.default_seed)
|
||||
|
||||
# Set up logging
|
||||
transformers_logger = logging.get_logger("transformers")
|
||||
transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
|
||||
@ -905,16 +899,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
inputs = inputs.to(model.device)
|
||||
request_id = req.get("request_id", "req_0")
|
||||
|
||||
# Temporary hack for GPTOSS 1: don't filter special tokens
|
||||
skip_special_tokens = True
|
||||
if "gptoss" in model.config.architectures[0].lower():
|
||||
skip_special_tokens = False
|
||||
|
||||
generation_streamer = TextIteratorStreamer(
|
||||
processor,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
skip_prompt=True,
|
||||
)
|
||||
generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
|
||||
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
|
||||
|
||||
last_kv_cache = None
|
||||
@ -930,21 +915,12 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
}
|
||||
|
||||
def stream_chat_completion(streamer, _request_id):
|
||||
# Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output
|
||||
# classes and piping the reasoning trace into a new field
|
||||
filter_cot = False
|
||||
cot_trace_end = None
|
||||
if "gptoss" in model.config.architectures[0].lower():
|
||||
filter_cot = True
|
||||
cot_trace_end = "<|channel|>final<|message|>"
|
||||
|
||||
# Thin wrapper to save the KV cache after generation
|
||||
def generate_with_cache(**kwargs):
|
||||
generate_output = model.generate(**kwargs)
|
||||
self.last_kv_cache = generate_output.past_key_values
|
||||
|
||||
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
|
||||
results = ""
|
||||
|
||||
try:
|
||||
thread.start()
|
||||
@ -955,20 +931,6 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)
|
||||
|
||||
for result in streamer:
|
||||
# Temporary hack for GPTOS 3: don't emit the final "<|return|>"
|
||||
if "gptoss" in model.config.architectures[0].lower():
|
||||
if result.endswith("<|return|>"):
|
||||
result = result[: -len("<|return|>")]
|
||||
results += result
|
||||
|
||||
# (related to temporary hack 2)
|
||||
if filter_cot:
|
||||
if cot_trace_end in results: # end of reasoning trace observed -> stop filtering
|
||||
filter_cot = False
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
# ====== TOOL CALL LOGIC ======
|
||||
if tool_model_family is not None:
|
||||
# Start of a tool call: reset state variables, set `inside_tool_call`
|
||||
@ -1070,38 +1032,10 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
self.last_model = model_id_and_revision
|
||||
model, processor = self.load_model_and_processor(model_id_and_revision)
|
||||
|
||||
if isinstance(req["input"], str):
|
||||
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
|
||||
inputs.append({"role": "user", "content": req["input"]})
|
||||
elif isinstance(req["input"], list):
|
||||
if "instructions" in req:
|
||||
if req["input"][0]["role"] != "system":
|
||||
inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]]
|
||||
else:
|
||||
inputs = req["input"]
|
||||
inputs[0]["content"] = req["instructions"]
|
||||
else:
|
||||
inputs = req["input"]
|
||||
elif isinstance(req["input"], dict):
|
||||
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
|
||||
inputs.append(req["input"])
|
||||
else:
|
||||
raise ValueError("inputs should be a list, dict, or str")
|
||||
|
||||
inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt")
|
||||
inputs = inputs.to(model.device)
|
||||
inputs = processor.apply_chat_template(req["input"], add_generation_prompt=True).to(model.device)
|
||||
request_id = req.get("previous_response_id", "req_0")
|
||||
|
||||
# Temporary hack for GPTOSS 1: don't filter special tokens
|
||||
skip_special_tokens = True
|
||||
if "gptoss" in model.config.architectures[0].lower():
|
||||
skip_special_tokens = False
|
||||
|
||||
generation_streamer = TextIteratorStreamer(
|
||||
processor,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
skip_prompt=True,
|
||||
)
|
||||
generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
|
||||
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
|
||||
|
||||
last_kv_cache = None
|
||||
@ -1118,14 +1052,6 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
}
|
||||
|
||||
def stream_response(streamer, _request_id):
|
||||
# Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output
|
||||
# classes and piping the reasoning trace into a new field
|
||||
filter_cot = False
|
||||
cot_trace_end = None
|
||||
if "gptoss" in model.config.architectures[0].lower():
|
||||
filter_cot = True
|
||||
cot_trace_end = "<|channel|>final<|message|>"
|
||||
|
||||
# Thin wrapper to save the KV cache after generation
|
||||
def generate_with_cache(**kwargs):
|
||||
generate_output = model.generate(**kwargs)
|
||||
@ -1212,21 +1138,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
# Stream the actual generated text
|
||||
results = ""
|
||||
for result in streamer:
|
||||
# Temporary hack for GPTOS 3: don't emit the final "<|return|>"
|
||||
if "gptoss" in model.config.architectures[0].lower():
|
||||
if result.endswith("<|return|>"):
|
||||
result = result[: -len("<|return|>")]
|
||||
results += result
|
||||
|
||||
# (related to temporary hack 2)
|
||||
if filter_cot:
|
||||
if cot_trace_end in results: # end of reasoning trace observed -> stop filtering
|
||||
filter_cot = False
|
||||
results = "" # reset the results -> results will now track the final response
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
response_output_text_delta = ResponseTextDeltaEvent(
|
||||
type="response.output_text.delta",
|
||||
item_id=f"msg_{request_id}",
|
||||
@ -1234,7 +1146,6 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
output_index=output_index,
|
||||
content_index=content_index,
|
||||
delta=result,
|
||||
logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs
|
||||
)
|
||||
sequence_number += 1
|
||||
yield self.build_response_event(response_output_text_delta)
|
||||
@ -1247,7 +1158,6 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
output_index=output_index,
|
||||
content_index=0,
|
||||
text=results,
|
||||
logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs
|
||||
)
|
||||
sequence_number += 1
|
||||
yield self.build_response_event(response_output_text_done)
|
||||
@ -1507,10 +1417,9 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
"attn_implementation": args.attn_implementation,
|
||||
"torch_dtype": torch_dtype,
|
||||
"device_map": "auto",
|
||||
"quantization_config": quantization_config,
|
||||
"trust_remote_code": args.trust_remote_code,
|
||||
}
|
||||
if quantization_config is not None:
|
||||
model_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
|
||||
architecture = getattr(transformers, config.architectures[0])
|
||||
|
||||
@ -248,7 +248,7 @@ class DebugUnderflowOverflow:
|
||||
|
||||
last_frame_of_batch = False
|
||||
|
||||
trace_mode = self.batch_number in self.trace_batch_nums
|
||||
trace_mode = True if self.batch_number in self.trace_batch_nums else False
|
||||
if trace_mode:
|
||||
self.reset_saved_frames()
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ deps = {
|
||||
"kenlm": "kenlm",
|
||||
"keras": "keras>2.9,<2.16",
|
||||
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
|
||||
"kernels": "kernels>=0.6.1,<=0.9",
|
||||
"kernels": "kernels>=0.6.1,<0.7",
|
||||
"librosa": "librosa",
|
||||
"natten": "natten>=0.14.6,<0.15.0",
|
||||
"nltk": "nltk<=3.8.1",
|
||||
@ -43,7 +43,7 @@ deps = {
|
||||
"onnxconverter-common": "onnxconverter-common",
|
||||
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
||||
"onnxruntime": "onnxruntime>=1.4.0",
|
||||
"openai": "openai>=1.98.0",
|
||||
"openai": "openai",
|
||||
"opencv-python": "opencv-python",
|
||||
"optimum-benchmark": "optimum-benchmark>=0.3.0",
|
||||
"optuna": "optuna",
|
||||
|
||||
@ -228,7 +228,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
cur_len = input_ids.shape[-1] + 1
|
||||
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
||||
|
||||
if batch_size != (input_ids.shape[0] // self.group_size):
|
||||
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
||||
if self.num_beam_groups > 1:
|
||||
raise ValueError(
|
||||
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
|
||||
@ -564,7 +564,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
||||
cur_len = input_ids.shape[-1] + 1
|
||||
batch_size = len(self._beam_hyps)
|
||||
if batch_size != (input_ids.shape[0] // self.group_size):
|
||||
if not (batch_size == (input_ids.shape[0] // self.group_size)):
|
||||
if self.num_beam_groups > 1:
|
||||
raise ValueError(
|
||||
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
|
||||
|
||||
@ -13,7 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import queue
|
||||
import statistics
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
@ -25,13 +27,13 @@ from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.decoders import DecodeStream
|
||||
from torch.profiler import profile, schedule, tensorboard_trace_handler
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
from ..generation.configuration_utils import GenerationConfig
|
||||
from ..tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ..utils.logging import logging
|
||||
from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
|
||||
|
||||
|
||||
@ -47,7 +49,9 @@ class RequestStatus(Enum):
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
# Setup your logger
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -155,8 +159,8 @@ class PagedAttentionCache:
|
||||
generation_config: GenerationConfig,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_requests: int = 100,
|
||||
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
|
||||
initial_prompt_shapes: Optional[list[list[int]]] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Initialize a paged attention cache for efficient memory usage.
|
||||
@ -175,6 +179,23 @@ class PagedAttentionCache:
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
self.head_dim = (
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
# Calculate optimal block size and number if not provided
|
||||
num_blocks = getattr(generation_config, "num_blocks", None)
|
||||
block_size = getattr(generation_config, "block_size", None)
|
||||
if num_blocks is None or block_size is None:
|
||||
logger.info("Calculating optimal block size and number...")
|
||||
num_blocks, block_size = compute_optimal_blocks(
|
||||
device, config, generation_config, initial_prompt_shapes or [], dtype, median_prefill_length=200
|
||||
)
|
||||
logger.info(f"Using calculated num_blocks={num_blocks}, block_size={block_size}")
|
||||
|
||||
self.block_size = block_size
|
||||
self.num_blocks = num_blocks
|
||||
num_key_value_heads = self.num_key_value_heads
|
||||
if tp_size is not None and tp_size > 1:
|
||||
if num_key_value_heads % tp_size != 0:
|
||||
@ -182,35 +203,8 @@ class PagedAttentionCache:
|
||||
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
|
||||
)
|
||||
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
|
||||
# self.num_key_value_heads //= tp_size
|
||||
num_key_value_heads //= tp_size
|
||||
|
||||
self.head_dim = (
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
# Calculate optimal block size and number if not provided
|
||||
num_blocks = getattr(generation_config, "num_blocks", 1024)
|
||||
block_size = getattr(generation_config, "block_size", 32)
|
||||
max_memory_percent = getattr(generation_config, "max_memory", 0.9)
|
||||
max_batch_tokens = getattr(generation_config, "max_batch_tokens", 256)
|
||||
if num_blocks is None or max_batch_tokens is None:
|
||||
num_blocks, max_batch_tokens = compute_optimal_blocks(
|
||||
generation_config.max_new_tokens,
|
||||
block_size=block_size,
|
||||
head_dim=self.head_dim,
|
||||
num_layers=self.num_hidden_layers,
|
||||
num_heads=self.num_key_value_heads,
|
||||
max_memory_percent=max_memory_percent,
|
||||
dtype=dtype,
|
||||
num_blocks=num_blocks,
|
||||
)
|
||||
logger.warning(
|
||||
f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}"
|
||||
)
|
||||
self.max_batch_tokens = max_batch_tokens
|
||||
self.block_size = block_size
|
||||
self.num_blocks = num_blocks
|
||||
self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
|
||||
|
||||
self.dtype = dtype
|
||||
@ -255,7 +249,7 @@ class PagedAttentionCache:
|
||||
blocks_to_free = self._block_tables.pop(request_id)
|
||||
self._free_blocks.extend(blocks_to_free)
|
||||
else:
|
||||
logger.info(f"Attempted to free blocks for non-existent request_id: {request_id}")
|
||||
logger.warning(f"Attempted to free blocks for non-existent request_id: {request_id}")
|
||||
|
||||
def get_num_free_blocks(self) -> int:
|
||||
"""Returns the number of free blocks available."""
|
||||
@ -349,7 +343,7 @@ class Scheduler(ABC):
|
||||
@traced
|
||||
def has_pending_requests(self) -> bool:
|
||||
"""Check if there are requests ready to be processed."""
|
||||
return len(self.active_requests) or len(self.waiting_requests)
|
||||
return self.active_requests or self.waiting_requests
|
||||
|
||||
@abstractmethod
|
||||
def finish_request(self, request_id: str, evict_from_cache: bool = True):
|
||||
@ -601,60 +595,94 @@ class PrefillFirstScheduler(Scheduler):
|
||||
del self.active_requests[request_id]
|
||||
|
||||
|
||||
def get_device_and_memory():
|
||||
# Select best available device
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
total_memory = torch.cuda.get_device_properties(device).total_memory
|
||||
reserved_memory = torch.cuda.memory_reserved(device)
|
||||
allocated_memory = torch.cuda.memory_allocated(device)
|
||||
|
||||
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||
device = torch.device("mps")
|
||||
# MPS memory reporting (PyTorch 2.0+)
|
||||
total_memory = torch.mps.driver_allocated_memory()
|
||||
allocated_memory = total_memory - torch.mps.recommended_max_memory()
|
||||
reserved_memory = 0 # MPS does not track reserved separately
|
||||
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
total_memory = None
|
||||
reserved_memory = 0
|
||||
allocated_memory = 0
|
||||
|
||||
return device, total_memory, reserved_memory, allocated_memory
|
||||
|
||||
|
||||
@traced(standalone=True)
|
||||
def compute_optimal_blocks(
|
||||
max_num_tokens,
|
||||
block_size,
|
||||
head_dim,
|
||||
num_heads,
|
||||
num_layers,
|
||||
max_memory_percent=0.9,
|
||||
num_blocks=None,
|
||||
dtype=torch.float16,
|
||||
device: torch.device,
|
||||
config: PretrainedConfig,
|
||||
generation_config: GenerationConfig,
|
||||
inputs: list[list[int]],
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
safety_margin: float = 0.9,
|
||||
median_prefill_length: Optional[int] = None,
|
||||
):
|
||||
device, total, reserved, allocated = get_device_and_memory()
|
||||
available_memory = int((total - max(allocated, reserved)) * max_memory_percent)
|
||||
"""Calculate optimal number and size of blocks for the KV cache.
|
||||
|
||||
Args:
|
||||
device: The device where the model runs
|
||||
config: The model configuration
|
||||
generation_config: The generation configuration
|
||||
inputs: Sample input sequences to estimate memory requirements
|
||||
dtype: Data type for cache tensors
|
||||
safety_margin: Fraction of available memory to use
|
||||
median_prefill_length: Override for median prefill length calculation
|
||||
|
||||
Returns:
|
||||
Tuple of (num_blocks, block_size)
|
||||
"""
|
||||
# Extract model dimensions
|
||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
|
||||
num_hidden_layers = getattr(config, "num_hidden_layers", 40)
|
||||
|
||||
# Get available device memory
|
||||
if device.type == "cuda":
|
||||
device_properties = torch.cuda.get_device_properties(device)
|
||||
total_memory = device_properties.total_memory
|
||||
allocated_memory = torch.cuda.memory_allocated(device)
|
||||
reserved_memory = torch.cuda.memory_reserved(device)
|
||||
available_memory = total_memory - max(allocated_memory, reserved_memory)
|
||||
elif device.type == "mps":
|
||||
logger.warning("MPS memory estimation is approximate. Using conservative defaults.")
|
||||
return 2048, 256
|
||||
else:
|
||||
logger.warning(f"Unsupported device type {device.type} for optimal block calculation. Using defaults.")
|
||||
return 32, 128
|
||||
|
||||
# Apply safety margin
|
||||
available_memory = int(available_memory * safety_margin)
|
||||
if available_memory <= 0:
|
||||
logger.warning("Not enough available memory. Using minimum configuration.")
|
||||
return 8, 128 # Minimum viable configuration
|
||||
|
||||
# Calculate memory per token
|
||||
dtype_size = torch.tensor([], dtype=dtype).element_size()
|
||||
bytes_per_token = 2 * num_heads * head_dim * dtype_size * num_layers
|
||||
if num_blocks is not None:
|
||||
# TODO
|
||||
max_possible_concurrent_requests = num_blocks * bytes_per_token
|
||||
# FIXME: forgot to add the inintial prompt length in the mix....
|
||||
max_possible_concurrent_requests = int(
|
||||
available_memory // (bytes_per_token * max_num_tokens * max_num_tokens // 4)
|
||||
memory_per_token = 2 * num_kv_heads * head_dim * dtype_size * num_hidden_layers # For K and V caches
|
||||
|
||||
# Estimate sequence length requirements
|
||||
tokens_to_generate = getattr(generation_config, "max_new_tokens") or 20
|
||||
|
||||
if median_prefill_length is None and inputs:
|
||||
non_empty_inputs = [len(seq) for seq in inputs if seq]
|
||||
median_prefill_length = int(statistics.median(non_empty_inputs)) if non_empty_inputs else 64
|
||||
elif median_prefill_length is None:
|
||||
median_prefill_length = 64 # Reasonable default if no inputs provided
|
||||
|
||||
# Total sequence length including generated tokens
|
||||
seq_length = median_prefill_length + tokens_to_generate
|
||||
|
||||
# Calculate block parameters
|
||||
MIN_BLOCK_SIZE = 16
|
||||
|
||||
# Estimate number of concurrent sequences
|
||||
per_sequence_memory = seq_length * memory_per_token
|
||||
max_concurrent_sequences = max(1, int(available_memory // per_sequence_memory))
|
||||
|
||||
# Total tokens that can fit in memory
|
||||
total_tokens = available_memory // memory_per_token
|
||||
|
||||
# Calculate block size (rounded to power of 2)
|
||||
initial_block_size = max(MIN_BLOCK_SIZE, total_tokens // (max_concurrent_sequences * 2))
|
||||
block_size = 1 << (initial_block_size - 1).bit_length() # Round to power of 2
|
||||
|
||||
# Calculate number of blocks
|
||||
num_blocks = max(1, total_tokens // block_size)
|
||||
|
||||
logger.info(
|
||||
f"Optimal cache: {num_blocks} blocks of size {block_size} "
|
||||
f"(can handle ~{num_blocks * block_size // seq_length} sequences of length {seq_length})"
|
||||
)
|
||||
if max_possible_concurrent_requests <= 0:
|
||||
logger.warning("you are trying to generate a bit too many tokens")
|
||||
max_possible_concurrent_requests = 32
|
||||
max_concurrent_tokens = min(64, max_possible_concurrent_requests)
|
||||
# FIXME: Optimal means uses all memory
|
||||
optimal_num_blocks = max(((max_concurrent_tokens * max_num_tokens) // block_size) + 1, 64)
|
||||
return optimal_num_blocks, max_concurrent_tokens
|
||||
|
||||
return int(num_blocks), int(block_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -747,13 +775,15 @@ class ContinuousBatchProcessor:
|
||||
|
||||
self.requests_in_batch: list[RequestState] = []
|
||||
|
||||
# Get batch size parameters from generation config
|
||||
self._configure_batch_parameters()
|
||||
|
||||
# Set up metrics collector
|
||||
self.max_batch_tokens = cache.max_batch_tokens
|
||||
self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens)
|
||||
self.metrics = ContinuousBatchProcessorMetrics(self.max_batch_tokens)
|
||||
|
||||
self.setup_static_tensors()
|
||||
|
||||
self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.config._name_or_path)
|
||||
self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path)
|
||||
self.decode_stream = DecodeStream(skip_special_tokens=True)
|
||||
|
||||
@traced(standalone=True)
|
||||
@ -817,6 +847,25 @@ class ContinuousBatchProcessor:
|
||||
+ self.get_model_kwargs().__repr__()
|
||||
)
|
||||
|
||||
@traced(standalone=True)
|
||||
def _configure_batch_parameters(self):
|
||||
"""Set up batch processing parameters based on generation config."""
|
||||
# Calculate total cache capacity
|
||||
total_cache_tokens = self.cache.num_blocks * self.cache.block_size
|
||||
|
||||
# Get or calculate max tokens per batch
|
||||
user_batch_tokens = getattr(self.generation_config, "max_batch_tokens", None)
|
||||
if user_batch_tokens is not None:
|
||||
self.max_batch_tokens = user_batch_tokens
|
||||
else:
|
||||
# Default to 1/8 of total cache capacity, adjusted for context
|
||||
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
|
||||
recommended_batch_size = min(total_cache_tokens // 8, self.max_context_len)
|
||||
self.max_batch_tokens = max(64, recommended_batch_size)
|
||||
|
||||
# Context length and EOS token
|
||||
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
|
||||
|
||||
@traced
|
||||
def _get_new_requests(self):
|
||||
"""Pull new requests from the input queue and add to waiting list."""
|
||||
@ -962,14 +1011,7 @@ class ContinuousBatchProcessor:
|
||||
|
||||
@traced
|
||||
def _sync(self):
|
||||
if self.output_ids is not None:
|
||||
try:
|
||||
out = self.output_ids.tolist()[0] # should be the only synch we do
|
||||
except Exception:
|
||||
out = [0, 1]
|
||||
else:
|
||||
out = [0, 0]
|
||||
return out
|
||||
return self.output_ids.tolist()[0] # should be the only synch we do
|
||||
|
||||
@traced
|
||||
def _maybe_send_output(self, state: RequestState, token: int):
|
||||
@ -999,8 +1041,6 @@ class ContinuousBatchProcessor:
|
||||
self._maybe_send_output(state, token)
|
||||
elif state.status == RequestStatus.PREFILLING_SPLIT:
|
||||
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
|
||||
if self.cache.get_num_free_blocks() == 0:
|
||||
raise ValueError("No more free blocks")
|
||||
|
||||
@traced
|
||||
def has_pending_requests(self) -> bool:
|
||||
@ -1022,9 +1062,7 @@ class ContinuousBatchProcessor:
|
||||
Args:
|
||||
error: The error to report in the failure message
|
||||
"""
|
||||
|
||||
requests = list(self.scheduler.active_requests.values())
|
||||
for state in requests:
|
||||
for state in self.scheduler.active_requests.values():
|
||||
self._handle_request_error(error, state)
|
||||
self.scheduler.finish_request(state.request_id)
|
||||
|
||||
@ -1068,8 +1106,7 @@ class ContinuousBatchingManager:
|
||||
max_queue_size: Maximum size of the request queue (0 = unlimited)
|
||||
streaming: Whether to stream tokens as they are generated
|
||||
"""
|
||||
self.model = model.eval()
|
||||
generation_config = model.generation_config if generation_config is None else generation_config
|
||||
self.model = model
|
||||
self.generation_config = generation_config
|
||||
self.input_queue = queue.Queue(maxsize=max_queue_size)
|
||||
self.output_queue = queue.Queue()
|
||||
@ -1081,6 +1118,7 @@ class ContinuousBatchingManager:
|
||||
self._request_lock = threading.Lock()
|
||||
self.model.generation_config.top_p = None
|
||||
self.do_sample = getattr(generation_config, "do_sample", True)
|
||||
generation_config = model.generation_config if generation_config is None else generation_config
|
||||
self.logit_processor = self.model._get_logits_processor(generation_config)
|
||||
self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
|
||||
self.profile = getattr(generation_config, "profile", False)
|
||||
@ -1204,7 +1242,7 @@ class ContinuousBatchingManager:
|
||||
|
||||
@traced
|
||||
def warmup(self, batch_processor):
|
||||
stream = torch.cuda.Stream(device=self.model.device)
|
||||
stream = torch.cuda.Stream()
|
||||
stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(stream):
|
||||
# Warmup the model with a dummy forward pass
|
||||
@ -1212,7 +1250,7 @@ class ContinuousBatchingManager:
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph, stream=stream):
|
||||
with torch.cuda.graph(self.graph):
|
||||
self._generation_step(batch_processor)
|
||||
|
||||
@traced
|
||||
@ -1258,8 +1296,7 @@ class ContinuousBatchingManager:
|
||||
self.generation_config,
|
||||
self.model.device,
|
||||
self.model.dtype,
|
||||
num_requests=len(self.input_queue.queue),
|
||||
tp_size=getattr(self.model, "_tp_size", 8), # TODO quantized converted don't set this
|
||||
tp_size=getattr(self.model, "tp_size"),
|
||||
)
|
||||
|
||||
scheduler = None
|
||||
@ -1287,10 +1324,33 @@ class ContinuousBatchingManager:
|
||||
)
|
||||
self.batch_processor = batch_processor
|
||||
is_first = True
|
||||
while (not self.stop_event.is_set()) or batch_processor.has_pending_requests():
|
||||
self._inner_generation_loop(batch_processor, is_first)
|
||||
if is_first:
|
||||
is_first = False
|
||||
|
||||
if self.profile:
|
||||
tracing_schedule = schedule(skip_first=2, warmup=3, active=200, repeat=100, wait=1)
|
||||
trace_handler = tensorboard_trace_handler(
|
||||
dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile"
|
||||
)
|
||||
activities = [
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
]
|
||||
with profile(
|
||||
activities=activities,
|
||||
schedule=tracing_schedule,
|
||||
on_trace_ready=trace_handler,
|
||||
record_shapes=False,
|
||||
with_stack=True,
|
||||
) as prof:
|
||||
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
|
||||
self._inner_generation_loop(batch_processor, is_first)
|
||||
if is_first:
|
||||
is_first = False
|
||||
prof.step()
|
||||
else:
|
||||
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
|
||||
self._inner_generation_loop(batch_processor, is_first)
|
||||
if is_first:
|
||||
is_first = False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in generation loop: {e}", exc_info=True)
|
||||
@ -1303,8 +1363,6 @@ class ContinuousBatchingManager:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
batch_processor.prepare_next_batch()
|
||||
device, total, reserved, allocated = get_device_and_memory()
|
||||
logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
|
||||
if torch.cuda.is_available() and self.use_cuda_graph:
|
||||
if is_first:
|
||||
self.warmup(batch_processor)
|
||||
@ -1444,7 +1502,6 @@ class ContinuousMixin:
|
||||
results[req_id] = result
|
||||
finished_count += 1
|
||||
pbar.update(1)
|
||||
logger.debug(manager.batch_processor.tokenizer.decode(result.generated_tokens))
|
||||
else:
|
||||
if not manager.is_running():
|
||||
logger.error("Generation thread terminated unexpectedly.")
|
||||
|
||||
@ -435,7 +435,9 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
|
||||
# create banned_tokens boolean mask
|
||||
banned_tokens_indices_mask = []
|
||||
for banned_tokens_slice in banned_tokens:
|
||||
banned_tokens_indices_mask.append([token in banned_tokens_slice for token in range(vocab_size)])
|
||||
banned_tokens_indices_mask.append(
|
||||
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
|
||||
)
|
||||
|
||||
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
|
||||
|
||||
|
||||
@ -1800,7 +1800,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
|
||||
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
|
||||
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
|
||||
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
|
||||
if "cache_position" in model_kwargs and model_kwargs["cache_position"]:
|
||||
return model_kwargs
|
||||
if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
|
||||
cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
|
||||
|
||||
@ -833,7 +833,7 @@ class ImageFeatureExtractionMixin:
|
||||
return image.crop((left, top, right, bottom))
|
||||
|
||||
# Check if image is in (n_channels, height, width) or (height, width, n_channels) format
|
||||
channel_first = image.shape[0] in [1, 3]
|
||||
channel_first = True if image.shape[0] in [1, 3] else False
|
||||
|
||||
# Transpose (height, width, n_channels) format images
|
||||
if not channel_first:
|
||||
|
||||
@ -119,14 +119,6 @@ _import_structure = {
|
||||
"run_hp_search_sigopt",
|
||||
"run_hp_search_wandb",
|
||||
],
|
||||
"mxfp4": [
|
||||
"replace_with_mxfp4_linear",
|
||||
"Mxfp4GptOssExperts",
|
||||
"quantize_to_mxfp4",
|
||||
"convert_moe_packed_tensors",
|
||||
"dequantize",
|
||||
"load_and_swizzle_mxfp4",
|
||||
],
|
||||
"peft": ["PeftAdapterMixin"],
|
||||
"quanto": ["replace_with_quanto_layers"],
|
||||
"spqr": ["replace_with_spqr_linear"],
|
||||
@ -263,13 +255,6 @@ if TYPE_CHECKING:
|
||||
run_hp_search_sigopt,
|
||||
run_hp_search_wandb,
|
||||
)
|
||||
from .mxfp4 import (
|
||||
Mxfp4GptOssExperts,
|
||||
dequantize,
|
||||
load_and_swizzle_mxfp4,
|
||||
quantize_to_mxfp4,
|
||||
replace_with_mxfp4_linear,
|
||||
)
|
||||
from .peft import PeftAdapterMixin
|
||||
from .quanto import replace_with_quanto_layers
|
||||
from .spqr import replace_with_spqr_linear
|
||||
|
||||
@ -4,11 +4,8 @@ from ..generation.continuous_batching import PagedAttentionCache
|
||||
from ..utils import is_flash_attn_2_available
|
||||
|
||||
|
||||
try:
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||
except Exception:
|
||||
pass
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||
|
||||
|
||||
def paged_attention_forward(
|
||||
@ -50,10 +47,8 @@ def paged_attention_forward(
|
||||
"""
|
||||
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
|
||||
|
||||
sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
|
||||
if implementation is not None:
|
||||
flash_attn_varlen_func = implementation.flash_attn_varlen_func
|
||||
custom_kwargs = {"s_aux": kwargs.get("s_aux")}
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q.transpose(1, 2).squeeze(0).contiguous(),
|
||||
k.transpose(1, 2).squeeze(0).contiguous(),
|
||||
@ -64,10 +59,9 @@ def paged_attention_forward(
|
||||
max_seqlen_k,
|
||||
softmax_scale=module.scaling,
|
||||
causal=True, # kind of a must, it automatically aligns the mask for q < k
|
||||
window_size=sliding_window, # -1 means infinite context window
|
||||
window_size=(-1, -1), # -1 means infinite context window
|
||||
# block_table=block_tables, -> torch.Tensor
|
||||
**custom_kwargs,
|
||||
# **kwargs,
|
||||
)
|
||||
if isinstance(attn_output, tuple):
|
||||
attn_output = attn_output[0]
|
||||
|
||||
return attn_output, None
|
||||
|
||||
@ -198,8 +198,8 @@ def make_flex_block_causal_mask(
|
||||
mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
|
||||
|
||||
if offsets is not None:
|
||||
q_offset = offsets[0].to(device)
|
||||
kv_offset = offsets[1].to(device)
|
||||
q_offset = offsets[0]
|
||||
kv_offset = offsets[1]
|
||||
|
||||
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
offset_q = q_idx + q_offset
|
||||
@ -241,7 +241,6 @@ def flex_attention_forward(
|
||||
scaling: Optional[float] = None,
|
||||
softcap: Optional[float] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
s_aux: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if head_mask is not None:
|
||||
@ -272,12 +271,6 @@ def flex_attention_forward(
|
||||
score = score + score_mask[batch_idx][0][q_idx][kv_idx]
|
||||
if head_mask is not None:
|
||||
score = score + head_mask[batch_idx][head_idx][0][0]
|
||||
if s_aux is not None:
|
||||
logits_max = torch.max(score, dim=-1, keepdim=True).values
|
||||
sinks = torch.exp(s_aux - logits_max)
|
||||
unnormalized_scores = torch.exp(score - logits_max)
|
||||
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
|
||||
score = unnormalized_scores / normalizer
|
||||
return score
|
||||
|
||||
enable_gqa = True
|
||||
|
||||
@ -18,7 +18,6 @@ try:
|
||||
from kernels import (
|
||||
Device,
|
||||
LayerRepository,
|
||||
Mode,
|
||||
register_kernel_mapping,
|
||||
replace_kernel_forward_from_hub,
|
||||
use_kernel_forward_from_hub,
|
||||
@ -45,14 +44,7 @@ try:
|
||||
repo_id="kernels-community/liger_kernels",
|
||||
layer_name="LigerRMSNorm",
|
||||
# revision="pure-layer-test",
|
||||
),
|
||||
"rocm": {
|
||||
Mode.INFERENCE: LayerRepository(
|
||||
repo_id="kernels-community/liger_kernels",
|
||||
layer_name="LigerRMSNorm",
|
||||
# revision="pure-layer-test",
|
||||
)
|
||||
},
|
||||
)
|
||||
},
|
||||
"MLP": {
|
||||
"cuda": LayerRepository(
|
||||
@ -61,22 +53,10 @@ try:
|
||||
)
|
||||
},
|
||||
"MegaBlocksMoeMLP": {
|
||||
"cuda": {
|
||||
Mode.TRAINING: LayerRepository(
|
||||
repo_id="kernels-community/megablocks",
|
||||
layer_name="MegaBlocksMoeMLP",
|
||||
),
|
||||
Mode.INFERENCE: LayerRepository(
|
||||
repo_id="kernels-community/megablocks",
|
||||
layer_name="MegaBlocksMoeMLP",
|
||||
),
|
||||
},
|
||||
"rocm": {
|
||||
Mode.INFERENCE: LayerRepository(
|
||||
repo_id="ahadnagy/megablocks",
|
||||
layer_name="MegaBlocksMoeMLP",
|
||||
)
|
||||
},
|
||||
"cuda": LayerRepository(
|
||||
repo_id="kernels-community/megablocks",
|
||||
layer_name="MegaBlocksMoeMLP",
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -35,8 +35,6 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
|
||||
if os.getenv("WANDB_MODE") == "offline":
|
||||
print("⚙️ Running in WANDB offline mode")
|
||||
@ -1045,14 +1043,6 @@ class WandbCallback(TrainerCallback):
|
||||
class TrackioCallback(TrainerCallback):
|
||||
"""
|
||||
A [`TrainerCallback`] that logs metrics to Trackio.
|
||||
|
||||
It records training metrics, model (and PEFT) configuration, and GPU memory usage.
|
||||
If `nvidia-ml-py` is installed, GPU power consumption is also tracked.
|
||||
|
||||
**Requires**:
|
||||
```bash
|
||||
pip install trackio
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@ -1129,14 +1119,12 @@ class TrackioCallback(TrainerCallback):
|
||||
device_idx = torch.cuda.current_device()
|
||||
total_memory = torch.cuda.get_device_properties(device_idx).total_memory
|
||||
memory_allocated = torch.cuda.memory_allocated(device_idx)
|
||||
|
||||
power = torch.cuda.power_draw(device_idx)
|
||||
gpu_memory_logs = {
|
||||
f"gpu/{device_idx}/allocated_memory": memory_allocated / (1024**3), # GB
|
||||
f"gpu/{device_idx}/memory_usage": memory_allocated / total_memory, # ratio
|
||||
f"gpu/{device_idx}/power": power / 1000, # Watts
|
||||
}
|
||||
if _is_package_available("pynvml"):
|
||||
power = torch.cuda.power_draw(device_idx)
|
||||
gpu_memory_logs[f"gpu/{device_idx}/power"] = power / 1000 # Watts
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
gathered_logs = [None] * dist.get_world_size()
|
||||
dist.all_gather_object(gathered_logs, gpu_memory_logs)
|
||||
|
||||
@ -1,470 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
import re
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
FP4_VALUES = [
|
||||
+0.0,
|
||||
+0.5,
|
||||
+1.0,
|
||||
+1.5,
|
||||
+2.0,
|
||||
+3.0,
|
||||
+4.0,
|
||||
+6.0,
|
||||
-0.0,
|
||||
-0.5,
|
||||
-1.0,
|
||||
-1.5,
|
||||
-2.0,
|
||||
-3.0,
|
||||
-4.0,
|
||||
-6.0,
|
||||
]
|
||||
|
||||
|
||||
# Copied from GPT_OSS repo and vllm
|
||||
def quantize_to_mxfp4(w):
|
||||
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
|
||||
|
||||
w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
|
||||
w, w_scale = swizzle_mxfp4(w, w_scale)
|
||||
return w, w_scale
|
||||
|
||||
|
||||
def swizzle_mxfp4(w, w_scale):
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.tensor_details.layout import StridedLayout
|
||||
|
||||
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
|
||||
# TODO : add that when we are actually sure that it works on B200
|
||||
# if torch.cuda.get_device_capability()[0] == 10:
|
||||
# constraints = {
|
||||
# "is_persistent": True,
|
||||
# "epilogue_subtile": 1,
|
||||
# }
|
||||
# opt_flags.update_opt_flags_constraints(constraints)
|
||||
# # transpose the tensor so that the quantization axis is on dim1
|
||||
|
||||
# TODO: there is still an issue with the scales on hopper
|
||||
# scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=8)
|
||||
# w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts)
|
||||
w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)
|
||||
return w, w_scale
|
||||
|
||||
|
||||
# Copied from GPT_OSS repo
|
||||
# TODO: Add absolute link when the repo is public
|
||||
def convert_moe_packed_tensors(
|
||||
blocks,
|
||||
scales,
|
||||
*,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
rows_per_chunk: int = 32768 * 1024,
|
||||
) -> torch.Tensor:
|
||||
import math
|
||||
|
||||
# Check if blocks and scales are on CPU, and move to GPU if so
|
||||
if not blocks.is_cuda and torch.cuda.is_available():
|
||||
blocks = blocks.cuda()
|
||||
scales = scales.cuda()
|
||||
|
||||
scales = scales.to(torch.int32) - 127
|
||||
|
||||
assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}"
|
||||
|
||||
lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
|
||||
|
||||
*prefix_shape, G, B = blocks.shape
|
||||
rows_total = math.prod(prefix_shape) * G
|
||||
|
||||
blocks = blocks.reshape(rows_total, B)
|
||||
scales = scales.reshape(rows_total, 1)
|
||||
|
||||
out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
|
||||
|
||||
for r0 in range(0, rows_total, rows_per_chunk):
|
||||
r1 = min(r0 + rows_per_chunk, rows_total)
|
||||
|
||||
blk = blocks[r0:r1]
|
||||
exp = scales[r0:r1]
|
||||
|
||||
# nibble indices -> int64
|
||||
idx_lo = (blk & 0x0F).to(torch.long)
|
||||
idx_hi = (blk >> 4).to(torch.long)
|
||||
|
||||
sub = out[r0:r1]
|
||||
sub[:, 0::2] = lut[idx_lo]
|
||||
sub[:, 1::2] = lut[idx_hi]
|
||||
|
||||
torch.ldexp(sub, exp, out=sub)
|
||||
del idx_lo, idx_hi, blk, exp, sub
|
||||
|
||||
out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
|
||||
|
||||
# TODO: Delete after making sure this is not necessary! since we go back to cpu in the end in create_quantized_param using .to(target_device)
|
||||
# Move back to CPU if needed
|
||||
# if need_to_move_back:
|
||||
# out = out.cpu()
|
||||
del blocks, scales, lut
|
||||
return out
|
||||
|
||||
|
||||
class Mxfp4GptOssExperts(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.num_experts = config.num_local_experts
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.gate_up_proj_blocks = nn.Parameter(
|
||||
torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.gate_up_proj_scales = nn.Parameter(
|
||||
torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.gate_up_proj_bias = nn.Parameter(
|
||||
torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
self.down_proj_blocks = nn.Parameter(
|
||||
torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.down_proj_scales = nn.Parameter(
|
||||
torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.down_proj_bias = nn.Parameter(
|
||||
torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
self.alpha = 1.702
|
||||
|
||||
self.gate_up_proj_precision_config = None
|
||||
self.down_proj_precision_config = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
|
||||
from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
|
||||
from triton_kernels.swiglu import swiglu_fn
|
||||
|
||||
with torch.cuda.device(hidden_states.device):
|
||||
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2)
|
||||
|
||||
intermediate_cache1 = matmul_ogs(
|
||||
hidden_states,
|
||||
self.gate_up_proj,
|
||||
self.gate_up_proj_bias.to(torch.float32),
|
||||
routing_data,
|
||||
gather_indx=gather_idx,
|
||||
precision_config=self.gate_up_proj_precision_config,
|
||||
gammas=None,
|
||||
fused_activation=act,
|
||||
)
|
||||
|
||||
intermediate_cache3 = matmul_ogs(
|
||||
intermediate_cache1,
|
||||
self.down_proj,
|
||||
self.down_proj_bias.to(torch.float32),
|
||||
routing_data,
|
||||
scatter_indx=scatter_idx,
|
||||
precision_config=self.down_proj_precision_config,
|
||||
gammas=routing_data.gate_scal,
|
||||
)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
# Adapted from GPT_OSS repo
|
||||
# TODO: Add absolute link when the repo is public
|
||||
def routing_torch_dist(
|
||||
logits,
|
||||
n_expts_act,
|
||||
):
|
||||
import os
|
||||
|
||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch
|
||||
|
||||
with torch.cuda.device(logits.device):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
replace_value = -1
|
||||
|
||||
n_tokens = logits.shape[0]
|
||||
n_expts_tot = logits.shape[1]
|
||||
|
||||
n_local_experts = n_expts_tot // world_size
|
||||
local_expert_start = rank * n_local_experts
|
||||
local_expert_end = (rank + 1) * n_local_experts
|
||||
|
||||
n_gates_pad = n_tokens * n_expts_act
|
||||
|
||||
def topk(vals, k):
|
||||
tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
|
||||
tk_indx = tk_indx.long()
|
||||
tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
|
||||
return tk_val, tk_indx.int()
|
||||
|
||||
expt_scal, expt_indx = topk(logits, n_expts_act)
|
||||
expt_scal = torch.softmax(expt_scal, dim=-1)
|
||||
expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
|
||||
expt_scal = torch.gather(expt_scal, 1, sort_indices)
|
||||
|
||||
# Flatten and mask for local experts
|
||||
expt_scal = expt_scal.reshape(-1)
|
||||
|
||||
hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start:local_expert_end]
|
||||
|
||||
expt_indx = expt_indx.view(-1).to(torch.int32)
|
||||
|
||||
# we use a large value to replace the indices that are not in the local expert range
|
||||
var = 1000
|
||||
expt_indx = torch.where(expt_indx < local_expert_start, var, expt_indx)
|
||||
topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32)
|
||||
gate_indx = torch.argsort(topk_indx).to(torch.int32)
|
||||
expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value)
|
||||
expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value)
|
||||
|
||||
gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx)
|
||||
gate_scal = expt_scal[topk_indx]
|
||||
|
||||
topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx)
|
||||
|
||||
# # Routing metadata for local expert computation
|
||||
gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
|
||||
scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
|
||||
|
||||
expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad)
|
||||
|
||||
hitted_experts = n_expts_act
|
||||
return RoutingData(gate_scal, hist, n_local_experts, hitted_experts, expt_data), gather_indx, scatter_indx
|
||||
|
||||
|
||||
def mlp_forward(self, hidden_states):
|
||||
import torch.distributed as dist
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
routing = routing_torch_dist
|
||||
else:
|
||||
from triton_kernels.routing import routing
|
||||
|
||||
routing = routing
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
|
||||
router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
|
||||
routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
|
||||
routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
|
||||
routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
|
||||
return routed_out, router_logits
|
||||
|
||||
|
||||
def should_convert_module(current_key_name, patterns):
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
if not any(
|
||||
re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
|
||||
from ..integrations.tensor_parallel import shard_and_distribute_module
|
||||
|
||||
model = kwargs.get("model", None)
|
||||
empty_param = kwargs.get("empty_param", None)
|
||||
casting_dtype = kwargs.get("casting_dtype", None)
|
||||
to_contiguous = kwargs.get("to_contiguous", None)
|
||||
rank = kwargs.get("rank", None)
|
||||
device_mesh = kwargs.get("device_mesh", None)
|
||||
|
||||
for proj in ["gate_up_proj", "down_proj"]:
|
||||
if proj in param_name:
|
||||
if device_mesh is not None:
|
||||
param_value = shard_and_distribute_module(
|
||||
model,
|
||||
param_value,
|
||||
empty_param,
|
||||
dq_param_name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
rank,
|
||||
device_mesh,
|
||||
set_param=False,
|
||||
)
|
||||
blocks_attr = f"{proj}_blocks"
|
||||
scales_attr = f"{proj}_scales"
|
||||
setattr(module, param_name.rsplit(".", 1)[1], param_value)
|
||||
if hasattr(module, blocks_attr) and hasattr(module, scales_attr):
|
||||
dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
|
||||
dequantized = dequantized.transpose(1, 2).contiguous().to(target_device)
|
||||
# TODO: this is perhaps necessary since if target_device is cpu, and the param was on gpu
|
||||
if target_device == "cpu" and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
setattr(module, proj, torch.nn.Parameter(dequantized))
|
||||
delattr(module, blocks_attr)
|
||||
delattr(module, scales_attr)
|
||||
|
||||
|
||||
def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs):
|
||||
from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
|
||||
|
||||
from ..integrations.tensor_parallel import shard_and_distribute_module
|
||||
|
||||
model = kwargs.get("model", None)
|
||||
empty_param = kwargs.get("empty_param", None)
|
||||
casting_dtype = kwargs.get("casting_dtype", None)
|
||||
to_contiguous = kwargs.get("to_contiguous", None)
|
||||
rank = kwargs.get("rank", None)
|
||||
device_mesh = kwargs.get("device_mesh", None)
|
||||
|
||||
for proj in ["gate_up_proj", "down_proj"]:
|
||||
if proj in param_name:
|
||||
if device_mesh is not None:
|
||||
shard_and_distribute_module(
|
||||
model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh
|
||||
)
|
||||
else:
|
||||
setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False))
|
||||
blocks_attr = f"{proj}_blocks"
|
||||
scales_attr = f"{proj}_scales"
|
||||
blocks = getattr(module, blocks_attr)
|
||||
scales = getattr(module, scales_attr)
|
||||
# Check if both blocks and scales both not on on meta device
|
||||
if blocks.device.type != "meta" and scales.device.type != "meta":
|
||||
# need it for ep
|
||||
local_experts = blocks.size(0)
|
||||
if proj == "gate_up_proj":
|
||||
blocks = blocks.view(local_experts, module.intermediate_size * 2, -1)
|
||||
else:
|
||||
blocks = blocks.view(local_experts, -1, module.intermediate_size // 2)
|
||||
# TODO: we need to have the weights on cuda, refactor later
|
||||
if getattr(target_device, "type", target_device) == "cpu":
|
||||
target_device = "cuda"
|
||||
# TODO: check why we still do move the tensors despite the context manager
|
||||
blocks = blocks.to(target_device)
|
||||
scales = scales.to(target_device)
|
||||
with torch.cuda.device(target_device):
|
||||
triton_weight_tensor, weight_scale = swizzle_mxfp4(
|
||||
blocks.transpose(-2, -1), scales.transpose(-2, -1)
|
||||
)
|
||||
|
||||
# need to overwrite the shapes for the kernels
|
||||
if proj == "gate_up_proj":
|
||||
triton_weight_tensor.shape = torch.Size(
|
||||
[local_experts, module.hidden_size, module.intermediate_size * 2]
|
||||
)
|
||||
else:
|
||||
triton_weight_tensor.shape = torch.Size(
|
||||
[local_experts, module.intermediate_size, module.hidden_size]
|
||||
)
|
||||
|
||||
# triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor
|
||||
setattr(module, proj, triton_weight_tensor)
|
||||
setattr(
|
||||
module,
|
||||
f"{proj}_precision_config",
|
||||
PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
|
||||
)
|
||||
|
||||
# delete blocks and scales
|
||||
delattr(module, scales_attr)
|
||||
delattr(module, blocks_attr)
|
||||
# setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False))
|
||||
del blocks
|
||||
|
||||
|
||||
def _replace_with_mxfp4_linear(
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
config=None,
|
||||
):
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
|
||||
for name, module in model.named_children():
|
||||
current_key_name.append(name)
|
||||
if not should_convert_module(current_key_name, modules_to_not_convert):
|
||||
current_key_name.pop(-1)
|
||||
continue
|
||||
if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
|
||||
with init_empty_weights():
|
||||
model._modules[name] = Mxfp4GptOssExperts(config)
|
||||
has_been_replaced = True
|
||||
if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize:
|
||||
from types import MethodType
|
||||
|
||||
module.forward = MethodType(mlp_forward, module)
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_mxfp4_linear(
|
||||
module,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
config=config,
|
||||
)
|
||||
current_key_name.pop(-1)
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
def replace_with_mxfp4_linear(
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
quantization_config=None,
|
||||
config=None,
|
||||
):
|
||||
if quantization_config.dequantize:
|
||||
return model
|
||||
|
||||
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
|
||||
|
||||
if quantization_config.modules_to_not_convert is not None:
|
||||
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
||||
modules_to_not_convert = list(set(modules_to_not_convert))
|
||||
model, has_been_replaced = _replace_with_mxfp4_linear(
|
||||
model,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
config=config,
|
||||
)
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
"You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model."
|
||||
" Please double check your model architecture, or submit an issue on github if you think this is"
|
||||
" a bug."
|
||||
)
|
||||
|
||||
return model
|
||||
@ -657,7 +657,7 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
if hasattr(mod, "bias") and mod.bias is not None:
|
||||
mod._bias = mod.bias.to_local()
|
||||
mod._bias = mod.bias
|
||||
mod.bias = None
|
||||
|
||||
input_tensor = inputs[0]
|
||||
@ -675,11 +675,10 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
# 2. to shard -> reduce_scatter
|
||||
if outputs.placements != output_layouts:
|
||||
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
|
||||
outputs = outputs.to_local()
|
||||
if hasattr(mod, "_bias"):
|
||||
outputs += mod._bias
|
||||
# back to local tensor if use_local_output is True
|
||||
return outputs
|
||||
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
|
||||
|
||||
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
||||
module._distribute_module_applied = True
|
||||
@ -997,7 +996,7 @@ def add_tensor_parallel_hooks_to_module(
|
||||
|
||||
|
||||
def shard_and_distribute_module(
|
||||
model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh, set_param=True
|
||||
model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
): # TODO: rename to shard_and_distribute_param
|
||||
r"""
|
||||
This function is called in `from_pretrained` when loading a model's checkpoints.
|
||||
|
||||
@ -158,7 +158,6 @@ LOSS_MAPPING = {
|
||||
"ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
||||
"DabDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
||||
"GroundingDinoForObjectDetection": GroundingDinoForObjectDetectionLoss,
|
||||
"MMGroundingDinoForObjectDetection": GroundingDinoForObjectDetectionLoss,
|
||||
"ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
|
||||
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
|
||||
"RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
|
||||
|
||||
@ -48,7 +48,7 @@ def and_masks(*mask_functions: list[Callable]) -> Callable:
|
||||
def and_mask(batch_idx, head_idx, q_idx, kv_idx):
|
||||
result = q_idx.new_ones((), dtype=torch.bool)
|
||||
for mask in mask_functions:
|
||||
result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
|
||||
result = result & mask(batch_idx, head_idx, q_idx, kv_idx)
|
||||
return result
|
||||
|
||||
return and_mask
|
||||
@ -62,7 +62,7 @@ def or_masks(*mask_functions: list[Callable]) -> Callable:
|
||||
def or_mask(batch_idx, head_idx, q_idx, kv_idx):
|
||||
result = q_idx.new_zeros((), dtype=torch.bool)
|
||||
for mask in mask_functions:
|
||||
result = result | mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
|
||||
result = result | mask(batch_idx, head_idx, q_idx, kv_idx)
|
||||
return result
|
||||
|
||||
return or_mask
|
||||
|
||||
@ -389,8 +389,7 @@ def _flash_attention_forward(
|
||||
flash_kwargs["deterministic"] = det
|
||||
if softcap is not None:
|
||||
flash_kwargs["softcap"] = softcap
|
||||
if "s_aux" in kwargs:
|
||||
flash_kwargs["s_aux"] = kwargs.get("s_aux")
|
||||
|
||||
query_states, key_states, value_states = fa_peft_integration_check(
|
||||
query_states, key_states, value_states, target_dtype
|
||||
)
|
||||
|
||||
@ -252,13 +252,10 @@ def _compute_yarn_parameters(
|
||||
"""Inverse dimension formula to find the dimension based on the number of rotations"""
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
||||
|
||||
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
|
||||
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
|
||||
"""Find dimension range bounds based on rotations"""
|
||||
low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
if truncate:
|
||||
low = low = math.floor(low)
|
||||
high = math.ceil(high)
|
||||
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
def linear_ramp_factor(min, max, dim):
|
||||
@ -275,8 +272,7 @@ def _compute_yarn_parameters(
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||
|
||||
truncate = config.rope_scaling.get("truncate", True)
|
||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
|
||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
|
||||
|
||||
# Get n-dimensional rotational scaling corrected for extrapolation
|
||||
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
|
||||
@ -469,7 +465,6 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
|
||||
"original_max_position_embeddings",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
"truncate",
|
||||
}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||
@ -513,13 +508,13 @@ def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optiona
|
||||
short_factor = rope_scaling.get("short_factor")
|
||||
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
|
||||
logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
|
||||
if len(short_factor) != dim // 2:
|
||||
if not len(short_factor) == dim // 2:
|
||||
logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
|
||||
|
||||
long_factor = rope_scaling.get("long_factor")
|
||||
if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
|
||||
logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
|
||||
if len(long_factor) != dim // 2:
|
||||
if not len(long_factor) == dim // 2:
|
||||
logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
|
||||
|
||||
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
|
||||
|
||||
@ -51,7 +51,6 @@ if is_torchao_available():
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .distributed import DistributedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .generation import CompileConfig, GenerationConfig
|
||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||
@ -710,7 +709,6 @@ def _infer_parameter_dtype(
|
||||
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.QUARK,
|
||||
QuantizationMethod.MXFP4,
|
||||
}:
|
||||
return True, None
|
||||
else:
|
||||
@ -780,8 +778,9 @@ def _load_state_dict_into_meta_model(
|
||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||
|
||||
for param_name, empty_param in state_dict.items():
|
||||
if param_name not in expected_keys: # when loading from ckpt, we skip param if doesnt exist in modeling
|
||||
if param_name not in expected_keys:
|
||||
continue
|
||||
|
||||
# we need to use serialized_param_name as file pointer is untouched
|
||||
if is_meta_state_dict:
|
||||
# This is the name of the parameter as it appears on disk file
|
||||
@ -789,6 +788,7 @@ def _load_state_dict_into_meta_model(
|
||||
param = file_pointer.get_slice(serialized_param_name)
|
||||
else:
|
||||
param = empty_param.to(tensor_device) # It is actually not empty!
|
||||
|
||||
to_contiguous, casting_dtype = _infer_parameter_dtype(
|
||||
model,
|
||||
param_name,
|
||||
@ -797,47 +797,17 @@ def _load_state_dict_into_meta_model(
|
||||
hf_quantizer,
|
||||
)
|
||||
|
||||
if device_mesh is not None:
|
||||
if (
|
||||
not is_quantized
|
||||
or (not hf_quantizer.requires_parameters_quantization)
|
||||
or (
|
||||
not hf_quantizer.check_quantized_param(
|
||||
model,
|
||||
param,
|
||||
param_name,
|
||||
state_dict,
|
||||
device_map=device_map,
|
||||
)
|
||||
)
|
||||
): # In this case, the param is already on the correct device!
|
||||
shard_and_distribute_module(
|
||||
model,
|
||||
param,
|
||||
empty_param,
|
||||
param_name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
device_mesh.get_local_rank(),
|
||||
device_mesh,
|
||||
)
|
||||
else: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param:
|
||||
sharding_kwargs = {
|
||||
"empty_param": empty_param,
|
||||
"casting_dtype": casting_dtype,
|
||||
"to_contiguous": to_contiguous,
|
||||
"rank": device_mesh.get_local_rank(),
|
||||
"device_mesh": device_mesh,
|
||||
}
|
||||
hf_quantizer.create_quantized_param(
|
||||
model,
|
||||
param,
|
||||
param_name,
|
||||
device_mesh.get_local_rank(),
|
||||
state_dict,
|
||||
unexpected_keys,
|
||||
**sharding_kwargs,
|
||||
)
|
||||
if device_mesh is not None: # In this case, the param is already on the correct device!
|
||||
shard_and_distribute_module(
|
||||
model,
|
||||
param,
|
||||
empty_param,
|
||||
param_name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
device_mesh.get_local_rank(),
|
||||
device_mesh,
|
||||
)
|
||||
else:
|
||||
param = param[...]
|
||||
if casting_dtype is not None:
|
||||
@ -882,24 +852,17 @@ def _load_state_dict_into_meta_model(
|
||||
hf_quantizer.create_quantized_param(
|
||||
model, param, param_name, param_device, state_dict, unexpected_keys
|
||||
)
|
||||
|
||||
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
|
||||
# and then cast it to CPU to avoid excessive memory usage on each GPU
|
||||
# in comparison to the sharded model across GPUs.
|
||||
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
||||
param_name = hf_quantizer.update_param_name(param_name)
|
||||
module, param_type = get_module_from_name(model, param_name)
|
||||
value = getattr(module, param_type)
|
||||
# special case for GptOssForCausalLM, we wait for the param to be leave the meta device before casting it to cpu
|
||||
if model.__class__.__name__ == "GptOssForCausalLM" and value.device.type == "meta":
|
||||
continue
|
||||
param_to = "cpu"
|
||||
if is_fsdp_enabled() and not is_local_dist_rank_0():
|
||||
param_to = "meta"
|
||||
val_kwargs = {}
|
||||
if (hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params") or (
|
||||
value.dtype == torch.uint8 or value.dtype == torch.int8
|
||||
):
|
||||
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
|
||||
val_kwargs["requires_grad"] = False
|
||||
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
|
||||
setattr(module, param_type, value)
|
||||
@ -2636,7 +2599,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
|
||||
before instantiating the full models if we know that the model does not support the requested attention.
|
||||
"""
|
||||
if not self._supports_sdpa and not is_init_check:
|
||||
if not self._supports_sdpa:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
|
||||
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
|
||||
@ -2720,51 +2683,34 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation):
|
||||
if not is_kernels_available():
|
||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
||||
attention_wrapper = None
|
||||
# FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work
|
||||
if "|" in applicable_attn_implementation:
|
||||
attention_wrapper, applicable_attn_implementation = applicable_attn_implementation.split("|")
|
||||
# `transformers` has wrapper for sdpa, paged, flash, flex etc.
|
||||
attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper)
|
||||
|
||||
# Extract repo_id and kernel_name from the string
|
||||
if ":" in applicable_attn_implementation:
|
||||
repo_id, kernel_name = attn_implementation.split(":")
|
||||
kernel_name = kernel_name.strip()
|
||||
else:
|
||||
repo_id = applicable_attn_implementation
|
||||
repo_id = attn_implementation
|
||||
kernel_name = None
|
||||
repo_id = repo_id.strip()
|
||||
try:
|
||||
kernel = get_kernel(repo_id)
|
||||
if hasattr(kernel, "flash_attn_varlen_func"):
|
||||
if attention_wrapper is None:
|
||||
attention_wrapper = flash_attention_forward
|
||||
kernel_function = partial(attention_wrapper, implementation=kernel)
|
||||
kernel_function = partial(flash_attention_forward, implementation=kernel)
|
||||
elif kernel_name is not None:
|
||||
kernel_function = getattr(kernel, kernel_name)
|
||||
ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register(
|
||||
attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
|
||||
)
|
||||
# Register it
|
||||
ALL_ATTENTION_FUNCTIONS.register(repo_id, kernel_function)
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register(repo_id, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
|
||||
applicable_attn_implementation = repo_id
|
||||
except Exception as e:
|
||||
logger.warning_once(
|
||||
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
|
||||
"default attention implementation instead (sdpa if available, eager otherwise)."
|
||||
)
|
||||
|
||||
attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||
return attn_implementation
|
||||
else:
|
||||
return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
|
||||
|
||||
def get_correct_attn_implementation(self, _requested_attention: str, is_init_check: bool = False) -> str:
|
||||
requested_attention = "sdpa" if _requested_attention is None else _requested_attention
|
||||
if is_init_check and requested_attention == "sdpa":
|
||||
if not self._supports_sdpa:
|
||||
requested_attention = "eager"
|
||||
if requested_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
message = (
|
||||
f'Specified `attn_implementation="{requested_attention}"` is not supported. The only possible arguments are '
|
||||
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
|
||||
'`attn_implementation="eager"` (manual attention implementation)'
|
||||
)
|
||||
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
|
||||
@ -2780,21 +2726,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
raise ValueError(message + ".")
|
||||
|
||||
# Perform relevant checks
|
||||
if requested_attention == "flash_attention_2":
|
||||
if applicable_attn_implementation == "flash_attention_2":
|
||||
self._flash_attn_2_can_dispatch(is_init_check)
|
||||
elif requested_attention == "flash_attention_3":
|
||||
elif applicable_attn_implementation == "flash_attention_3":
|
||||
self._flash_attn_3_can_dispatch(is_init_check)
|
||||
elif requested_attention == "flex_attention":
|
||||
elif applicable_attn_implementation == "flex_attention":
|
||||
self._flex_attn_can_dispatch(is_init_check)
|
||||
elif requested_attention == "sdpa":
|
||||
elif applicable_attn_implementation == "sdpa":
|
||||
# Sdpa is the default, so we try it and fallback to eager otherwise when not possible
|
||||
try:
|
||||
self._sdpa_can_dispatch(is_init_check)
|
||||
except (ValueError, ImportError) as e:
|
||||
if _requested_attention == "sdpa":
|
||||
# In this case, sdpa was requested explicitly, but we can't use it, so let's raise
|
||||
if attn_implementation == "sdpa":
|
||||
raise e
|
||||
requested_attention = "eager"
|
||||
return requested_attention
|
||||
applicable_attn_implementation = "eager"
|
||||
|
||||
return applicable_attn_implementation
|
||||
|
||||
@classmethod
|
||||
def _can_set_attn_implementation(cls) -> bool:
|
||||
@ -2842,7 +2790,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
)
|
||||
# Apply the change (on the internal attr, to avoid setting it recursively)
|
||||
self.config._attn_implementation_internal = applicable_attn_implementation
|
||||
except Exception as e:
|
||||
except (ValueError, ImportError) as e:
|
||||
logger.warning(
|
||||
f"Impossible to set the requested `attn_implementation`. The following error was captured: {str(e)}"
|
||||
)
|
||||
@ -2866,13 +2814,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
subconfig_key, submodule.config._attn_implementation
|
||||
)
|
||||
break
|
||||
# check the module can use correctly, otherwise we silently set the config without the model using it
|
||||
try:
|
||||
sub_implementation = submodule.get_correct_attn_implementation(sub_implementation)
|
||||
submodule.config._attn_implementation = sub_implementation
|
||||
subconfigs_changed.add(submodule.config.__class__)
|
||||
except Exception:
|
||||
pass
|
||||
submodule.set_attn_implementation(sub_implementation)
|
||||
subconfigs_changed.add(submodule.config.__class__)
|
||||
|
||||
# We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
|
||||
for subconfig_key in self.config.sub_configs:
|
||||
@ -4615,6 +4558,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
load_in_4bit = kwargs.pop("load_in_4bit", False)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
distributed_config = kwargs.pop("distributed_config", None)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
@ -4624,7 +4568,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
tp_plan = kwargs.pop("tp_plan", None)
|
||||
tp_size = kwargs.pop("tp_size", None)
|
||||
distributed_config: DistributedConfig = kwargs.pop("distributed_config", None)
|
||||
device_mesh = kwargs.pop("device_mesh", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
use_kernels = kwargs.pop("use_kernels", False)
|
||||
@ -4986,6 +4929,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
# Let's make sure we don't run the init function of buffer modules
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if _torch_distributed_available and device_mesh is not None:
|
||||
model = distribute_model(model, distributed_config, device_mesh, tp_size)
|
||||
|
||||
# Make sure to tie the weights correctly
|
||||
model.tie_weights()
|
||||
|
||||
@ -5014,7 +4960,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config, use_kernels=use_kernels
|
||||
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config
|
||||
)
|
||||
# We store the original dtype for quantized models as we cannot easily retrieve it
|
||||
# once the weights have been quantized
|
||||
@ -5031,9 +4977,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
config._pre_quantization_dtype = original_dtype
|
||||
_assign_original_dtype(model)
|
||||
|
||||
if _torch_distributed_available and device_mesh is not None:
|
||||
model = distribute_model(model, distributed_config, device_mesh, tp_size)
|
||||
|
||||
# Prepare the full device map
|
||||
if device_map is not None:
|
||||
device_map = _get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_regex)
|
||||
@ -5080,10 +5023,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
|
||||
# check if using kernels
|
||||
if use_kernels:
|
||||
|
||||
if not is_kernels_available():
|
||||
raise ValueError("Kernels are not available. To use kernels, please install kernels using `pip install kernels`")
|
||||
|
||||
from kernels import Device, kernelize
|
||||
|
||||
kernelize(model, device=Device(type=model.device.type))
|
||||
@ -5157,8 +5096,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
dispatch_model(model, **device_map_kwargs)
|
||||
|
||||
if hf_quantizer is not None:
|
||||
model.hf_quantizer = hf_quantizer
|
||||
hf_quantizer.postprocess_model(model, config=config)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
|
||||
if _adapter_model_path is not None:
|
||||
adapter_kwargs["key_mapping"] = key_mapping
|
||||
@ -5807,8 +5746,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
# Check if base model has a TP plan
|
||||
if getattr(self.base_model, "_tp_plan", None) is not None:
|
||||
return True
|
||||
if self.config.base_model_tp_plan is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
@ -6070,16 +6007,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
|
||||
if param_name in tied_param_names:
|
||||
continue
|
||||
|
||||
# For example in the case of MXFP4 quantization, we need to update the param name to the original param name
|
||||
# because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
|
||||
if hf_quantizer is not None:
|
||||
param_name = hf_quantizer.update_param_name(param_name)
|
||||
|
||||
try:
|
||||
param = model.get_parameter_or_buffer(param_name)
|
||||
except AttributeError:
|
||||
raise AttributeError(f"Parameter {param_name} not found in model")
|
||||
|
||||
param = model.get_parameter_or_buffer(param_name)
|
||||
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
|
||||
param_byte_count = param.numel() * param.element_size()
|
||||
|
||||
|
||||
@ -63,7 +63,6 @@ if TYPE_CHECKING:
|
||||
from .codegen import *
|
||||
from .cohere import *
|
||||
from .cohere2 import *
|
||||
from .cohere2_vision import *
|
||||
from .colpali import *
|
||||
from .colqwen2 import *
|
||||
from .conditional_detr import *
|
||||
@ -140,7 +139,6 @@ if TYPE_CHECKING:
|
||||
from .gpt_neo import *
|
||||
from .gpt_neox import *
|
||||
from .gpt_neox_japanese import *
|
||||
from .gpt_oss import *
|
||||
from .gpt_sw3 import *
|
||||
from .gptj import *
|
||||
from .granite import *
|
||||
|
||||
@ -275,7 +275,7 @@ def replace_params(hf_params, tf_params, key_mapping):
|
||||
new_hf_value = torch.from_numpy(np.transpose(value))
|
||||
elif "temperature" in key:
|
||||
new_hf_value = value
|
||||
elif "bn/gamma" in key or "bn/beta" in key:
|
||||
elif "bn/gamma" or "bn/beta" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
|
||||
else:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
|
||||
@ -394,7 +394,7 @@ class AlignVisionBlock(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.expand_ratio = expand_ratio
|
||||
self.expand = self.expand_ratio != 1
|
||||
self.expand = True if self.expand_ratio != 1 else False
|
||||
expand_in_dim = in_dim * expand_ratio
|
||||
|
||||
if self.expand:
|
||||
@ -464,10 +464,10 @@ class AlignVisionEncoder(nn.Module):
|
||||
expand_ratio = config.expand_ratios[i]
|
||||
|
||||
for j in range(round_repeats(config.num_block_repeats[i])):
|
||||
id_skip = j == 0
|
||||
id_skip = True if j == 0 else False
|
||||
stride = 1 if j > 0 else stride
|
||||
in_dim = out_dim if j > 0 else in_dim
|
||||
adjust_padding = curr_block_num not in config.depthwise_padding
|
||||
adjust_padding = False if curr_block_num in config.depthwise_padding else True
|
||||
drop_rate = config.drop_connect_rate * curr_block_num / num_blocks
|
||||
|
||||
block = AlignVisionBlock(
|
||||
|
||||
@ -515,8 +515,8 @@ class AriaImageProcessor(BaseImageProcessor):
|
||||
Returns:
|
||||
`int`: Number of patches per image.
|
||||
"""
|
||||
split_image = images_kwargs.get("split_image", self.split_image)
|
||||
max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
|
||||
split_image = images_kwargs["split_image"] if "split_image" in images_kwargs else self.split_image
|
||||
max_image_size = images_kwargs["max_image_size"] if "max_image_size" in images_kwargs else self.max_image_size
|
||||
|
||||
resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
|
||||
num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
|
||||
|
||||
@ -978,30 +978,6 @@ class AriaModel(AriaPreTrainedModel):
|
||||
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
|
||||
return image_features
|
||||
|
||||
def get_placeholder_mask(
|
||||
self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
|
||||
):
|
||||
"""
|
||||
Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
|
||||
equal to the length of multimodal features. If the lengths are different, an error is raised.
|
||||
"""
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = special_image_mask.sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
return special_image_mask
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
@ -1031,15 +1007,29 @@ class AriaModel(AriaPreTrainedModel):
|
||||
|
||||
# 2. Merge text and images
|
||||
if pixel_values is not None and inputs_embeds.shape[1] != 1:
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
pixel_mask=pixel_mask,
|
||||
vision_feature_layer=self.config.vision_feature_layer,
|
||||
)
|
||||
n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
|
||||
n_image_features = n_images * n_features_per_image
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
special_image_mask = self._get_image_mask(
|
||||
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
@ -1116,7 +1106,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
self.model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.get_decoder()
|
||||
return self.model.get_decoder
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
|
||||
@ -901,8 +901,8 @@ class AriaImageProcessor(BaseImageProcessor):
|
||||
Returns:
|
||||
`int`: Number of patches per image.
|
||||
"""
|
||||
split_image = images_kwargs.get("split_image", self.split_image)
|
||||
max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
|
||||
split_image = images_kwargs["split_image"] if "split_image" in images_kwargs else self.split_image
|
||||
max_image_size = images_kwargs["max_image_size"] if "max_image_size" in images_kwargs else self.max_image_size
|
||||
|
||||
resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
|
||||
num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
|
||||
@ -1431,15 +1431,29 @@ class AriaModel(LlavaModel):
|
||||
|
||||
# 2. Merge text and images
|
||||
if pixel_values is not None and inputs_embeds.shape[1] != 1:
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
pixel_mask=pixel_mask,
|
||||
vision_feature_layer=self.config.vision_feature_layer,
|
||||
)
|
||||
n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
|
||||
n_image_features = n_images * n_features_per_image
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
special_image_mask = self._get_image_mask(
|
||||
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
|
||||
@ -81,7 +81,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("codegen", "CodeGenConfig"),
|
||||
("cohere", "CohereConfig"),
|
||||
("cohere2", "Cohere2Config"),
|
||||
("cohere2_vision", "Cohere2VisionConfig"),
|
||||
("colpali", "ColPaliConfig"),
|
||||
("colqwen2", "ColQwen2Config"),
|
||||
("conditional_detr", "ConditionalDetrConfig"),
|
||||
@ -172,7 +171,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("gpt_neo", "GPTNeoConfig"),
|
||||
("gpt_neox", "GPTNeoXConfig"),
|
||||
("gpt_neox_japanese", "GPTNeoXJapaneseConfig"),
|
||||
("gpt_oss", "GptOssConfig"),
|
||||
("gptj", "GPTJConfig"),
|
||||
("gptsan-japanese", "GPTSanJapaneseConfig"),
|
||||
("granite", "GraniteConfig"),
|
||||
@ -245,7 +243,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("mixtral", "MixtralConfig"),
|
||||
("mlcd", "MLCDVisionConfig"),
|
||||
("mllama", "MllamaConfig"),
|
||||
("mm-grounding-dino", "MMGroundingDinoConfig"),
|
||||
("mobilebert", "MobileBertConfig"),
|
||||
("mobilenet_v1", "MobileNetV1Config"),
|
||||
("mobilenet_v2", "MobileNetV2Config"),
|
||||
@ -479,7 +476,6 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("codegen", "CodeGen"),
|
||||
("cohere", "Cohere"),
|
||||
("cohere2", "Cohere2"),
|
||||
("cohere2_vision", "Cohere2Vision"),
|
||||
("colpali", "ColPali"),
|
||||
("colqwen2", "ColQwen2"),
|
||||
("conditional_detr", "Conditional DETR"),
|
||||
@ -578,7 +574,6 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("gpt_neo", "GPT Neo"),
|
||||
("gpt_neox", "GPT NeoX"),
|
||||
("gpt_neox_japanese", "GPT NeoX Japanese"),
|
||||
("gpt_oss", "GptOss"),
|
||||
("gptj", "GPT-J"),
|
||||
("gptsan-japanese", "GPTSAN-japanese"),
|
||||
("granite", "Granite"),
|
||||
@ -660,7 +655,6 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("mlcd", "MLCD"),
|
||||
("mllama", "Mllama"),
|
||||
("mluke", "mLUKE"),
|
||||
("mm-grounding-dino", "MM Grounding DINO"),
|
||||
("mms", "MMS"),
|
||||
("mobilebert", "MobileBERT"),
|
||||
("mobilenet_v1", "MobileNetV1"),
|
||||
|
||||
@ -72,14 +72,13 @@ else:
|
||||
("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")),
|
||||
("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("cohere2_vision", ("Cohere2VisionImageProcessorFast",)),
|
||||
("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
|
||||
("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
|
||||
("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")),
|
||||
("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")),
|
||||
("deepseek_vl", ("DeepseekVLImageProcessor")),
|
||||
("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor")),
|
||||
("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
|
||||
("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
|
||||
("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")),
|
||||
@ -113,7 +112,7 @@ else:
|
||||
("imagegpt", ("ImageGPTImageProcessor",)),
|
||||
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
|
||||
("instructblipvideo", ("InstructBlipVideoImageProcessor",)),
|
||||
("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
|
||||
("janus", ("JanusImageProcessor")),
|
||||
("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
|
||||
("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
|
||||
@ -130,7 +129,6 @@ else:
|
||||
("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
|
||||
("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("mllama", ("MllamaImageProcessor",)),
|
||||
("mm-grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
|
||||
("mobilenet_v1", ("MobileNetV1ImageProcessor", "MobileNetV1ImageProcessorFast")),
|
||||
("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),
|
||||
("mobilevit", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
|
||||
|
||||
@ -77,7 +77,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("codegen", "CodeGenModel"),
|
||||
("cohere", "CohereModel"),
|
||||
("cohere2", "Cohere2Model"),
|
||||
("cohere2_vision", "Cohere2VisionModel"),
|
||||
("conditional_detr", "ConditionalDetrModel"),
|
||||
("convbert", "ConvBertModel"),
|
||||
("convnext", "ConvNextModel"),
|
||||
@ -163,7 +162,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("gpt_neo", "GPTNeoModel"),
|
||||
("gpt_neox", "GPTNeoXModel"),
|
||||
("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
|
||||
("gpt_oss", "GptOssModel"),
|
||||
("gptj", "GPTJModel"),
|
||||
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
|
||||
("granite", "GraniteModel"),
|
||||
@ -234,7 +232,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("mixtral", "MixtralModel"),
|
||||
("mlcd", "MLCDVisionModel"),
|
||||
("mllama", "MllamaModel"),
|
||||
("mm-grounding-dino", "MMGroundingDinoModel"),
|
||||
("mobilebert", "MobileBertModel"),
|
||||
("mobilenet_v1", "MobileNetV1Model"),
|
||||
("mobilenet_v2", "MobileNetV2Model"),
|
||||
@ -632,7 +629,6 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("gpt_neo", "GPTNeoForCausalLM"),
|
||||
("gpt_neox", "GPTNeoXForCausalLM"),
|
||||
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
|
||||
("gpt_oss", "GptOssForCausalLM"),
|
||||
("gptj", "GPTJForCausalLM"),
|
||||
("granite", "GraniteForCausalLM"),
|
||||
("granitemoe", "GraniteMoeForCausalLM"),
|
||||
@ -715,7 +711,6 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
|
||||
("aimv2_vision_model", "Aimv2VisionModel"),
|
||||
("beit", "BeitModel"),
|
||||
("bit", "BitModel"),
|
||||
("cohere2_vision", "Cohere2VisionModel"),
|
||||
("conditional_detr", "ConditionalDetrModel"),
|
||||
("convnext", "ConvNextModel"),
|
||||
("convnextv2", "ConvNextV2Model"),
|
||||
@ -949,7 +944,6 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
||||
("blip", "BlipForConditionalGeneration"),
|
||||
("blip-2", "Blip2ForConditionalGeneration"),
|
||||
("chameleon", "ChameleonForConditionalGeneration"),
|
||||
("cohere2_vision", "Cohere2VisionForConditionalGeneration"),
|
||||
("deepseek_vl", "DeepseekVLForConditionalGeneration"),
|
||||
("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"),
|
||||
("emu3", "Emu3ForConditionalGeneration"),
|
||||
@ -1060,7 +1054,6 @@ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Zero Shot Object Detection mapping
|
||||
("grounding-dino", "GroundingDinoForObjectDetection"),
|
||||
("mm-grounding-dino", "MMGroundingDinoForObjectDetection"),
|
||||
("omdet-turbo", "OmDetTurboForObjectDetection"),
|
||||
("owlv2", "Owlv2ForObjectDetection"),
|
||||
("owlvit", "OwlViTForObjectDetection"),
|
||||
|
||||
@ -60,7 +60,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("clip", "CLIPProcessor"),
|
||||
("clipseg", "CLIPSegProcessor"),
|
||||
("clvp", "ClvpProcessor"),
|
||||
("cohere2_vision", "Cohere2VisionProcessor"),
|
||||
("colpali", "ColPaliProcessor"),
|
||||
("colqwen2", "ColQwen2Processor"),
|
||||
("deepseek_vl", "DeepseekVLProcessor"),
|
||||
@ -100,7 +99,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("mgp-str", "MgpstrProcessor"),
|
||||
("mistral3", "PixtralProcessor"),
|
||||
("mllama", "MllamaProcessor"),
|
||||
("mm-grounding-dino", "GroundingDinoProcessor"),
|
||||
("moonshine", "Wav2Vec2Processor"),
|
||||
("oneformer", "OneFormerProcessor"),
|
||||
("owlv2", "Owlv2Processor"),
|
||||
|
||||
@ -300,7 +300,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)),
|
||||
("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
|
||||
("granite", ("GPT2Tokenizer", None)),
|
||||
@ -431,7 +430,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
),
|
||||
("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("mm-grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
||||
@ -485,8 +485,8 @@ class AutoformerAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
||||
@ -33,9 +33,9 @@ class AyaVisionConfig(PretrainedConfig):
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SiglipVisionConfig`):
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Cohere2Config`):
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
@ -82,7 +82,7 @@ class AyaVisionConfig(PretrainedConfig):
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
vision_config["model_type"] = (
|
||||
vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
|
||||
vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
|
||||
)
|
||||
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
|
||||
elif vision_config is None:
|
||||
@ -99,7 +99,7 @@ class AyaVisionConfig(PretrainedConfig):
|
||||
self.vision_config = vision_config
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "cohere2"
|
||||
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
|
||||
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||
elif text_config is None:
|
||||
text_config = CONFIG_MAPPING["cohere2"]()
|
||||
|
||||
@ -32,8 +32,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
||||
from ...utils.generic import check_model_inputs
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
|
||||
from ..auto import AutoModel
|
||||
from .configuration_aya_vision import AyaVisionConfig
|
||||
|
||||
@ -100,10 +99,6 @@ class AyaVisionPreTrainedModel(PreTrainedModel):
|
||||
_can_compile_fullgraph = False
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = {
|
||||
"hidden_states": "DecoderLayer",
|
||||
"attentions": "Attention",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -242,31 +237,7 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
|
||||
def get_placeholder_mask(
|
||||
self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
|
||||
):
|
||||
"""
|
||||
Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
|
||||
equal to the length of multimodal features. If the lengths are different, an error is raised.
|
||||
"""
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = special_image_mask.sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
return special_image_mask
|
||||
|
||||
@check_model_inputs
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -279,9 +250,17 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[tuple, AyaVisionModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
@ -303,10 +282,24 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
special_image_mask = self.get_placeholder_mask(
|
||||
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
@ -315,6 +308,9 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
@ -361,7 +357,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
|
||||
self.model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.get_decoder()
|
||||
return self.model.get_decoder
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
|
||||
@ -32,8 +32,7 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils.generic import check_model_inputs
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
|
||||
from .configuration_aya_vision import AyaVisionConfig
|
||||
|
||||
|
||||
@ -92,10 +91,6 @@ class AyaVisionMultiModalProjector(nn.Module):
|
||||
|
||||
class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
|
||||
_can_compile_fullgraph = False
|
||||
_can_record_outputs = {
|
||||
"hidden_states": "DecoderLayer",
|
||||
"attentions": "Attention",
|
||||
}
|
||||
|
||||
|
||||
class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||
@ -163,7 +158,7 @@ class AyaVisionModel(LlavaModel):
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
|
||||
@check_model_inputs
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -176,9 +171,17 @@ class AyaVisionModel(LlavaModel):
|
||||
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[tuple, AyaVisionModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
@ -200,10 +203,24 @@ class AyaVisionModel(LlavaModel):
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
special_image_mask = self.get_placeholder_mask(
|
||||
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
@ -212,6 +229,9 @@ class AyaVisionModel(LlavaModel):
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -269,7 +269,7 @@ class BarkConfig(PretrainedConfig):
|
||||
self.semantic_config = BarkSemanticConfig(**semantic_config)
|
||||
self.coarse_acoustics_config = BarkCoarseConfig(**coarse_acoustics_config)
|
||||
self.fine_acoustics_config = BarkFineConfig(**fine_acoustics_config)
|
||||
codec_model_type = codec_config.get("model_type", "encodec")
|
||||
codec_model_type = codec_config["model_type"] if "model_type" in codec_config else "encodec"
|
||||
self.codec_config = CONFIG_MAPPING[codec_model_type](**codec_config)
|
||||
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
@ -248,8 +248,8 @@ class BertSelfAttention(nn.Module):
|
||||
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_layer = curr_past_key_value.layers[self.layer_idx].values
|
||||
key_layer = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_layer = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_layer = self.key(current_states)
|
||||
key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
|
||||
@ -378,8 +378,8 @@ class BertSdpaSelfAttention(BertSelfAttention):
|
||||
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_layer = curr_past_key_value.layers[self.layer_idx].values
|
||||
key_layer = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_layer = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_layer = (
|
||||
self.key(current_states)
|
||||
@ -414,7 +414,9 @@ class BertSdpaSelfAttention(BertSelfAttention):
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
|
||||
# a causal mask in case tgt_len == 1.
|
||||
is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1
|
||||
is_causal = (
|
||||
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
|
||||
)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_layer,
|
||||
|
||||
@ -110,8 +110,8 @@ class BertGenerationSelfAttention(nn.Module):
|
||||
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_layer = curr_past_key_value.layers[self.layer_idx].values
|
||||
key_layer = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_layer = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_layer = self.key(current_states)
|
||||
key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
|
||||
|
||||
@ -96,7 +96,9 @@ def load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=False):
|
||||
|
||||
name_items[0] = f"bert/encoder/layer_{layer_name_items[2]}"
|
||||
|
||||
name = "/".join([_TRIVIA_QA_MAPPING.get(x, x) for x in name_items])[:-2] # remove last :0 in variable
|
||||
name = "/".join([_TRIVIA_QA_MAPPING[x] if x in _TRIVIA_QA_MAPPING else x for x in name_items])[
|
||||
:-2
|
||||
] # remove last :0 in variable
|
||||
|
||||
if "self/attention/output" in name:
|
||||
name = name.replace("self/attention/output", "output")
|
||||
@ -338,8 +340,8 @@ class BigBirdSelfAttention(nn.Module):
|
||||
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
|
||||
if is_cross_attention and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value.layers[self.layer_idx].keys
|
||||
value_layer = past_key_value.layers[self.layer_idx].values
|
||||
key_layer = past_key_value.key_cache[self.layer_idx]
|
||||
value_layer = past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_layer = (
|
||||
self.key(current_states)
|
||||
|
||||
@ -103,7 +103,7 @@ def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPeg
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict:
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(i in k for i in ["dense", "query", "key", "value"]):
|
||||
if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
@ -116,7 +116,7 @@ def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPeg
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(i in k for i in ["dense", "query", "key", "value"]):
|
||||
if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
if k != "pegasus/embeddings/position_embeddings":
|
||||
|
||||
@ -150,8 +150,8 @@ class BigBirdPegasusSelfAttention(nn.Module):
|
||||
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
|
||||
if is_cross_attention and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value.layers[self.layer_idx].keys
|
||||
value_layer = past_key_value.layers[self.layer_idx].values
|
||||
key_layer = past_key_value.key_cache[self.layer_idx]
|
||||
value_layer = past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_layer = (
|
||||
self.key(current_states)
|
||||
|
||||
@ -252,7 +252,7 @@ class BitEmbeddings(nn.Module):
|
||||
else:
|
||||
self.pad = nn.ConstantPad2d(padding=(1, 1, 1, 1), value=0.0)
|
||||
|
||||
if config.layer_type != "preactivation":
|
||||
if not config.layer_type == "preactivation":
|
||||
self.norm = BitGroupNormActivation(config, num_channels=config.embedding_size)
|
||||
else:
|
||||
self.norm = nn.Identity()
|
||||
|
||||
@ -176,8 +176,8 @@ class BlipTextSelfAttention(nn.Module):
|
||||
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_layer = curr_past_key_value.layers[self.layer_idx].values
|
||||
key_layer = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_layer = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_layer = (
|
||||
self.key(current_states)
|
||||
|
||||
@ -304,7 +304,7 @@ class Blip2Config(PretrainedConfig):
|
||||
|
||||
self.vision_config = Blip2VisionConfig(**vision_config)
|
||||
self.qformer_config = Blip2QFormerConfig(**qformer_config)
|
||||
text_model_type = text_config.get("model_type", "opt")
|
||||
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
|
||||
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
|
||||
|
||||
self.num_query_tokens = num_query_tokens
|
||||
|
||||
@ -1455,21 +1455,6 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
|
||||
return query_outputs
|
||||
|
||||
def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
|
||||
"""
|
||||
Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`.
|
||||
"""
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
return special_image_mask
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1560,8 +1545,16 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
|
||||
inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
|
||||
special_image_mask, language_model_inputs
|
||||
)
|
||||
@ -1945,21 +1938,6 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
return language_model_inputs, vision_outputs, query_outputs
|
||||
return language_model_inputs
|
||||
|
||||
def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
|
||||
"""
|
||||
Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`.
|
||||
"""
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
return special_image_mask
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -2064,8 +2042,16 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
|
||||
inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
|
||||
special_image_mask, language_model_inputs
|
||||
)
|
||||
|
||||
@ -460,8 +460,8 @@ class BridgeTowerSelfAttention(nn.Module):
|
||||
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_layer = curr_past_key_value.layers[self.layer_idx].values
|
||||
key_layer = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_layer = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_layer = self.key(current_states)
|
||||
key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
|
||||
|
||||
@ -198,8 +198,8 @@ class CamembertSelfAttention(nn.Module):
|
||||
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_layer = curr_past_key_value.layers[self.layer_idx].values
|
||||
key_layer = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_layer = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_layer = self.key(current_states)
|
||||
key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
|
||||
@ -329,8 +329,8 @@ class CamembertSdpaSelfAttention(CamembertSelfAttention):
|
||||
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_layer = curr_past_key_value.layers[self.layer_idx].values
|
||||
key_layer = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_layer = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_layer = (
|
||||
self.key(current_states)
|
||||
@ -365,7 +365,9 @@ class CamembertSdpaSelfAttention(CamembertSelfAttention):
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
|
||||
# a causal mask in case tgt_len == 1.
|
||||
is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1
|
||||
is_causal = (
|
||||
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
|
||||
)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_layer,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user