Compare commits

..

1 Commits

Author SHA1 Message Date
d8d0770674 test: bump pytest-asyncio to >=1.2.0 2025-11-10 10:47:29 +01:00
882 changed files with 15332 additions and 19673 deletions

View File

@ -46,8 +46,8 @@ jobs:
- run: uv pip install -U -e .
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
- run: mkdir -p test_preparation
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt || true
- run: python utils/tests_fetcher.py --filter_tests || true
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt
- run: python utils/tests_fetcher.py --filter_tests
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
- run: |
if [ ! -s test_preparation/generated_config.yml ]; then
@ -98,8 +98,8 @@ jobs:
- run: uv pip install -U -e .
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
- run: mkdir -p test_preparation
- run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt || true
- run: python utils/tests_fetcher.py --filter_tests || true
- run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt
- run: python utils/tests_fetcher.py --filter_tests
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
- run: |
if [ ! -s test_preparation/generated_config.yml ]; then

View File

@ -40,6 +40,7 @@ jobs:
run: python3 -m pip install -r benchmark_v2/requirements.txt kernels
- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
working-directory: /transformers
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]"
- name: Run benchmark

View File

@ -125,9 +125,8 @@ If you're contributing a **vision-language model** (or any multimodal model that
All new models should use the modular architecture pattern. Create a `modular_<model_name>.py` file using the modular model converter:
- Use the CLI, [`transformers add-new-model-like`](https://github.com/huggingface/transformers/blob/main/src/transformers/cli/add_new_model_like.py) to generate a modular skeleton and get started
- All code should be in the modular file if possible. Modeling must be in it, it's better if configuration is in it as well. [Modular guide](./modular_transformers#implementing-a-modular-file) shows a quick way to set up a modular file.
- All code should be in the modular file if possible. Modeling must be in it, it's better if configuration is in it as well.
- Reuse existing patterns from similar models as much as possible
- You can make the model compatible with inference engines such as vLLM or SGLang, and enable zero-effort integration. See specific requirements for model implementation in ["Transformers modeling backend"](./transformers_as_backend#multimodal-models)
To verify your modular file is correct, run:

View File

@ -45,7 +45,6 @@ repo-consistency:
python utils/check_modular_conversion.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_init_weights_data.py
python utils/check_inits.py
python utils/check_pipeline_typing.py
python utils/check_config_docstrings.py

View File

@ -117,6 +117,8 @@ def flush_memory():
# Clear CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
gc.collect()

View File

@ -1,4 +1,4 @@
FROM rocm/pytorch:rocm7.1_ubuntu22.04_py3.10_pytorch_release_2.8.0
FROM rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
LABEL maintainer="Hugging Face"
ARG DEBIAN_FRONTEND=noninteractive

View File

@ -508,16 +508,16 @@ BERT `_init_weights` Methode:
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
```
Sie können weitere benutzerdefinierte Schemata verwenden, wenn Sie eine spezielle Initialisierung für einige Module benötigen. Zum Beispiel in
@ -533,9 +533,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
```
Das Flag `_is_hf_initialized` wird intern verwendet, um sicherzustellen, dass wir ein Submodul nur einmal initialisieren. Wenn Sie es auf

View File

@ -118,7 +118,7 @@
- local: tools
title: Tools
- local: transformers_as_backend
title: Transformers as modeling backend
title: Inference server backends
- local: continuous_batching
title: Continuous Batching
title: Inference
@ -1008,8 +1008,6 @@
title: AltCLIP
- local: model_doc/aria
title: Aria
- local: model_doc/audioflamingo3
title: AudioFlamingo3
- local: model_doc/aya_vision
title: AyaVision
- local: model_doc/blip
@ -1066,8 +1064,6 @@
title: Gemma3n
- local: model_doc/git
title: GIT
- local: model_doc/glm46v
title: Glm46V
- local: model_doc/glm4v
title: glm4v
- local: model_doc/glm4v_moe

View File

@ -314,16 +314,16 @@ Random initialization occurs in the `_init_weights` method of `BrandNewLlamaPreT
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
```
The initialization scheme can look different if you need to adapt it to your model. For example, [`Wav2Vec2ForPreTraining`] initializes [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) in its last two linear layers.
@ -339,9 +339,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
```
### Convert checkpoints to Transformers

View File

@ -1,402 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
*This model was released on 2025-07-10 and added to Hugging Face Transformers on 2025-11-11.*
# Audio Flamingo 3
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
## Overview
Audio Flamingo 3 (AF3) is a fully open large audiolanguage model designed for robust understanding and reasoning over speech, environmental sounds, and music. AF3 pairs a Whisper-style audio encoder with a causal language model and performs replace-in-place audiotext fusion: the processor aligns post-pool audio frames to a dedicated placeholder token and the model replaces those token slots with projected audio embeddings during the forward pass.
The model checkpoint is available at: [nvidia/audio-flamingo-3-hf](https://huggingface.co/nvidia/audio-flamingo-3-hf)
Highlights:
- Unified audio encoder across speech, sound, and music.
- **Long-audio support via windowing and post-pool alignment (up to 10 minutes maximum).** The model processes audio in 30-second windows with a hard limit of 20 windows (10 minutes total). Audio longer than 10 minutes will be truncated.
- Deterministic fusion that preserves sequence length by replacing audio placeholder tokens with audio embeddings.
This model was contributed by [Lasha Koroshinadze](https://huggingface.co/lashahub) and [Eric Bezzam](https://huggingface.co/bezzam).
### Paper
[Audio Flamingo 3](https://huggingface.co/papers/2507.08128): Advancing Audio Intelligence with Fully Open Large Audio Language Models
A. Goel, S. Ghosh, J. Kim, S. Kumar, Z. Kong, S. Lee, C.-H. H. Yang, R. Duraiswami, D. Manocha, R. Valle, B. Catanzaro
NVIDIA and University of Maryland
Project: https://research.nvidia.com/labs/adlr/AF3/
## Usage
### Audio Instruct Mode
The model supports audio-text instructions, including multi-turn interactions, all processed in batches.
➡️ audio + text instruction
```python
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
model_id = "nvidia/audio-flamingo-3-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "Transcribe the input speech."},
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"},
],
}
]
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=500)
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(decoded_outputs)
```
➡️ multi-turn:
```python
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
model_id = "nvidia/audio-flamingo-3-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
conversation = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Instruction: How does the tone of female speech change throughout the audio? Choose the correct option among the options below: (A) Sad to happy (B) Happy to sad (C) Neutral to happy (D) Happy to neutral.",
},
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/000000786159.31.wav"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "(A) Sad to happy"}],
},
{
"role": "user",
"content": [
{"type": "text", "text": "Why do you think so?"},
],
},
]
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=500)
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(decoded_outputs)
```
➡️ text only:
```python
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
model_id = "nvidia/audio-flamingo-3-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is the capital of France?"},
],
}
]
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=500)
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(decoded_outputs)
```
➡️ audio only:
```python
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
model_id = "nvidia/audio-flamingo-3-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
conversation = [
{
"role": "user",
"content": [
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"},
],
}
]
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=500)
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(decoded_outputs)
```
➡️ batched inference!
```python
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
model_id = "nvidia/audio-flamingo-3-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
conversations = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "Transcribe the input speech."},
{
"type": "audio",
"path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
},
],
}
],
[
{
"role": "user",
"content": [
{
"type": "text",
"text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
},
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
],
}
],
]
inputs = processor.apply_chat_template(
conversations,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=500)
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(decoded_outputs)
```
➡️ Training:
```python
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
model_id = "nvidia/audio-flamingo-3-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
model.train()
conversation = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "Transcribe the input speech."},
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "The transcription of the audio is 'summer follows spring the days grow longer and the nights are warm'."}],
}
],
[
{
"role": "user",
"content": [
{
"type": "text",
"text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
},
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "The transcription of the audio is 'some transcription of the audio'."}],
}
]
]
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
output_labels=True,
).to(model.device)
loss = model(**inputs).loss
loss.backward()
```
➡️ transcription shortcut
```python
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
model_id = "nvidia/audio-flamingo-3-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
inputs = processor.apply_transcription_request(audio="https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=500)
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True, strip_prefix=True)
print(decoded_outputs)
```
The model is trained to emit transcriptions prefixed with assistant framing such as `The spoken content of the audio is "<text>".`. Use `strip_prefix=True` (as shown above) to remove the fixed assistant sentence and surrounding quotes so that only the transcription remains.
## How the model works
### Architecture
* **AudioFlamingo3Encoder**
Whisper-style feature extractor + encoder → average-pool over time (stride 2) → LayerNorm.
Produces per-frame hidden states at the post-pool rate.
* **AudioFlamingo3MultiModalProjector**
A small MLP that maps encoder features to the language models hidden size.
* **AudioFlamingo3ForConditionalGeneration**
A causal language model that accepts text embeddings where each audio placeholder token slot is replaced, in place, by an audio frame embedding. No sequence-length change is introduced by fusion.
### Processor-level alignment
1. Each raw waveform is split into fixed-length windows based on the feature extractors `chunk_length` (seconds) and `sampling_rate` (Hz).
2. For each window, the processor computes the number of post-pool frames `post_pool_len` that the encoder will output (matching the conv/pool schedule).
3. The processor expands the audio placeholder token by the total number of post-pool frames across all windows.
4. The model later replaces those token positions with the corresponding projected audio embeddings.
## Usage patterns
### Transcription shortcut
For automatic speech recognition you can skip writing the default instruction each time and call
[`~transformers.AudioFlamingo3Processor.apply_transcription_request`]:
```python
inputs = processor.apply_transcription_request(audio=audio_array)
```
Pass `prompt="Transcribe the input speech."` (or a list of prompts for batch audio) to customize the instruction while
keeping the audio placeholder handling.
`audio` accepts in-memory arrays, local file paths, or URLs. Any processor kwargs (`text_kwargs`, `audio_kwargs`, etc.)
are forwarded, so you can tweak padding or tensor formats just like when calling `processor(...)`.
## Long audio and windowing
**Important: Maximum audio length is 10 minutes.** Audio longer than this will be truncated.
* The default setup processes 30-second windows at 16 kHz mono.
* **The processor enforces a hard limit of 20 windows per sample, resulting in a maximum of 10 minutes of audio (20 windows × 30 seconds).**
* For each window:
* `mel_len` is the padded mel length.
* A conv stack reduces time as `conv_output_len = (mel_len - 1) // 2 + 1`.
* Post-pool frames per window: `post_pool_len = (conv_output_len - 2) // 2 + 1`.
* An audio placeholder token is expanded to the sum of `post_pool_len` across all windows.
## Padding, attention, and caching
* **Left padding vs right padding**
For generation with mixed prompt lengths in a batch, left padding is usually preferable.
For training, right padding is common; AF3s fusion mechanism itself is padding-agnostic because it replaces in place.
* **Attention masks**
The processor returns `attention_mask` (text) and `input_features_mask` (audio). The model builds an internal 4-D mask on the encoders pre-pool axis with negative infinity at pad positions.
* **Caching**
During generation, `input_features` and `input_features_mask` are only passed on the first step. Subsequent steps use cached keys/values from the language model.
## Troubleshooting
* Empty or truncated outputs when batching
Use left padding for batched generation and decode only the new tokens after the prompt length, as shown in the quickstart.
## AudioFlamingo3Config
[[autodoc]] AudioFlamingo3Config
## AudioFlamingo3EncoderConfig
[[autodoc]] AudioFlamingo3EncoderConfig
## AudioFlamingo3Processor
[[autodoc]] AudioFlamingo3Processor
## AudioFlamingo3Encoder
[[autodoc]] AudioFlamingo3Encoder
- forward
## AudioFlamingo3ForConditionalGeneration
[[autodoc]] AudioFlamingo3ForConditionalGeneration
- forward

View File

@ -169,9 +169,6 @@ print("Pooled output shape:", pooled_output.shape)
[[autodoc]] DINOv3ViTModel
- forward
## DINOv3ViTBackbone
[[autodoc]] DINOv3ViTBackbone
## DINOv3ConvNextModel
[[autodoc]] DINOv3ConvNextModel

View File

@ -1,34 +0,0 @@
# GLM-4.6V
## Glm46VConfig
[[autodoc]] Glm46VConfig
## Glm46VImageProcessor
[[autodoc]] Glm46VImageProcessor
- preprocess
## Glm46VVideoProcessor
[[autodoc]] Glm46VVideoProcessor
- preprocess
## Glm46VImageProcessorFast
[[autodoc]] Glm46VImageProcessorFast
- preprocess
## Glm46VProcessor
[[autodoc]] Glm46VProcessor
## Glm46VModel
[[autodoc]] Glm46VModel
- forward
## Glm46VForConditionalGeneration
[[autodoc]] Glm46VForConditionalGeneration
- forward

View File

@ -170,11 +170,6 @@ print(output_text)
[[autodoc]] Glm4vConfig
## Glm4vVisionConfig
[[autodoc]] Glm4vVisionConfig
## Glm4vTextConfig
[[autodoc]] Glm4vTextConfig
@ -198,11 +193,6 @@ print(output_text)
[[autodoc]] Glm4vProcessor
## Glm4vVisionModel
[[autodoc]] Glm4vVisionModel
- forward
## Glm4vTextModel
[[autodoc]] Glm4vTextModel

View File

@ -22,7 +22,7 @@ rendered properly in your Markdown viewer.
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white"> </div>
</div>
# Glm4vMoeMoe
# Glm4vMoe
## Overview
@ -48,20 +48,10 @@ The model also introduces a **Thinking Mode** switch, allowing users to balance
[[autodoc]] Glm4vMoeConfig
## Glm4vMoeVisionConfig
[[autodoc]] Glm4vMoeVisionConfig
## Glm4vMoeTextConfig
[[autodoc]] Glm4vMoeTextConfig
## Glm4vMoeVisionModel
[[autodoc]] Glm4vMoeVisionModel
- forward
## Glm4vMoeTextModel
[[autodoc]] Glm4vMoeTextModel
@ -75,4 +65,4 @@ The model also introduces a **Thinking Mode** switch, allowing users to balance
## Glm4vMoeForConditionalGeneration
[[autodoc]] Glm4vMoeForConditionalGeneration
- forward
- forward

View File

@ -136,7 +136,7 @@ inputs = processor.apply_chat_template(
tokenize=True,
return_dict=True,
return_tensors="pt",
fps=1,
video_fps=1,
# kwargs to be passed to `Qwen2-5-OmniProcessor`
padding=True,
@ -245,7 +245,7 @@ inputs = processor.apply_chat_template(
tokenize=True,
return_dict=True,
return_tensors="pt",
fps=1,
video_fps=1,
# kwargs to be passed to `Qwen2-5-OmniProcessor`
padding=True,

View File

@ -54,7 +54,7 @@ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_co
prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
audio, sr = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(model.device)
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
generate_ids = model.generate(**inputs, max_length=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
@ -63,7 +63,7 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
# We can also omit the audio_bos and audio_eos tokens
prompt = "<|AUDIO|>Generate the caption in English:"
inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(model.device)
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
generate_ids = model.generate(**inputs, max_length=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
@ -106,7 +106,7 @@ for message in conversation:
sr=processor.feature_extractor.sampling_rate)[0]
)
inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
inputs.input_ids = inputs.input_ids.to(model.device)
generate_ids = model.generate(**inputs, max_length=256)
@ -156,7 +156,7 @@ for message in conversation:
sr=processor.feature_extractor.sampling_rate)[0]
)
inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
inputs.input_ids = inputs.input_ids.to(model.device)
generate_ids = model.generate(**inputs, max_length=256)
@ -213,7 +213,7 @@ for conversation in conversations:
sr=processor.feature_extractor.sampling_rate)[0]
)
inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
inputs['input_ids'] = inputs['input_ids'].to(model.device)
inputs.input_ids = inputs.input_ids.to(model.device)

View File

@ -80,7 +80,7 @@ inputs = processor.apply_chat_template(
tokenize=True,
return_dict=True,
return_tensors="pt",
fps=1,
video_fps=1,
# kwargs to be passed to `Qwen3OmniMoeProcessor`
padding=True,
@ -136,7 +136,7 @@ inputs = processor.apply_chat_template(
tokenize=True,
return_dict=True,
return_tensors="pt",
fps=1,
video_fps=1,
# kwargs to be passed to `Qwen3OmniMoeProcessor`
padding=True,
@ -245,7 +245,7 @@ inputs = processor.apply_chat_template(
tokenize=True,
return_dict=True,
return_tensors="pt",
fps=1,
video_fps=1,
# kwargs to be passed to `Qwen3OmniMoeProcessor`
padding=True,

View File

@ -61,7 +61,7 @@ Here is how to use the processor to process text and audio:
>>> audio_sample = next(iter(dataset))["audio"]
>>> # now, process it
>>> audio_inputs = processor(audio=audio_sample["array"], return_tensors="pt")
>>> audio_inputs = processor(audios=audio_sample["array"], return_tensors="pt")
>>> # now, process some English test as well
>>> text_inputs = processor(text = "Hello, my dog is cute", src_lang="eng", return_tensors="pt")

View File

@ -61,7 +61,7 @@ Here is how to use the processor to process text and audio:
>>> audio_sample = next(iter(dataset))["audio"]
>>> # now, process it
>>> audio_inputs = processor(audio=audio_sample["array"], return_tensors="pt")
>>> audio_inputs = processor(audios=audio_sample["array"], return_tensors="pt")
>>> # now, process some English text as well
>>> text_inputs = processor(text = "Hello, my dog is cute", src_lang="eng", return_tensors="pt")

View File

@ -159,7 +159,7 @@ conversation3 = [
conversations = [conversation1, conversation2, conversation3]
inputs = processor.apply_chat_template(
conversations,
conversation,
add_generation_prompt=True,
tokenize=True,
return_dict=True,

View File

@ -1,6 +1,6 @@
# Contributing a new model to Transformers
Modular Transformers lowers the bar for contributing models and significantly reduces the code required to add a model by allowing imports and inheritance. We recommend to go through [general contribution guidelines for new models](./contributing#do-you-want-to-implement-a-new-model) before diving into the details here.
Modular Transformers lowers the bar for contributing models and significantly reduces the code required to add a model by allowing imports and inheritance.
One of Transformers' core design feature is the [single model, single file](https://huggingface.co/blog/transformers-design-philosophy) policy. Model components - such as attention layers - are repeated across many files and any independent implementations tend to diverge as fixes and changes are applied to specific parts of the code.

View File

@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m
```python
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
```
Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.

View File

@ -329,7 +329,7 @@ from torchao.dtypes import Int4XPULayout
from torchao.quantization.quant_primitives import ZeroPointDomain
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT, int4_packing_format="plain_int32")
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT)
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model
@ -342,7 +342,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
@ -395,7 +395,7 @@ from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4CPULayout
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout(), int4_packing_format="opaque")
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model
@ -422,7 +422,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
#### 1. Skip quantization for certain layers
With `FqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
With `ModuleFqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
```py
import torch
@ -430,11 +430,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
model_id = "meta-llama/Llama-3.1-8B-Instruct"
from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig
from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig
config = Int4WeightOnlyConfig(group_size=128)
# set default to int4 (for linears), and skip quantizing `model.layers.0.self_attn.q_proj`
quant_config = FqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
quant_config = ModuleFqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16, quantization_config=quantization_config)
# lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized
@ -459,7 +459,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
model_id = "facebook/opt-125m"
from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
weight_dtype = torch.int8
granularity = PerAxis(0)
@ -470,7 +470,7 @@ embedding_config = IntxWeightOnlyConfig(
mapping_type=mapping_type,
)
linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128)
quant_config = FqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
quant_config = ModuleFqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
# set `include_embedding` to True in order to include embedding in quantization
# when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True)
@ -521,7 +521,7 @@ from torchao.quantization import (
IntxWeightOnlyConfig,
PerRow,
PerAxis,
FqnToConfig,
ModuleFqnToConfig,
Float8Tensor,
Int4TilePackedTo4dTensor,
IntxUnpackedToInt8Tensor,
@ -550,7 +550,7 @@ qconfig_dict = {
"_default": intxwo,
}
quant_config = FqnToConfig(qconfig_dict)
quant_config = ModuleFqnToConfig(qconfig_dict)
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
model_id,

View File

@ -14,9 +14,9 @@ rendered properly in your Markdown viewer.
-->
# Transformers as modeling backend
# Inference server backends
Transformers' models are compatible with different inference servers like vLLM and SGLang. Instead of implementing a new model architecture from scratch for each inference server, you only need a model definition in `transformers`, which can be plugged into any inference server. It simplifies maintenance and makes it easy for users to use different inference servers for different use cases.
Transformers' models are compatible with different inference servers like vLLM and SGLang. Instead of implementing a model for each inference server, you only need one model, which can be plugged into any inference server. It simplifies maintenance and makes it easy for users to use different inference servers for different use cases.
With Transformers as a backend, you can also serve any model - including custom and Hub-hosted models - without waiting for native support.
@ -157,13 +157,57 @@ class MyConfig(PreTrainedConfig):
### Multimodal models
For multimodal models, you need to include a few more changes on top of the general recommendations outlined in ["contribuiting a model"](./contributing#vision-language-model-contribution-checklist). These rules ensure that your model integrates properly and enables processing multimodal data.
For multimodal models, you need to include a few more changes on top of the general recommendations. These rules ensure that your model integrates properly with multimodal data.
1. A multimodal model's processing class must have the `self.image_token` and `self.image_token_ids` attributes. These are placeholder tokens used to indicate image positions in the input. This placeholder token is the same token used in the input prompt to denote images and used in model code to scatter image features.
1. A multimodal model requires a base `MyMultiModalModel` class to handle multimodal fusion without a language modeling head and a separate generative class that adds a head.
2. The processing class needs `self._get_num_multimodal_tokens` method to compute the number of placeholder tokens needed for multimodal inputs with given sizes and to return a [`MultiModalData`] object. The placeholders between `<image>` tokens such as row or column tokens don't count as image placeholders. Only tokens that are actually replaced by image features later in modeling should be counted!
The base model needs to implement the `get_image_features()` method to accept image pixel values and return encoded outputs. These are later merged with the language embeddings and don't require any postprocessing. The shape of the returned features must match the number of input images. If a vision encoder returns variable-length outputs (patch-based), return a list of 2D tensors of size `(image_seq_len, image_dim)` for each image.
3. The processor needs to check the value of `return_mm_token_type_ids` and return `mm_token_type_ids` to indicate whether each position is a text token (`0`), image placeholder token (`1`) or video placeholder token (`2`). Each multimodal token type ID sequence must be contiguous without breaks between consecutive tokens, therefore special tokens for begin/end/row/column must be treated as placeholders.
Expand the code below for an example.
<details>
<summary>modeling_my_multimodal_model.py</summary>
```python
from transformers.generation import GenerationMixin
class MyMultimodalModel(MyMultimodalPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.language_model = AutoModel.from_config(config.text_config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multimodal_projection = nn.Linear(vision_dim, text_dim)
def get_image_features(self, pixel_values):
return self.vision_tower(pixel_values).last_hidden_states
def forward(self, input_ids, pixel_values, **kwargs):
# process your inputs
return MyModelOutputWithPast(
last_hidden_state=last_hidden_state,
image_hidden_states=image_features,
[...]
)
class MyMultimodalModelForConditionalGeneration(MyMultimodalPreTrainedModel, GenerationMixin):
def __init__(self, config):
super().__init__(config)
self.model = MyMultimodalModel(config)
self.lm_head = nn.Linear(hidden_dim, vocab_size)
```
</details>
2. A multimodal model config must be nested with the following fields.
* text_config: decoder language model config
* vision_config: vision encoder config
* image_token_id: ID of the image placeholder token used in the input to indicate image position
3. A multimodal model's processing class must have the `self.image_token` and `self.image_token_ids` attributes. These are placeholder tokens used to indicate image positions in the input. The placeholder token is the same token used in the input prompt and to mask scatter image features.
The processing class also needs `self._get_num_multimodal_tokens` method to compute the number of placeholder tokens needed for multimodal inputs with given sizes and to return a [`MultiModalData`] object. The placeholder for row and column tokens don't count as image placeholders. Only the tokens that are actually replaced by image features are computed.
Finally, when `return_mm_token_type_ids=True`, the class has to return `mm_token_type_ids` to indicate whether each position is a text token (`0`) or image placeholder token (`1`). Each image's token type IDs must be contiguous with no breaks between consecutive ones.
Expand the code below for an example.
@ -202,5 +246,5 @@ class MyMultimodalProcessor(ProcessorMixin):
## Resources
* Read the [Transformers modeling backend integration in vLLM](https://blog.vllm.ai/2025/04/11/transformers-backend.html) blog post for more details about the Transformers modeling backend in vLLM.
* Read the [Transformers modeling backend integration in SGLang](https://huggingface.co/blog/transformers-backend-sglang) blog post for more details about the Transformers modeling backend in SGLang.
* Read the [Transformers backend integration in vLLM](https://blog.vllm.ai/2025/04/11/transformers-backend.html) blog post for more details about the Transformers backend in vLLM.
* Read the [Transformers backend integration in SGLang](https://huggingface.co/blog/transformers-backend-sglang) blog post for more details about the Transformers backend in SGLang.

View File

@ -170,7 +170,7 @@ Per quanto riguarda la classe `TrainingArguments`:
- L'argomento `evaluate_during_training` di `TrainingArguments` è deprecato a favore di `eval_strategy`.
Per quanto riguarda il modello Transfo-XL:
- L'attributo di configurazione `tie_weight` di Transfo-XL diventa `tie_word_embeddings`.
- L'attributo di configurazione `tie_weight` di Transfo-XL diventa `tie_words_embeddings`.
- Il metodo di modellazione `reset_length` di Transfo-XL diventa `reset_memory_length`.
Per quanto riguarda le pipeline:

View File

@ -406,16 +406,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
```
特定のモジュールに特別な初期化が必要な場合、カスタムスキームをさらに持つことができます。たとえば、
@ -431,9 +431,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
```
`_is_hf_initialized`フラグは、サブモジュールを一度だけ初期化することを確実にするために内部で使用されます。

View File

@ -348,16 +348,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
```
몇 가지 모듈에 대해 특별한 초기화가 필요한 경우 사용자 정의 방식을 사용할 수도 있습니다. 예를 들어, `Wav2Vec2ForPreTraining`에서 마지막 두 개의 선형 레이어는 일반적인 PyTorch `nn.Linear`의 초기화를 가져야 하지만, 다른 모든 레이어는 위와 같은 초기화를 사용해야 합니다. 이는 다음과 같이 코드화됩니다:
@ -371,9 +371,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
```
`_is_hf_initialized` 플래그는 서브모듈을 한 번만 초기화하도록 내부적으로 사용됩니다. `module.project_q``module.project_hid`에 대해 `True`로 설정함으로써, 우리가 수행한 사용자 정의 초기화가 이후에 덮어쓰이지 않도록 합니다. 즉, `_init_weights` 함수가 이들에게 적용되지 않습니다.

View File

@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
```python
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
```
배치 행렬 곱셈을 `forward` 패스에서 사용하여 `gate_up_proj` 모듈의 출력을 계산할 수 있습니다.

View File

@ -502,10 +502,16 @@ class DummyBertLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@ -530,18 +536,18 @@ class DummyBertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, DummyBertLMPredictionHead):
module.bias.zero_()
module.bias.data.zero_()
@auto_docstring(

View File

@ -265,7 +265,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
if "RMSNorm" in module.__class__.__name__:
module.weight.zero_()
module.weight.data.zero_()
class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):

View File

@ -104,9 +104,9 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=std)
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
def token_type_ids_mask_function(
@ -428,7 +428,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
_tied_weights_keys = ["lm_head.weight"]
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
def __init__(self, config):
@ -440,15 +440,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
if self.language_model._tied_weights_keys is not None:
prefix = "model.language_model."
prefixed_mapping = {
f"{prefix}{target}": f"{prefix}{source}"
for target, source in self.language_model._tied_weights_keys.items()
}
if isinstance(self._tied_weights_keys, dict):
self._tied_weights_keys.update(prefixed_mapping)
else:
self._tied_weights_keys = prefixed_mapping
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
self.post_init()
def get_input_embeddings(self):

View File

@ -505,10 +505,16 @@ class RobertaLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@ -533,18 +539,18 @@ class RobertaPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, RobertaLMPredictionHead):
module.bias.zero_()
module.bias.data.zero_()
@auto_docstring(

View File

@ -846,11 +846,11 @@ class TestDetrPreTrainedModel(PreTrainedModel):
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.normal_(mean=0.0, std=std)
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=std)
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if hasattr(module, "reference_points") and not self.config.two_stage:

View File

@ -19,15 +19,7 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
if self.language_model._tied_weights_keys is not None:
prefix = "model.language_model."
prefixed_mapping = {
f"{prefix}{target}": f"{prefix}{source}"
for target, source in self.language_model._tied_weights_keys.items()
}
if isinstance(self._tied_weights_keys, dict):
self._tied_weights_keys.update(prefixed_mapping)
else:
self._tied_weights_keys = prefixed_mapping
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
self.post_init()

View File

@ -27,6 +27,7 @@
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from random import randint
from typing import Optional
@ -179,11 +180,29 @@ class ModelArguments:
)
},
)
freeze_feature_extractor: Optional[bool] = field(
default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
ignore_mismatched_sizes: bool = field(
default=False,
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
)
def __post_init__(self):
if not self.freeze_feature_extractor and self.freeze_feature_encoder:
warnings.warn(
"The argument `--freeze_feature_extractor` is deprecated and "
"will be removed in a future version. Use `--freeze_feature_encoder` "
"instead. Setting `freeze_feature_encoder==True`.",
FutureWarning,
)
if self.freeze_feature_extractor and not self.freeze_feature_encoder:
raise ValueError(
"The argument `--freeze_feature_extractor` is deprecated and "
"should not be used in combination with `--freeze_feature_encoder`. "
"Only make use of `--freeze_feature_encoder`."
)
def main():
# See all possible arguments in src/transformers/training_args.py

View File

@ -17,7 +17,6 @@ import contextlib
import json
import os
import time
from itertools import cycle
from typing import Optional
import datasets
@ -30,32 +29,42 @@ from transformers.generation import GenerationConfig
from transformers.generation.continuous_batching.requests import logger
def generate_without_cb(
model_id: str, sliding_window: int, attn_impl: str, batched_inputs: list[int], generation_config: GenerationConfig
# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
SLIDING_WINDOW = 0
MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
FORCE_MAX_LENGTH = False # should be False unless you are debugging sliding window features
SKIP_SPECIAL_TOKENS = False
def generate_simple(
attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
) -> dict[str, str]:
# Setup model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, attn_implementation=attn_impl)
attn_impl = {
"sdpa": "sdpa",
"eager": "eager",
"paged_attention": "eager", # TODO: this does not work on AMD docker
"flash_paged": "flash_attention_2", # TODO: this does not work on AMD docker
"kernels-community/flash-attn": "eager",
}[attn_impl]
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
model = model.cuda().eval()
if sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
model.config.sliding_window = sliding_window
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Generate one by one
if getattr(model.config, "sliding_window", None) is not None:
model.config.sliding_window = SLIDING_WINDOW
decoded_outputs = {}
for input_ids in tqdm(batched_inputs, desc="Generating outputs without CB"):
for input_ids in tqdm(simple_batch_inputs, desc="Generating outputs without CB"):
key = " ".join(map(str, input_ids)) # This will be used to identify the output after batched generation
input_ids = torch.tensor([input_ids]).to("cuda")
attention_mask = torch.ones_like(input_ids)
outputs = model.generate(
input_ids, attention_mask=attention_mask, generation_config=generation_config, use_model_defaults=False
)
# attention_mask = torch.ones_like(input_ids)
outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
generated_tokens = outputs[0][input_ids.shape[1] :]
decoded_outputs[key] = tokenizer.decode(generated_tokens, skip_special_tokens=False)
decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS)
decoded_outputs[key] = decoded_output
return decoded_outputs
def maybe_setup_metrics(use_metrics: bool) -> None:
if not use_metrics:
return
def setup_metrics():
try:
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -110,14 +119,16 @@ def batch_generate(
token_count = 0
data = []
for i, request in enumerate(batch_outputs):
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=SKIP_SPECIAL_TOKENS)
# The key is used to tie back to the output of unbatched generation
key = " ".join(map(str, batch_outputs[request].prompt_ids))
data.append({"input": input_text, "key": key})
# Try to decode the output
try:
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
output_text = tokenizer.decode(
batch_outputs[request].generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS
)
token_count += len(batch_outputs[request].generated_tokens[1:])
data[-1]["cb_outputs"] = output_text
except Exception as e:
@ -127,7 +138,14 @@ def batch_generate(
# Display sample if asked
if i < displayed_samples:
print("-" * 20, f"{request} Input: {input_text}", f"{request} Output: {output_text}", sep="\n")
if len(output_text) > 0:
print("-" * 20)
print(f"{request} Input: {input_text}")
print(f"{request} Output: {output_text}")
else:
print(f"{request} Input: {input_text}")
print("[WARN]")
print(f"{request} Output was empty!")
# Compare with classic generate if asked
if expected_outputs is not None:
@ -164,102 +182,75 @@ def batch_generate(
if __name__ == "__main__":
# Parse args
parser = argparse.ArgumentParser()
# Continuous batching parameters
parser.add_argument("--num-blocks", "-n", type=int, default=None)
parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
# Model parameters
parser.add_argument("--sliding-window", type=int, default=0)
parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
# Performance parameters
parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable
parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
parser.add_argument("--do-sample", action="store_true", help="Activate sampling")
# Benchmark parameters
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
parser.add_argument("--profile", type=str, default=None)
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")
# Display parameters
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
parser.add_argument("--log-level", type=str, default="INFO")
parser.add_argument("--output-file", type=str, default=None)
parser.add_argument("--compare", action="store_true")
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--profile", type=str, default=None)
args = parser.parse_args()
# Create model
model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct"
has_system_role = args.sliding_window == 0
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=args.attn, dtype=torch.bfloat16)
model = model.cuda().eval()
if args.sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
print(f"Setting sliding window from {model.config.sliding_window} to {args.sliding_window}")
model.config.sliding_window = args.sliding_window
# Set up diagnostics
# Set log level
logger.setLevel(args.log_level.upper())
maybe_setup_metrics(args.metrics)
# Set up performance
# If turned on, we setup metrics
if args.metrics:
setup_metrics()
# Set matmul precision if not none
if args.matmul_precision != "none":
torch.set_float32_matmul_precision(args.matmul_precision)
# Parse cuda graph argument
if args.cuda_graph is not None:
use_cuda_graph = {
"none": None,
"yes": True, "y": True, "true": True, "t": True, "1": True,
"no": False, "n": False, "false": False, "f": False, "0": False,
}[args.cuda_graph.lower()] # fmt: skip
else:
use_cuda_graph = None
cuda_graph_arg = args.cuda_graph.lower() if args.cuda_graph is not None else None
use_cuda_graph = {
"none": None, None: None,
"yes": True, "y": True, "true": True, "t": True, "1": True,
"no": False, "n": False, "false": False, "f": False, "0": False,
}[cuda_graph_arg] # fmt: skip
# Prepare model
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
attn_implementation=args.attn,
dtype=torch.bfloat16,
)
model = model.cuda().eval()
if getattr(model.config, "sliding_window", None) is not None:
print(f"Setting sliding window from {model.config.sliding_window} to {SLIDING_WINDOW}")
model.config.sliding_window = SLIDING_WINDOW
# If turned on, we compile the model
if args.compile:
model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
# Prepare tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(args.samples))
if args.add_prefix:
possible_prefixes = [
None,
"You are a bot that solves math problems.",
"You are a bot who solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning.",
"You are a bot with the aim to solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning. No loud words or emojis, all responses must be readable by a child. Here is now the problem:",
] # fmt: skip
else:
possible_prefixes = [None]
batched_inputs = []
for item, prefix in zip(dataset, cycle(possible_prefixes)):
messages = []
question = item["question"]
if prefix is not None:
if has_system_role:
messages.append({"role": "system", "content": prefix})
else:
question = prefix + "\n\n" + question
messages.append({"role": "user", "content": question})
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
batched_inputs.append(inputs["input_ids"])
simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]
# Prepare generation config
generation_cfg = GenerationConfig(
generation_config = GenerationConfig(
max_new_tokens=512,
use_cuda_graph=use_cuda_graph,
eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id,
eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=args.do_sample,
do_sample=not args.compare,
temperature=0.8,
top_p=0.9,
num_blocks=args.num_blocks,
@ -267,12 +258,7 @@ if __name__ == "__main__":
)
# If we need to compare, we need to generate the reference outputs
if args.compare:
expected_outputs = generate_without_cb(
model_id, args.sliding_window, args.attn, batched_inputs, generation_cfg
)
else:
expected_outputs = None
expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None
# If no output file is provided, we pick a name based on the args
if args.output_file is None:
@ -285,8 +271,8 @@ if __name__ == "__main__":
# Run warmup batch generation # TODO: understand why warmup incurs a large overhead during cache creation
batch_generate(
model,
batched_inputs[: min(5, args.samples)],
generation_cfg,
simple_batch_inputs[: min(5, args.samples)],
generation_config,
tokenizer,
displayed_samples=-1,
)
@ -299,8 +285,8 @@ if __name__ == "__main__":
# Run batch generation
gen_time, tok_per_sec = batch_generate(
model,
batched_inputs,
generation_cfg,
simple_batch_inputs,
generation_config,
tokenizer,
displayed_samples=args.displayed,
output_file=args.output_file,
@ -311,5 +297,5 @@ if __name__ == "__main__":
prof.export_chrome_trace(filename)
# Example usage:
# python examples/pytorch/continuous_batching.py --attn sdpa --add-prefix --samples 10 --compare
# python examples/pytorch/continuous_batching.py --attn flash_attention_2 -mp none --add-prefix --samples 500
# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --samples 3 --compare
# python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json

View File

@ -127,7 +127,7 @@ def parse_args():
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",

View File

@ -132,7 +132,7 @@ def parse_args():
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",

View File

@ -130,7 +130,7 @@ def parse_args():
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",

View File

@ -128,7 +128,7 @@ def parse_args():
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the HuggingFace Tokenizers library).",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",

View File

@ -151,7 +151,7 @@ def parse_args():
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",

View File

@ -223,7 +223,7 @@ def parse_args():
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",

View File

@ -74,7 +74,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
def tearDownClass(cls):
shutil.rmtree(cls.tmpdir)
@slow
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
def test_run_glue_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
@ -148,7 +147,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))
@slow
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
def test_run_ner_no_trainer(self):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
@ -177,7 +175,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))
@slow
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
def test_run_squad_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
@ -206,7 +203,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))
@slow
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
def test_run_swag_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
@ -309,7 +305,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
@slow
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
def test_run_image_classification_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()

View File

@ -374,7 +374,6 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_bleu"], 30)
@slow
def test_run_image_classification(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
@ -404,7 +403,6 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
@slow
def test_run_speech_recognition_ctc(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
@ -575,7 +573,6 @@ class ExamplesTests(TestCasePlus):
model = ViTMAEForPreTraining.from_pretrained(tmp_dir)
self.assertIsNotNone(model)
@slow
def test_run_semantic_segmentation(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
@ -600,7 +597,6 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.1)
@slow
@patch.dict(os.environ, {"WANDB_DISABLED": "true"})
def test_run_object_detection(self):
tmp_dir = self.get_auto_remove_tmp_dir()
@ -628,7 +624,6 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertGreaterEqual(result["test_map"], 0.1)
@slow
@patch.dict(os.environ, {"WANDB_DISABLED": "true"})
def test_run_instance_segmentation(self):
tmp_dir = self.get_auto_remove_tmp_dir()

View File

@ -120,7 +120,7 @@ def parse_args():
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",

View File

@ -212,7 +212,7 @@ def parse_args():
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",

View File

@ -50,7 +50,6 @@ checkpoint: 检查点
</p>
<p align="center">
<a href="https://huggingface.co/models"><img alt="Checkpoints on Hub" src="https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen"></a>
<a href="https://circleci.com/gh/huggingface/transformers"><img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/main"></a>
<a href="https://github.com/huggingface/transformers/blob/main/LICENSE"><img alt="GitHub" src="https://img.shields.io/github/license/huggingface/transformers.svg?color=blue"></a>
<a href="https://huggingface.co/docs/transformers/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/transformers/index.svg?down_color=red&down_message=offline&up_message=online"></a>
@ -61,7 +60,7 @@ checkpoint: 检查点
<h4 align="center">
<p>
<a href="https://github.com/huggingface/transformers/blob/main/README.md">English</a> |
<a href="https://github.com/huggingface/transformers/">English</a> |
<b>简体中文</b> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_zh-hant.md">繁體中文</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ko.md">한국어</a> |
@ -69,7 +68,7 @@ checkpoint: 检查点
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ja.md">日本語</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_hd.md">हिन्दी</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ru.md">Русский</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Português</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Рortuguês</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_te.md">తెలుగు</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_fr.md">Français</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_de.md">Deutsch</a> |
@ -82,258 +81,182 @@ checkpoint: 检查点
</h4>
<h3 align="center">
<p>文本、视觉、音频、视频与多模态提供推理与训练的先进预训练模型</p>
<p> Jax、PyTorch 和 TensorFlow 打造的先进的自然语言处理函数库</p>
</h3>
<h3 align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers_as_a_model_definition.png"/>
<a href="https://hf.co/course"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/course_banner.png"></a>
</h3>
Transformers 充当跨文本、计算机视觉、音频、视频与多模态的最先进机器学习模型的「模型定义框架」,同时覆盖推理与训练
🤗 Transformers 提供了数以千计的预训练模型,支持 100 多种语言的文本分类、信息抽取、问答、摘要、翻译、文本生成。它的宗旨是让最先进的 NLP 技术人人易用
它将模型的定义集中化,使整个生态系统对该定义达成一致。`transformers` 是跨框架的枢纽:一旦某模型定义被支持,它通常就能兼容多数训练框架(如 Axolotl、Unsloth、DeepSpeed、FSDP、PyTorchLightning 等)、推理引擎(如 vLLM、SGLang、TGI 等),以及依赖 `transformers` 模型定义的相关库(如 llama.cpp、mlx 等)
🤗 Transformers 提供了便于快速下载和使用的API让你可以把预训练模型用在给定文本、在你的数据集上微调然后通过 [model hub](https://huggingface.co/models) 与社区共享。同时,每个定义的 Python 模块都是完全独立的,便于修改和快速进行研究实验
我们的目标是持续支持新的最先进模型,并通过让模型定义保持简单、可定制且高效来普及其使用
🤗 Transformers 支持三个最热门的深度学习库: [Jax](https://jax.readthedocs.io/en/latest/), [PyTorch](https://pytorch.org/) 以及 [TensorFlow](https://www.tensorflow.org/) — 并与之无缝整合。你可以直接使用一个框架训练你的模型然后用另一个加载和推理
目前在 [Hugging Face Hub](https://huggingface.com/models) 上有超过 1M+ 使用 `transformers` 的[模型检查点](https://huggingface.co/models?library=transformers&sort=trending),可随取随用。
今天就去探索 Hub找到一个模型并用 Transformers 立刻开始吧。
## 在线演示
## 安装
你可以直接在模型页面上测试大多数 [model hub](https://huggingface.co/models) 上的模型。 我们也提供了 [私有模型托管、模型版本管理以及推理API](https://huggingface.co/pricing)。
Transformers 支持 Python 3.9+,以及 [PyTorch](https://pytorch.org/get-started/locally/) 2.1+。
这里是一些例子:
- [用 BERT 做掩码填词](https://huggingface.co/google-bert/bert-base-uncased?text=Paris+is+the+%5BMASK%5D+of+France)
- [用 Electra 做命名实体识别](https://huggingface.co/dbmdz/electra-large-discriminator-finetuned-conll03-english?text=My+name+is+Sarah+and+I+live+in+London+city)
- [用 GPT-2 做文本生成](https://huggingface.co/openai-community/gpt2?text=A+long+time+ago%2C+)
- [用 RoBERTa 做自然语言推理](https://huggingface.co/FacebookAI/roberta-large-mnli?text=The+dog+was+lost.+Nobody+lost+any+animal)
- [用 BART 做文本摘要](https://huggingface.co/facebook/bart-large-cnn?text=The+tower+is+324+metres+%281%2C063+ft%29+tall%2C+about+the+same+height+as+an+81-storey+building%2C+and+the+tallest+structure+in+Paris.+Its+base+is+square%2C+measuring+125+metres+%28410+ft%29+on+each+side.+During+its+construction%2C+the+Eiffel+Tower+surpassed+the+Washington+Monument+to+become+the+tallest+man-made+structure+in+the+world%2C+a+title+it+held+for+41+years+until+the+Chrysler+Building+in+New+York+City+was+finished+in+1930.+It+was+the+first+structure+to+reach+a+height+of+300+metres.+Due+to+the+addition+of+a+broadcasting+aerial+at+the+top+of+the+tower+in+1957%2C+it+is+now+taller+than+the+Chrysler+Building+by+5.2+metres+%2817+ft%29.+Excluding+transmitters%2C+the+Eiffel+Tower+is+the+second+tallest+free-standing+structure+in+France+after+the+Millau+Viaduct)
- [用 DistilBERT 做问答](https://huggingface.co/distilbert/distilbert-base-uncased-distilled-squad?text=Which+name+is+also+used+to+describe+the+Amazon+rainforest+in+English%3F&context=The+Amazon+rainforest+%28Portuguese%3A+Floresta+Amaz%C3%B4nica+or+Amaz%C3%B4nia%3B+Spanish%3A+Selva+Amaz%C3%B3nica%2C+Amazon%C3%ADa+or+usually+Amazonia%3B+French%3A+For%C3%AAt+amazonienne%3B+Dutch%3A+Amazoneregenwoud%29%2C+also+known+in+English+as+Amazonia+or+the+Amazon+Jungle%2C+is+a+moist+broadleaf+forest+that+covers+most+of+the+Amazon+basin+of+South+America.+This+basin+encompasses+7%2C000%2C000+square+kilometres+%282%2C700%2C000+sq+mi%29%2C+of+which+5%2C500%2C000+square+kilometres+%282%2C100%2C000+sq+mi%29+are+covered+by+the+rainforest.+This+region+includes+territory+belonging+to+nine+nations.+The+majority+of+the+forest+is+contained+within+Brazil%2C+with+60%25+of+the+rainforest%2C+followed+by+Peru+with+13%25%2C+Colombia+with+10%25%2C+and+with+minor+amounts+in+Venezuela%2C+Ecuador%2C+Bolivia%2C+Guyana%2C+Suriname+and+French+Guiana.+States+or+departments+in+four+nations+contain+%22Amazonas%22+in+their+names.+The+Amazon+represents+over+half+of+the+planet%27s+remaining+rainforests%2C+and+comprises+the+largest+and+most+biodiverse+tract+of+tropical+rainforest+in+the+world%2C+with+an+estimated+390+billion+individual+trees+divided+into+16%2C000+species)
- [用 T5 做翻译](https://huggingface.co/google-t5/t5-base?text=My+name+is+Wolfgang+and+I+live+in+Berlin)
使用 [venv](https://docs.python.org/3/library/venv.html) 或 [uv](https://docs.astral.sh/uv/)(一个基于 Rust 的快速 Python 包与项目管理器)创建并激活虚拟环境:
**[Write With Transformer](https://transformer.huggingface.co)**,由 Hugging Face 团队打造,是一个文本生成的官方 demo。
```py
# venv
python -m venv .my-env
source .my-env/bin/activate
# uv
uv venv .my-env
source .my-env/bin/activate
```
## 如果你在寻找由 Hugging Face 团队提供的定制化支持服务
在虚拟环境中安装 Transformers
```py
# pip
pip install "transformers[torch]"
# uv
uv pip install "transformers[torch]"
```
如果你需要库中的最新改动或计划参与贡献,可从源码安装(注意:最新版可能不稳定;如遇错误,欢迎在 [issues](https://github.com/huggingface/transformers/issues) 中反馈):
```shell
git clone https://github.com/huggingface/transformers.git
cd transformers
# pip
pip install '.[torch]'
# uv
uv pip install '.[torch]'
```
<a target="_blank" href="https://huggingface.co/support">
<img alt="HuggingFace Expert Acceleration Program" src="https://huggingface.co/front/thumbnails/support.png" style="max-width: 600px; border: 1px solid #eee; border-radius: 4px; box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);">
</a><br>
## 快速上手
使用 [Pipeline](https://huggingface.co/docs/transformers/pipeline_tutorial) API 一步上手。`Pipeline` 是一个高级推理类,支持文本、音频、视觉与多模态任务,负责输入预处理并返回适配的输出。
我们为快速使用模型提供了 `pipeline` API。Pipeline 聚合了预训练模型和对应的文本预处理。下面是一个快速使用 pipeline 去判断正负面情绪的例子:
实例化一个用于文本生成的 pipeline指定使用的模型。模型会被下载并缓存方便复用。最后传入文本作为提示
```python
>>> from transformers import pipeline
```py
from transformers import pipeline
pipeline = pipeline(task="text-generation", model="Qwen/Qwen2.5-1.5B")
pipeline("the secret to baking a really good cake is ")
[{'generated_text': 'the secret to baking a really good cake is 1) to use the right ingredients and 2) to follow the recipe exactly. the recipe for the cake is as follows: 1 cup of sugar, 1 cup of flour, 1 cup of milk, 1 cup of butter, 1 cup of eggs, 1 cup of chocolate chips. if you want to make 2 cakes, how much sugar do you need? To make 2 cakes, you will need 2 cups of sugar.'}]
# 使用情绪分析 pipeline
>>> classifier = pipeline('sentiment-analysis')
>>> classifier('We are very happy to introduce pipeline to the transformers repository.')
[{'label': 'POSITIVE', 'score': 0.9996980428695679}]
```
要与模型进行「聊天」,用法也一致。唯一不同是需要构造一段「聊天历史」(即 `Pipeline` 的输入):
第二行代码下载并缓存了 pipeline 使用的预训练模型,而第三行代码则在给定的文本上进行了评估。这里的答案"正面" (positive) 具有 99 的置信度。
> [!TIP]
> 你也可以直接在命令行与模型聊天:
> ```shell
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
> ```
许多的 NLP 任务都有开箱即用的预训练 `pipeline`。比如说,我们可以轻松的从给定文本中抽取问题答案:
```py
import torch
from transformers import pipeline
``` python
>>> from transformers import pipeline
chat = [
{"role": "system", "content": "You are a sassy, wise-cracking robot as imagined by Hollywood circa 1986."},
{"role": "user", "content": "Hey, can you tell me any fun things to do in New York?"}
]
# 使用问答 pipeline
>>> question_answerer = pipeline('question-answering')
>>> question_answerer({
... 'question': 'What is the name of the repository ?',
... 'context': 'Pipeline has been included in the huggingface/transformers repository'
... })
{'score': 0.30970096588134766, 'start': 34, 'end': 58, 'answer': 'huggingface/transformers'}
pipeline = pipeline(task="text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", dtype=torch.bfloat16, device_map="auto")
response = pipeline(chat, max_new_tokens=512)
print(response[0]["generated_text"][-1]["content"])
```
展开下方示例,查看 `Pipeline` 在不同模态与任务中的用法
除了给出答案,预训练模型还给出了对应的置信度分数、答案在词符化 (tokenized) 后的文本中开始和结束的位置。你可以从[这个教程](https://huggingface.co/docs/transformers/task_summary)了解更多 `pipeline` API 支持的任务
<details>
<summary>自动语音识别</summary>
要在你的任务上下载和使用任意预训练模型也很简单,只需三行代码。这里是 PyTorch 版的示例:
```python
>>> from transformers import AutoTokenizer, AutoModel
```py
from transformers import pipeline
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = AutoModel.from_pretrained("google-bert/bert-base-uncased")
pipeline = pipeline(task="automatic-speech-recognition", model="openai/whisper-large-v3")
pipeline("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}
>>> inputs = tokenizer("Hello world!", return_tensors="pt")
>>> outputs = model(**inputs)
```
这里是等效的 TensorFlow 代码:
```python
>>> from transformers import AutoTokenizer, TFAutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-uncased")
>>> inputs = tokenizer("Hello world!", return_tensors="tf")
>>> outputs = model(**inputs)
```
</details>
词符化器 (tokenizer) 为所有的预训练模型提供了预处理,并可以直接对单个字符串进行调用(比如上面的例子)或对列表 (list) 调用。它会输出一个你可以在下游代码里使用或直接通过 `**` 解包表达式传给模型的词典 (dict)。
<details>
<summary>图像分类</summary>
模型本身是一个常规的 [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) 或 [TensorFlow `tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model)(取决于你的后端),可以常规方式使用。 [这个教程](https://huggingface.co/transformers/training.html)解释了如何将这样的模型整合到经典的 PyTorch 或 TensorFlow 训练循环中,或是如何使用我们的 `Trainer` 训练器API 来在一个新的数据集上快速微调。
<h3 align="center">
<a><img src="https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"></a>
</h3>
## 为什么要用 transformers
```py
from transformers import pipeline
1. 便于使用的先进模型:
- NLU 和 NLG 上表现优越
- 对教学和实践友好且低门槛
- 高级抽象,只需了解三个类
- 对所有模型统一的API
pipeline = pipeline(task="image-classification", model="facebook/dinov2-small-imagenet1k-1-layer")
pipeline("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
[{"label": "macaw", "score": 0.997848391532898},
{"label": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
"score": 0.0016551691805943847},
{"label": "lorikeet", "score": 0.00018523589824326336},
{"label": "African grey, African gray, Psittacus erithacus",
"score": 7.85409429227002e-05},
{"label": "quail", "score": 5.502637941390276e-05}]
1. 更低计算开销,更少的碳排放:
- 研究人员可以分享已训练的模型而非每次从头开始训练
- 工程师可以减少计算用时和生产环境开销
- 数十种模型架构、两千多个预训练模型、100多种语言支持
1. 对于模型生命周期的每一个部分都面面俱到:
- 训练先进的模型,只需 3 行代码
- 模型在不同深度学习框架间任意转移,随你心意
- 为训练、评估和生产选择最适合的框架,衔接无缝
1. 为你的需求轻松定制专属模型和用例:
- 我们为每种模型架构提供了多个用例来复现原论文结果
- 模型内部结构保持透明一致
- 模型文件可单独使用,方便修改和快速实验
## 什么情况下我不该用 transformers
- 本库并不是模块化的神经网络工具箱。模型文件中的代码特意呈若璞玉,未经额外抽象封装,以便研究人员快速迭代修改而不致溺于抽象和文件跳转之中。
- `Trainer` API 并非兼容任何模型,只为本库之模型优化。若是在寻找适用于通用机器学习的训练循环实现,请另觅他库。
- 尽管我们已尽力而为,[examples 目录](https://github.com/huggingface/transformers/tree/main/examples)中的脚本也仅为用例而已。对于你的特定问题,它们并不一定开箱即用,可能需要改几行代码以适之。
## 安装
### 使用 pip
这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.1+ 和 TensorFlow 2.6+ 下经过测试。
你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
首先,用你打算使用的版本的 Python 创建一个虚拟环境并激活。
然后,你需要安装 Flax、PyTorch 或 TensorFlow 其中之一。关于在你使用的平台上安装这些框架,请参阅 [TensorFlow 安装页](https://www.tensorflow.org/install/), [PyTorch 安装页](https://pytorch.org/get-started/locally/#start-locally) 或 [Flax 安装页](https://github.com/google/flax#quick-install)。
当这些后端之一安装成功后, 🤗 Transformers 可依此安装:
```bash
pip install transformers
```
</details>
如果你想要试试用例或者想在正式发布前使用最新的开发中代码,你得[从源代码安装](https://huggingface.co/docs/transformers/installation#installing-from-source)。
<details>
<summary>视觉问答</summary>
### 使用 conda
<h3 align="center">
<a><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg"></a>
</h3>
🤗 Transformers 可以通过 conda 依此安装:
```py
from transformers import pipeline
pipeline = pipeline(task="visual-question-answering", model="Salesforce/blip-vqa-base")
pipeline(
image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg",
question="What is in the image?",
)
[{"answer": "statue of liberty"}]
```shell script
conda install conda-forge::transformers
```
</details>
> **_笔记:_** 从 `huggingface` 渠道安装 `transformers` 已被废弃。
## 为什么要用 Transformers
要通过 conda 安装 Flax、PyTorch 或 TensorFlow 其中之一,请参阅它们各自安装页的说明。
1. 易于使用的最先进模型:
- 在自然语言理解与生成、计算机视觉、音频、视频与多模态任务上表现优越。
- 对研究者、工程师与开发者友好且低门槛。
- 少量用户侧抽象,仅需学习三个类。
- 统一的 API使用所有预训练模型体验一致。
## 模型架构
1. 更低计算开销与更小碳足迹:
- 共享已训练的模型,而非每次从零开始训练。
- 减少计算时间与生产环境成本。
- 覆盖数十种模型架构,跨所有模态提供 1M+ 预训练检查点。
🤗 Transformers 支持的[**所有的模型检查点**](https://huggingface.co/models)由[用户](https://huggingface.co/users)和[组织](https://huggingface.co/organizations)上传,均与 huggingface.co [model hub](https://huggingface.co) 无缝整合。
1. 在模型生命周期的每个阶段都可以选用合适的框架:
- 3 行代码即可训练最先进模型。
- 在 PyTorch/JAX/TF2.0 间自由迁移同一个模型。
- 为训练、评估与生产挑选最合适的框架。
目前的检查点数量: ![](https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen)
1. 轻松定制模型或用例:
- 为每个架构提供示例以复现原论文结果。
- 尽可能一致地暴露模型内部。
- 模型文件可独立于库使用,便于快速实验。
🤗 Transformers 目前支持如下的架构: 模型概述请阅[这里](https://huggingface.co/docs/transformers/model_summary).
<a target="_blank" href="https://huggingface.co/enterprise">
<img alt="Hugging Face Enterprise Hub" src="https://github.com/user-attachments/assets/247fb16d-d251-4583-96c4-d3d76dda4925">
</a><br>
要检查某个模型是否已有 Flax、PyTorch 或 TensorFlow 的实现,或其是否在 🤗 Tokenizers 库中有对应词符化器tokenizer敬请参阅[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。
## 为什么我不该用 Transformers
- 该库不是一个可自由拼搭的神经网络模块化工具箱。模型文件中的代码刻意减少额外抽象,以便研究者能快速在各个模型上迭代,而无需深入更多抽象或文件跳转。
- 训练 API 优化用于 Transformers 提供的 PyTorch 模型。若需要通用的机器学习训练循环,请使用其它库,如 [Accelerate](https://huggingface.co/docs/accelerate)。
- [示例脚本](https://github.com/huggingface/transformers/tree/main/examples)只是「示例」。它们不一定能直接适配你的具体用例,需要你进行必要的改动。
这些实现均已于多个数据集测试(请参看用例脚本)并应于原版实现表现相当。你可以在用例文档的[此节](https://huggingface.co/docs/transformers/examples)中了解表现的细节。
## 100 个使用 Transformers 的项目
## 了解更多
Transformers 不止是一个使用预训练模型的工具包,它还是围绕 Hugging Face Hub 构建的项目社区。我们希望 Transformers 能助力开发者、研究人员、学生、老师、工程师与任何人构建理想项目。
为庆祝 Transformers 获得 100,000 颗星,我们制作了 [awesome-transformers](./awesome-transformers.md) 页面,展示了 100 个由社区构建的优秀项目。
如果你拥有或使用某个项目,认为它应该在列表中出现,欢迎提交 PR 添加它!
## 示例模型
你可以直接在它们的 [Hub 模型页](https://huggingface.co/models) 上测试我们的多数模型。
展开每个模态以查看不同用例中的部分示例模型。
<details>
<summary>音频</summary>
- 使用 [Whisper](https://huggingface.co/openai/whisper-large-v3-turbo) 进行音频分类
- 使用 [Moonshine](https://huggingface.co/UsefulSensors/moonshine) 进行自动语音识别
- 使用 [Wav2Vec2](https://huggingface.co/superb/wav2vec2-base-superb-ks) 进行关键词检索
- 使用 [Moshi](https://huggingface.co/kyutai/moshiko-pytorch-bf16) 进行语音到语音生成
- 使用 [MusicGen](https://huggingface.co/facebook/musicgen-large) 文本到音频生成
- 使用 [Bark](https://huggingface.co/suno/bark) 文本到语音生成
</details>
<details>
<summary>计算机视觉</summary>
- 使用 [SAM](https://huggingface.co/facebook/sam-vit-base) 自动生成掩码
- 使用 [DepthPro](https://huggingface.co/apple/DepthPro-hf) 进行深度估计
- 使用 [DINO v2](https://huggingface.co/facebook/dinov2-base) 进行图像分类
- 使用 [SuperPoint](https://huggingface.co/magic-leap-community/superpoint) 进行关键点检测
- 使用 [SuperGlue](https://huggingface.co/magic-leap-community/superglue_outdoor) 进行关键点匹配
- 使用 [RT-DETRv2](https://huggingface.co/PekingU/rtdetr_v2_r50vd) 进行目标检测
- 使用 [VitPose](https://huggingface.co/usyd-community/vitpose-base-simple) 进行姿态估计
- 使用 [OneFormer](https://huggingface.co/shi-labs/oneformer_ade20k_swin_large) 进行通用分割
- 使用 [VideoMAE](https://huggingface.co/MCG-NJU/videomae-large) 进行视频分类
</details>
<details>
<summary>多模态</summary>
- 使用 [Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B) 实现音频或文本到文本
- 使用 [LayoutLMv3](https://huggingface.co/microsoft/layoutlmv3-base) 进行文档问答
- 使用 [Qwen-VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) 实现图像或文本到文本
- 使用 [BLIP-2](https://huggingface.co/Salesforce/blip2-opt-2.7b) 进行图文描述
- 使用 [GOT-OCR2](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf) 进行基于 OCR 的文档理解
- 使用 [TAPAS](https://huggingface.co/google/tapas-base) 进行表格问答
- 使用 [Emu3](https://huggingface.co/BAAI/Emu3-Gen) 进行统一的多模态理解与生成
- 使用 [Llava-OneVision](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf) 视觉到文本
- 使用 [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf) 进行视觉问答
- 使用 [Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224) 进行视觉指代表达分割
</details>
<details>
<summary>NLP</summary>
- 使用 [ModernBERT](https://huggingface.co/answerdotai/ModernBERT-base) 进行掩码词填充
- 使用 [Gemma](https://huggingface.co/google/gemma-2-2b) 进行命名实体识别NER
- 使用 [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) 进行问答
- 使用 [BART](https://huggingface.co/facebook/bart-large-cnn) 进行摘要
- 使用 [T5](https://huggingface.co/google-t5/t5-base) 进行翻译
- 使用 [Llama](https://huggingface.co/meta-llama/Llama-3.2-1B) 进行文本生成
- 使用 [Qwen](https://huggingface.co/Qwen/Qwen2.5-0.5B) 进行文本分类
</details>
| 章节 | 描述 |
|-|-|
| [文档](https://huggingface.co/docs/transformers/) | 完整的 API 文档和教程 |
| [任务总结](https://huggingface.co/docs/transformers/task_summary) | 🤗 Transformers 支持的任务 |
| [预处理教程](https://huggingface.co/docs/transformers/preprocessing) | 使用 `Tokenizer` 来为模型准备数据 |
| [训练和微调](https://huggingface.co/docs/transformers/training) | 在 PyTorch/TensorFlow 的训练循环或 `Trainer` API 中使用 🤗 Transformers 提供的模型 |
| [快速上手:微调和用例脚本](https://github.com/huggingface/transformers/tree/main/examples) | 为各种任务提供的用例脚本 |
| [模型分享和上传](https://huggingface.co/docs/transformers/model_sharing) | 和社区上传和分享你微调的模型 |
| [迁移](https://huggingface.co/docs/transformers/migration) | 从 `pytorch-transformers` 或 `pytorch-pretrained-bert` 迁移到 🤗 Transformers |
## 引用

View File

@ -14,6 +14,43 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
<!---
A useful guide for English-Traditional Chinese translation of Hugging Face documentation
- Add space around English words and numbers when they appear between Chinese characters. E.g., 共 100 多種語言; 使用 transformers 函式庫。
- Use square quotes, e.g.,「引用」
- Some of terms in the file can be found at National Academy for Educational Research (https://terms.naer.edu.tw/), an official website providing bilingual translations between English and Traditional Chinese.
Dictionary
API: API (不翻譯)
add: 加入
checkpoint: 檢查點
code: 程式碼
community: 社群
confidence: 信賴度
dataset: 資料集
documentation: 文件
example: 基本翻譯為「範例」,或依語意翻為「例子」
finetune: 微調
Hugging Face: Hugging Face不翻譯
implementation: 實作
inference: 推論
library: 函式庫
module: 模組
NLP/Natural Language Processing: 以 NLP 出現時不翻譯,以 Natural Language Processing 出現時翻譯為自然語言處理
online demos: 線上Demo
pipeline: pipeline不翻譯
pretrained/pretrain: 預訓練
Python data structures (e.g., list, set, dict): 翻譯為串列,集合,字典,並用括號標註原英文
repository: repository不翻譯
summary: 概覽
token-: token-(不翻譯)
Trainer: Trainer不翻譯
transformer: transformer不翻譯
tutorial: 教學
user: 使用者
-->
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/transformers-logo-dark.svg">
@ -25,7 +62,6 @@ limitations under the License.
</p>
<p align="center">
<a href="https://huggingface.com/models"><img alt="Checkpoints on Hub" src="https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen"></a>
<a href="https://circleci.com/gh/huggingface/transformers"><img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/main"></a>
<a href="https://github.com/huggingface/transformers/blob/main/LICENSE"><img alt="GitHub" src="https://img.shields.io/github/license/huggingface/transformers.svg?color=blue"></a>
<a href="https://huggingface.co/docs/transformers/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/transformers/index.svg?down_color=red&down_message=offline&up_message=online"></a>
@ -36,7 +72,7 @@ limitations under the License.
<h4 align="center">
<p>
<a href="https://github.com/huggingface/transformers/blob/main/README.md">English</a> |
<a href="https://github.com/huggingface/transformers/">English</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_zh-hans.md">简体中文</a> |
<b>繁體中文</b> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ko.md">한국어</a> |
@ -44,7 +80,7 @@ limitations under the License.
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ja.md">日本語</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_hd.md">हिन्दी</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ru.md">Русский</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Português</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Рortuguês</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_te.md">తెలుగు</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_fr.md">Français</a> |
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_de.md">Deutsch</a> |
@ -57,261 +93,186 @@ limitations under the License.
</h4>
<h3 align="center">
<p>最先進的預訓練模型,專為推理與訓練而生</p>
<p>為 Jax、PyTorch 以及 TensorFlow 打造的先進自然語言處理函式庫</p>
</h3>
<h3 align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers_as_a_model_definition.png"/>
<a href="https://hf.co/course"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/course_banner.png"></a>
</h3>
Transformers 是一個為最先進的機器學習模型(涵蓋文字、電腦視覺、音訊、影片及多模態)提供推理和訓練支援的模型定義框架
🤗 Transformers 提供了數以千計的預訓練模型,支援 100 多種語言的文本分類、資訊擷取、問答、摘要、翻譯、文本生成。它的宗旨是讓最先進的 NLP 技術人人易用
它將模型定義集中化,使得該定義在整個生態系中能夠達成共識。`transformers` 是貫穿各個框架的樞紐:如果一個模型定義受到支援,它將與大多數訓練框架(如 Axolotl、Unsloth、DeepSpeed、FSDP、PyTorch-Lightning 等)、推理引擎(如 vLLM、SGLang、TGI 等)以及利用 `transformers` 模型定義的周邊建模函式庫(如 llama.cpp、mlx 等)相容
🤗 Transformers 提供了便於快速下載和使用的API讓你可以將預訓練模型用在給定文本、在你的資料集上微調然後經由 [model hub](https://huggingface.co/models) 與社群共享。同時,每個定義的 Python 模組架構均完全獨立,方便修改和快速研究實驗
我們致力於支援最新的頂尖模型,並透過使其模型定義變得簡單、可客製化且高效,來普及它們的應用
🤗 Transformers 支援三個最熱門的深度學習函式庫: [Jax](https://jax.readthedocs.io/en/latest/), [PyTorch](https://pytorch.org/) 以及 [TensorFlow](https://www.tensorflow.org/) — 並與之完美整合。你可以直接使用其中一個框架訓練你的模型,然後用另一個載入和推論
在 [Hugging Face Hub](https://huggingface.com/models) 上,有超過 100 萬個 Transformers [模型檢查點](https://huggingface.co/models?library=transformers&sort=trending) 供您使用。
## 線上Demo
立即探索 [Hub](https://huggingface.com/),尋找合適的模型,並使用 Transformers 幫助您快速上手
你可以直接在 [model hub](https://huggingface.co/models) 上測試大多數的模型。我們也提供了 [私有模型託管、模型版本管理以及推論API](https://huggingface.co/pricing)
這裡是一些範例:
- [用 BERT 做遮蓋填詞](https://huggingface.co/google-bert/bert-base-uncased?text=Paris+is+the+%5BMASK%5D+of+France)
- [用 Electra 做專有名詞辨識](https://huggingface.co/dbmdz/electra-large-discriminator-finetuned-conll03-english?text=My+name+is+Sarah+and+I+live+in+London+city)
- [用 GPT-2 做文本生成](https://huggingface.co/openai-community/gpt2?text=A+long+time+ago%2C+)
- [用 RoBERTa 做自然語言推論](https://huggingface.co/FacebookAI/roberta-large-mnli?text=The+dog+was+lost.+Nobody+lost+any+animal)
- [用 BART 做文本摘要](https://huggingface.co/facebook/bart-large-cnn?text=The+tower+is+324+metres+%281%2C063+ft%29+tall%2C+about+the+same+height+as+an+81-storey+building%2C+and+the+tallest+structure+in+Paris.+Its+base+is+square%2C+measuring+125+metres+%28410+ft%29+on+each+side.+During+its+construction%2C+the+Eiffel+Tower+surpassed+the+Washington+Monument+to+become+the+tallest+man-made+structure+in+the+world%2C+a+title+it+held+for+41+years+until+the+Chrysler+Building+in+New+York+City+was+finished+in+1930.+It+was+the+first+structure+to+reach+a+height+of+300+metres.+Due+to+the+addition+of+a+broadcasting+aerial+at+the+top+of+the+tower+in+1957%2C+it+is+now+taller+than+the+Chrysler+Building+by+5.2+metres+%2817+ft%29.+Excluding+transmitters%2C+the+Eiffel+Tower+is+the+second+tallest+free-standing+structure+in+France+after+the+Millau+Viaduct)
- [用 DistilBERT 做問答](https://huggingface.co/distilbert/distilbert-base-uncased-distilled-squad?text=Which+name+is+also+used+to+describe+the+Amazon+rainforest+in+English%3F&context=The+Amazon+rainforest+%28Portuguese%3A+Floresta+Amaz%C3%B4nica+or+Amaz%C3%B4nia%3B+Spanish%3A+Selva+Amaz%C3%B3nica%2C+Amazon%C3%ADa+or+usually+Amazonia%3B+French%3A+For%C3%AAt+amazonienne%3B+Dutch%3A+Amazoneregenwoud%29%2C+also+known+in+English+as+Amazonia+or+the+Amazon+Jungle%2C+is+a+moist+broadleaf+forest+that+covers+most+of+the+Amazon+basin+of+South+America.+This+basin+encompasses+7%2C000%2C000+square+kilometres+%282%2C700%2C000+sq+mi%29%2C+of+which+5%2C500%2C000+square+kilometres+%282%2C100%2C000+sq+mi%29+are+covered+by+the+rainforest.+This+region+includes+territory+belonging+to+nine+nations.+The+majority+of+the+forest+is+contained+within+Brazil%2C+with+60%25+of+the+rainforest%2C+followed+by+Peru+with+13%25%2C+Colombia+with+10%25%2C+and+with+minor+amounts+in+Venezuela%2C+Ecuador%2C+Bolivia%2C+Guyana%2C+Suriname+and+French+Guiana.+States+or+departments+in+four+nations+contain+%22Amazonas%22+in+their+names.+The+Amazon+represents+over+half+of+the+planet%27s+remaining+rainforests%2C+and+comprises+the+largest+and+most+biodiverse+tract+of+tropical+rainforest+in+the+world%2C+with+an+estimated+390+billion+individual+trees+divided+into+16%2C000+species)
- [用 T5 做翻譯](https://huggingface.co/google-t5/t5-base?text=My+name+is+Wolfgang+and+I+live+in+Berlin)
**[Write With Transformer](https://transformer.huggingface.co)**,由 Hugging Face 團隊所打造,是一個文本生成的官方 demo。
## 如果你在尋找由 Hugging Face 團隊所提供的客製化支援服務
<a target="_blank" href="https://huggingface.co/support">
<img alt="HuggingFace Expert Acceleration Program" src="https://huggingface.co/front/thumbnails/support.png" style="max-width: 600px; border: 1px solid #eee; border-radius: 4px; box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);">
</a><br>
## 快速上手
我們為快速使用模型提供了 `pipeline` API。 Pipeline 包含了預訓練模型和對應的文本預處理。下面是一個快速使用 pipeline 去判斷正負面情緒的例子:
```python
>>> from transformers import pipeline
# 使用情緒分析 pipeline
>>> classifier = pipeline('sentiment-analysis')
>>> classifier('We are very happy to introduce pipeline to the transformers repository.')
[{'label': 'POSITIVE', 'score': 0.9996980428695679}]
```
第二行程式碼下載並快取 pipeline 使用的預訓練模型,而第三行程式碼則在給定的文本上進行了評估。這裡的答案“正面” (positive) 具有 99.97% 的信賴度。
許多的 NLP 任務都有隨選即用的預訓練 `pipeline`。例如,我們可以輕鬆地從給定文本中擷取問題答案:
``` python
>>> from transformers import pipeline
# 使用問答 pipeline
>>> question_answerer = pipeline('question-answering')
>>> question_answerer({
... 'question': 'What is the name of the repository ?',
... 'context': 'Pipeline has been included in the huggingface/transformers repository'
... })
{'score': 0.30970096588134766, 'start': 34, 'end': 58, 'answer': 'huggingface/transformers'}
```
除了提供問題解答,預訓練模型還提供了對應的信賴度分數以及解答在 tokenized 後的文本中開始和結束的位置。你可以從[這個教學](https://huggingface.co/docs/transformers/task_summary)了解更多 `pipeline` API支援的任務。
要在你的任務中下載和使用任何預訓練模型很簡單,只需三行程式碼。這裡是 PyTorch 版的範例:
```python
>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = AutoModel.from_pretrained("google-bert/bert-base-uncased")
>>> inputs = tokenizer("Hello world!", return_tensors="pt")
>>> outputs = model(**inputs)
```
這裡是對應的 TensorFlow 程式碼:
```python
>>> from transformers import AutoTokenizer, TFAutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-uncased")
>>> inputs = tokenizer("Hello world!", return_tensors="tf")
>>> outputs = model(**inputs)
```
Tokenizer 為所有的預訓練模型提供了預處理,並可以直接轉換單一字串(比如上面的例子)或串列 (list)。它會輸出一個的字典 (dict) 讓你可以在下游程式碼裡使用或直接藉由 `**` 運算式傳給模型。
模型本身是一個常規的 [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) 或 [TensorFlow `tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model)(取決於你的後端),可依常規方式使用。 [這個教學](https://huggingface.co/transformers/training.html)解釋了如何將這樣的模型整合到一般的 PyTorch 或 TensorFlow 訓練迴圈中,或是如何使用我們的 `Trainer` API 在一個新的資料集上快速進行微調。
## 為什麼要用 transformers
1. 便於使用的先進模型:
- NLU 和 NLG 上性能卓越
- 對教學和實作友好且低門檻
- 高度抽象,使用者只須學習 3 個類別
- 對所有模型使用的制式化API
1. 更低的運算成本,更少的碳排放:
- 研究人員可以分享已訓練的模型而非每次從頭開始訓練
- 工程師可以減少計算時間以及生產成本
- 數十種模型架構、兩千多個預訓練模型、100多種語言支援
1. 對於模型生命週期的每一個部分都面面俱到:
- 訓練先進的模型,只需 3 行程式碼
- 模型可以在不同深度學習框架之間任意轉換
- 為訓練、評估和生產選擇最適合的框架,並完美銜接
1. 為你的需求輕鬆客製化專屬模型和範例:
- 我們為每種模型架構提供了多個範例來重現原論文結果
- 一致的模型內部架構
- 模型檔案可單獨使用,便於修改和快速實驗
## 什麼情況下我不該用 transformers
- 本函式庫並不是模組化的神經網絡工具箱。模型文件中的程式碼並未做額外的抽象封裝,以便研究人員快速地翻閱及修改程式碼,而不會深陷複雜的類別包裝之中。
- `Trainer` API 並非相容任何模型,它只為本函式庫中的模型最佳化。對於一般的機器學習用途,請使用其他函式庫。
- 儘管我們已盡力而為,[examples 目錄](https://github.com/huggingface/transformers/tree/main/examples)中的腳本也僅為範例而已。對於特定問題,它們並不一定隨選即用,可能需要修改幾行程式碼以符合需求。
## 安裝
Transformers 支援 Python 3.9+ 和 [PyTorch](https://pytorch.org/get-started/locally/) 2.1+。
### 使用 pip
使用 [venv](https://docs.python.org/3/library/venv.html) 或基於 Rust 的高速 Python 套件及專案管理器 [uv](https://docs.astral.sh/uv/) 來建立並啟用虛擬環境
這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.1+ 和 TensorFlow 2.6+ 下經過測試
```py
# venv
python -m venv .my-env
source .my-env/bin/activate
# uv
uv venv .my-env
source .my-env/bin/activate
你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
首先,用你打算使用的版本的 Python 創建一個虛擬環境並進入。
然後,你需要安裝 Flax、PyTorch 或 TensorFlow 其中之一。對於該如何在你使用的平台上安裝這些框架,請參閱 [TensorFlow 安裝頁面](https://www.tensorflow.org/install/), [PyTorch 安裝頁面](https://pytorch.org/get-started/locally/#start-locally) 或 [Flax 安裝頁面](https://github.com/google/flax#quick-install)。
當其中一個後端安裝成功後,🤗 Transformers 可依此安裝:
```bash
pip install transformers
```
在您的虛擬環境中安裝 Transformers
如果你想要試試範例或者想在正式發布前使用最新開發中的程式碼,你必須[從原始碼安裝](https://huggingface.co/docs/transformers/installation#installing-from-source)
```py
# pip
pip install "transformers[torch]"
### 使用 conda
# uv
uv pip install "transformers[torch]"
🤗 Transformers 可以藉由 conda 依此安裝:
```shell script
conda install conda-forge::transformers
```
如果您想使用函式庫的最新變更或有興趣參與貢獻,可以從原始碼安裝 Transformers。然而*最新*版本可能不穩定。如果您遇到任何錯誤,歡迎隨時提交一個 [issue](https://github.com/huggingface/transformers/issues)
> **_筆記:_** 從 `huggingface` 頻道安裝 `transformers` 已被淘汰
```shell
git clone https://github.com/huggingface/transformers.git
cd transformers
要藉由 conda 安裝 Flax、PyTorch 或 TensorFlow 其中之一,請參閱它們各自安裝頁面的說明。
# pip
pip install '.[torch]'
## 模型架構
# uv
uv pip install '.[torch]'
```
**🤗 Transformers 支援的[所有的模型檢查點](https://huggingface.co/models)**,由[使用者](https://huggingface.co/users)和[組織](https://huggingface.co/organizations)上傳,均與 huggingface.co [model hub](https://huggingface.co) 完美結合。
## 快速入門
目前的檢查點數量: ![](https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen)
透過 [Pipeline](https://huggingface.co/docs/transformers/pipeline_tutorial) API 快速開始使用 Transformers。`Pipeline` 是一個高階的推理類別,支援文字、音訊、視覺和多模態任務。它負責處理輸入資料的預處理,並回傳適當的輸出。
🤗 Transformers 目前支援以下的架構: 模型概覽請參閱[這裡](https://huggingface.co/docs/transformers/model_summary).
實例化一個 pipeline 並指定用於文字生成的模型。該模型會被下載並快取,方便您之後輕鬆複用。最後,傳入一些文字來提示模型
要檢查某個模型是否已有 Flax、PyTorch 或 TensorFlow 的實作,或其是否在🤗 Tokenizers 函式庫中有對應的 tokenizer敬請參閱[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)
```py
from transformers import pipeline
這些實作均已於多個資料集測試(請參閱範例腳本)並應與原版實作表現相當。你可以在範例文件的[此節](https://huggingface.co/docs/transformers/examples)中了解實作的細節。
pipeline = pipeline(task="text-generation", model="Qwen/Qwen2.5-1.5B")
pipeline("the secret to baking a really good cake is ")
[{'generated_text': 'the secret to baking a really good cake is 1) to use the right ingredients and 2) to follow the recipe exactly. the recipe for the cake is as follows: 1 cup of sugar, 1 cup of flour, 1 cup of milk, 1 cup of butter, 1 cup of eggs, 1 cup of chocolate chips. if you want to make 2 cakes, how much sugar do you need? To make 2 cakes, you will need 2 cups of sugar.'}]
```
與模型進行聊天,使用模式是相同的。唯一的區別是您需要建構一個您與系統之間的聊天歷史(作為 `Pipeline` 的輸入)。
## 了解更多
> [!TIP]
> 你也可以直接在命令列中與模型聊天。
> ```shell
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
> ```
```py
import torch
from transformers import pipeline
chat = [
{"role": "system", "content": "You are a sassy, wise-cracking robot as imagined by Hollywood circa 1986."},
{"role": "user", "content": "Hey, can you tell me any fun things to do in New York?"}
]
pipeline = pipeline(task="text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", dtype=torch.bfloat16, device_map="auto")
response = pipeline(chat, max_new_tokens=512)
print(response[0]["generated_text"][-1]["content"])
```
展開下面的範例,查看 `Pipeline` 如何在不同模態和任務上運作。
<details>
<summary>自動語音辨識</summary>
```py
from transformers import pipeline
pipeline = pipeline(task="automatic-speech-recognition", model="openai/whisper-large-v3")
pipeline("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}
```
</details>
<details>
<summary>影像分類</summary>
<h3 align="center">
<a><img src="https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"></a>
</h3>
```py
from transformers import pipeline
pipeline = pipeline(task="image-classification", model="facebook/dinov2-small-imagenet1k-1-layer")
pipeline("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
[{'label': 'macaw', 'score': 0.997848391532898},
{'label': 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
'score': 0.0016551691805943847},
{'label': 'lorikeet', 'score': 0.00018523589824326336},
{'label': 'African grey, African gray, Psittacus erithacus',
'score': 7.85409429227002e-05},
{'label': 'quail', 'score': 5.502637941390276e-05}]
```
</details>
<details>
<summary>視覺問答</summary>
<h3 align="center">
<a><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg"></a>
</h3>
```py
from transformers import pipeline
pipeline = pipeline(task="visual-question-answering", model="Salesforce/blip-vqa-base")
pipeline(
image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg",
question="What is in the image?",
)
[{'answer': 'statue of liberty'}]
```
</details>
## 為什麼我應該使用 Transformers
1. 易於使用的最先進模型:
* 在自然語言理解與生成、電腦視覺、音訊、影片和多模態任務上表現卓越。
* 為研究人員、工程師與開發者提供了低門檻的入門途徑。
* 面向使用者的抽象層級少,只需學習三個核心類別。
* 為所有預訓練模型提供了統一的 API 介面。
2. 更低的運算成本,更小的碳足跡:
* 分享訓練好的模型,而不是從零開始訓練。
* 減少運算時間和生產成本。
* 擁有數十種模型架構和超過100萬個橫跨所有模態的預訓練檢查點。
3. 為模型的每個生命週期階段選擇合適的框架:
* 僅用3行程式碼即可訓練最先進的模型。
* 在PyTorch/JAX/TF2.0框架之間輕鬆切換單一模型。
* 為訓練、評估和生產選擇最合適的框架。
4. 輕鬆根據您的需求客製化模型或範例:
* 我們為每個架構提供了範例,以重現其原作者發表的結果。
* 模型內部結構盡可能保持一致地暴露給使用者。
* 模型檔案可以獨立於函式庫使用,便於快速實驗。
<a target="_blank" href="https://huggingface.co/enterprise">
<img alt="Hugging Face Enterprise Hub" src="https://github.com/user-attachments/assets/247fb16d-d251-4583-96c4-d3d76dda4925">
</a><br>
## 為什麼我不應該使用 Transformers
- 本函式庫並非一個用於建構神經網路的模組化工具箱。模型檔案中的程式碼為了讓研究人員能快速在模型上迭代,而沒有進行過度的抽象重構,避免了深入額外的抽象層/檔案。
- 訓練 API 針對 Transformers 提供的 PyTorch 模型進行了最佳化。對於通用的機器學習迴圈,您應該使用像 [Accelerate](https://huggingface.co/docs/accelerate) 這樣的其他函式庫。
- [範例指令稿](https://github.com/huggingface/transformers/tree/main/examples)僅僅是*範例*。它們不一定能在您的特定用例上開箱即用,您可能需要修改程式碼才能使其正常運作。
## 100個使用 Transformers 的專案
Transformers 不僅僅是一個使用預訓練模型的工具包,它還是一個圍繞它和 Hugging Face Hub 建構的專案社群。我們希望 Transformers 能夠賦能開發者、研究人員、學生、教授、工程師以及其他任何人,去建構他們夢想中的專案。
為了慶祝 Transformers 獲得 10 萬顆星標,我們希望透過 [awesome-transformers](./awesome-transformers.md) 頁面來聚焦社群該頁面列出了100個基於 Transformers 建構的精彩專案。
如果您擁有或使用一個您認為應該被列入其中的專案,請隨時提交 PR 將其加入!
## 範例模型
您可以在我們大多數模型的 [Hub 模型頁面](https://huggingface.co/models) 上直接進行測試。
展開下面的每個模態,查看一些用於不同用例的範例模型。
<details>
<summary>音訊</summary>
- Audio classification with [Whisper](https://huggingface.co/openai/whisper-large-v3-turbo)
- Automatic speech recognition with [Moonshine](https://huggingface.co/UsefulSensors/moonshine)
- Keyword spotting with [Wav2Vec2](https://huggingface.co/superb/wav2vec2-base-superb-ks)
- Speech to speech generation with [Moshi](https://huggingface.co/kyutai/moshiko-pytorch-bf16)
- Text to audio with [MusicGen](https://huggingface.co/facebook/musicgen-large)
- Text to speech with [Bark](https://huggingface.co/suno/bark)
</details>
<details>
<summary>電腦視覺</summary>
- Automatic mask generation with [SAM](https://huggingface.co/facebook/sam-vit-base)
- Depth estimation with [DepthPro](https://huggingface.co/apple/DepthPro-hf)
- Image classification with [DINO v2](https://huggingface.co/facebook/dinov2-base)
- Keypoint detection with [SuperPoint](https://huggingface.co/magic-leap-community/superpoint)
- Keypoint matching with [SuperGlue](https://huggingface.co/magic-leap-community/superglue_outdoor)
- Object detection with [RT-DETRv2](https://huggingface.co/PekingU/rtdetr_v2_r50vd)
- Pose Estimation with [VitPose](https://huggingface.co/usyd-community/vitpose-base-simple)
- Universal segmentation with [OneFormer](https://huggingface.co/shi-labs/oneformer_ade20k_swin_large)
- Video classification with [VideoMAE](https://huggingface.co/MCG-NJU/videomae-large)
</details>
<details>
<summary>多模態</summary>
- Audio or text to text with [Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B)
- Document question answering with [LayoutLMv3](https://huggingface.co/microsoft/layoutlmv3-base)
- Image or text to text with [Qwen-VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
- Image captioning [BLIP-2](https://huggingface.co/Salesforce/blip2-opt-2.7b)
- OCR-based document understanding with [GOT-OCR2](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)
- Table question answering with [TAPAS](https://huggingface.co/google/tapas-base)
- Unified multimodal understanding and generation with [Emu3](https://huggingface.co/BAAI/Emu3-Gen)
- Vision to text with [Llava-OneVision](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf)
- Visual question answering with [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf)
- Visual referring expression segmentation with [Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224)
</details>
<details>
<summary>自然語言處理 (NLP)</summary>
- Masked word completion with [ModernBERT](https://huggingface.co/answerdotai/ModernBERT-base)
- Named entity recognition with [Gemma](https://huggingface.co/google/gemma-2-2b)
- Question answering with [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
- Summarization with [BART](https://huggingface.co/facebook/bart-large-cnn)
- Translation with [T5](https://huggingface.co/google-t5/t5-base)
- Text generation with [Llama](https://huggingface.co/meta-llama/Llama-3.2-1B)
- Text classification with [Qwen](https://huggingface.co/Qwen/Qwen2.5-0.5B)
</details>
| 章節 | 描述 |
|-|-|
| [文件](https://huggingface.co/transformers/) | 完整的 API 文件和教學 |
| [任務概覽](https://huggingface.co/docs/transformers/task_summary) | 🤗 Transformers 支援的任務 |
| [預處理教學](https://huggingface.co/docs/transformers/preprocessing) | 使用 `Tokenizer` 來為模型準備資料 |
| [訓練和微調](https://huggingface.co/docs/transformers/training) | 使用 PyTorch/TensorFlow 的內建的訓練方式或於 `Trainer` API 中使用 🤗 Transformers 提供的模型 |
| [快速上手:微調和範例腳本](https://github.com/huggingface/transformers/tree/main/examples) | 為各種任務提供的範例腳本 |
| [模型分享和上傳](https://huggingface.co/docs/transformers/model_sharing) | 上傳並與社群分享你微調的模型 |
| [遷移](https://huggingface.co/docs/transformers/migration) | 從 `pytorch-transformers` 或 `pytorch-pretrained-bert` 遷移到 🤗 Transformers |
## 引用
現在我們有一篇可供您引用的關於 🤗 Transformers 函式庫的 [論文](https://www.aclweb.org/anthology/2020.emnlp-demos.6/)
我們已將此函式庫的[論文](https://www.aclweb.org/anthology/2020.emnlp-demos.6/)正式發表。如果你使用了 🤗 Transformers 函式庫,可以引用
```bibtex
@inproceedings{wolf-etal-2020-transformers,
title = "Transformers: State-of-the-Art Natural Language Processing",
@ -324,4 +285,4 @@ Transformers 不僅僅是一個使用預訓練模型的工具包,它還是一
url = "https://www.aclweb.org/anthology/2020.emnlp-demos.6",
pages = "38--45"
}
```
```

View File

@ -137,7 +137,7 @@ _deps = [
"psutil",
"pyyaml>=5.1",
"pydantic>=2",
"pytest>=7.2.0,<9.0.0",
"pytest>=7.2.0",
"pytest-asyncio>=1.2.0",
"pytest-rerunfailures<16.0",
"pytest-timeout",

View File

@ -876,7 +876,7 @@ class PreTrainedConfig(PushToHubMixin):
if hasattr(self, "quantization_config"):
serializable_config_dict["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
self.dict_dtype_to_str(serializable_config_dict)
@ -910,7 +910,7 @@ class PreTrainedConfig(PushToHubMixin):
if hasattr(self, "quantization_config"):
output["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
self.dict_dtype_to_str(output)

View File

@ -1,136 +0,0 @@
# coding=utf-8
# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
from .utils import is_torch_available
if is_torch_available():
import torch
def _build_checkpoint_conversion_mapping():
mapping = {
"mixtral": [
WeightConverter(
source_keys=[
"block_sparse_moe.experts.*.w1.weight",
"block_sparse_moe.experts.*.w3.weight",
], # you give me a list of 2 keys, I collect a list of a list of tensors
target_keys="mlp.experts.gate_up_proj", # target key gets the list of two tensors
operations=[
MergeModulelist(
dim=0
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
Concatenate(dim=1), # each process has 2 tensors, gate and up, we concat them into gate_up
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
),
WeightConverter(
source_keys=[
"block_sparse_moe.experts.*.w2.weight",
],
target_keys="mlp.experts.down_proj", # target key gets the list of two tensors
operations=[
MergeModulelist(
dim=0
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
),
# WeightConverter(
# ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
# "self_attn.qkv_proj",
# operations=[Concatenate(dim=0)], # more like stack?
# ),
WeightConverter("*.block_sparse_moe.", "*.mlp."),
],
"qwen2_moe": [
WeightConverter(
source_keys=[
"mlp.experts.*.gate_proj.weight",
"mlp.experts.*.up_proj.weight",
],
target_keys="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_keys=["mlp.experts.*.down_proj.weight"],
target_keys="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
"legacy": [
WeightConverter(
source_keys="LayerNorm.gamma",
target_keys="LayerNorm.weight",
),
WeightConverter(
source_keys="LayerNorm.beta",
target_keys="LayerNorm.bias",
),
],
}
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
mapping["legacy"] += [
WeightConverter(
source_keys="weight_g",
target_keys="parametrizations.weight.original0",
),
WeightConverter(
source_keys="weight_v",
target_keys="parametrizations.weight.original1",
),
]
else:
mapping["legacy"] += [
WeightConverter(
source_keys="parametrizations.weight.original0",
target_keys="weight_g",
),
WeightConverter(
source_keys="parametrizations.weight.original1",
target_keys="weight_v",
),
]
mapping["phimoe"] = mapping["mixtral"].copy()
mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
mapping["dot1"] = mapping["qwen2_moe"].copy()
mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
mapping["jamba"] = mapping["qwen2_moe"].copy()
mapping["lfm2_moe"] = mapping["mixtral"].copy()
mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
mapping["minimax"] = mapping["mixtral"].copy()
return mapping
_checkpoint_conversion_mapping_cache = None
def get_checkpoint_conversion_mapping(model_type):
global _checkpoint_conversion_mapping_cache
_checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache
return deepcopy(_checkpoint_conversion_mapping_cache.get(model_type, None))

View File

@ -1,631 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core helpers for loading model checkpoints."""
from __future__ import annotations
import itertools
import os
import re
from abc import abstractmethod
from collections import defaultdict
from collections.abc import MutableMapping, MutableSet, Sequence
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from .integrations.accelerate import offload_weight
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
from .utils import is_torch_greater_or_equal, logging
_torch_distributed_available = torch.distributed.is_available()
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
if _is_dtensor_available:
from torch.distributed.tensor import DTensor
if TYPE_CHECKING:
from .modeling_utils import PreTrainedModel
from .quantizers import HfQuantizer
logger = logging.get_logger(__name__)
str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
}
logger = logging.get_logger(__name__)
def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
"""
Convert a glob with '*' into a regex *source* string. We don't use `glob.translate`
'*' matches (\\d+) if digits_only else (.+). Inner groups are non-capturing.
"""
star = r"(\d+)" if digits_only else r"(.+)"
return glob.replace(r"\*", star)
def build_glob_alt(
globs: list[str],
) -> tuple[re.Pattern, dict[str, str]]:
r"""
Build one compiled regex alternation with a named group per glob. This allows to run a single
re.match and get the correct group name to finally get which pattern matched.
Returns (compiled_regex, name->glob map).
Example:
```py
>>> reg, map_ = build_glob_alt(["mlp.*.w1", "mlp.*.w2"])
>>> print(reg)
(re.compile(r'(?P<g0>.*mlp\.(\d+)\.w1)|(?P<g1>.*mlp\.(\d+)\.w2)', re.UNICODE),
>>> print(map_)
{'g0': 'mlp.*.w1', 'g1': 'mlp.*.w2'})
>>> match_ = reg.match("model.layers.0.mlp.0.w1.weight")
>>> print(match_.lastgroup)
'g0'
>>> print(map_[match_.lastgroup])
mlp.*.w1
```
"""
name_map: dict[str, str] = {}
parts: list[str] = []
for i, g in enumerate(globs):
name = f"g{i}"
name_map[name] = g
pat_src = _glob_to_regex_src(g)
prefix_src = ""
if pat_src.startswith("*"):
prefix_src = "."
elif not pat_src.startswith(r"\^") and not pat_src.startswith(r".*"):
prefix_src = ".*"
parts.append(f"(?P<{name}>{prefix_src}{pat_src}.*)")
alt_src = "|".join(parts).replace("\\^", "^").replace("\\.", r"\.")
try:
reg = re.compile(alt_src)
except re.error as e:
logger.error(f"Error compiling regex for alternation: {alt_src}")
raise e
return reg, name_map
def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]:
"""
Match the key against the alternation; return the original glob string that matched.
"""
m = alt.match(key)
if not m:
return None
return name_map.get(m.lastgroup)
class ConversionOps:
"""Base class for weight conversion operations."""
# The inverse operation class, will be used when saving the checkpoint
reverse_op: type[ConversionOps]
@abstractmethod
def convert(
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], *args, **kwargs
) -> torch.Tensor:
raise NotImplementedError
class Chunk(ConversionOps):
"""Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""
reverse_op: type[ConversionOps]
def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
if chunks is None and sizes is None:
raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.")
if chunks is not None and chunks <= 0:
raise ValueError("`chunks` must be a strictly positive integer.")
self.dim = dim
self.chunks = chunks
self.sizes = list(sizes) if sizes is not None else None
self.reverse_op = Concatenate
def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]:
# chunk requires a single tensor input
if len(value) != 1 or len(value[0]) != 1:
raise ValueError("Chunk operation requires a single tensor input.")
return list(torch.chunk(value[0][0], self.chunks, dim=self.dim))
class Concatenate(ConversionOps):
"""Concatenate tensors along `dim` using a reusable buffer."""
reverse_op: type[ConversionOps]
def __init__(self, dim: int = 0):
self.dim = dim
self.reverse_op = Chunk
@torch.no_grad
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
if isinstance(value[0], list):
value = [v[0] for v in value]
tensors = value
if not tensors:
raise ValueError("Fuse requires at least one tensor to concatenate.")
return torch.cat(tuple(tensors), dim=self.dim)
class MergeModulelist(Concatenate):
"""
Merge a list of tensors into a single tensor along the first dimension.
We explicitly define this because for EP or TP you want to make sure you know what you are doing!
"""
def __init__(self, dim: int = 0):
super().__init__(dim=dim)
self.reverse_op = SplitModulelist
@torch.no_grad
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
merged = []
for group in value:
if not isinstance(group, Sequence) or len(group) == 0:
raise ValueError("MergeModulelist requires non-empty sub-sequences.")
group = [k for k in group if k.ndim]
merged.append(torch.stack(group, dim=self.dim))
return merged
class SplitModulelist(ConversionOps):
"""Inverse of :class:`MergeModulelist` using explicit split sizes per group."""
def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
self.sizes = [list(sub) for sub in sizes]
self.dim = dim
self.reverse_op = MergeModulelist
@torch.no_grad
def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
if not isinstance(value, Sequence):
raise TypeError("SplitModulelist expects a sequence of tensors.")
if len(value) != len(self.sizes):
raise ValueError("Number of tensors does not match the provided split specifications.")
result: list[list[torch.Tensor]] = []
for tensor, split_sizes in zip(value, self.sizes):
if not isinstance(tensor, torch.Tensor):
raise TypeError("SplitModulelist can only split torch.Tensor instances.")
splits = torch.split(tensor, split_sizes, dim=self.dim)
result.append(list(splits))
return result
class PermuteForRope(ConversionOps):
"""
Applies the permutation required to convert complex RoPE weights to the split sin/cos format.
"""
def __init__(self):
pass
def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
dim1, dim2 = tensor.shape
n_heads = self.config.getattr("num_attention_heads", 1)
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
tensor = tensor.transpose(1, 2).reshape(dim1, dim2)
return tensor
@torch.no_grad
def convert(
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], config
) -> Union[dict[str, torch.Tensor], list[torch.Tensor], torch.Tensor]:
self.config = config
out = [[self._apply(x) for x in inner] if isinstance(inner, list) else self._apply(inner) for inner in value]
return out
@dataclass(slots=True)
class WeightConverter:
r"""
A weight convert that acts on a pattern of source keys.
The keys need to be collected based on the target keys.
With wild card, glob patterns are matched, so you have to be detailed with what to match. If you match:
`model.layers.*.experts.*` -> it will act on all of them
{"model.layers.*.experts.*": []}
but
`experts.*.mlp` will be layer specific.
{"model.layers.1.experts.*": [], }
- source_keys: str | list[str] (wildcards '*' match digits)
- target_keys: str | list[str] | None
- distributed_operation / operations / quantization_operations are ALWAYS lists.
TODO: for BNB we need to collect model.weight.quant_state_keys
"""
source_keys: Union[str, list[str]]
target_keys: Optional[Union[str, list[str]]] = None
operations: list[ConversionOps] = field(default_factory=list, repr=False)
distributed_operation: Optional[TensorParallelLayer] = None
quantization_operation: Optional[ConversionOps] = None
def __post_init__(self):
if not isinstance(self.source_keys, list):
self.source_keys = [self.source_keys]
targets_were_none = False
if not isinstance(self.target_keys, list):
if self.target_keys is None:
self.target_keys = list(self.source_keys)
targets_were_none = True
else:
self.target_keys = [self.target_keys]
if not targets_were_none and bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2:
raise ValueError(
f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one."
)
@dataclass(slots=True)
class ConversionEntry:
weight_converter: WeightConverter
collected_tensors: dict = field(default_factory=lambda: defaultdict(dict))
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4
def _materialize_copy(tensor, dtype=None):
tensor = tensor[...]
if dtype is not None:
tensor = tensor.to(dtype)
return tensor
def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
def _job():
return _materialize_copy(tensor, dtype)
return thread_pool.submit(_job)
def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
def _job():
return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]
return thread_pool.submit(_job)
def dot_natural_key(s: str):
parts = s.split(".")
for i, p in enumerate(parts):
# whole-segment digits -> int; otherwise leave as str
if p.isdigit():
parts[i] = int(p)
return parts
@contextmanager
def log_to_misc(
full_param_name: str,
misc: MutableMapping[str, str],
extras: Any = None,
op: Union[list[ConversionOps], ConversionOps, None] = None,
):
# A simple helper to handle errors with contextual messages.
try:
yield
except Exception as e:
def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> Optional[str]:
if curr_op is None:
return None
if isinstance(curr_op, (list, tuple, set)):
names = [o.__class__.__name__ for o in curr_op if o is not None]
if not names:
return None
return ", ".join(names)
return curr_op.__class__.__name__
op_name = _format_op_name(op)
if isinstance(extras, tuple) and len(extras) == 2:
values, target_keys = extras
descriptor = f"{op_name} " if op_name else ""
misc[full_param_name] = (
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
)
elif isinstance(extras, str):
suffix = f" via {op_name}" if op_name else ""
misc[full_param_name] = f"{e}\nError{suffix} when processing parameter {extras}"
elif extras is None and op_name:
misc[full_param_name] = f"{op_name}: {e}"
else:
misc[full_param_name] = f"{extras} |Error: {e}"
raise SkipLayer()
def set_param_for_module(
model: PreTrainedModel,
full_param_name: str,
param_value: torch.Tensor,
mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
missing_keys: MutableSet[str],
misc: MutableMapping[str, Any],
distributed_operation: Optional[TensorParallelLayer],
):
with log_to_misc(full_param_name, misc, full_param_name):
module_path, _, param_name = full_param_name.rpartition(".")
module_obj = model.get_submodule(module_path) if module_path else model
ref = getattr(module_obj, param_name)
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
if not isinstance(param_value, torch.nn.Parameter):
if distributed_operation is not None:
param_value = DTensor.from_local(
param_value,
distributed_operation.device_mesh,
getattr(distributed_operation, "shard", Replicate()),
run_check=False,
shape=ref.size(),
stride=ref.stride(),
)
if not use_dtensor:
# we convert to local
param_value = param_value.to_local()
if param_name not in module_obj._buffers:
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
# Remove from missing keys (it's either mismatched, or all good)
missing_keys.discard(full_param_name)
if ref is not None and ref.shape != param_value.shape:
mismatch_keys.add((full_param_name, param_value.shape, ref.shape))
module_obj.param_name._is_hf_initialized = False # Needs to be initialized
else:
param_value._is_hf_initialized = True # super important otherwise _init_weight re-initi if bias is missing
setattr(module_obj, param_name, param_value)
class SkipLayer(Exception):
"""Control-flow sentinel: abort processing of the current layer only."""
pass
def convert_and_load_state_dict_in_model(
model: PreTrainedModel,
state_dict: dict[str, Any],
weight_mapping: dict[str, WeightConverter] | None,
tp_plan: dict[str, str] | None,
quantizer: HfQuantizer | None,
dtype: torch.dtype | None = None,
device_map: dict | None = None,
dtype_plan: dict | None = None,
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
disk_offload_index: dict | None = None,
disk_offload_folder: str | None = None,
):
"""
Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
collecting tensors per *layer instance* (the concrete indices captured from '*').
"""
prefix = model.base_model_prefix
tp_plan = tp_plan or {} # {glob_pattern: plan_obj_or_key}
device_map = device_map or {} # {exact_target_key: device}
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
dtype_plan = dtype_plan or {} # {glob_pattern: dtype}
weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter}
meta_model_state_dict = model.state_dict()
missing_keys = set(meta_model_state_dict.keys())
misc = {}
mismatch_keys = set()
unexpected_keys = set()
# Global thread_pool
thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
_patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(dtype_plan.keys()))
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
# 1. Create the conversion entries
by_conversion_pattern: dict[str, ConversionEntry] = {}
for original_key, tensor in state_dict:
matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
if matched_pattern is not None:
converter = source_to_target[matched_pattern] # TODO make sure its the ref
sub_with_extractor = partial(re.sub, matched_pattern.replace("*", r"(\d+)"), string=original_key)
entry_key = "|".join(converter.target_keys)
target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys]))
entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
converter_key = sub_with_extractor(matched_pattern)
else:
converter = WeightConverter(original_key)
converter_key = entry_key = target_key = original_key
entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))
_dtype = dtype
new_target_key = [] # test_load_with_mismatched_shapes for AutoModel.from_pretrained(AutoForCausal, vocab=10)
for t in target_key.split("|"):
if t.startswith(prefix) and meta_model_state_dict.get(re.sub(f"^{prefix}.", "", t, count=1)) is not None:
t = re.sub(f"^{prefix}.", "", t, count=1)
elif meta_model_state_dict.get(f"{prefix}.{t}") is not None:
t = f"{prefix}.{t}"
new_target_key.append(t)
empty_param = meta_model_state_dict.get(t)
# If it does not exist, it's unexpected
if empty_param is None:
unexpected_keys.add(t)
continue
if quantizer is not None and quantizer.param_needs_quantization(model, t):
if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
from .integrations.finegrained_fp8 import Fp8Quantize
converter.quantization_operation = Fp8Quantize() # TODO support other methods
else:
raise ValueError("This quantization method is gonna be supported SOOOON")
else:
_dtype = dtype
matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
if matched_dtype_pattern is not None:
_dtype = dtype_plan[matched_dtype_pattern]
elif empty_param.dtype != _dtype:
_dtype = empty_param.dtype
first_target_key = new_target_key[0]
target_key = "|".join(new_target_key)
future = None
if device_mesh:
if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
empty_param = meta_model_state_dict.get(first_target_key)
if getattr(converter, "distributed_operation", {}) is None:
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
converter.distributed_operation = tp_layer(
device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
)
# VERY IMPORTANT: this tells us wether we collected stuffs or not.
shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
future = spawn_tp_materialize(
thread_pool,
tensor,
_dtype,
converter.distributed_operation,
shard_index,
)
if future is None:
future = spawn_materialize(thread_pool, tensor, _dtype)
entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
# 2. Actually convert the ckpt
inverse_converters = {}
keys = list(by_conversion_pattern.keys())
with logging.tqdm(total=len(keys), desc="Loading weights") as pbar:
for key in keys[::-1]: # revert to process simple keys first
group = by_conversion_pattern.pop(key)
converter = group.weight_converter
operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
for full_param_name, tensors_for_this_layer in group.collected_tensors.items():
pbar.update(1)
pbar.set_postfix({"Materializing param": full_param_name})
pbar.refresh()
concrete_target_keys = full_param_name.split("|")
try:
if bool(set(concrete_target_keys) - unexpected_keys):
with log_to_misc(full_param_name, misc):
values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
for op in operations:
with log_to_misc(full_param_name, misc, (values, concrete_target_keys), operations):
values = op.convert(values, model.config)
values = [values] if not isinstance(values, list) else values
with log_to_misc(full_param_name, misc, (values, concrete_target_keys), operations):
realized_value = {
k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
}
for k in list(realized_value.keys()).copy():
if op := converter.quantization_operation:
with log_to_misc(full_param_name, misc, op=op):
realized_value.update(
op.convert(
{k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
)
)
for k, output_value in realized_value.items():
output_value = output_value[0] if isinstance(output_value, list) else output_value
for src in converter.source_keys: # what should happen to k when we meet k at saving
inverse_converters[k] = {src: converter}
param_device = device_map[re.search(device_map_regex, k).group()]
# Offloading support
if param_device == "disk":
missing_keys.discard(k)
# If not already offloaded, or if we applied any special Operation, we need to re-save
if k not in disk_offload_index or len(operations) > 0:
disk_offload_index = offload_weight(
output_value, k, disk_offload_folder, disk_offload_index
)
else:
set_param_for_module(
model,
k,
output_value,
mismatch_keys,
missing_keys,
misc,
converter.distributed_operation,
)
except SkipLayer:
continue
del group
model.inverse_converters = inverse_converters
thread_pool.shutdown(wait=False)
return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc
# TODO this is not done yet!
def revert_weight_conversion(model, state_dict):
mapping = getattr(model, "_checkpoint_conversion_mapping", {}) # IDK why but setting this will fail all llava.
reverse_key_mapping = [(v, k) for k, v in mapping.items()]
original_state_dict = {}
for key, value in state_dict.items():
for pattern, inverse_converter in reverse_key_mapping:
# TODO FIXME you name it
replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*\)", "", replacement)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
return state_dict

View File

@ -723,7 +723,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
if self.mask_replace_prob < 1:
warnings.warn(
"Random token replacement is not supported with whole word masking. "
"Random token replacement is not supported with whole word masking.",
"Setting mask_replace_prob to 1.",
)
self.mask_replace_prob = 1

View File

@ -82,7 +82,7 @@ class GlueDataset(Dataset):
cache_dir: Optional[str] = None,
):
warnings.warn(
"This dataset will be removed from the library soon, preprocessing should be handled with the Hugging Face Datasets "
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
"library. You can have a look at this example script for pointers: "
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
FutureWarning,

View File

@ -21,7 +21,7 @@ if is_sklearn_available():
DEPRECATION_WARNING = (
"This metric will be removed from the library soon, metrics should be handled with the Hugging Face Evaluate "
"This metric will be removed from the library soon, metrics should be handled with the 🤗 Evaluate "
"library. You can have a look at this example script for pointers: "
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
)

View File

@ -28,7 +28,7 @@ from .utils import DataProcessor, InputExample, InputFeatures
logger = logging.get_logger(__name__)
DEPRECATION_WARNING = (
"This {0} will be removed from the library soon, preprocessing should be handled with the Hugging Face Datasets "
"This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
"library. You can have a look at this example script for pointers: "
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
)

View File

@ -47,7 +47,7 @@ deps = {
"psutil": "psutil",
"pyyaml": "pyyaml>=5.1",
"pydantic": "pydantic>=2",
"pytest": "pytest>=7.2.0,<9.0.0",
"pytest": "pytest>=7.2.0",
"pytest-asyncio": "pytest-asyncio>=1.2.0",
"pytest-rerunfailures": "pytest-rerunfailures<16.0",
"pytest-timeout": "pytest-timeout",

View File

@ -39,7 +39,6 @@ from .utils import (
is_torch_dtype,
logging,
requires_backends,
safe_load_json_file,
)
from .utils.hub import cached_file
@ -428,42 +427,35 @@ class FeatureExtractionMixin(PushToHubMixin):
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
if os.path.isfile(pretrained_model_name_or_path):
resolved_feature_extractor_file = pretrained_model_name_or_path
resolved_processor_file = None
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
feature_extractor_file = pretrained_model_name_or_path
resolved_processor_file = None
resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
else:
feature_extractor_file = FEATURE_EXTRACTOR_NAME
try:
# Load from local folder or from cache or download from model Hub and cache
resolved_processor_file = cached_file(
pretrained_model_name_or_path,
filename=PROCESSOR_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
resolved_feature_extractor_file = cached_file(
pretrained_model_name_or_path,
filename=feature_extractor_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
resolved_feature_extractor_files = [
resolved_file
for filename in [feature_extractor_file, PROCESSOR_NAME]
if (
resolved_file := cached_file(
pretrained_model_name_or_path,
filename=filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
subfolder=subfolder,
token=token,
user_agent=user_agent,
revision=revision,
_raise_exceptions_for_missing_entries=False,
)
)
is not None
]
resolved_feature_extractor_file = resolved_feature_extractor_files[0]
except OSError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
@ -477,24 +469,19 @@ class FeatureExtractionMixin(PushToHubMixin):
f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
)
# Load feature_extractor dict. Priority goes as (nested config if found -> image processor config)
# We are downloading both configs because almost all models have a `processor_config.json` but
# not all of these are nested. We need to check if it was saved recebtly as nested or if it is legacy style
feature_extractor_dict = None
if resolved_processor_file is not None:
processor_dict = safe_load_json_file(resolved_processor_file)
if "feature_extractor" in processor_dict or "audio_processor" in processor_dict:
feature_extractor_dict = processor_dict.get("feature_extractor", processor_dict.get("audio_processor"))
try:
# Load feature_extractor dict
with open(resolved_feature_extractor_file, encoding="utf-8") as reader:
text = reader.read()
feature_extractor_dict = json.loads(text)
if "audio_processor" in feature_extractor_dict:
feature_extractor_dict = feature_extractor_dict["audio_processor"]
else:
feature_extractor_dict = feature_extractor_dict.get("feature_extractor", feature_extractor_dict)
if resolved_feature_extractor_file is not None and feature_extractor_dict is None:
feature_extractor_dict = safe_load_json_file(resolved_feature_extractor_file)
if feature_extractor_dict is None:
except json.JSONDecodeError:
raise OSError(
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a {feature_extractor_file} file"
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
)
if is_local:

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from math import floor, gcd, sqrt
from typing import Optional
@ -20,8 +21,8 @@ import torch
from ...configuration_utils import PreTrainedConfig
from ...generation.configuration_utils import GenerationConfig
from ...utils.metrics import attach_tracer, traced
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import RequestState, get_device_and_memory_breakdown, logger
from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import get_device_and_memory_breakdown, logger
def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
@ -31,7 +32,7 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]]
- All groups have the same number of layers
For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
We would get four groups: [0, 3], [1, 2], [4,5] and [6,7].
We would get two groups: [0, 3] and [1, 2], [4,5], [6,7].
"""
# If the config has no layer_type attribute, it means all layers are the same attention type
layer_types = getattr(config, "layer_types", None)
@ -115,6 +116,7 @@ class PagedAttentionCache:
for the sliding-attention group, although it is not needed.
"""
# TODO: this init is quite long, maybe a refactor is in order
def __init__(
self,
config: PreTrainedConfig,
@ -122,10 +124,8 @@ class PagedAttentionCache:
device: torch.device,
dtype: torch.dtype = torch.float16,
tp_size: Optional[int] = None,
allow_prefix_sharing: bool = True,
) -> None:
"""Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has
only full attention layers.
"""Initialize a paged attention cache for efficient memory usage.
Args:
config: Model configuration
@ -133,7 +133,6 @@ class PagedAttentionCache:
device: Device for the cache tensors
dtype: Data type of the cache
tp_size: Tensor parallelism size
allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers.
"""
self.config = config
self.dtype = dtype
@ -174,12 +173,10 @@ class PagedAttentionCache:
page_size = self.head_dim * self.num_key_value_heads
if "flash" in self.config._attn_implementation:
num_attention_masks = 0 # only used to compute the default memory footprint args
elif "sliding_attention" in group_types:
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
num_attention_masks = 2
num_attention_masks = 1 # only used to compute the default meme args
else:
num_attention_masks = 1
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
num_attention_masks = 2 if "sliding_attention" in group_types else 1
memory_handler = PagedAttentionMemoryHandler(
block_size=self.block_size,
@ -192,9 +189,7 @@ class PagedAttentionCache:
num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
num_blocks=getattr(generation_config, "num_blocks", None),
max_batch_tokens=getattr(generation_config, "max_batch_tokens", None),
max_memory_percent=getattr(
generation_config, "max_memory", 0.8
), # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
max_memory_percent=getattr(generation_config, "max_memory", 0.9),
cache_dtype=self.dtype,
)
@ -221,6 +216,7 @@ class PagedAttentionCache:
logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")
# Block management data structures
self._free_blocks = deque(range(num_blocks))
self.group_cache_managers: list[CacheAllocator] = []
for i, group_type in enumerate(group_types):
if group_type == "full_attention":
@ -231,19 +227,13 @@ class PagedAttentionCache:
raise ValueError(f"Invalid group type: {group_type}")
self.group_cache_managers.append(cm)
# We only use prefix sharing if the whole model has only full attention layers
self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"]
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
self.blocks_to_complete: dict[str, int] = {}
self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests
@traced
def allocate_blocks(self, n_blocks: int, state: RequestState) -> int:
def allocate_blocks(self, n_blocks: int, request_id: str) -> int:
"""Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
managers, and this method only returns the maximum number of blocks actually allocated across all managers."""
max_allocated = 0
for cm in self.group_cache_managers:
allocated = cm.allocate_blocks(n_blocks, state.request_id, self._block_manager)
allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
if allocated is None:
return None
max_allocated = max(max_allocated, allocated)
@ -254,11 +244,11 @@ class PagedAttentionCache:
"""Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
by the cache managers."""
for cm in self.group_cache_managers:
cm.free_blocks(request_id, self._block_manager)
cm.free_blocks(request_id, self._free_blocks)
def get_num_free_blocks(self) -> int:
"""Get the current number of unallocated blocks available for new requests."""
return self._block_manager.num_free_blocks
return len(self._free_blocks)
@traced
def extend_read_indices(
@ -345,44 +335,6 @@ class PagedAttentionCache:
# Return the new KV values
return key_states_with_cache, value_states_with_cache
def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
"""Searches for a prefix match in the cache for the given (prompts_ids). If one is found, we reference the
matching blocks in the (request_id), increase the reference count of the blocks and return the number of blocks
that match. If no prefix match is found, we return 0."""
current_hash = None
allocated_blocks = []
for b in range(len(prompt_ids) // self.block_size):
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
current_hash = self._block_manager.compute_hash(current_hash, tokens)
block_id = self._block_manager._hash_to_id.get(current_hash)
if block_id is not None:
allocated_blocks.append(block_id)
self._block_manager.increase_ref_count(block_id)
else:
break
# If we found a matching prefix, we reference the blocks in the request
if allocated_blocks:
logger.debug(f"Found prefix match for request {request_id} with {len(allocated_blocks)} blocks")
cm = self.group_cache_managers[0]
cm.block_table[request_id] = allocated_blocks
prefix_length = len(allocated_blocks) * self.block_size
self._total_prefix_length += prefix_length
return prefix_length
def mark_blocks_as_complete(self, state: RequestState) -> None:
"""Marks the blocks that have been computed in the forward pass as complete. If prefix sharing is off, this is
a no-op."""
num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
if num_complete_blocks == 0:
return None
cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
self._block_manager.mark_blocks_as_complete(
num_complete_blocks=num_complete_blocks,
allocated_blocks=cm.block_table[state.request_id],
prompt_ids=(state.full_prompt_ids + state.static_outputs),
)
# TODO: rework computation with the groups and their sizes
class PagedAttentionMemoryHandler:
@ -462,7 +414,7 @@ class PagedAttentionMemoryHandler:
self,
num_blocks: Optional[int] = None,
max_batch_tokens: Optional[int] = None,
max_memory_percent: float = 0.8, # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
max_memory_percent: float = 0.9,
cache_dtype: torch.dtype = torch.float16,
) -> tuple[int, int]:
"""Determine optimal number of blocks and maximum number of tokens per batch based on available memory and
@ -502,7 +454,7 @@ class PagedAttentionMemoryHandler:
def compute_num_blocks_and_max_batch_tokens(
self,
max_memory_percent: float,
max_memory_percent: float = 0.9,
cache_dtype: torch.dtype = torch.float16,
m: float = 0.01,
) -> tuple[int, int]:
@ -517,8 +469,6 @@ class PagedAttentionMemoryHandler:
2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
])
If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial.
"""
cache_memory = self.get_available_memory(max_memory_percent)
logger.info(f"Cache memory: {cache_memory}")
@ -530,16 +480,11 @@ class PagedAttentionMemoryHandler:
c = -cache_memory
logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")
# If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial
if self.num_attention_masks == 0:
greatest_solution = -c / b
# Otherwise, we solve the quadratic equation
else:
discriminant = b**2 - 4 * a * c
if discriminant < 0:
raise ValueError(f"Discriminant is negative: {discriminant = }")
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
# Compute discriminant and greatest solution
discriminant = b**2 - 4 * a * c
if discriminant < 0:
raise ValueError(f"Discriminant is negative: {discriminant = }")
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
if greatest_solution < 0:
raise ValueError(f"Greatest solution is negative: {greatest_solution = }")
@ -558,7 +503,7 @@ class PagedAttentionMemoryHandler:
def compute_max_batch_tokens(
self,
num_blocks: int,
max_memory_percent: float,
max_memory_percent: float = 0.9,
cache_dtype: torch.dtype = torch.float16,
) -> int:
"""Calculate maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:
@ -586,7 +531,7 @@ class PagedAttentionMemoryHandler:
def compute_num_blocks(
self,
max_batch_tokens: int,
max_memory_percent: float,
max_memory_percent: float = 0.9,
cache_dtype: torch.dtype = torch.float16,
) -> int:
"""Calculate number of cache blocks N given a fixed maximum token per token M. The formula for N is given by:

View File

@ -14,211 +14,29 @@
# limitations under the License.
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterator
from math import ceil
from typing import Optional, TypeVar
from typing import Optional
from .requests import logger
T = TypeVar("T")
def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
index = len(xs) - 1
for x in xs[::-1]:
yield index, x
index -= 1
class Block:
"""A class to represent a block managed by the block manager. We say that a block is complete when the physical KV
cache it points to is fully computed. A block can have a parent, which is the block that came before in the
sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block and
its parent's hash (if there is a parent)."""
def __init__(self, id_: int, parent_id: int | None) -> None:
self.id: int = id_
self.parent_id: int | None = parent_id
self.hash: int | None = None
self.ref_count: int = 1
def __repr__(self) -> str:
return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})"
@property
def is_complete(self) -> bool:
return self.hash is not None
class BlockManager:
"""A class to manage the number of free blocks and block re-use. If prefix sharing is off, the block manager is a
simple FIFO structure where blocks are either free or in use. If prefix sharing is on, blocks can have 3 states:
- in use: one or more requests references this block, thus it cannot be written over. The number of requests
referencing this block is stored as ref_count in the Block object.
- un-initialized: the block points to a space in the KV cache tensor that contains no data yet. Those blocks can
be given as free blocks to new requests without any overhead.
- initialized: the block is complete and was used by one or more request that are finished. It contains KV cache
data and its hash is stored in the hash table. If a new request needs a block with the same hash, we increase
the ref_count of the block and remove it from the list of initialized blocks, because it is now in use.
Still, the block can be freed if no un-initialized blocks are left. In that case, we remove its hash from the
hash table.
There is no structure to keep track of the blocks in use: if a block is neither un-initialized nor initialized,
it is in use.
"""
def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None:
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size). Prefix sharing
can be turned on with the (use_prefix_sharing) flag, which only happens if the model has only full attention
layers."""
self.num_blocks = num_blocks
self.block_size = block_size
self._uninit_block_ids = deque(range(num_blocks))
self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
self._use_prefix_sharing = use_prefix_sharing
self._hash_to_id: dict[int, int] = {}
self._id_to_block: dict[int, Block] = {}
@property
def num_free_blocks(self) -> int:
"""Returns the number of free blocks left. Both initialized and uninitialized blocks are considered free."""
return len(self._uninit_block_ids) + len(self._init_block_ids)
def has_enough_free_blocks(self, n_blocks: int) -> bool:
"""Checks if there are enough free blocks to allocate the requested number of blocks (n_blocks). If there are
not enough uninitialized blocks, we uninitialize the required number of initialized blocks."""
# Exit early if there are enough uninitialized blocks
if len(self._uninit_block_ids) >= n_blocks:
return True
# Exit early if even after uninitializing all initialized blocks, there are not enough free blocks
block_to_unintialize = n_blocks - len(self._uninit_block_ids)
if len(self._init_block_ids) < block_to_unintialize:
return False
# Uninitialize the required amount of blocks
for _ in range(block_to_unintialize):
id_to_unintialize = self._init_block_ids.popitem()[0]
block = self._id_to_block[id_to_unintialize]
self._hash_to_id.pop(block.hash)
self._uninit_block_ids.append(id_to_unintialize)
return True
def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None:
"""Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures. One
can also pass a (last_block_id) to indicate the last block id in the sequence, which is used to keep track of
the parent block. If the manager cannot find enough free blocks, it returns None."""
if not self.has_enough_free_blocks(n_blocks):
return None
allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)]
# If we use prefix caching, we keep track of the allocated blocks as partial blocks
if self._use_prefix_sharing:
for block_id in allocated_block_ids:
block = Block(block_id, last_block_id)
self._id_to_block[block_id] = block
last_block_id = block_id
# In both cases, we return the allocated block ids
return allocated_block_ids
def increase_ref_count(self, block_id: int) -> None:
"""Increases the reference count of a given (block_id)."""
block = self._id_to_block[block_id]
block.ref_count += 1
if block.ref_count == 1:
self._init_block_ids.pop(block_id)
def decrease_ref_count(self, block_id: int) -> None:
"""Decreases the reference count of a given (block_id). If the reference count reaches 0, the block is no longer
in use, and becomes initialized (if it was complete) or uninitialized (if it was incomplete)."""
block = self._id_to_block[block_id]
block.ref_count -= 1
if block.ref_count == 0:
if block.is_complete:
self._init_block_ids[block_id] = None
else:
self._id_to_block.pop(block_id)
self._uninit_block_ids.append(block_id)
def free_blocks(self, blocks: list[int]) -> None:
"""Marks a list of (blocks) as free. If there is no prefix sharing, we simply add them to the uninitialized
blocks queue. Otherwise, their new state depends on whether they are complete."""
if self._use_prefix_sharing:
for block_id in blocks:
self.decrease_ref_count(block_id)
else:
self._uninit_block_ids.extend(blocks)
def mark_blocks_as_complete(
self, num_complete_blocks: int, allocated_blocks: list[int], prompt_ids: list[int]
) -> None:
"""Among the list of (allocated_blocks), mark (num_complete_blocks) incomplete blocks as now complete. The list
of (prompt_ids) is used to compute the hash of the new block."""
# Look for the first complete block, starting from the last block in the sequence
parent_hash = None
incomplete_blocks: list[Block] = []
for i, block_id in reverse_enumerate(allocated_blocks):
block = self._id_to_block[block_id]
if block.is_complete:
parent_hash = block.hash
break
incomplete_blocks.append((i, block))
# Now go through the incomplete blocks and updated them
new_parent_id = None
while incomplete_blocks:
i, block = incomplete_blocks.pop()
# If the parent id has been updated, we apply the change
if new_parent_id is not None:
block.parent_id = new_parent_id
new_parent_id = None
# If we have set the hash for all complete blocks, we can stop
if num_complete_blocks == 0:
break
# Otherwise, we compute the hash
num_complete_blocks -= 1
tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size]
block.hash = self.compute_hash(parent_hash, tokens)
existing_block_id = self._hash_to_id.get(block.hash)
# If the block hash is already in the hash to id mapping, we reference the existing block instead
if existing_block_id is not None:
logger.debug(f"Found existing block {existing_block_id} for block {block.id}")
allocated_blocks[i] = existing_block_id
self._id_to_block[existing_block_id].ref_count += 1
new_parent_id = existing_block_id
self.free_blocks([block.id])
# Otherwise, we add the completed block to the hash table
else:
self._hash_to_id[block.hash] = block.id
# Update loop variables
parent_hash = block.hash
def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
"""Computes the hash of a block containing the given (tokens) with a given (parent_hash). If the block has no
parent, the parent hash is None."""
return hash((parent_hash, tuple(tokens)))
class CacheAllocator(ABC):
"""Abstract base class for cache managers. Cache managers keep track of per-request cache allocations, determine
when a new physical block needs to be allocated and compute physical indices for reading or writing to the cache."""
_index: int
block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
_block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
@abstractmethod
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocates (n_blocks) for a given (request_id) using the (block_manager). Returns the num of blocks allocated
if successful and None otherwise."""
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
"""Allocates n_blocks for a given request_id. Returns the num of blocks allocated if successful and None
otherwise."""
def free_blocks(self, request_id: str, block_manager: BlockManager) -> None:
"""Frees all blocks associated with a (request_id) using the (block_manager)."""
if request_id in self.block_table:
blocks_to_free = self.block_table.pop(request_id)
block_manager.free_blocks(blocks_to_free)
def free_blocks(self, request_id: str, free_blocks: deque[int]) -> None:
"""Frees all blocks associated with a request_id."""
if request_id in self._block_table:
blocks_to_free = self._block_table.pop(request_id)
free_blocks.extend(blocks_to_free)
else:
logger.warning(
f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
@ -248,30 +66,23 @@ class FullAttentionCacheAllocator(CacheAllocator):
"""
self._index = index
self.block_size = block_size
self.block_table = {}
self._block_table = {}
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocate (n_blocks) for a given (request_id) using the (block_manager). Returns the number of blocks
allocated if successful and None otherwise. For group of full attention layers, we always allocate the number of
requested blocks."""
# Make sure the request_id is in the block table and get the first block id
if request_id not in self.block_table:
self.block_table[request_id] = [] # TODO: check the impact of making this a deque
last_block_id = None
else:
last_block_id = self.block_table[request_id][-1]
# Actual allocation, return early if failed
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id)
if allocated_blocks is None:
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
otherwise. For group of full attention layers, we always allocate the number of requested blocks."""
if len(free_blocks) < n_blocks:
return None
self.block_table[request_id].extend(allocated_blocks)
if request_id not in self._block_table:
self._block_table[request_id] = []
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(n_blocks))
return n_blocks
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
"""Returns the physical indices of where to read request_id's cache. For a group of full attention layers, we
first write the new cache to the cache tensor and then read the entire cache from the beginning to the end."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Compute the physical indices
@ -286,7 +97,7 @@ class FullAttentionCacheAllocator(CacheAllocator):
def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
"""Returns the physical indices for writing to the cache. For a group of full attention layers, we write the new
cache as a continuation of the existing cache for the same request."""
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Compute the physical indices
@ -318,26 +129,25 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
self.block_size = block_size
self.sliding_window = sliding_window
self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
self.block_table = {}
self._block_table = {}
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocate (n_blocks) for a given (request_id) using the (block_manager). Returns the number of blocks
allocated otherwise. For group of sliding window attention layers, we only allocate up to the point where we can
fit an entire sliding window in the cache tensor."""
if request_id not in self.block_table:
self.block_table[request_id] = []
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
otherwise. For group of sliding window attention layers, we only allocate up to the point where we can fit an
entire sliding window in the cache tensor."""
if request_id not in self._block_table:
self._block_table[request_id] = []
# Early return if we are already at the max number of blocks per request
already_allocated = len(self.block_table[request_id])
already_allocated = len(self._block_table[request_id])
if already_allocated == self._max_blocks_per_request:
return 0
# Compute actual number of blocks to allocate
after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
actual_n_blocks = after_allocation - already_allocated
# Classic allocation
allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window
if allocated_blocks is None:
if len(free_blocks) < actual_n_blocks:
return None
self.block_table[request_id].extend(allocated_blocks)
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(actual_n_blocks))
return actual_n_blocks
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
@ -347,7 +157,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
sliding_window - 1 cache page and then manually add the new key / values states after. Hence the -1 indices
which indicate where to store the new key or values indices."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Apply sliding window
@ -368,7 +178,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
sliding window attention layers, we write the new cache in rolling-buffer kind of way: if we reach the end of
the allocated physical cache, we start writing from the beginning of the physical cache again."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Apply sliding window
@ -391,3 +201,22 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
seqlens_k = query_length + min(past_length, self.sliding_window - 1)
return "sliding_attention", seqlens_k
# TODO: test the impact of this
# def get_read_indices(self, request_id: str, past_length: int) -> list[int]:
# # Retrieve the block table for the request and raise an error if it doesn't exist
# block_table = self._block_table.get(request_id)
# if block_table is None:
# raise ValueError(f"No block table found for request {request_id}")
# # Compute the physical indices
# physical_indices = []
# n_left = past_length
# for block_idx in block_table:
# block_physical_index = block_idx * self.block_size
# pages_used = min(self.block_size, n_left)
# physical_indices.extend(block_physical_index + i for i in range(pages_used))
# n_left -= pages_used
# if n_left == 0:
# return physical_indices
# raise ValueError(f"Request {request_id} required too many indices: {past_length = } and {len(block_table) = }")

View File

@ -16,13 +16,12 @@
import queue
import threading
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from itertools import count
from math import ceil
from time import perf_counter
from typing import Optional
from typing import Optional, Union
import torch
from torch import nn
@ -447,7 +446,10 @@ class ContinuousBatchProcessor:
cumulative_seqlens_q = [0]
logits_indices = []
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
if isinstance(self.cumulative_seqlens_k, dict):
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
else:
cumulative_seqlens_k = [0]
read_index = [[] for _ in range(self.cache.num_groups)]
write_index = [[] for _ in range(self.cache.num_groups)]
@ -496,7 +498,10 @@ class ContinuousBatchProcessor:
self.metrics.record_kv_cache_memory_metrics(self.cache)
if logger.isEnabledFor(logging.DEBUG):
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
if isinstance(self.cumulative_seqlens_k, dict):
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
else:
ck = cumulative_seqlens_k[-1]
logger.debug(
f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
@ -512,7 +517,7 @@ class ContinuousBatchProcessor:
read_index: list[list[int]],
write_index: list[list[int]],
cumulative_seqlens_q: list[int],
cumulative_seqlens_k: dict[str, list[int]],
cumulative_seqlens_k: Union[list[int], dict[str, list[int]]],
logits_indices: list[int],
) -> None:
"""Builds the actual tensors for the current batch, by modifying the already allocated tensors in place."""
@ -556,7 +561,9 @@ class ContinuousBatchProcessor:
@traced
def _maybe_send_output(self, state: RequestState) -> None:
"""Send output to the queue based on streaming mode and request state."""
if state.streaming or state.status == RequestStatus.FINISHED:
if state.streaming:
self.output_queue.put(state.to_generation_output())
elif state.status == RequestStatus.FINISHED:
self.output_queue.put(state.to_generation_output())
@traced
@ -564,27 +571,17 @@ class ContinuousBatchProcessor:
"""Update request states based on generated tokens."""
out_tokens = self._sync()
for i, state in enumerate(self.requests_in_batch):
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
if len(state.remaining_prompt_ids) == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING
token = out_tokens[self.logits_indices[i]]
state.prompt_ids = [token]
# Update the request and stop if it is complete
is_finished = state.update_and_check_completion(token)
# We mark the completed blocks as such
self.cache.mark_blocks_as_complete(state)
if is_finished:
if state.update_with_token(token):
self.metrics.record_request_completion(state.created_time, state.request_id)
self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
self._maybe_send_output(state)
# Otherwise, the request is still prefilling, but the prefill has been split
elif state.status == RequestStatus.PREFILLING_SPLIT:
self.cache.mark_blocks_as_complete(state)
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
else:
raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")
if self.cache.get_num_free_blocks() == 0:
raise ValueError("No more free blocks")
@ -729,7 +726,6 @@ class ContinuousBatchingManager:
max_queue_size: int = 0,
num_q_cuda_graphs: int = 0,
num_kv_cuda_graphs: int = 0,
allow_prefix_sharing: bool = True,
) -> None:
"""Initialize the continuous batching manager.
@ -739,7 +735,6 @@ class ContinuousBatchingManager:
max_queue_size: Maximum size of the request queue (0 = unlimited)
num_q_cuda_graphs: (optional) Number of CUDA graphs to use for the query dimension
num_kv_cuda_graphs: (optional) Number of CUDA graphs to use for the keys/values dimension
allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
"""
if "paged|" not in model.config._attn_implementation:
attn_implementation = f"paged|{model.config._attn_implementation}"
@ -772,8 +767,6 @@ class ContinuousBatchingManager:
self.manual_eviction = manual_eviction
self.batch_processor: Optional[ContinuousBatchProcessor] = None
self._allow_prefix_sharing = allow_prefix_sharing
# If a number of cuda graphs was specified for either Q or KV, we activate cuda graphs
if num_q_cuda_graphs > 0 or num_kv_cuda_graphs > 0:
self.use_cuda_graph = True
@ -806,6 +799,7 @@ class ContinuousBatchingManager:
logger.warning("Manager thread is already running.")
return
self._result_queue = queue.Queue()
self._generation_thread = threading.Thread(target=self._run_generation_loop)
self._generation_thread.start()
@ -820,16 +814,6 @@ class ContinuousBatchingManager:
block: Whether to wait for the thread to stop
timeout: Maximum time to wait for the thread to stop
"""
if self.batch_processor is None:
logger.warning("\nBatch processor was not initialized.")
else:
if self.batch_processor.cache.use_prefix_sharing:
logger.warning(
f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
)
else:
logger.warning("\nPrefix sharing was off.")
if self._generation_thread is None:
logger.warning("Manager not started.")
return
@ -842,8 +826,6 @@ class ContinuousBatchingManager:
if block:
self.join(stop_trigger_time, timeout)
self.batch_processor = None
def join(self, stop_trigger_time: float, timeout: Optional[float] = None) -> None:
"""Wait for the background thread to finish.
@ -955,6 +937,20 @@ class ContinuousBatchingManager:
request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)
@traced
def warmup(self, batch_processor: ContinuousBatchProcessor) -> None:
stream = torch.cuda.Stream(device=self.model.device)
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
# Warmup the model with a dummy forward pass
self._generation_step(batch_processor)
torch.cuda.current_stream().wait_stream(stream)
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, stream=stream):
self._generation_step(batch_processor)
@traced
# @torch.compile
def _generation_step(self) -> None:
"""Perform a single generation step. This is cuda graphed"""
self.batch_processor._generation_step(self.model, self.logit_processor, self.do_sample)
@ -970,7 +966,6 @@ class ContinuousBatchingManager:
self.model.device,
self.model.dtype,
tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting
allow_prefix_sharing=self._allow_prefix_sharing,
)
logger.debug(f"PagedAttentionCache created in {perf_counter() - t0} seconds")
@ -1062,15 +1057,6 @@ class ContinuousBatchingManager:
class ContinuousMixin:
"""Mixin class for models to add continuous batching capabilities."""
@contextmanager
def continuous_batching_context_manager(self, **kwargs) -> Generator[ContinuousBatchingManager]:
manager = self.init_continuous_batching(**kwargs)
manager.start()
try:
yield manager
finally:
manager.stop(block=True)
def init_continuous_batching(
self,
generation_config: Optional[GenerationConfig] = None,
@ -1078,7 +1064,6 @@ class ContinuousMixin:
max_queue_size: int = 0,
num_q_cuda_graphs: int = 0,
num_kv_cuda_graphs: int = 0,
allow_prefix_sharing: bool = True,
) -> ContinuousBatchingManager:
"""Initialize a manager for continuous batching inference.
@ -1111,7 +1096,6 @@ class ContinuousMixin:
max_queue_size=max_queue_size,
num_q_cuda_graphs=num_q_cuda_graphs,
num_kv_cuda_graphs=num_kv_cuda_graphs,
allow_prefix_sharing=allow_prefix_sharing,
)
# TODO: support streaming
@ -1183,6 +1167,5 @@ class ContinuousMixin:
except Exception as e:
logger.error(f"Error during batch generation: {e}", exc_info=True)
finally:
logger.debug("Generate batch is finished.") # a dummy log needed for the logs of stop to show. Won't show.
manager.stop(block=True, timeout=5.0)
return results

View File

@ -19,7 +19,6 @@ from typing import Optional
import torch
from ...utils import is_torch_xpu_available
from ...utils.logging import logging
from ...utils.metrics import traced
@ -36,13 +35,6 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
total_memory = torch.cuda.get_device_properties(device).total_memory
reserved_memory = torch.cuda.memory_reserved(device)
allocated_memory = torch.cuda.memory_allocated(device)
elif is_torch_xpu_available():
device = torch.device("xpu")
torch.xpu.empty_cache()
torch.xpu.synchronize()
total_memory = torch.xpu.get_device_properties(device).total_memory
reserved_memory = torch.xpu.memory_reserved(device)
allocated_memory = torch.xpu.memory_allocated(device)
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device("mps")
# MPS memory reporting (PyTorch 2.0+)
@ -116,10 +108,10 @@ class RequestState:
error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
"""
# Required fields # TODO: come up with better names / not sure prompt_ids and such are not redundant
# Required fields
request_id: str
full_prompt_ids: Optional[list[int]] = None # Full initial prompt
prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed
prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated)
remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process
static_outputs: list[int] = field(default_factory=list) # Generated tokens
allocated_blocks: int = 0 # Number of blocks allocated to the request
@ -164,7 +156,7 @@ class RequestState:
# TODO: this logic seems one token off, check it out
@traced
def update_and_check_completion(self, token_id: int) -> bool:
def update_with_token(self, token_id: int) -> bool:
"""Update the request with a newly generated token and check for completion.
Args:

View File

@ -104,7 +104,7 @@ class Scheduler(ABC):
)
@traced
def _allocate_blocks_if_needed(self, state: RequestState) -> bool:
def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
"""Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
accommodate the next tokens. It calculates how many blocks are needed based on the request's current
cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
@ -113,11 +113,10 @@ class Scheduler(ABC):
# 1. we check that the occupancy is less than the requested length
# 2. we allocate enough blocks to cover the requested length
current_len = state.current_len()
len_next_tokens = len(state.prompt_ids)
occupancy = state.allocated_blocks * self.cache.block_size - current_len
if occupancy < len_next_tokens or state.allocated_blocks == 0:
blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
allocated = self.cache.allocate_blocks(blocks_needed, state)
allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
if allocated is None:
return False
state.allocated_blocks += allocated
@ -126,29 +125,11 @@ class Scheduler(ABC):
@traced(span_name="prepare_request")
def _prepare_request_for_processing(
self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
) -> None:
"""Prepares a request for processing in the current batch. If prefix sharing is enabled, and the request was
pending, this is where we look for a prefix match and split the request if found."""
# If prefix sharing is enabled, we look for a prefix match and split the request if found
if self.cache.use_prefix_sharing and state.status == RequestStatus.PENDING:
prefill_length = self.cache.search_prefix_match(state.request_id, state.prompt_ids)
if prefill_length > 0:
self.active_requests[state.request_id] = state
request_ids_to_remove_from_waiting.add(state.request_id)
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
# Even if we match the whole request, we keep at least 1 token to start decoding
prefill_length = min(prefill_length, len(state.prompt_ids) - 1)
state.remaining_prompt_ids = state.prompt_ids[prefill_length:]
state.prompt_ids = state.prompt_ids[prefill_length:]
state.position_offset += prefill_length
# If the request has a split prefill, the tokens to process are the remaining prompt ids
if state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
request_tokens = state.remaining_prompt_ids
# Otherwise, the tokens to process are the prompt ids, which are the full prompt or the last predicted tokens
else:
request_tokens = state.prompt_ids
):
"""Prepares a request for processing in the current batch."""
request_tokens = (
state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
)
if len(request_tokens) < token_budget:
# Can process the entire prompt/remainder
if state.status == RequestStatus.PENDING:
@ -171,7 +152,6 @@ class Scheduler(ABC):
state.prompt_ids = request_tokens[:token_budget]
# TODO: further common-ize the two classes
@attach_tracer()
class FIFOScheduler(Scheduler):
"""This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
@ -215,31 +195,30 @@ class FIFOScheduler(Scheduler):
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
# If we can't allocate blocks, do not schedule the request and break if the cache is full
if not self._allocate_blocks_if_needed(state):
if self.cache.get_num_free_blocks() == 0:
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
break
continue
# Add the request to the scheduled requests
scheduled_requests.append(state)
@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)
_add_to_scheduled_requests(state)
# Update the token budget
token_budget -= request_len
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
if self.cache.use_prefix_sharing:
tokens_in_current_block = state.current_len() % self.cache.block_size
tokens_after_forward = tokens_in_current_block + request_len
complete_blocks = tokens_after_forward // self.cache.block_size
self.cache.blocks_to_complete[state.request_id] = complete_blocks
# Remove the request from the waiting queue and mark it as removed
req_id = state.request_id
was_waiting = self.waiting_requests.pop(req_id, None) is not None
if was_waiting:
request_ids_to_remove_from_waiting.add(req_id)
@traced
def _remove_from_waiting_requests(state: RequestState):
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
_remove_from_waiting_requests(state)
# Early exit of the loop if we have no token budget left
if token_budget == 0:
break
@ -270,7 +249,6 @@ class PrefillFirstScheduler(Scheduler):
elif state.status == RequestStatus.DECODING:
second_priority_states.append(state)
# Add waiting requests to second priority
for req_id in self.waiting_requests_order:
second_priority_states.append(self.waiting_requests[req_id])
@ -281,31 +259,30 @@ class PrefillFirstScheduler(Scheduler):
for state in candidates:
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
# If we can't allocate blocks, do not schedule the request and break if the cache is full
if not self._allocate_blocks_if_needed(state):
if self.cache.get_num_free_blocks() == 0:
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
break
continue
# Add the request to the scheduled requests
scheduled_requests.append(state)
@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)
_add_to_scheduled_requests(state)
# Update the token budget
token_budget -= request_len
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
if self.cache.use_prefix_sharing:
tokens_in_current_block = state.current_len() % self.cache.block_size
tokens_after_forward = tokens_in_current_block + request_len
complete_blocks = tokens_after_forward // self.cache.block_size
self.cache.blocks_to_complete[state.request_id] = complete_blocks
# Remove the request from the waiting queue and mark it as removed
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
@traced
def _remove_from_waiting_requests(state: RequestState):
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
_remove_from_waiting_requests(state)
# Early exit of the loop if we have no token budget left
if token_budget == 0:
break

View File

@ -411,7 +411,7 @@ class GenerationMixin(ContinuousMixin):
"Generation config file not found, using a generation config created from the model config."
)
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
if hasattr(self, "load_custom_generate") and trust_remote_code:
if hasattr(self, "load_custom_generate"):
try:
custom_generate = self.load_custom_generate(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
@ -608,7 +608,7 @@ class GenerationMixin(ContinuousMixin):
use_cache = kwargs.get("use_cache")
if use_cache is None:
use_cache = getattr(self.config, "use_cache", False)
if past_key_values is not None or use_cache:
if past_key_values is None or use_cache:
# TODO (joao): handle the case where cache length == input_ids length. The function below results in an
# exception because we get empty input_ids after slicing. In essence, we need to roll back the cache 1
# token to recompute the logits for the first token to be generated (but not all caches support roll backs)
@ -1635,12 +1635,7 @@ class GenerationMixin(ContinuousMixin):
# TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
for key, value in model_kwargs.items():
if (
value is not None
and key not in model_args
and key not in TransformersKwargs.__optional_keys__
and key != "debug_io"
):
if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
unused_model_args.append(key)
if unused_model_args:
@ -2175,7 +2170,7 @@ class GenerationMixin(ContinuousMixin):
return False
# Base logic
valid_hardware = self.device.type in ["cuda", "xpu"] or bool(
valid_hardware = self.device.type == "cuda" or bool(
generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices
)
using_compilable_cache = (

View File

@ -23,7 +23,6 @@ import torch
from torch import nn
from torch.nn import BCELoss
from .. import initialization as init
from ..modeling_utils import PreTrainedModel
from ..utils import ModelOutput, logging
from .configuration_utils import PreTrainedConfig, WatermarkingConfig
@ -384,11 +383,10 @@ class BayesianDetectorModel(PreTrainedModel):
)
self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Parameter):
init.normal_(module.weight, mean=0.0, std=0.02)
module.weight.data.normal_(mean=0.0, std=0.02)
def _compute_posterior(
self,

View File

@ -32,7 +32,6 @@ from .utils import (
is_offline_mode,
is_remote_url,
logging,
safe_load_json_file,
)
from .utils.hub import cached_file
@ -281,41 +280,35 @@ class ImageProcessingMixin(PushToHubMixin):
image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
if os.path.isfile(pretrained_model_name_or_path):
resolved_image_processor_file = pretrained_model_name_or_path
resolved_processor_file = None
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
image_processor_file = pretrained_model_name_or_path
resolved_processor_file = None
resolved_image_processor_file = download_url(pretrained_model_name_or_path)
else:
image_processor_file = image_processor_filename
try:
resolved_processor_file = cached_file(
pretrained_model_name_or_path,
filename=PROCESSOR_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
resolved_image_processor_file = cached_file(
pretrained_model_name_or_path,
filename=image_processor_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
# Load from local folder or from cache or download from model Hub and cache
resolved_image_processor_files = [
resolved_file
for filename in [image_processor_file, PROCESSOR_NAME]
if (
resolved_file := cached_file(
pretrained_model_name_or_path,
filename=filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
)
is not None
]
resolved_image_processor_file = resolved_image_processor_files[0]
except OSError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
@ -329,24 +322,16 @@ class ImageProcessingMixin(PushToHubMixin):
f" directory containing a {image_processor_filename} file"
)
# Load image_processor dict. Priority goes as (nested config if found -> image processor config)
# We are downloading both configs because almost all models have a `processor_config.json` but
# not all of these are nested. We need to check if it was saved recebtly as nested or if it is legacy style
image_processor_dict = None
if resolved_processor_file is not None:
processor_dict = safe_load_json_file(resolved_processor_file)
if "image_processor" in processor_dict:
image_processor_dict = processor_dict["image_processor"]
try:
# Load image_processor dict
with open(resolved_image_processor_file, encoding="utf-8") as reader:
text = reader.read()
image_processor_dict = json.loads(text)
image_processor_dict = image_processor_dict.get("image_processor", image_processor_dict)
if resolved_image_processor_file is not None and image_processor_dict is None:
image_processor_dict = safe_load_json_file(resolved_image_processor_file)
if image_processor_dict is None:
except json.JSONDecodeError:
raise OSError(
f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a {image_processor_filename} file"
f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
)
if is_local:

View File

@ -821,26 +821,14 @@ def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_widt
return image
def _cast_tensor_to_float(x):
if x.is_floating_point():
return x
return x.float()
def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
"""
Helper function to flatten a single level of nested image and batch structures and group by shape.
Args:
nested_images (list):
A list of images or a single tensor
paired_inputs (Any, *optional*):
Zero or more lists that mirror the structure of `nested_images` (flat list, or list of lists when
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
same shape key. These paired values are grouped alongside `nested_images` but are not stacked in the output, so
they do not need to be tensors.
is_nested (bool, *optional*, defaults to False):
Whether the images are nested.
Returns:
tuple[dict, ...]:
- A dictionary with shape as key and list of images with that shape as value
- A dictionary with shape as key and list of paired values with that shape as value
- A dictionary mapping original indices to (shape, index) tuples
- A dictionary mapping original indices to (shape, index) tuples for each paired input
"""
"""Helper function to flatten a single level of nested image and batch structures and group by shape."""
grouped_images = defaultdict(list)
grouped_images_index = {}
paired_grouped_values = [defaultdict(list) for _ in paired_inputs]
@ -892,20 +880,27 @@ def _reconstruct_nested_structure(indices, processed_images):
return result
def _iterate_items(items, is_nested: bool):
"""
Helper function to iterate over items yielding (key, item) pairs.
def _disable_grouping_output_nested(images, *paired_inputs):
"""Build the disable_grouping output tuple for a single-level nested structure."""
outer_range = range(len(images))
inner_ranges = [range(len(images[i])) for i in outer_range]
For nested structures, yields ((row_index, col_index), item).
For flat structures, yields (index, item).
"""
if is_nested:
for i, row in enumerate(items):
for j, item in enumerate(row):
yield (i, j), item
else:
for i, item in enumerate(items):
yield i, item
# Precompute all (i, j) pairs
ij_pairs = [(i, j) for i in outer_range for j in inner_ranges[i]]
images_dict = {(i, j): images[i][j].unsqueeze(0) for (i, j) in ij_pairs}
paired_dicts = [{(i, j): paired_list[i][j].unsqueeze(0) for (i, j) in ij_pairs} for paired_list in paired_inputs]
index_map = {(i, j): ((i, j), 0) for (i, j) in ij_pairs}
return images_dict, *paired_dicts, index_map
def _disable_grouping_output_flat(images, *paired_inputs):
"""Build the disable_grouping output tuple for a flat list structure."""
idx_range = range(len(images))
images_dict = {i: images[i].unsqueeze(0) for i in idx_range}
paired_dicts = [{i: paired_list[i].unsqueeze(0) for i in idx_range} for paired_list in paired_inputs]
index_map = {i: (i, 0) for i in idx_range}
return images_dict, *paired_dicts, index_map
def group_images_by_shape(
@ -925,7 +920,7 @@ def group_images_by_shape(
Args:
images (Union[list["torch.Tensor"], "torch.Tensor"]):
A list of images or a single tensor
paired_inputs (Any, *optional*):
*paired_inputs (Any):
Zero or more lists that mirror the structure of `images` (flat list, or list of lists when
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
same shape key. These paired values are grouped alongside `images` but are not stacked in the output, so
@ -949,14 +944,10 @@ def group_images_by_shape(
disable_grouping = device == "cpu"
if disable_grouping:
return (
{key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)},
*[
{key: item.unsqueeze(0) for key, item in _iterate_items(paired_list, is_nested)}
for paired_list in paired_inputs
],
{key: (key, 0) for key, _ in _iterate_items(images, is_nested)},
)
if is_nested:
return _disable_grouping_output_nested(images, *paired_inputs)
else:
return _disable_grouping_output_flat(images, *paired_inputs)
# Handle single level nested structure
grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(
@ -999,3 +990,14 @@ def reorder_images(
]
return _reconstruct_nested_structure(grouped_images_index, processed_images)
class NumpyToTensor:
"""
Convert a numpy array to a PyTorch tensor.
"""
def __call__(self, image: np.ndarray):
# Same as in PyTorch, we assume incoming numpy images are in HWC format
# c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
return torch.from_numpy(image.transpose(2, 0, 1)).contiguous()

View File

@ -1,191 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections import defaultdict
from contextlib import contextmanager
import torch
# Record all the torch primitives in advance, so that we can use them without them being modified when we patch torch
# in context managers
TORCH_INIT_FUNCTIONS = {
"uniform_": torch.nn.init.uniform_,
"normal_": torch.nn.init.normal_,
"constant_": torch.nn.init.constant_,
"ones_": torch.nn.init.ones_,
"zeros_": torch.nn.init.zeros_,
"eye_": torch.nn.init.eye_,
"dirac_": torch.nn.init.dirac_,
"xavier_uniform_": torch.nn.init.xavier_uniform_,
"xavier_normal_": torch.nn.init.xavier_normal_,
"kaiming_uniform_": torch.nn.init.kaiming_uniform_,
"kaiming_normal_": torch.nn.init.kaiming_normal_,
"trunc_normal_": torch.nn.init.trunc_normal_,
"orthogonal_": torch.nn.init.orthogonal_,
"sparse_": torch.nn.init.sparse_,
}
def uniform_(
tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator)
return tensor
def normal_(
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
return tensor
def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val)
return tensor
def ones_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["ones_"](tensor)
return tensor
def zeros_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["zeros_"](tensor)
return tensor
def eye_(tensor: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["eye_"](tensor)
return tensor
def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups)
return tensor
def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator)
return tensor
def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator)
return tensor
def kaiming_uniform_(
tensor: torch.Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["kaiming_uniform_"](
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
)
return tensor
def kaiming_normal_(
tensor: torch.Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["kaiming_normal_"](
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
)
return tensor
def trunc_normal_(
tensor: torch.Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator)
return tensor
def orthogonal_(
tensor: torch.Tensor,
gain: float = 1,
generator: torch.Generator | None = None,
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator)
return tensor
def sparse_(
tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None
) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator)
return tensor
def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
if not getattr(tensor, "_is_hf_initialized", False):
with torch.no_grad():
return tensor.copy_(other)
return tensor
@contextmanager
def guard_torch_init_functions():
"""
Guard the `torch.nn.init` primitive functions to behave exactly like the functions in this file, i.e. be
protected against the `_is_hf_initialized` flag to avoid re-init if the param was already loaded.
Usually, all models are using the init from `transformers` which are already guarded, but just to make extra sure
and for remote code, we also use this context manager.
"""
originals = defaultdict(dict)
try:
# Replace all torch funcs by the ones in this file
for name in TORCH_INIT_FUNCTIONS.keys():
# Here, we need to check all modules imported, and hot patch all of them, as usually torch does
# something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in torch.nn.modules,
# where MultiHeadAttention lives), so the function name is binded at import time and just doing
# `setattr(torch.nn.init, name, gloabls()[name])` is thus not enough
for module in sys.modules.copy().values():
if module and hasattr(module, name):
originals[module][name] = getattr(module, name)
setattr(module, name, globals()[name])
yield
finally:
# Set back the original functions on all modules
for module, functions in originals.items():
for name, func in functions.items():
setattr(module, name, func)

View File

@ -23,8 +23,6 @@ from collections import defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING
from safetensors.torch import save_file
from ..utils import (
is_accelerate_available,
is_torch_available,
@ -484,6 +482,19 @@ def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload
dispatch_model(model, **device_map_kwargs)
def get_disk_only_shard_files(device_map, weight_map):
"""
Returns the list of shard files containing only weights offloaded to disk.
"""
files_content = defaultdict(list)
for weight_name, filename in weight_map.items():
while len(weight_name) > 0 and weight_name not in device_map:
weight_name = ".".join(weight_name.split(".")[:-1])
files_content[filename].append(device_map[weight_name])
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
def expand_device_map(device_map, param_names):
"""
Expand a device map to return the correspondence parameter name to device.
@ -497,57 +508,50 @@ def expand_device_map(device_map, param_names):
def accelerate_disk_offload(
disk_offload_folder: str | None,
checkpoint_files: list[str] | None,
device_map: dict,
expected_keys: list[str],
sharded_metadata: dict | None,
dtype: torch.dtype | None,
disk_offload_folder,
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
):
disk_only_shard_files = []
if disk_offload_folder is not None:
os.makedirs(disk_offload_folder, exist_ok=True)
is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
# In this case, the offload index is simply the existing safetensors (except if using custom weight loading
# Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
if disk_offload_folder is None and not is_offloaded_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if is_offloaded_safetensors:
param_device_map = expand_device_map(device_map, expected_keys)
param_device_map = expand_device_map(device_map, checkpoint_keys)
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
weight_map = dict.fromkeys(expected_keys, checkpoint_files[0])
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
# Fix the weight map keys according to the key mapping
weight_map = {
key_renaming_mapping[k]: v
for k, v in sharded_metadata["weight_map"].items()
if k in key_renaming_mapping
}
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
disk_offload_index = {
name: {
"safetensors_file": file,
"weight_name": name,
"weight_name": reverse_key_renaming_mapping[name],
"dtype": str_dtype,
}
for name, file in weight_map.items()
if param_device_map[name] == "disk"
}
# In this case we will resave every offloaded weight
else:
disk_offload_index = {}
return disk_offload_index
def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str | None, offload_index: dict) -> dict:
"""Write `weight` to disk inside `offload_folder`, and update `offload_index` accordingly. Everything is
saved in `safetensors` format."""
if offload_folder is None:
raise ValueError(
"The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is either "
"because the weights are not in `safetensors` format, or because the model uses an internal weight format "
"different than the one saved (i.e. most MoE models). Please provide an `offload_folder` for them in "
"`from_pretrained`."
)
# Write the weight to disk
safetensor_file = os.path.join(offload_folder, f"{weight_name}.safetensors")
save_file({weight_name: weight}, safetensor_file)
# Update the offloading index
str_dtype = str(weight.dtype).replace("torch.", "")
offload_index[weight_name] = {"safetensors_file": safetensor_file, "weight_name": weight_name, "dtype": str_dtype}
return offload_index
return disk_offload_index, disk_only_shard_files, is_offloaded_safetensors

View File

@ -1,4 +1,5 @@
import inspect
from copy import deepcopy
from inspect import signature
from ..utils import (
@ -23,6 +24,7 @@ if is_accelerate_available():
import accelerate
from accelerate import init_empty_weights
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import find_tied_parameters
logger = logging.get_logger(__name__)
@ -149,6 +151,52 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
return model
def get_keys_to_not_convert(model):
r"""
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
int8.
Parameters:
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model and tie the weights, then
# check if it contains tied weights
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
tied_model.tie_weights()
tied_params = find_tied_parameters(tied_model)
tied_keys = sum(tied_params, [])
has_tied_params = len(tied_keys) > 0
# If there is not tied weights, we want to keep the lm_headoutput_embedding) in full precision
if not has_tied_params:
output_emb = model.get_output_embeddings()
if output_emb is not None:
list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
return list_last_module
# otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision
list_modules = list(model.named_parameters())
list_last_module = [list_modules[-1][0]]
# add last module together with tied weights
intersection = set(list_last_module) - set(tied_keys)
list_untouched = list(set(tied_keys)) + list(intersection)
# remove ".weight" from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
"""

View File

@ -11,6 +11,7 @@
# specific language governing permissions and limitations under the License.
import logging
from collections.abc import Callable
from typing import Optional
import torch
@ -23,7 +24,13 @@ from ..cache_utils import (
StaticCache,
)
from ..generation.configuration_utils import GenerationConfig
from ..modeling_utils import PreTrainedModel
from ..masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
_ignore_causal_mask_sdpa,
_is_torch_greater_or_equal_than_2_5,
prepare_padding_mask,
)
from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ..pytorch_utils import (
is_torch_greater_or_equal,
is_torch_greater_or_equal_than_2_3,
@ -222,6 +229,10 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
)
self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
self.model.model.config._attn_implementation = "sdpa_without_vmap"
def forward(
self,
@ -757,6 +768,11 @@ def convert_and_export_with_cache(
import torch.export._trace
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
model.config._attn_implementation = "sdpa_without_vmap"
with torch.no_grad():
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
example_input_ids = (
@ -1020,6 +1036,11 @@ def export_with_dynamic_cache(
if not is_torch_greater_or_equal_than_2_3:
raise ImportError("torch >= 2.3 is required.")
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
model.config._attn_implementation = "sdpa_without_vmap"
register_dynamic_cache_export_support()
with torch.no_grad():
@ -1088,3 +1109,92 @@ def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
value = value_list[idx] if idx < len(value_list) else None
cache.update(key, value, idx)
return cache
def sdpa_mask_without_vmap(
batch_size: int,
cache_position: torch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Optional[Callable] = None,
attention_mask: Optional[torch.Tensor] = None,
local_size: Optional[int] = None,
allow_is_causal_skip: bool = True,
allow_torch_fix: bool = True,
**kwargs,
) -> Optional[torch.Tensor]:
"""
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
the element should take part in the attention computation, and False that it should not.
This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export.
Args:
batch_size (`int`):
The batch size of the input sequence.
cache_position (`torch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
kv_offset (`int`, optional):
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
attention_mask (`torch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
local_size (`int`, optional):
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
to try to skip mask creation if possible.
allow_is_causal_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
`torch.sdpa` instead. Default to `True`.
allow_torch_fix (`bool`, optional):
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
versions. We need an arg to skip it when using eager. By default `True`.
"""
q_length = cache_position.shape[0]
# Potentially pad the 2D mask, and slice it correctly
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
# Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
return None
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = torch.arange(kv_length, device=cache_position.device)
kv_arange += kv_offset
reshaped_cache_position = cache_position.view(-1, 1)
# This is a bit hacky to know what pattern we are using, but all mask creation function actually forward
# the config through kwargs anyway, so it allows to rely on it
# Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it,
# but this is more efficient
sliding_window = getattr(kwargs["config"], "sliding_window", None)
chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)
if sliding_window is not None and chunk_size is not None:
raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")
# Simplest and most efficient way to obtain a causal mask
causal_mask = kv_arange <= reshaped_cache_position
# If using sliding window, add the sliding mask
if sliding_window is not None:
sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
causal_mask *= sliding_mask_overlay
# If using chunk attention, add the chunked mask
elif chunk_size is not None:
chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
causal_mask *= chunked_mask_overlay
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
if padding_mask is not None:
causal_mask = causal_mask * padding_mask[:, None, None, :]
# Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
return causal_mask

View File

@ -13,11 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing import Optional
from ..core_model_loading import ConversionOps
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
@ -33,18 +30,6 @@ if is_accelerate_available():
logger = logging.get_logger(__name__)
try:
_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
_FP8_IS_INT = False
except AttributeError:
_FP8_DTYPE = torch.int8
_FP8_MIN, _FP8_MAX = -127, 127
_FP8_IS_INT = True
logger.warning_once(
"torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations."
)
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@ -347,12 +332,6 @@ class FP8Linear(nn.Linear):
if self.weight.element_size() > 1:
return F.linear(input, self.weight, self.bias)
else:
if isinstance(self.weight, torch.distributed.tensor.DTensor):
weight = self.weight._local_tensor.contiguous()
scale_inv = self.weight_scale_inv._local_tensor.contiguous()
else:
weight = self.weight.contiguous()
scale_inv = self.weight_scale_inv.contiguous()
# Context manager used to switch among the available accelerators
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
@ -360,9 +339,9 @@ class FP8Linear(nn.Linear):
qinput, scale = act_quant(input, self.block_size[1])
output = w8a8_block_fp8_matmul_triton(
qinput,
weight,
self.weight,
scale,
scale_inv,
self.weight_scale_inv,
self.block_size,
output_dtype=input.dtype,
)
@ -371,124 +350,9 @@ class FP8Linear(nn.Linear):
torch_accelerator_module.synchronize()
if self.bias is not None:
output = output + self.bias
output = torch.nan_to_num(output, nan=0.0)
return output.to(dtype=input.dtype)
def _ceil_div(a, b):
return (a + b - 1) // b
class FP8Expert(nn.Module):
dtype = torch.float8_e4m3fn
def __init__(self, config, block_size, device):
super().__init__()
from ..activations import ACT2FN
self.block_size = block_size
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim
self.gate_up_proj = nn.Parameter(
torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
)
self.down_proj = nn.Parameter(
torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
)
# Create inverse scale tiles only when using 1-byte types (fp8)
if self.gate_up_proj.element_size() == 1:
bo, bi = self.block_size
# gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
gu_scale_o = _ceil_div(Wg_out, bo)
gu_scale_i = _ceil_div(Wg_in, bi)
self.gate_up_proj_scales_inv = nn.Parameter(
torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
)
# down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
dp_scale_o = _ceil_div(Wd_out, bo)
dp_scale_i = _ceil_div(Wd_in, bi)
self.down_proj_scales_inv = nn.Parameter(
torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
)
else:
# Match FP8Linear behavior when not using 1-byte weights
self.register_parameter("gate_up_proj_scale_inv", None)
self.register_parameter("down_proj_scale_inv", None)
# (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
self.register_parameter("gate_up_bias", None)
self.register_parameter("down_bias", None)
# Activation used in the MLP (same as your config / ACT2FN)
# Keep a handle here; actual usage happens in forward of your MoE block
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states.index_select(0, token_idx)
gate, up = self.linear(
current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scales_inv[expert_idx]
).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = self.linear(
current_hidden_states, self.down_proj[expert_idx], self.down_proj_scales_inv[expert_idx]
)
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor:
if weight.element_size() > 1:
return F.linear(input, weight, None)
else:
# Context manager used to switch among the available accelerators
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
with torch_accelerator_module.device(input.device):
qinput, scale = act_quant(input, self.block_size[1])
output = w8a8_block_fp8_matmul_triton(
qinput,
weight,
scale,
weight_scale_inv,
self.block_size,
output_dtype=input.dtype,
)
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
# preceding operations are ready before proceeding
torch_accelerator_module.synchronize()
return output.to(dtype=input.dtype)
# TODO: we do need this.... but not recursive...
def _replace_with_fp8_linear(
model,
tp_plan=None,
@ -497,48 +361,40 @@ def _replace_with_fp8_linear(
quantization_config=None,
has_been_replaced=False,
):
iterator = list(model.named_parameters()).copy()
for name, empty_tensor in iterator:
current_key_name = name
name = name.rsplit(".", 1)[0] if "." in name else name
module = model.get_submodule(name)
"""Replace Linear layers with FP8Linear."""
if current_key_name is None:
current_key_name = []
current_key_name_str = re.sub(r"\d+", "*", current_key_name)
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
if (
"gate_up_proj" in current_key_name
or "down_proj" in current_key_name
and "experts" in current_key_name
): # Experts!
in_features = empty_tensor.size(-2)
out_features = empty_tensor.size(-1)
model.set_submodule(
name,
FP8Expert(
config=model.config,
block_size=quantization_config.weight_block_size,
device=empty_tensor.device,
),
)
for name, module in model.named_children():
current_key_name.append(name)
elif isinstance(module, nn.Linear):
in_features = module.in_features
out_features = module.out_features
model.set_submodule(
name,
FP8Linear(
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
),
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
current_key_name_str = ".".join(current_key_name)
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
model._modules[name] = FP8Linear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
)
has_been_replaced = True
# when changing a layer the TP PLAN for that layer should be updated. TODO
has_been_replaced = True
# when changing a layer the TP PLAN for that layer should be updated. TODO
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fp8_linear(
module,
tp_plan,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
current_key_name.pop(-1)
return model, has_been_replaced
@ -549,7 +405,7 @@ def replace_with_fp8_linear(
quantization_config=None,
):
"""Helper function to replace model layers with FP8 versions."""
modules_to_not_convert += ["lm_head"]
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
@ -568,133 +424,3 @@ def replace_with_fp8_linear(
)
return model
class QuantizationOp(ConversionOps):
"""Base class for quantization operations."""
pass
class Fp8Quantize(QuantizationOp):
"""
A quantization operation that creates two tensors, weight and scale out of a weight.
"""
reverse_op: type[ConversionOps]
def __init__(self, block_size: Optional[tuple[int, int]] = None):
self.block_size = block_size
self.reverse_op = Fp8Dequantize
def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) -> dict[str, torch.Tensor]:
# Unpack single key/value (value may be wrapped in a list)
target_keys, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value
# Resolve block size (support dict-like or attr-like quant_config)
block_size = None
if quant_config is not None:
if isinstance(quant_config, dict):
block_size = quant_config.get("weight_block_size")
else:
block_size = getattr(quant_config, "weight_block_size", None)
if block_size is None:
block_size = (value.shape[-2], value.shape[-1])
block_m, block_n = block_size
rows, cols = value.shape[-2], value.shape[-1]
# Enforce exact tiling like your original
if rows % block_m != 0 or cols % block_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}"
)
# Leading dims can be empty (2D) or include num_experts/... (3D+)
leading_shape = value.shape[:-2]
rows_tiles = rows // block_m
cols_tiles = cols // block_n
original_shape = value.shape
value_fp32 = value.to(torch.float32)
# Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)
# Per-tile max-abs over the block dims
# dims: block_m is at -3, block_n is at -1 after the reshape
max_abs = reshaped.abs().amax(dim=(-3, -1))
safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
# Tile scale (we store inverse scale like your Linear: weight_scale_inv)
scales = _FP8_MAX / safe_max_abs
scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable
# Broadcast scales back over the block dims and quantize
# max_abs/scales shape: (..., rows_tiles, cols_tiles)
scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
scaled = reshaped * scales_broadcast
if _FP8_IS_INT:
quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
else:
quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
quantized = quantized.reshape(original_shape)
inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles)
if target_keys.endswith("weight"):
scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
else:
scale_key = target_keys + "_scales_inv"
# Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
return {
target_keys: quantized,
scale_key: inv_scales,
}
class Fp8Dequantize(QuantizationOp):
"""Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
def __init__(self, block_size: Optional[tuple[int, int]] = None):
self.block_size = block_size
self.reverse_op = Fp8Quantize
def convert(
self,
value: Union[Sequence[torch.Tensor], dict[str, torch.Tensor]],
*,
context: dict[str, Any],
) -> torch.Tensor:
if isinstance(value, dict):
tensors = list(value.values())
else:
tensors = list(value) if isinstance(value, Sequence) else [value]
if len(tensors) != 2:
raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.")
quantized, scales = tensors
if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
raise TypeError("Fp8Dequantize expects tensors as inputs.")
quantized_fp32 = quantized.to(torch.float32)
rows, cols = quantized_fp32.shape[-2:]
block_size = self.block_size
if block_size is None:
quant_config = context.get("quantization_config")
block_size = getattr(quant_config, "weight_block_size", None)
if block_size is None:
block_size = (rows, cols)
block_m, block_n = block_size
if rows % block_m != 0 or cols % block_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
)
reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
dequantized = reshaped * expanded_scales
return dequantized.reshape(quantized_fp32.shape)

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from collections.abc import Callable
from functools import partial
@ -19,7 +18,7 @@ from types import ModuleType
from typing import Optional, Union
from ..modeling_flash_attention_utils import lazy_import_flash_attention
from ..utils import ENV_VARS_TRUE_VALUES, logging
from ..utils import logging
from ..utils.import_utils import is_kernels_available
from .flash_attention import flash_attention_forward
@ -34,22 +33,10 @@ try:
get_kernel,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
)
_TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
_kernels_available = True
_kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES
def use_kernel_forward_from_hub(layer_name: str):
if _kernels_enabled:
from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub
return _kernels_use_kernel_forward_from_hub(layer_name)
else:
logger.warning_once(
f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
)
return lambda cls: cls
_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
"MultiScaleDeformableAttention": {
@ -84,12 +71,6 @@ try:
layer_name="RMSNorm",
)
},
"npu": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/liger_kernels",
layer_name="LigerRMSNorm",
)
},
},
"MLP": {
"cuda": LayerRepository(
@ -180,7 +161,6 @@ try:
except ImportError:
_kernels_available = False
_kernels_enabled = False
# Stub to make decorators int transformers work when `kernels`
# is not installed.

View File

@ -38,7 +38,7 @@ from transformers.utils.import_utils import _is_package_available
if os.getenv("WANDB_MODE") == "offline":
print("[INFO] Running in WANDB offline mode")
print("⚙️ Running in WANDB offline mode")
from .. import PreTrainedModel, TrainingArguments
from .. import __version__ as version

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
from ..utils import is_accelerate_available, is_torch_available, logging
if is_torch_available():
@ -114,9 +114,6 @@ def convert_moe_packed_tensors(
if not blocks.is_cuda and torch.cuda.is_available():
blocks = blocks.cuda()
scales = scales.cuda()
elif (blocks.device.type != "xpu") and is_torch_xpu_available():
blocks = blocks.to("xpu")
scales = scales.to("xpu")
scales = scales.to(torch.int32) - 127 # TODO that's because 128=2**7
@ -354,8 +351,6 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
if target_device == "cpu" and torch.cuda.is_available():
torch.cuda.empty_cache()
elif target_device == "cpu" and is_torch_xpu_available():
torch.xpu.empty_cache()
setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
delattr(module, blocks_attr)
delattr(module, scales_attr)
@ -400,7 +395,7 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, trito
else:
blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
if getattr(target_device, "type", target_device) == "cpu":
target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
target_device = "cuda"
blocks = blocks.to(target_device).contiguous()
scales = scales.to(target_device).contiguous()
with on_device(target_device):

View File

@ -236,7 +236,7 @@ class PeftAdapterMixin:
**adapter_kwargs,
)
peft_config.inference_mode = not is_trainable
# TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
# Create and add fresh new adapters into the model.
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)

View File

@ -63,6 +63,9 @@ def sdpa_attention_forward(
else:
sdpa_kwargs = {"enable_gqa": True}
if attention_mask is not None and attention_mask.ndim == 4:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
# Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)

View File

@ -18,7 +18,6 @@ import operator
import os
import re
from functools import partial, reduce
from typing import Optional
import torch
import torch.distributed as dist
@ -141,16 +140,6 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int
return [single_size] * blocks
def replace_layer_number_by_wildcard(name: str) -> str:
"""
Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
a dot (`.`) and the end of the string.
This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
"""
return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
"""
Get the TP style for a parameter from the TP plan.
@ -161,11 +150,11 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
not parent classes for `post_init` calls
"""
generic_param_name = replace_layer_number_by_wildcard(parameter_name)
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
return tp_plan[generic_param_name]
elif is_weight and "." in generic_param_name and (module_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
return tp_plan[module_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
return None
@ -317,7 +306,7 @@ def repack_weights(
return final_ordered_tensor
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Optional[int] = None):
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
"""
Generalized tensor sharding across a multi-dimensional device mesh.
Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`.
@ -369,57 +358,32 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Opt
rank (int): Global rank of the current process/device.
dim (int): Dimension along which to shard the tensor.
"""
param_dim = empty_param.ndim
param_dim = empty_param.dim()
if dim < 0:
dim = param_dim + dim
if dim >= param_dim:
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
# Flatten the mesh to get the total number of devices
mesh_shape = device_mesh.shape
world_size = reduce(operator.mul, mesh_shape)
if dim < 0:
dim = param_dim + dim
if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2:
dim = 0
elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2:
dim = 0
shard_size = math.ceil(empty_param.size(dim) / world_size)
start = rank * shard_size
end = min(start + shard_size, empty_param.size(dim))
if dim >= param_dim:
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
if rank >= world_size:
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
# we have the full tensor not 1 part of it.
# in that case, we just assume that the weight was properly saved
# and thus because we TP if the layer is colwise it should not use this. Layer should be packed_colwise
# to inform that it needs to read form a packed tensor. It will also take care of the module list thingy.
# here we take care of potential chunking / layer split / layer chunking.
# The only "hard" case is? if we collect q,k,v -> merge it into qkv. In that case
# actually we still shard dim=0 does not change
# so only case is if the dim of the empty param is 3 and the shard dim is 0 -> we put the
# tensor on a certain device (with the input tensor_index)
dimensions = param.get_shape()
shard_size = math.ceil(empty_param.shape[dim] / world_size)
start = rank * shard_size
if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2:
# special case we don't "shard" just send this entire tensor to the correct rank.
if start <= tensor_idx < end:
# this tensor does need to be materialized on this device:
return param[:]
else:
return torch.empty([], dtype=torch.int64, device=rank)
slice_indices = [slice(None)] * len(param.get_shape())
if start < param.get_shape()[dim]:
# Construct slicing index dynamically
end = min(start + shard_size, empty_param.shape[dim])
slice_indices = [slice(None)] * param_dim
if start < empty_param.shape[dim]:
slice_indices[dim] = slice(start, end)
param = param[tuple(slice_indices)]
if isinstance(param, list): # TODO handle the modulelist case!
param = [p[:] for p in param]
return param
return param[tuple(slice_indices)]
dimensions = list(param.shape)
dimensions[dim] = 0
return torch.empty(tuple(dimensions), dtype=torch.int64) # empty allocates memory....
return torch.empty(tuple(dimensions), dtype=torch.int64)
def distribute_module(
@ -446,19 +410,6 @@ class TensorParallelLayer:
"""
use_dtensor = True
device_mesh = None
rank = None
# Used to compare the shape of the original tensor
empty_param = None
# Used to init the corresponding DTensor
shard = None
def __init__(self, device_mesh=None, rank=None, empty_param=None):
self.rank = rank
self.device_mesh = device_mesh
self.empty_param = empty_param
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
@ -488,12 +439,12 @@ class GatherParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
**kwargs,
):
super().__init__(**kwargs)
super().__init__()
self.input_layouts = (input_layouts or Replicate(),)
self.output_layouts = output_layouts
self.desired_input_layouts = (Replicate(),)
@ -514,21 +465,6 @@ class GatherParallel(TensorParallelLayer):
dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
return outputs
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
shard = [Replicate()]
parameter = param[...].to(param_casting_dtype)
self.shard = shard
return parameter, shard
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
distribute_module(
module,
@ -557,23 +493,6 @@ class IsolatedParallel(TensorParallelLayer):
# TODO: figure out dynamo support for instance method and switch this to instance method
return outputs
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
mesh = device_mesh or self.device_mesh
parameter = param[...].to(param_casting_dtype)
if mesh is not None:
parameter = parameter / mesh.size()
self.shard = None
return parameter, None
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
param = param[...].to(param_casting_dtype)
if to_contiguous:
@ -596,8 +515,8 @@ class ReplicateParallel(TensorParallelLayer):
This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example)
"""
def __init__(self, use_dtensor=True, use_local_output=True, **kwargs):
super().__init__(**kwargs)
def __init__(self, *, use_dtensor=True, use_local_output=True):
super().__init__()
self.input_layouts = (Replicate(),)
self.output_layouts = (Replicate(),)
self.desired_input_layouts = (Replicate(),)
@ -618,33 +537,12 @@ class ReplicateParallel(TensorParallelLayer):
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...].to(param_casting_dtype)
shard = [Replicate()]
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
parameter, shard = self.shard_tensor(
param,
param_type=param_type,
param_casting_dtype=param_casting_dtype,
to_contiguous=to_contiguous,
rank=rank,
device_mesh=device_mesh,
)
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
return parameter
param = param[...].to(param_casting_dtype)
if to_contiguous:
param = param.contiguous()
param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
return param
class ColwiseParallel(TensorParallelLayer):
@ -654,13 +552,13 @@ class ColwiseParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
use_dtensor=True,
**kwargs,
):
super().__init__(**kwargs)
super().__init__()
self.input_layouts = (input_layouts or Replicate(),)
self.output_layouts = (output_layouts or Shard(-1),)
self.desired_input_layouts = (Replicate(),)
@ -680,34 +578,18 @@ class ColwiseParallel(TensorParallelLayer):
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
return input_tensor
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = self.device_mesh
empty_param = self.empty_param
rank = self.rank
if param_type == "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
shard = [Shard(-1)]
else:
shard = [Shard(-2)]
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
parameter = parameter.to(param_casting_dtype)
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
parameter, shard = self.shard_tensor(param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh)
if param_type == "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
shard = [Shard(-1)]
else:
shard = [Shard(-2)]
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
@ -726,21 +608,6 @@ class ColwiseParallel(TensorParallelLayer):
class PackedColwiseParallel(ColwiseParallel):
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = device_mesh or self.device_mesh
empty_param = self.empty_param
rank = rank if rank is not None else self.rank
return get_packed_weights(param, empty_param, device_mesh, rank, -2).to(param_casting_dtype), [Shard(-2)]
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
@ -775,41 +642,18 @@ class RowwiseParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
use_dtensor=True,
**kwargs,
):
super().__init__(**kwargs)
super().__init__()
self.input_layouts = (input_layouts or Shard(-1),)
self.output_layouts = (output_layouts or Replicate(),)
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = device_mesh or self.device_mesh
empty_param = self.empty_param
rank = rank if rank is not None else self.rank
if param_type == "bias":
shard = [Replicate()]
parameter = param[...]
else:
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx)
shard = [Shard(-1)]
parameter = parameter.to(param_casting_dtype)
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
# means Rowwise as nn.Linear is input * weight^T + bias, where
@ -881,21 +725,6 @@ class RowwiseParallel(TensorParallelLayer):
class PackedRowwiseParallel(RowwiseParallel):
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = device_mesh or self.device_mesh
empty_param = self.empty_param
rank = rank if rank is not None else self.rank
return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)]
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
@ -954,8 +783,8 @@ class SequenceParallel(TensorParallelLayer):
to ensure that they are replicated.
"""
def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
super().__init__(**kwargs)
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
super().__init__()
self.input_layouts = (Replicate(),)
self.desired_input_layouts = (Shard(1),)
self.output_layouts = (Replicate(),)
@ -964,21 +793,6 @@ class SequenceParallel(TensorParallelLayer):
self.sequence_sharding = (Shard(sequence_dim),)
self.use_local_output = use_local_output
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...].to(param_casting_dtype)
shard = [Replicate()]
self.shard = shard
return parameter, shard
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
@ -1013,34 +827,10 @@ class GroupedGemmParallel(TensorParallelLayer):
Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init__(self):
super().__init__()
self.use_dtensor = False
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
empty_param = self.empty_param
ep_rank = self.rank
device_mesh = self.device_mesh
global_num_experts = empty_param.shape[0]
if global_num_experts % device_mesh.size() != 0:
raise ValueError(
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
)
local_num_experts = global_num_experts // device_mesh.size()
parameter = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
self.shard = None
return parameter, None
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
ep_rank = rank
global_num_experts = empty_param.shape[0]
@ -1061,8 +851,8 @@ class RouterParallel(TensorParallelLayer):
"""
def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
self.args = args
self.kwargs = kwargs
self.use_dtensor = False
@staticmethod
@ -1127,20 +917,6 @@ class RouterParallel(TensorParallelLayer):
) # masking class for one hot
return router_scores, router_indices
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...].to(param_casting_dtype)
self.shard = None
return parameter, None
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# TODO: i'd like for this to be the default
param = param[...].to(param_casting_dtype)
@ -1283,9 +1059,6 @@ def shard_and_distribute_module(
if current_shard_plan is not None:
try:
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
tp_layer.empty_param = empty_param
tp_layer.device_mesh = device_mesh
tp_layer.rank = rank
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
@ -1313,7 +1086,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
if tp_plan is None:
return
generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
unsharded_layers = set(generic_keys)
unused_rules = tp_plan

View File

@ -82,10 +82,8 @@ def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int)
def bidirectional_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
"""
This creates a full bidirectional mask.
NOTE: It is important to keep an index-based version for non-vmap expansion.
"""
return q_idx >= 0
return q_idx.new_ones((), dtype=torch.bool)
def sliding_window_overlay(sliding_window: int) -> Callable:
@ -112,6 +110,18 @@ def chunked_overlay(chunk_size: int, left_padding: torch.Tensor) -> Callable:
return inner_mask
def _legacy_chunked_overlay(chunk_size: int) -> Callable:
"""
Same as the above function, but do not correctly account for left padding tokens.
Only kept for compatibility with older torch versions (< 2.6).
"""
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
return kv_idx // chunk_size == q_idx // chunk_size
return inner_mask
def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
"""
This return the mask_function function to create a sliding window mask.
@ -123,6 +133,8 @@ def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) ->
"""
This return the mask_function function to create a chunked attention mask.
"""
if not _is_torch_greater_or_equal_than_2_6:
return and_masks(_legacy_chunked_overlay(chunk_size), causal_mask_function)
return and_masks(chunked_overlay(chunk_size, left_padding), causal_mask_function)
@ -163,56 +175,55 @@ def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offs
return inner_mask
def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
"""
Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
the batch and head indices as well if `bh_indices=True`.
Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
Args:
mask_function (`Callable`):
The mask_function to vmap.
bh_indices (`bool`, optional):
Whether to vmap over the batch and head indices as well, or only q and kv indices.
Returns:
Callable: The vmapped function.
"""
# We vmap the function 2 times, broadcasting the [q_idx, kv_idx] dimensions
dimensions = [(None, None, None, 0), (None, None, 0, None)]
if bh_indices:
# We extend broadcasting over the [batch_idx, head_idx] dimensions
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
for dims in dimensions:
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
return mask_function
def prepare_padding_mask(
attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int
attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int, _slice: bool = True
) -> Optional[torch.Tensor]:
"""
From the 2D attention mask, prepare the correct padding mask to use by potentially padding it.
From the 2D attention mask, prepare the correct padding mask to use by potentially padding it, and slicing
according to the `kv_offset` if `_slice` is `True`.
"""
local_padding_mask = attention_mask
if attention_mask is not None:
# Pad it if necessary
if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
# For flex, we should not slice them, only use an offset
if _slice:
# Equivalent to: `local_padding_mask = attention_mask[:, kv_offset : kv_offset + kv_length]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indices = torch.arange(kv_length, device=local_padding_mask.device)
mask_indices += kv_offset
local_padding_mask = local_padding_mask[:, mask_indices]
return local_padding_mask
def _can_skip_causal_mask_xpu(
padding_mask: Optional[torch.Tensor],
query_length: int,
kv_length: int,
local_attention_size: Optional[int],
) -> bool:
"""
XPU-specific logic for determining if we can skip causal mask creation.
For XPU devices, we have special handling:
- Single query tokens (query_length == 1) use the same logic as CUDA
- Multi-query tokens can skip if padding_mask is provided and correctly structured
The mask must have all True values in the query window and all False after
"""
if is_tracing(padding_mask):
return False
# Check local attention constraint (same as CUDA)
if local_attention_size is not None and kv_length >= local_attention_size:
return False
if padding_mask is None:
# Without padding mask, can skip if single query token or full causal attention
return query_length == 1 or kv_length == query_length
# XPU allows skipping under additional conditions when padding_mask is provided
if query_length == 1:
# Single query token: skip only if no padding tokens present
return padding_mask.all()
# XPU-specific: check if query window is all True and rest is all False
# This allows XPU to optimize the 1st token in static cache
return padding_mask[:, :query_length].all() and not padding_mask[:, query_length:].any()
def _ignore_causal_mask_sdpa(
padding_mask: Optional[torch.Tensor],
query_length: int,
@ -233,12 +244,6 @@ def _ignore_causal_mask_sdpa(
mask_indices += kv_offset
padding_mask = padding_mask[:, mask_indices]
if _is_torch_xpu_available:
# XPU devices have special handling for mask skipping:
# - Single query tokens use the same logic as CUDA
# - Multi-query tokens can skip if padding_mask is provided and correctly structured
# (all True in query window, all False after)
return _can_skip_causal_mask_xpu(padding_mask, query_length, kv_length, local_attention_size)
# When using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
# hard-coded to the forward. If a user exports a model with query_length > 1, the exported model will hard-code `is_causal=True`
# which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set
@ -246,11 +251,18 @@ def _ignore_causal_mask_sdpa(
if (
not is_tracing(padding_mask)
# only cases when lower and upper diags are the same, see https://github.com/pytorch/pytorch/issues/108108
and (query_length == 1 or kv_length == query_length)
and (query_length == 1 or (kv_length == query_length or _is_torch_xpu_available))
# in this case we need to add special patterns to the mask so cannot be skipped otherwise
and (local_attention_size is None or kv_length < local_attention_size)
# In this case, we need to add padding to the mask, so cannot be skipped otherwise
and (padding_mask is None or padding_mask.all())
and (
padding_mask is None
or (
padding_mask.all()
if not _is_torch_xpu_available or query_length == 1
else padding_mask[:, :query_length].all()
)
)
):
return True
@ -270,39 +282,7 @@ def _ignore_bidirectional_mask_sdpa(padding_mask: Optional[torch.Tensor]) -> boo
return False
def _vmap_expansion_sdpa(mask_function: Callable) -> Callable:
"""
Used to vmap our mask_functions over the all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
"""
# We vmap the function over all 4 dimensions, broadcasting [b_idx, h_idx, q_idx, kv_idx]
dimensions = [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]
for dims in dimensions:
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
return mask_function
def _non_vmap_expansion_sdpa(
batch_indices: torch.Tensor, head_indices: torch.Tensor, q_indices: torch.Tensor, kv_indices: torch.Tensor
):
"""
Used to broadcast our mask_functions over the all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
Allows the usage of any index-based mask function without relying on vmap.
NOTE: This is limited to index based functions only and is not guaranteed to work otherwise.
Reference:
- https://github.com/huggingface/optimum-onnx/blob/c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365
"""
batch_indices = batch_indices[:, None, None, None]
head_indices = head_indices[None, :, None, None]
q_indices = q_indices[None, None, :, None]
kv_indices = kv_indices[None, None, None, :]
return batch_indices, head_indices, q_indices, kv_indices
def sdpa_mask(
def sdpa_mask_recent_torch(
batch_size: int,
cache_position: torch.Tensor,
kv_length: int,
@ -312,8 +292,6 @@ def sdpa_mask(
local_size: Optional[int] = None,
allow_is_causal_skip: bool = True,
allow_is_bidirectional_skip: bool = False,
allow_torch_fix: bool = True,
use_vmap: bool = False,
**kwargs,
) -> Optional[torch.Tensor]:
"""
@ -346,12 +324,6 @@ def sdpa_mask(
allow_is_bidirectional_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
i.e. full attention without any padding. Default to `False`.
allow_torch_fix (`bool`, optional):
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
versions. We need an arg to skip it when using eager. By default `True`.
use_vmap (`bool`, optional):
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
index-based (for the cost of speed performance). By default `False`.
## Creating a simple causal mask:
@ -419,8 +391,97 @@ def sdpa_mask(
"""
q_length = cache_position.shape[0]
# Potentially pad the 2D mask, and slice it correctly
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
# Potentially pad the 2D mask
# Under specific conditions, we can avoid materializing the mask
# 1. Causal masks can rely on the `is_causal` argument
# 2. Bidirectional do not need any further processing (no bias)
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
return None
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
return None
# vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
# padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
if mask_function is bidirectional_mask_function:
if padding_mask is not None:
# used for slicing without data-dependent slicing
mask_indices = torch.arange(kv_length, device=cache_position.device) + kv_offset
return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
else:
return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device)
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = torch.arange(kv_length, device=cache_position.device)
kv_arange += kv_offset
# Potentially add the padding 2D mask
if padding_mask is not None:
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
batch_arange = torch.arange(batch_size, device=cache_position.device)
head_arange = torch.arange(1, device=cache_position.device)
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
# scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
# We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
with TransformGetItemToIndex():
causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
return causal_mask
def sdpa_mask_older_torch(
batch_size: int,
cache_position: torch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Callable = causal_mask_function,
attention_mask: Optional[torch.Tensor] = None,
local_size: Optional[int] = None,
allow_is_causal_skip: bool = True,
allow_torch_fix: bool = True,
allow_is_bidirectional_skip: bool = False,
**kwargs,
) -> Optional[torch.Tensor]:
"""
NOTE: This function is only used when torch version is torch<2.5 - see `sdpa_mask_recent_torch` otherwise.
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
the element should take part in the attention computation, and False that it should not.
If `allow_torch_fix=True` (the default), rows corresponding to query tokens that do not attend
to any other tokens (due to padding) will be fully attended to instead, in order to avoid `nan` propagation (this does
not change the final result).
Args:
batch_size (`int`):
The batch size of the input sequence.
cache_position (`torch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
kv_offset (`int`, optional):
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
attention_mask (`torch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
local_size (`int`, optional):
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
to try to skip mask creation if possible.
allow_is_causal_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
`torch.sdpa` instead. Default to `True`.
allow_torch_fix (`bool`, optional):
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
versions. We need an arg to skip it when using eager. By default `True`.
allow_is_bidirectional_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
i.e. full attention without any padding. Default to `False`.
"""
q_length = cache_position.shape[0]
# Potentially pad the 2D mask, and slice it correctly
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
# Under specific conditions, we can avoid materializing the mask
@ -431,45 +492,38 @@ def sdpa_mask(
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
return None
# Potentially add the padding 2D mask
if padding_mask is not None:
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
# vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
# padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
if mask_function is bidirectional_mask_function:
if padding_mask is not None:
return padding_mask[:, None, None, :].expand(-1, -1, q_length, -1)
else:
return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device)
batch_arange = torch.arange(batch_size, device=cache_position.device)
head_arange = torch.arange(1, device=cache_position.device)
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
kv_arange = torch.arange(kv_length, device=cache_position.device)
kv_arange += kv_offset
# Actual mask creation
# Option 1: Fast non-vmap mask creation (default)
if not use_vmap:
# Apply mask function element-wise through broadcasting
attention_mask = mask_function(*_non_vmap_expansion_sdpa(batch_arange, head_arange, cache_position, kv_arange))
# Expand the mask to match batch size and query length if they weren't used in the mask function
attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
# Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
elif _is_torch_greater_or_equal_than_2_6:
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
# scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
# We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
with TransformGetItemToIndex():
attention_mask = _vmap_expansion_sdpa(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
# Option 3: Error out since it indicates that the user did something custom, which they shouldn't have (torch<2.6)
else:
raise ValueError(
"The vmap functionality for mask creation is only supported from torch>=2.6. "
"Please update your torch version or use `use_vmap=False` with index-based masks."
)
# This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well,
# as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow
# However, in more recent version of Pytorch, a trick was introduced to handle it - which is the reason we have
# `sdpa_mask_recent_torch`, as it allows more general `mask_function`
causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
if padding_mask is not None:
causal_mask = causal_mask * padding_mask[:, None, None, :]
# Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
return causal_mask
return attention_mask
# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
# (especially mask_function indexing a tensor, such as the padding mask function)
sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch
def eager_mask(
@ -480,7 +534,6 @@ def eager_mask(
mask_function: Callable = causal_mask_function,
attention_mask: Optional[torch.Tensor] = None,
dtype: torch.dtype = torch.float32,
use_vmap: bool = False,
**kwargs,
) -> torch.Tensor:
"""
@ -503,14 +556,10 @@ def eager_mask(
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
dtype (`torch.dtype`, optional):
The dtype to use for the mask. By default, `torch.float32`.
use_vmap (`bool`, optional):
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
index-based (for the cost of speed performance). By default `False`.
"""
# The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
_ = kwargs.pop("allow_is_causal_skip", None)
_ = kwargs.pop("allow_is_bidirectional_skip", None)
_ = kwargs.pop("allow_torch_fix", None)
mask = sdpa_mask(
batch_size=batch_size,
cache_position=cache_position,
@ -521,7 +570,6 @@ def eager_mask(
allow_is_causal_skip=False,
allow_is_bidirectional_skip=False,
allow_torch_fix=False,
use_vmap=use_vmap,
**kwargs,
)
min_dtype = torch.finfo(dtype).min
@ -607,7 +655,7 @@ def flex_attention_mask(
if not _is_torch_greater_or_equal_than_2_6 and pad_len > 0:
attention_mask = torch.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
# Add the offsets on top (because flex interface only allows length, not start and end indices)
@ -733,19 +781,9 @@ def _preprocess_mask_arguments(
# If using a cache, it can give all information about mask sizes based on seen tokens
if past_key_values is not None:
kv_length, kv_offset = past_key_values.get_mask_sizes(cache_position, layer_idx)
# Otherwise, we infer based on our input
# Otherwise, the sizes are simply the input sizes
else:
# 1. Rely on input directly
if attention_mask is None:
kv_length, kv_offset = input_embeds.shape[1], 0
# 2. Rely on the mask instead - needed for special cases like prefix tuning in PEFT
#
# This is a very unique and special case where an encoder utilizes a cache and expects its length
# to be accounted for (usually, they should never use a cache). In general, the mask should always
# match with the input sizes nonetheless (i.e. it does not affect others).
# Conclusion: "prefix tuning is evil"
else:
kv_length, kv_offset = attention_mask.shape[-1], 0
kv_length, kv_offset = input_embeds.shape[1], 0
# We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
# and we don't have past_key_values, i.e. generally a training setup)
@ -813,11 +851,6 @@ def create_causal_mask(
mask_factory_function = causal_mask_function
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
# Defaulting to using non-vmap based mask creations except when detecting
# users passing custom mask functions (as we cannot guarantee that they
# are properly index-based as required by our implementation).
use_vmap = False
# Do not allow skip if we are compiling (this is to match BC)
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
if _is_torch_xpu_available:
@ -834,16 +867,14 @@ def create_causal_mask(
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_causal_skip = False
use_vmap = True
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_causal_skip = False
use_vmap = True
# If we detected packing format
if packed_sequence_mask is not None:
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
allow_is_causal_skip = False
@ -858,7 +889,6 @@ def create_causal_mask(
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
)
return causal_mask
@ -912,10 +942,6 @@ def create_bidirectional_mask(
# Allow skipping the mask creation except we have additional masking operators (and/or masks)
allow_is_bidirectional_skip = True
# Defaulting to using non-vmap based mask creations except when detecting
# users passing custom mask functions (as we cannot guarantee that they
# are properly index-based as required by our implementation).
use_vmap = False
# Allow slight deviations from the base mask
# Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
@ -925,13 +951,11 @@ def create_bidirectional_mask(
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_bidirectional_skip = False
use_vmap = True
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_bidirectional_skip = False
use_vmap = True
# We now create the mask
attention_mask = mask_interface(
@ -946,7 +970,6 @@ def create_bidirectional_mask(
allow_is_bidirectional_skip=allow_is_bidirectional_skip,
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
)
return attention_mask
@ -1009,10 +1032,6 @@ def create_sliding_window_causal_mask(
mask_factory_function = sliding_window_causal_mask_function(sliding_window)
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
# Defaulting to using non-vmap based mask creations except when detecting
# users passing custom mask functions (as we cannot guarantee that they
# are properly index-based as required by our implementation).
use_vmap = False
# Do not allow skip if we are compiling (this is to match BC)
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
@ -1025,16 +1044,14 @@ def create_sliding_window_causal_mask(
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_causal_skip = False
use_vmap = True
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_causal_skip = False
use_vmap = True
# If we detected packing format
if packed_sequence_mask is not None:
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
allow_is_causal_skip = False
@ -1050,7 +1067,6 @@ def create_sliding_window_causal_mask(
local_size=sliding_window, # Additional kwarg for sdpa
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
)
return causal_mask
@ -1124,13 +1140,20 @@ def create_chunked_causal_mask(
left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1)
else:
left_padding_tokens = torch.zeros(batch_size, device=cache_position.device, dtype=int)
# Raise a warning for older versions if the problematic left-padding situation arises
if (
not _is_torch_greater_or_equal_than_2_6
and kv_length + kv_offset > chunk_size
and (left_padding_tokens > 0).any()
):
logger.warning_once(
"Due to limitations of your current torch version, we cannot correctly account for the left-padding "
"when computing the chunked attention pattern. This will lead to a wrong attention mask for the padded "
"sequences. Behavior will be undefined. Please upgrade to `torch>=2.6` to solve this issue."
)
mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens)
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
# Defaulting to using non-vmap based mask creations except when detecting
# users passing custom mask functions (as we cannot guarantee that they
# are properly index-based as required by our implementation).
use_vmap = False
# Do not allow skip if we are compiling (this is to match BC)
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
@ -1143,16 +1166,14 @@ def create_chunked_causal_mask(
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_causal_skip = False
use_vmap = True
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_causal_skip = False
use_vmap = True
# If we detected packing format
if packed_sequence_mask is not None:
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
allow_is_causal_skip = False
@ -1168,7 +1189,6 @@ def create_chunked_causal_mask(
local_size=chunk_size, # Additional kwarg for sdpa
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
)
return causal_mask

File diff suppressed because it is too large Load Diff

View File

@ -25,7 +25,6 @@ if TYPE_CHECKING:
from .arcee import *
from .aria import *
from .audio_spectrogram_transformer import *
from .audioflamingo3 import *
from .auto import *
from .autoformer import *
from .aya_vision import *
@ -142,9 +141,6 @@ if TYPE_CHECKING:
from .git import *
from .glm import *
from .glm4 import *
from .glm4v import *
from .glm4v_moe import *
from .glm46v import *
from .glpn import *
from .got_ocr2 import *
from .gpt2 import *

View File

@ -29,7 +29,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from ... import initialization as init
from ...activations import ACT2FN
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask
@ -38,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from .configuration_aimv2 import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig
@ -406,14 +406,13 @@ class Aimv2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_flex_attn = True
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
init.constant_(module.logit_scale, math.log(1 / 0.07))
module.logit_scale.data.fill_(math.log(1 / 0.07))
elif isinstance(module, Aimv2AttentionPoolingHead):
init.normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
@auto_docstring(
@ -446,11 +445,13 @@ class Aimv2VisionModel(Aimv2PreTrainedModel):
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embed
@deprecate_kwarg("attention_mask", version="v4.58.0")
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
r"""

View File

@ -22,7 +22,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from ... import initialization as init
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
@ -33,6 +32,7 @@ from ...utils import (
auto_docstring,
can_return_tuple,
)
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm
from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm
@ -449,14 +449,13 @@ class Aimv2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_flex_attn = True
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
init.constant_(module.logit_scale, math.log(1 / 0.07))
module.logit_scale.data.fill_(math.log(1 / 0.07))
elif isinstance(module, Aimv2AttentionPoolingHead):
init.normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
@auto_docstring(
@ -489,11 +488,13 @@ class Aimv2VisionModel(Aimv2PreTrainedModel):
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embed
@deprecate_kwarg("attention_mask", version="v4.58.0")
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
r"""

View File

@ -22,7 +22,6 @@ import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ... import initialization as init
from ...activations import ACT2FN
from ...masking_utils import create_bidirectional_mask
from ...modeling_outputs import (
@ -303,23 +302,21 @@ class AlbertPreTrainedModel(PreTrainedModel):
"attentions": AlbertAttention,
}
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
init.zeros_(module.bias)
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
# Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
init.zeros_(module.weight[module.padding_idx])
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
init.zeros_(module.bias)
init.ones_(module.weight)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, AlbertMLMHead):
init.zeros_(module.bias)
module.bias.data.zero_()
@dataclass
@ -428,10 +425,7 @@ class AlbertModel(AlbertPreTrainedModel):
"""
)
class AlbertForPreTraining(AlbertPreTrainedModel):
_tied_weights_keys = {
"predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
"predictions.decoder.bias": "predictions.bias",
}
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
def __init__(self, config: AlbertConfig):
super().__init__(config)
@ -531,6 +525,7 @@ class AlbertMLMHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
self.activation = ACT2FN[config.hidden_act]
self.decoder.bias = self.bias
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
@ -542,6 +537,14 @@ class AlbertMLMHead(nn.Module):
return prediction_scores
def _tie_weights(self) -> None:
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
class AlbertSOPHead(nn.Module):
def __init__(self, config: AlbertConfig):
@ -558,10 +561,7 @@ class AlbertSOPHead(nn.Module):
@auto_docstring
class AlbertForMaskedLM(AlbertPreTrainedModel):
_tied_weights_keys = {
"predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
"predictions.decoder.bias": "predictions.bias",
}
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)

View File

@ -22,7 +22,6 @@ from typing import Any, Optional, Union
import torch
from torch import nn
from ... import initialization as init
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@ -824,26 +823,24 @@ class AlignPreTrainedModel(PreTrainedModel):
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
@torch.no_grad()
def _init_weights(self, module: nn.Module):
"""Initialize the weights"""
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d)):
init.normal_(module.weight, mean=0.0, std=std)
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
init.zeros_(module.bias)
module.bias.data.zero_()
elif isinstance(module, AlignModel):
init.xavier_uniform_(module.text_projection.weight)
init.zeros_(module.text_projection.bias)
init.constant_(module.temperature, self.config.temperature_init_value)
nn.init.xavier_uniform_(module.text_projection.weight)
module.text_projection.bias.data.zero_()
module.temperature.data.fill_(self.config.temperature_init_value)
elif isinstance(module, nn.Embedding):
init.normal_(module.weight, mean=0.0, std=std)
# Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
init.zeros_(module.weight[module.padding_idx])
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
init.zeros_(module.bias)
init.ones_(module.weight)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@auto_docstring(

View File

@ -22,7 +22,6 @@ from typing import Any, Optional, Union
import torch
import torch.nn as nn
from ... import initialization as init
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
@ -771,50 +770,50 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_module = []
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_factor
if isinstance(module, AltCLIPVisionEmbeddings):
factor = self.config.initializer_factor
init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
elif isinstance(module, AltCLIPAttention):
factor = self.config.initializer_factor
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
out_proj_std = (module.embed_dim**-0.5) * factor
init.normal_(module.q_proj.weight, std=in_proj_std)
init.normal_(module.k_proj.weight, std=in_proj_std)
init.normal_(module.v_proj.weight, std=in_proj_std)
init.normal_(module.out_proj.weight, std=out_proj_std)
nn.init.normal_(module.q_proj.weight, std=in_proj_std)
nn.init.normal_(module.k_proj.weight, std=in_proj_std)
nn.init.normal_(module.v_proj.weight, std=in_proj_std)
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
elif isinstance(module, AltCLIPMLP):
factor = self.config.initializer_factor
in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
init.normal_(module.fc1.weight, std=fc_std)
init.normal_(module.fc2.weight, std=in_proj_std)
nn.init.normal_(module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std)
elif isinstance(module, AltCLIPModel):
init.normal_(
nn.init.normal_(
module.text_projection.weight,
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
)
init.normal_(
module.text_projection._is_hf_initialized = True
nn.init.normal_(
module.visual_projection.weight,
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
)
module.visual_projection._is_hf_initialized = True
elif isinstance(module, nn.LayerNorm):
init.zeros_(module.bias)
init.ones_(module.weight)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear):
init.normal_(module.weight, mean=0.0, std=self.config.initializer_factor)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
if module.bias is not None:
init.zeros_(module.bias)
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
init.normal_(module.weight, mean=0.0, std=self.config.initializer_factor)
# Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
init.zeros_(module.weight[module.padding_idx])
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class AltCLIPVisionTransformer(nn.Module):

View File

@ -17,6 +17,7 @@ Image/Text processor class for AltCLIP
"""
from ...processing_utils import ProcessorMixin
from ...utils.deprecation import deprecate_kwarg
class AltCLIPProcessor(ProcessorMixin):
@ -34,6 +35,7 @@ class AltCLIPProcessor(ProcessorMixin):
The tokenizer is a required input.
"""
@deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)

View File

@ -106,6 +106,7 @@ class ApertusConfig(PreTrainedConfig):
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -429,7 +429,7 @@ class ApertusModel(ApertusPreTrainedModel):
@auto_docstring
class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

View File

@ -123,6 +123,7 @@ class ApertusConfig(LlamaConfig):
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
}
def __init__(

View File

@ -434,7 +434,7 @@ class ArceeModel(ArceePreTrainedModel):
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

View File

@ -99,14 +99,15 @@ class AriaTextConfig(PreTrainedConfig):
model_type = "aria_text"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `AriaTextModel`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -25,7 +25,6 @@ from typing import Optional, Union
import torch
from torch import nn
from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
@ -586,11 +585,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
"attentions": AriaTextAttention,
}
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, AriaGroupedExpertsGemm):
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@auto_docstring
@ -610,11 +608,10 @@ class AriaPreTrainedModel(PreTrainedModel):
"attentions": AriaTextAttention,
}
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, AriaProjector):
init.trunc_normal_(module.query, std=self.config.initializer_range)
nn.init.trunc_normal_(module.query, std=self.config.initializer_range)
class AriaTextRotaryEmbedding(nn.Module):
@ -763,7 +760,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
@auto_docstring
class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
@ -893,6 +890,8 @@ class AriaModelOutputWithPast(BaseModelOutputWithPast):
"""
)
class AriaModel(AriaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: AriaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
@ -1049,12 +1048,12 @@ class AriaModel(AriaPreTrainedModel):
)
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
r"^language_model.model": "model.language_model",
r"^vision_tower": "model.vision_tower",
r"^multi_modal_projector": "model.multi_modal_projector",
r"^language_model.lm_head": "lm_head",
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: AriaConfig):
super().__init__(config)

View File

@ -19,7 +19,6 @@ import numpy as np
import torch
from torch import nn
from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...configuration_utils import PreTrainedConfig
@ -170,15 +169,6 @@ class AriaTextConfig(LlamaConfig):
model_type = "aria_text"
base_config_key = "text_config"
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
}
def __init__(
self,
@ -1197,11 +1187,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
"attentions": AriaTextAttention,
}
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, AriaGroupedExpertsGemm):
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
class AriaPreTrainedModel(LlamaPreTrainedModel):
@ -1210,11 +1199,10 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = True
@torch.no_grad()
def _init_weights(self, module):
PreTrainedModel._init_weights(self, module)
if isinstance(module, AriaProjector):
init.trunc_normal_(module.query, std=self.config.initializer_range)
nn.init.trunc_normal_(module.query, std=self.config.initializer_range)
class AriaTextModel(LlamaModel):
@ -1228,7 +1216,7 @@ class AriaTextModel(LlamaModel):
class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: AriaTextConfig):
super().__init__(config)
@ -1367,8 +1355,6 @@ class AriaModel(LlavaModel):
"""
)
class AriaForConditionalGeneration(LlavaForConditionalGeneration):
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
def get_image_features(
self,
pixel_values: torch.FloatTensor,

View File

@ -272,9 +272,7 @@ if __name__ == "__main__":
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether or not to push the converted model to the Hugging Face hub.",
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()

View File

@ -20,7 +20,6 @@ from typing import Optional, Union
import torch
from torch import nn
from ... import initialization as init
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
@ -301,20 +300,23 @@ class ASTPreTrainedModel(PreTrainedModel):
"attentions": ASTSelfAttention,
}
@torch.no_grad()
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
init.zeros_(module.bias)
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
init.zeros_(module.bias)
init.ones_(module.weight)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ASTEmbeddings):
init.zeros_(module.cls_token)
init.zeros_(module.position_embeddings)
init.zeros_(module.distillation_token)
module.cls_token.data.zero_()
module.position_embeddings.data.zero_()
module.distillation_token.data.zero_()
@auto_docstring

View File

@ -1,31 +0,0 @@
# coding=utf-8
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_audioflamingo3 import *
from .modeling_audioflamingo3 import *
from .processing_audioflamingo3 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -1,210 +0,0 @@
# coding=utf-8
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig
logger = logging.get_logger(__name__)
class AudioFlamingo3EncoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`AudioFlamingo3Encoder`]. It is used to instantiate an
AudioFlamingo3 audio encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the audio encoder of the AudioFlamingo3
architecture.
e.g. [nvidia/audio-flamingo-3-hf](https://huggingface.co/nvidia/audio-flamingo-3-hf)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_mel_bins (`int`, *optional*, defaults to 128):
Number of mel features used per input features. Should correspond to the value used in the
`AudioFlamingo3Processor` class.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of encoder layers.
num_attention_heads (`int`, *optional*, defaults to 20):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 5120):
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
for more details.
activation_function (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
hidden_size (`int`, *optional*, defaults to 1280):
Dimensionality of the layers.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
scale_embedding (`bool`, *optional*, defaults to `False`):
Scale embeddings by dividing by sqrt(hidden_size).
max_source_positions (`int`, *optional*, defaults to 1500):
The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
Example:
```python
>>> from transformers import AudioFlamingo3EncoderConfig, AudioFlamingo3Encoder
>>> # Initializing an AudioFlamingo3EncoderConfig
>>> configuration = AudioFlamingo3EncoderConfig()
>>> # Initializing an AudioFlamingo3Encoder (with random weights)
>>> model = AudioFlamingo3Encoder(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "audioflamingo3_encoder"
attribute_map = {
"d_model": "hidden_size",
"encoder_layers": "num_hidden_layers",
"encoder_attention_heads": "num_attention_heads",
"encoder_ffn_dim": "intermediate_size",
"encoder_layerdrop": "layerdrop",
}
def __init__(
self,
num_mel_bins=128,
num_hidden_layers=32,
num_attention_heads=20,
intermediate_size=5120,
layerdrop=0.0,
activation_function="gelu",
hidden_size=1280,
dropout=0.0,
attention_dropout=0.0,
activation_dropout=0.0,
initializer_range=0.02,
scale_embedding=False,
max_source_positions=1500,
**kwargs,
):
super().__init__(**kwargs)
self.num_mel_bins = num_mel_bins
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.initializer_range = initializer_range
self.layerdrop = layerdrop
self.num_hidden_layers = num_hidden_layers
self.scale_embedding = scale_embedding
self.max_source_positions = max_source_positions
class AudioFlamingo3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`AudioFlamingo3ForConditionalGeneration`]. It is used to instantiate an
AudioFlamingo3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the AudioFlamingo3.
e.g. [nvidia/audio-flamingo-3-hf](https://huggingface.co/nvidia/audio-flamingo-3-hf)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
audio_config (`Union[AudioFlamingo3EncoderConfig, dict]`, *optional*, defaults to `AudioFlamingo3EncoderConfig`):
The config object or dictionary of the audio backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
The config object or dictionary of the text backbone.
audio_token_id (`int`, *optional*, defaults to 151669):
The audio token index to encode the audio prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
Activation function used in the projector.
projector_bias (`bool`, *optional*, defaults to `True`):
Whether to include bias terms in the projector.
Example:
```python
>>> from transformers import AudioFlamingo3ForConditionalGeneration, AudioFlamingo3Config, AudioFlamingo3EncoderConfig, Qwen2Config
>>> # Initializing an AudioFlamingo3Encoder config
>>> audio_config = AudioFlamingo3EncoderConfig()
>>> # Initializing a Qwen2 config
>>> text_config = Qwen2Config()
>>> # Initializing an AudioFlamingo3 configuration
>>> configuration = AudioFlamingo3Config(audio_config, text_config)
>>> # Initializing a model from the audioflamingo3 style configuration
>>> model = AudioFlamingo3ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "audioflamingo3"
sub_configs = {
"audio_config": AudioFlamingo3EncoderConfig,
"text_config": AutoConfig,
}
def __init__(
self,
audio_config=None,
text_config=None,
audio_token_id=151669,
projector_hidden_act="gelu",
projector_bias=True,
**kwargs,
):
self.audio_token_id = audio_token_id
if isinstance(audio_config, dict):
audio_config["model_type"] = audio_config.get("model_type", "audioflamingo3_encoder")
audio_config = CONFIG_MAPPING[audio_config["model_type"]](**audio_config)
elif audio_config is None:
audio_config = CONFIG_MAPPING["audioflamingo3_encoder"]()
self.audio_config = audio_config
if isinstance(text_config, dict):
text_config["model_type"] = text_config.get("model_type", "qwen2")
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
text_config = CONFIG_MAPPING["qwen2"]()
self.text_config = text_config
self.projector_hidden_act = projector_hidden_act
self.projector_bias = projector_bias
super().__init__(**kwargs)
__all__ = ["AudioFlamingo3Config", "AudioFlamingo3EncoderConfig"]

View File

@ -1,286 +0,0 @@
# coding=utf-8
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert AudioFlamingo3 checkpoints into a Hugging Face repository layout."""
from __future__ import annotations
import argparse
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any
import torch
from safetensors.torch import safe_open
from transformers import (
AudioFlamingo3Config,
AudioFlamingo3ForConditionalGeneration,
AudioFlamingo3Processor,
AutoTokenizer,
GenerationConfig,
Qwen2Config,
WhisperFeatureExtractor,
)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
def _load_json(p: Path):
if not p.is_file():
raise FileNotFoundError(f"Missing JSON: {p}")
with p.open("r", encoding="utf-8") as f:
return json.load(f)
def write_processor(src_root: Path, dst_root: Path):
llm_dir = src_root / "llm"
# fmt: off
tokenizer_chat_template = (
"{% if messages[0]['role'] != 'system' %}"
"{{ '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}"
"{% endif %}"
"{% for message in messages if message['content'] is not none %}"
"{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\\n' }}"
"{% endif %}"
)
# fmt: on
# fmt: off
processor_chat_template = (
"{% if messages[0]['role'] != 'system' %}"
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"{% endif %}"
"{% for m in messages if m['content'] is not none %}"
"<|im_start|>{{ m['role'] }}\n"
"{% if m['content'] is string %}"
"{{ m['content'] }}"
"{% else %}"
"{% set audio = namespace(found=False) %}"
"{% set text_buf = namespace(v='') %}"
"{% for c in m['content'] %}"
"{% if c.get('type') == 'audio' or 'audio' in c %}"
"{% set audio.found = True %}"
"{% elif c.get('type') == 'text' or 'text' in c %}"
"{% set text_buf.v = text_buf.v + c['text'] %}"
"{% endif %}"
"{% endfor %}"
"{% if audio.found %}{{ '<sound>' }}{% endif %}{{ text_buf.v }}"
"{% endif %}"
"<|im_end|>\n"
"{% endfor %}"
"{% if add_generation_prompt %}"
"<|im_start|>assistant\n"
"{% endif %}"
)
# fmt: on
processor = AudioFlamingo3Processor(
feature_extractor=WhisperFeatureExtractor(feature_size=128, return_attention_mask=True),
tokenizer=AutoTokenizer.from_pretrained(str(llm_dir), chat_template=tokenizer_chat_template, use_fast=True),
chat_template=processor_chat_template,
)
processor.save_pretrained(str(dst_root))
logger.info("processor (tokenizer + preprocessor)")
return processor
PREFIX_MAP = {
"llm": "language_model",
"sound_tower": "audio_tower",
"sound_mm_projector": "multi_modal_projector",
}
def _resolve_component_dir(dirpath: Path):
if not dirpath.is_dir():
return None
idx = dirpath / "model.safetensors.index.json"
mono = dirpath / "model.safetensors"
if idx.exists():
wm = _load_json(idx).get("weight_map") or {}
by_shard: dict[str, list[str]] = defaultdict(list)
for k, shard in wm.items():
by_shard[shard].append(k)
return ("sharded", dirpath, {k: sorted(v) for k, v in sorted(by_shard.items())})
if mono.exists():
return ("file", mono)
cands = sorted([x for x in dirpath.iterdir() if x.suffix == ".safetensors"])
return ("file", cands[0]) if len(cands) == 1 else None
def merge_and_shard_weights(src_root: Path, dst_root: Path, processor: AudioFlamingo3Processor):
state: dict[str, Any] = {}
for tag in PREFIX_MAP.keys():
comp = _resolve_component_dir(src_root / tag)
if not comp:
continue
out_prefix = PREFIX_MAP.get(tag, tag)
if comp[0] == "file":
fp: Path = comp[1]
with safe_open(str(fp), framework="pt", device="cpu") as f:
for k in f.keys():
if k == "__metadata__":
continue
state[f"{out_prefix}.{k}"] = f.get_tensor(k)
else:
base: Path = comp[1]
shard_map: dict[str, list[str]] = comp[2]
for shard, keys in shard_map.items():
sp = base / shard
with safe_open(str(sp), framework="pt", device="cpu") as f:
for k in keys:
state[f"{out_prefix}.{k}"] = f.get_tensor(k)
if not state:
raise FileNotFoundError("No tensors found in llm/, sound_tower/, or sound_mm_projector/.")
tok = processor.tokenizer
text_config = Qwen2Config(
bos_token_id=tok.bos_token_id,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.pad_token_id,
vocab_size=len(tok),
hidden_size=3584,
intermediate_size=18944,
model_max_length=8192,
num_attention_heads=28,
num_hidden_layers=28,
num_key_value_heads=4,
rope_theta=1000000.0,
use_cache=False,
)
config = AudioFlamingo3Config(text_config=text_config, audio_token_id=tok.get_vocab()["<sound>"])
model = AudioFlamingo3ForConditionalGeneration(config).to(dtype=torch.bfloat16)
# Update state dict to new key names if necessary
projector_key_mapping = {
"multi_modal_projector.layers.0.weight": "multi_modal_projector.linear_1.weight",
"multi_modal_projector.layers.0.bias": "multi_modal_projector.linear_1.bias",
"multi_modal_projector.layers.2.weight": "multi_modal_projector.linear_2.weight",
"multi_modal_projector.layers.2.bias": "multi_modal_projector.linear_2.bias",
}
for old_key, new_key in projector_key_mapping.items():
if old_key in state:
state[new_key] = state.pop(old_key)
# Load weights into the instantiated model so we can push via `push_to_hub` later.
load_res = model.load_state_dict(state, strict=True)
# Enforce a clean load
if getattr(load_res, "missing_keys", None) and load_res.missing_keys:
mk = load_res.missing_keys
raise ValueError(f"Missing keys when loading: {mk[:10]}{' ...' if len(mk) > 10 else ''}")
if getattr(load_res, "unexpected_keys", None) and load_res.unexpected_keys:
uk = load_res.unexpected_keys
raise ValueError(f"Unexpected keys when loading: {uk[:10]}{' ...' if len(uk) > 10 else ''}")
generation_config = GenerationConfig(
bos_token_id=tok.bos_token_id,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.pad_token_id,
max_new_tokens=2048,
)
model.generation_config = generation_config
model.save_pretrained(save_directory=str(dst_root))
logger.info("model.safetensors index and shards")
return model
"""
Reproducible Usage
==================
1) Download the original AudioFlamingo-3 weights from NVIDIA (requires Git LFS):
```
git lfs install
git clone https://huggingface.co/nvidia/audio-flamingo-3
```
This will create a folder `audio-flamingo-3/` containing the original components:
`llm/`, `sound_tower/`, and `sound_mm_projector/`.
2) Convert to the Hugging Face Transformers format (locally):
```
python src/transformers/models/audioflamingo3/convert_audioflamingo3_to_hf.py \
--src_dir audio-flamingo-3 \
--dst_dir audio-flamingo-3-hf
```
3) Convert and push directly to the Hub (requires `huggingface-cli login` or `HF_TOKEN`):
```
python src/transformers/models/audioflamingo3/convert_audioflamingo3_to_hf.py \
--src_dir audio-flamingo-3 \
--dst_dir audio-flamingo-3-hf \
--push_to_hub <username-or-org>/audio-flamingo-3
```
This command uploads both the processor (tokenizer + feature extractor) and the converted
model (sharded safetensors + configs) to the specified Hub repository.
"""
def main() -> None:
ap = argparse.ArgumentParser(description="Convert AudioFlamingo3 to Hugging Face format.")
ap.add_argument("--src_dir", required=True, help="Source model root directory")
ap.add_argument("--dst_dir", required=True, help="Destination directory for converted model")
ap.add_argument(
"--push_to_hub",
default=None,
type=str,
help=(
"Optional repository ID to push the converted assets to the Hugging Face Hub, "
"e.g. 'username/audio-flamingo-3'."
),
)
args = ap.parse_args()
src_root = Path(args.src_dir).resolve()
if not src_root.is_dir():
raise FileNotFoundError(f"Source directory not found: {src_root}")
dst_root = Path(args.dst_dir).resolve()
if dst_root.exists():
raise FileExistsError(f"Destination already exists: {dst_root}")
processor = write_processor(src_root, dst_root)
model = merge_and_shard_weights(src_root, dst_root, processor)
# Optionally push converted assets using native push_to_hub only
if args.push_to_hub:
logger.info("Pushing processor to the Hub ...")
processor.push_to_hub(args.push_to_hub)
logger.info("Pushing model to the Hub ...")
model.push_to_hub(args.push_to_hub)
if __name__ == "__main__":
main()

View File

@ -1,603 +0,0 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/audioflamingo3/modular_audioflamingo3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_audioflamingo3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from typing import Optional, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...masking_utils import eager_mask, padding_mask_function
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_audioflamingo3 import AudioFlamingo3Config, AudioFlamingo3EncoderConfig
logger = logging.get_logger(__name__)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
dropout: float = 0.0,
**kwargs,
):
if scaling is None:
scaling = query.size(-1) ** -0.5
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None and attention_mask.ndim == 4:
attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class AudioFlamingo3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
layer_idx: Optional[int] = None,
config: Optional[AudioFlamingo3Config] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
if layer_idx is None and is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.layer_idx = layer_idx
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
q_input_shape = (bsz, tgt_len, -1, self.head_dim)
# Scaling is susceptible to floating point arithmetics' inprecisions
# which can lead to different results (this is dependent from model
# to model, e.g. audioflamingo3 is one such case). We therefore keep the
# original order of scaling to follow the original implementation
# and enforce no scaling (1.0) in the attention call below.
query_states = self.q_proj(hidden_states) * self.scaling
query_states = query_states.view(*q_input_shape)
query_states = query_states.transpose(1, 2).contiguous()
# Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
is_updated = past_key_values.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_values.is_updated[self.layer_idx] = True
past_key_values = past_key_values.cross_attention_cache
else:
past_key_values = past_key_values.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_values and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_values.layers[self.layer_idx].keys
value_states = past_key_values.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()
if past_key_values is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=1.0,
output_attentions=output_attentions,
**kwargs,
)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class AudioFlamingo3EncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: AudioFlamingo3Config):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = AudioFlamingo3Attention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
return hidden_states, attn_weights
@auto_docstring
class AudioFlamingo3PreTrainedModel(PreTrainedModel):
config: AudioFlamingo3Config
base_model_prefix = "model"
input_modalities = ["audio", "text"]
supports_gradient_checkpointing = True
_no_split_modules = ["AudioFlamingo3Attention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
_supports_sdpa = True
@auto_docstring(
custom_intro="""
The audio model from AudioFlamingo3 without any head or projection on top.
"""
)
class AudioFlamingo3Encoder(AudioFlamingo3PreTrainedModel):
"""
AudioFlamingo3 encoder: Whisper encoder, average pool (time/2), then LayerNorm.
"""
# Ignore copy
config: AudioFlamingo3EncoderConfig
main_input_name = "input_features"
input_modalities = "audio"
_no_split_modules = ["AudioFlamingo3EncoderLayer"]
def __init__(self, config: AudioFlamingo3EncoderConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.embed_positions.requires_grad_(False)
self.layers = nn.ModuleList([AudioFlamingo3EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
# Ignore copy
self.avg_pooler = nn.AvgPool1d(2, stride=2)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
def get_input_embeddings(self) -> nn.Module:
return self.conv1
def set_input_embeddings(self, value: nn.Module):
self.conv1 = value
@can_return_tuple
def forward(
self,
input_features: torch.Tensor,
input_features_mask: Optional[torch.Tensor] = None,
):
r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
Log-Mel features extracted from raw audio. Use the processor/feature extractor to compute and pad
these features from waveform input.
input_features_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
"""
# Prepare attention mask for transformer layers
batch_size = input_features.shape[0]
seq_len = (input_features.shape[-1] - 1) // 2 + 1 # After conv2 downsampling
input_features_lengths = input_features_mask.sum(-1)
input_features_lengths = (input_features_lengths - 1) // 2 + 1 # conv2 downsampling
input_features_mask = torch.arange(seq_len, device=input_features.device) < input_features_lengths[:, None]
attention_mask = eager_mask(
batch_size=batch_size,
cache_position=torch.arange(seq_len, device=input_features.device),
kv_length=seq_len,
mask_function=padding_mask_function(input_features_mask),
dtype=self.conv1.weight.dtype,
)
# Conv front-end
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1)
# Add positions, dropout
hidden_states = inputs_embeds + self.embed_positions.weight
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# Transformer stack
for layer in self.layers:
drop = self.training and torch.rand([]) < self.layerdrop
if not drop:
hidden_states = layer(hidden_states, attention_mask)[0]
# AvgPool (time/2) + LayerNorm
hidden_states = hidden_states.permute(0, 2, 1)
hidden_states = self.avg_pooler(hidden_states).permute(0, 2, 1)
hidden_states = self.layer_norm(hidden_states)
return BaseModelOutput(
last_hidden_state=hidden_states,
)
# Ignore copy
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
return input_lengths, output_lengths
class AudioFlamingo3MultiModalProjector(nn.Module):
"""
Audio adaptor (small MLP) that projects AudioFlamingo3Encoder features
to the LLM embedding space so they can replace `<sound>` tokens.
"""
def __init__(self, config: AudioFlamingo3Config):
super().__init__()
self.linear_1 = nn.Linear(
config.audio_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
)
def forward(self, audio_features):
hidden_states = self.linear_1(audio_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
@auto_docstring(
custom_intro="""
The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
"""
)
class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin):
_keep_in_fp32_modules_strict = None
_tp_plan = None
_pp_plan = None
def __init__(self, config):
super().__init__(config)
self.vocab_size = config.text_config.vocab_size
self.audio_tower = AutoModel.from_config(config.audio_config)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_audio_features(
self, input_features: torch.FloatTensor, input_features_mask: torch.Tensor
) -> torch.FloatTensor:
"""
This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector.
Args:
input_features (`torch.FloatTensor`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
Mask to avoid performing attention on padded feature indices.
Returns:
`torch.FloatTensor`:
The audio embeddings.
"""
# Encode audio
encoder_output = self.audio_tower(input_features, input_features_mask=input_features_mask)
audio_embeds = self.multi_modal_projector(encoder_output.last_hidden_state)
# Mask according to avg pooling (which is after attention blocks)
post_lengths = (input_features_mask.sum(-1) - 2) // 2 + 1
valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None]
audio_embeds = audio_embeds[valid_mask.to(audio_embeds.device)]
return audio_embeds
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
input_features: Optional[torch.FloatTensor] = None,
input_features_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
>>> model_id = "nvidia/audio-flamingo-3-hf"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
>>> conversations = [
>>> [
>>> {
>>> "role": "user",
>>> "content": [
>>> {"type": "text", "text": "Transcribe the input speech."},
>>> {
>>> "type": "audio",
>>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
>>> },
>>> ],
>>> }
>>> ],
>>> [
>>> {
>>> "role": "user",
>>> "content": [
>>> {
>>> "type": "text",
>>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
>>> },
>>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
>>> ],
>>> }
>>> ],
>>> ]
>>> inputs = processor.apply_chat_template(
>>> conversations,
>>> tokenize=True,
>>> add_generation_prompt=True,
>>> return_dict=True,
>>> ).to(model.device)
>>> outputs = model.generate(**inputs, max_new_tokens=500)
>>> decoded_outputs = processor.batch_decode(
>>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
>>> )
>>> print(decoded_outputs)
["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."]
```"""
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if input_features is not None and input_ids is not None:
audio_embeds = self.get_audio_features(input_features, input_features_mask)
# replace text-audio token placeholders with audio embeddings
audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
inputs_embeds = inputs_embeds.masked_scatter(
audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
)
outputs: CausalLMOutputWithPast = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
labels=labels,
use_cache=use_cache,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
return outputs
def prepare_inputs_for_generation(self, *args, **kwargs):
# Overwritten -- we should not pass input_features when we are in cached decoding stage
input_features = kwargs.pop("input_features", None)
input_features_mask = kwargs.pop("input_features_mask", None)
cache_position = kwargs.get("cache_position")
model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
if cache_position is not None and cache_position[0] == 0:
# input_features should only be passed when we are not in cached decoding stage
if input_features is not None:
model_inputs["input_features"] = input_features
if input_features_mask is not None:
model_inputs["input_features_mask"] = input_features_mask
return model_inputs
__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"]

Some files were not shown because too many files have changed in this diff Show More