Compare commits

...

15 Commits

Author SHA1 Message Date
10baffb599 Multiple llama4 fixes (#37353)
* update for fixes

* more fixes

* fix dynamic cache?

* style

* fix both training and generating. Eager seems alright

* dynamic does not work

* fix most cases, use_cache or not, eager or not, no default cache (ex: not training but you want to get cache states)

* should be final fixes

* fix more stuff no cat

* style

* fix

* style

* final style

* quality

* fix

* revert
2025-04-08 11:15:06 +02:00
4a88ffae40 v4.51.1 2025-04-08 00:27:58 +02:00
f19aec737e Fixing flex attention for torch=2.6.0 (#37285)
* adding compile kwarg for torch 2.6

* fixing dynamic

* addressing comment

* typo

* Update src/transformers/integrations/flex_attention.py

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2025-04-08 00:22:21 +02:00
d8f0695e84 more fixes for post-training llama4 (#37329)
* more fixes for post-training llama4

* use target_length instead of guarded past_key_values
2025-04-08 00:22:17 +02:00
d27c8c38f4 Remove HQQ from caching allocator warmup (#37347)
Update modeling_utils.py
2025-04-08 00:22:07 +02:00
04c0cedcdf fix derived berts _init_weights (#37341)
* fix derived berts

* more

* roformer
2025-04-08 00:21:44 +02:00
4f536ba0ae Fix init empty weights without accelerate (#37337)
* add the integration

* Update accelerate.py

* Update accelerate.py

* add find_tied_params as well

* Update accelerate.py

* add where copied from

* simplify

* add error
2025-04-08 00:21:36 +02:00
6b82af0a5b Fix deepspeed with quantization (#37324)
* Update modeling_utils.py

* Update modeling_utils.py
2025-04-08 00:21:32 +02:00
2bf3d4aca8 fix llama4 training (#37319) 2025-04-08 00:21:28 +02:00
a79b7abede fix flex attn when optional args aren't passed (#37327) 2025-04-08 00:21:24 +02:00
0720e206c6 Release: v4.51.0 2025-04-05 22:03:17 +02:00
25b7f27234 Add llama4 (#37307)
* remove one of the last deps

* update fast image processor after refactor

* styling

* more quality of life improvements

* nit

* update

* cleanups

* some cleanups

* vllm updates

* update fake image token

* [convert] Fix typo

* [convert] Strip extraneous bytes from shards

* [convert] Minor fixes

* [convert] Use num_experts

* multi-image fixes in modeling + processor

* fixup size

* 128 experts

* Use default rope

* Unfuse mlp

* simplify a lot inputs embeds merging

* remove .item() 👀

* fix from review

* Address feedback

* Use None "default" for rope_scaling. Add eot.

* set seed

* return aspect ratios and bug fixes

* Moe 128 rebased (#8)

* 128 experts

* Use default rope

* Unfuse mlp

* Address feedback

* Use None "default" for rope_scaling. Add eot.

* Meta/llama quant compat (#7)

* add quant compatible model & conversion code for llama4

* fix a few issues

* fix a few issues

* minor type mapping fix

---------

Co-authored-by: Lu Fang <fanglu@fb.com>

* use a new config parameter to determine which model definition to use for MoE

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Lu Fang <fanglu@fb.com>

* un-comment write_tokenizer from converting script

* remove un-used imports

* [llama4] Pop aspect_ratios from image processor output in Llama4Processor

Signed-off-by: Jon Swenson <jmswen@gmail.com>

* Fix parameter_count name

* Update src/transformers/models/llama4/configuration_llama4.py

* nit

* Add changes for no_rope, moe_layers, chunked attention. Just need to test all

* Update src/transformers/models/llama4/image_processing_llama4_fast.py

* nit

* fix post merge with main

* support flex attention

* fixes

* fix

* add layer

* small updates

* rebase and delete llm_compressor

* nit

* [llama4/mm] Add back <|image|> token that delimits global tile

* [llama4/mm] Fix Llama 4 image processing unit tests

* add explicit dtype

Signed-off-by: Jon Swenson <jmswen@gmail.com>

* sdpa works

* comment todo small

* fix model loading

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>

* revert

* nits

* small fix for TP on 1 node

* Read new params from config

* Add <|eom|>

* lol don't know how this got here

* adding fp8

* Save processor, fix chat template

* style

* Add boi/eoi tokens

We don't use them.

* fixes for now flex seems to work :)

* updates

* nits

* updates

* missing keys

* add context parallel

* update

* update

* fix

* nits

* add worldsize and make eager attn work for vision

* Ignore new key present in base models

* add tp_plan

* fix nope

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>

* minor fix

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>

* Clean up Llama4 vision model

* current updates

* add support for `attn_temperature_tuning`

* add floor scale

* add missing attn scales

* push what works, dirty trick for the device synch

* oops

* Fix pad_token_id

See
https://huggingface.co/ll-re/Llama-4-Scout-17B-16E/discussions/2/files
Confirmed in the original codebase.

* fix causallml loading

* rm

* fix tied-weights

* fix sdpa

* push current version

* should work with both short and long

* add compressed_tensors & fix fbgemm tp

* Fix flex impl

* style

* chunking

* try to revert the potentially breaking change

* fix auto factory

* fix shapes in general

* rm processing

* commit cache utils cleanup

* Fix context length

* fix

* allocate

* update tp_plan

* fix SDPA!

* Add support for sparse `Llama4TextMoe` layer from the kernel hub

* cleanup

* better merge

* update

* still broken fixing now

* nits

* revert print

* Write max_position_embeddings and max_model_length

* Update modeling_llama4.py

* Save attention_chunk_size

* Sync eos terminators

* Read initializer_range

* style

* remove `dict`

* fix

* eager should use `chunked_attention_mask`

* revert

* fixup

* fix config

* Revert "Merge pull request #36 from huggingface/sparse-llama4-moe"

This reverts commit ccda19f050867dd42ea143c5de60f3dec81375f0, reversing
changes made to a515579aed8c0fe9bf529b6c40446a289406d5d6.

* Fix typo and remove warning with compiled flex and chunked prefill

* Fix MoE vs FF (#41)

* fix

* Use correct no_rope_layers if provided one is empty list

* update tests

* fix

* skipping some tests

* fix fp8 loading

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>

* fix text generation pipeline

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>

* eager needs 4D mask

* fix

* Some cleanup

* fix

* update

* fix

* replace correctly module

* patch

* modulelist

* update

* update

* clean up

* Don't move to `cuda:0` in distributed mode

* restrict to compressed tensors for now

* rm print

* Docs!

* Fixes

* Update docs/source/en/model_doc/llama4.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fixes

* cuda graph fix

* revert some stuff

* fixup

* styling

* Update src/transformers/models/llama4/modeling_llama4.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

* commit licence, cleanup here and there and style

* more styling changes

* fix dummies

* fix and clean docstrings

* remove comment

* remove warning

* Only fast image processor is supported

* nit

* trigger CI

* fix issue with flex encoder

* fix dynamic cache

* Code quality

* Code quality

* fix more tests for now

* Code quality

* Code quality

* Nuke bunch of failing stuff

* Code quality

* Code quality

* cleanup removal of slow image processor

* ruff fix fast image processor

* fix

* fix styling

* Docs

* Repo consistency

* Repo consistency

* fix sliding window issue

* separate llama cache

* styling

* Repo consistency

* Repo consistency

* push what works

* L4 Repo consistency

* Docs

* fix last

---------

Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: yonigozlan <yoni.gozlan10@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: Keyun Tong <tongkeyun@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@users.noreply.github.com>
Co-authored-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: Jon Swenson <jmswen@gmail.com>
Co-authored-by: jmswen <jmswen@users.noreply.github.com>
Co-authored-by: MekkCyber <mekk.cyber@gmail.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
Co-authored-by: Yong Hoon Shin <yhshin@meta.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: drisspg <drisspguessous@gmail.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: Daniël de Kok <me@danieldk.eu>
Co-authored-by: Lysandre <hi@lysand.re>
Co-authored-by: Ye (Charlotte) Qi <ye.charlotte.qi@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-04-05 22:02:22 +02:00
aa40fda346 Hf Xet extra (#37305)
* Hf Xet extra

* Hf Xet extra
2025-04-05 21:06:05 +02:00
e94571580b Fix deepspeed loading (part 2) (#37306)
* fix

* Update modeling_utils.py

* Update modeling_utils.py

* oops, remove print
2025-04-05 20:41:42 +02:00
84aa13dd85 Fix deepspeed loading (#37281)
* Update modeling_utils.py

* Update modeling_utils.py

* fix and remove all imports

* Update modeling_utils.py

* Update modeling_utils.py

* style

* Update modeling_utils.py
2025-04-05 17:05:45 +02:00
377 changed files with 5212 additions and 66524 deletions

View File

@ -507,6 +507,8 @@
title: Llama2
- local: model_doc/llama3
title: Llama3
- local: model_doc/llama4
title: Llama4
- local: model_doc/longformer
title: Longformer
- local: model_doc/longt5

View File

@ -0,0 +1,442 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Llama4
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
</div>
</div>
Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture.
This generation includes two models:
- The highly capable Llama 4 Maverick, with 17B active parameters out of ~400B total, spread across 128 experts.
- The efficient Llama 4 Scout, which also has 17B active parameters out of ~109B total, using just 16 experts.
Both models leverage early fusion for native multimodality, enabling them to process text and image inputs.
Maverick and Scout are both trained on up to 40 trillion tokens of data encompassing 200 languages
(with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).
For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via
on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats.
These models are released under the custom Llama 4 Community License Agreement, available on the model repositories.
You can find all the original Llama checkpoints under the [meta-llama](https://huggingface.co/meta-llama) organization.
> [!TIP]
> The Llama 4 family of models comes in two flavors: 109B and 402B parameters. Both of these flavors are extremely
> large and won't fit on your run-of-the-mill device. See below for some examples to reduce the memory usage of the
> model.
>
> For the download to be faster and more resilient, we recommend installing the `hf_xet` dependency as follows:
> `pip install transformers[hf_xet]`
The examples below demonstrate how to generate with [`Pipeline`] or the [`AutoModel`] classes. We additionally include an example
showing how to toggle the right attributes to enable very long-context generation, as some flavors of Llama 4
have context lengths of up to 10 million tokens.
<hfoptions id="usage">
<hfoption id="Pipeline">
```py
from transformers import pipeline
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
messages = [
{"role": "user", "content": "what is the recipe of mayonnaise?"},
]
pipe = pipeline(
"text-generation",
model=model_id,
device_map="auto",
torch_dtype=torch.bfloat16
)
output = pipe(messages, do_sample=False, max_new_tokens=200)
print(output[0]["generated_text"][-1]["content"])
```
</hfoption>
<hfoption id="AutoModel - Text only">
```py
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16
)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
</hfoption>
<hfoption id="AutoModel - Multimodal">
```py
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": img_url},
{"type": "text", "text": "Describe this image in two sentences."},
]
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=256,
)
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
```
</hfoption>
<hfoption id="AutoModel - Multimodal with multiple images">
```py
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": url1},
{"type": "image", "url": url2},
{"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
]
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=256,
)
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
```
</hfoption>
<hfoption id="AutoModel - Long context">
Beware: the example below uses both `device_map="auto"` and flex-attention.
Please use `torchrun` to run this example in tensor-parallel mode.
We will work to enable running with `device_map="auto"` and flex-attention without
tensor-parallel in the future.
```py
from transformers import Llama4ForConditionalGeneration, AutoTokenizer
import torch
import time
file = "very_long_context_prompt.txt"
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
with open(file, "r") as f:
very_long_text = "\n".join(f.readlines())
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
attn_implementation="flex_attention",
torch_dtype=torch.bfloat16
)
messages = [
{"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."},
]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
torch.cuda.synchronize()
start = time.time()
out = model.generate(
input_ids.to(model.device),
prefill_chunk_size=2048*8,
max_new_tokens=300,
cache_implementation="hybrid",
)
print(time.time()-start)
print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:]))
print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB")
```
</hfoption>
</hfoptions>
## Efficiency: how to get the best out of Llama 4
### The Attention methods
Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the [Attention Interface](../attention_interface) overview for an in-depth explanation of our interface.
As of the release, the Llama 4 model supports the following attention methods: `eager`, `flex_attention`, `sdpa`. We recommend using `flex_attention` for best results.
Switching the attention mechanism is done at model initialization:
<hfoptions id="Attention">
<hfoption id="Flex Attention">
Setting Flex Attention ensures the best results with the very long context the model can handle.
> [!TIP] Beware: the example below uses both `device_map="auto"` and flex-attention.
> Please use `torchrun` to run this example in tensor-parallel mode.
>
> We will work to enable running with `device_map="auto"` and flex-attention without
> tensor-parallel in the future.
```py
from transformers import Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
attn_implementation="flex_attention",
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
</hfoption>
<hfoption id="SDPA">
The `sdpa` attention method is generally more compute-efficient than the `eager` method.
```py
from transformers import Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
attn_implementation="sdpa",
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
</hfoption>
<hfoption id="Eager">
The `eager` attention method is used by default, so no changes are needed when loading the model:
```py
from transformers import Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
</hfoption>
</hfoptions>
### Quantization
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for available quantization backends.
At the time of release, both FBGEMM and LLM-Compressor are supported; more quantization methods will be supported in the days that follow.
See below for examples using both.
Here is an example loading a BF16 model in FP8 using the FBGEMM approach:
<hfoptions id="Quantization">
<hfoption id="FBGEMM">
```python
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
quantization_config=FbgemmFp8Config()
)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
</hfoption>
<hfoption id="LLM-Compressor">
To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release:
```python
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
tp_plan="auto",
torch_dtype=torch.bfloat16,
)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
</hfoption>
</hfoptions>
### Offloading
Enabling CPU offloading means that components of the model may be placed on the CPU instead of the GPU when the available GPU memory isn't sufficient to load the entire model.
At inference, the different components are loaded to and unloaded from the GPU on the fly. This ensures that the model can be loaded on smaller machines as long as the CPU memory is sufficient.
However, this also slows down inference, as it adds communication overhead.
To enable CPU offloading, simply set `device_map="auto"` when loading the model:
```py
from transformers import Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
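If you want more explicit control over how much of the model stays on the GPU, you can additionally cap the per-device memory so that whatever doesn't fit is offloaded to the CPU. The snippet below is a minimal sketch: the `max_memory` budget (a single 24GiB GPU plus CPU RAM) is purely illustrative and should be adapted to your hardware.
```py
from transformers import Llama4ForConditionalGeneration
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

# Anything that does not fit in the GPU budget below is placed on the CPU and
# moved back and forth during inference.
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    max_memory={0: "24GiB", "cpu": "200GiB"},
    torch_dtype=torch.bfloat16,
)
```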
## Llama4Config
[[autodoc]] Llama4Config
## Llama4TextConfig
[[autodoc]] Llama4TextConfig
## Llama4VisionConfig
[[autodoc]] Llama4VisionConfig
## Llama4Processor
[[autodoc]] Llama4Processor
## Llama4ImageProcessorFast
[[autodoc]] Llama4ImageProcessorFast
## Llama4ForConditionalGeneration
[[autodoc]] Llama4ForConditionalGeneration
- forward
## Llama4ForCausalLM
[[autodoc]] Llama4ForCausalLM
- forward
## Llama4TextModel
[[autodoc]] Llama4TextModel
- forward
## Llama4ForCausalLM
[[autodoc]] Llama4ForCausalLM
- forward
## Llama4VisionModel
[[autodoc]] Llama4VisionModel
- forward

View File

@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = get_logger(__name__)

View File

@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = get_logger(__name__)

View File

@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = get_logger(__name__)

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = get_logger(__name__)
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = logging.getLogger(__name__)

View File

@ -54,7 +54,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = get_logger(__name__)

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = get_logger(__name__)

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version(
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = logging.getLogger(__name__)

View File

@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
logger = logging.getLogger(__name__)

View File

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
task_to_keys = {
"cola": ("sentence", None),

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.51.0.dev0")
check_min_version("4.51.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -117,6 +117,7 @@ _deps = [
"fugashi>=1.0",
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
"hf_xet",
"huggingface-hub>=0.30.0,<1.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
@ -283,6 +284,7 @@ extras["tf-cpu"] = deps_list(
extras["torch"] = deps_list("torch", "accelerate")
extras["accelerate"] = deps_list("accelerate")
extras["hf_xet"] = deps_list("hf_xet")
if os.name == "nt": # windows
extras["retrieval"] = deps_list("datasets") # faiss is not supported on windows
@ -451,7 +453,7 @@ install_requires = [
setup(
name="transformers",
version="4.51.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.51.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.51.0.dev0"
__version__ = "4.51.1"
from typing import TYPE_CHECKING
@ -562,6 +562,12 @@ _import_structure = {
"models.levit": ["LevitConfig"],
"models.lilt": ["LiltConfig"],
"models.llama": ["LlamaConfig"],
"models.llama4": [
"Llama4Config",
"Llama4Processor",
"Llama4TextConfig",
"Llama4VisionConfig",
],
"models.llava": [
"LlavaConfig",
"LlavaProcessor",
@ -1354,6 +1360,7 @@ else:
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.gemma3"].append("Gemma3ImageProcessorFast")
_import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
_import_structure["models.llama4"].append("Llama4ImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
@ -2510,6 +2517,15 @@ else:
"GlmPreTrainedModel",
]
)
_import_structure["models.llama4"].extend(
[
"Llama4ForCausalLM",
"Llama4ForConditionalGeneration",
"Llama4TextModel",
"Llama4VisionModel",
"Llama4PreTrainedModel",
]
)
_import_structure["models.glpn"].extend(
[
"GLPNForDepthEstimation",
@ -5807,6 +5823,12 @@ if TYPE_CHECKING:
from .models.levit import LevitConfig
from .models.lilt import LiltConfig
from .models.llama import LlamaConfig
from .models.llama4 import (
Llama4Config,
Llama4Processor,
Llama4TextConfig,
Llama4VisionConfig,
)
from .models.llava import (
LlavaConfig,
LlavaProcessor,
@ -6646,6 +6668,7 @@ if TYPE_CHECKING:
from .models.detr import DetrImageProcessorFast
from .models.gemma3 import Gemma3ImageProcessorFast
from .models.got_ocr2 import GotOcr2ImageProcessorFast
from .models.llama4 import Llama4ImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
@ -7827,6 +7850,13 @@ if TYPE_CHECKING:
LlamaModel,
LlamaPreTrainedModel,
)
from .models.llama4 import (
Llama4ForCausalLM,
Llama4ForConditionalGeneration,
Llama4PreTrainedModel,
Llama4TextModel,
Llama4VisionModel,
)
from .models.llava import (
LlavaForConditionalGeneration,
LlavaPreTrainedModel,

View File

@ -1811,6 +1811,204 @@ class HybridCache(Cache):
self.value_cache[layer_idx].zero_()
class HybridChunkedCache(Cache):
"""
Hybrid chunked cache class to be used with `torch.compile` for models that alternate between a local (sliding-window or chunked) attention
and global attention in every other layer. Under the hood, it leverages [`SlidingWindowCache`] for the local attention
and [`StaticCache`] for global attention. For more information, see the documentation of each subcomponent cache class.
Parameters:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`, *optional*):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache
and the model is split between different gpus. You can find out which layers are mapped to which device by
checking the associated device_map: `model.hf_device_map`.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
HybridCache()
```
"""
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
# ALL changes from the PR that commented the line below when reactivating it.
is_compileable = True
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: Optional[int] = None,
device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.bfloat16,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192)
else:
self.sliding_window = config.sliding_window
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self._dtype = dtype
if hasattr(config.get_text_config(), "no_rope_layers"):
self.is_sliding = config.no_rope_layers
else:
layer_switch = getattr(config, "sliding_window_pattern", 2)
self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.cumulative_length = [0 for _ in range(config.num_hidden_layers)]
def initialise_cache_layer(self, layer_idx, key_states):
if len(self.key_cache) > layer_idx:
return
num_key_value_heads = key_states.shape[1]
device = key_states.device
global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
sliding_cache_shape = (
self.max_batch_size,
num_key_value_heads,
self.sliding_window,
self.head_dim,
)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
if cache_position.shape[0] > max_cache_len:
cache_position = cache_position.clamp(0, max_cache_len - 1)
k_out = key_states[:, :, -max_cache_len:, :]
v_out = value_states[:, :, -max_cache_len:, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states
# otherwise we are decoding. Most efficient way to cat 1 token
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, max_cache_len - 1)
to_shift = cache_position >= max_cache_len - 1
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
self.initialise_cache_layer(layer_idx, key_states)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
if self.is_sliding[layer_idx]:
update_fn = self._sliding_update
else:
update_fn = self._static_update
return update_fn(
cache_position,
layer_idx,
key_states,
value_states,
k_out,
v_out,
k_out.shape[2],
)
def get_max_cache_shape(self) -> Optional[int]:
return self.max_cache_len
def get_seq_length(self, layer_idx: Optional[int] = 0):
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
if layer_idx != 0:
raise ValueError(
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported."
)
if len(self.key_cache) == 0:
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
class MambaCache:
"""
Cache for mamba model which does not have attention mechanism and key value states.

View File

@ -801,18 +801,19 @@ class PretrainedConfig(PushToHubMixin):
def to_diff_dict(self) -> dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
serializes to a Python dictionary.
Removes all attributes from the configuration that correspond to the default config attributes for
better readability, while always retaining the `config` attribute from the class. Serializes to a
Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
"""
config_dict = self.to_dict()
# get the default config dict
# Get the default config dict (from a fresh PreTrainedConfig instance)
default_config_dict = PretrainedConfig().to_dict()
# get class specific config dict
# Get class-specific config dict if not part of a composition
class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
serializable_config_dict = {}
@ -847,8 +848,7 @@ class PretrainedConfig(PushToHubMixin):
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
# Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = serializable_config_dict.pop("_pre_quantization_dtype", None)
self.dict_torch_dtype_to_str(serializable_config_dict)

View File

@ -24,6 +24,7 @@ deps = {
"fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"hf_xet": "hf_xet",
"huggingface-hub": "huggingface-hub>=0.30.0,<1.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",

View File

@ -52,6 +52,7 @@ if is_torch_available():
from ..cache_utils import (
HQQQuantizedCache,
HybridCache,
HybridChunkedCache,
MambaCache,
OffloadedStaticCache,
QuantizedCacheConfig,
@ -69,6 +70,7 @@ if is_torch_available():
"offloaded_static": OffloadedStaticCache,
"sliding_window": SlidingWindowCache,
"hybrid": HybridCache,
"hybrid_chunked": HybridChunkedCache,
"mamba": MambaCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
@ -416,6 +418,7 @@ class GenerationConfig(PushToHubMixin):
if isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)
# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)

View File

@ -1830,6 +1830,9 @@ class GenerationMixin:
Returns the resulting cache object.
"""
if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
cache_implementation = "hybrid_chunked"
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
@ -1958,6 +1961,9 @@ class GenerationMixin:
)
generation_config.cache_implementation = None
generation_config.cache_implementation = generation_config.cache_implementation or getattr(
self.config.get_text_config(), "cache_implementation", None
)
if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
@ -3405,7 +3411,12 @@ class GenerationMixin:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)
is_prefill = True
if generation_config.prefill_chunk_size is not None:
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
is_prefill = False
else:
is_prefill = True
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@ -4855,6 +4866,45 @@ class GenerationMixin:
else:
return input_ids
def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
# Even if we are not compiling the forward, flex is always compiled when used. With chunked prefill, we may
# end up needing a few more graphs than the default limit (which is 8). Raising it avoids very cryptic warnings
torch._dynamo.config.cache_size_limit = 64
chunk_size = generation_config.prefill_chunk_size
# Only chunk up to the token just before the last one, so that decoding is performed entirely outside this function
# (here we simply prefill the cache)
input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)
if "past_key_values" not in model_kwargs:
raise ValueError("Cannot use prefill chunkink without a cache")
model_forward = self.get_compiled_call(generation_config.compile_config)
attention_mask = model_kwargs.pop("attention_mask", None)
past_length = 0
for input_chunk in input_chunks:
current_length = past_length + input_chunk.shape[-1]
# Prepare inputs
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask[:, :current_length]
model_kwargs["cache_position"] = torch.arange(
past_length, current_length, dtype=torch.long, device=input_chunk.device
)
model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)
outputs = model_forward(**model_inputs, return_dict=True)
model_kwargs["past_key_values"] = outputs.past_key_values
past_length = current_length
model_kwargs["attention_mask"] = attention_mask
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
_ = model_kwargs.pop("position_ids", None)
return model_kwargs
def _speculative_sampling(
candidate_input_ids,

View File
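
To make the bookkeeping in `_prefill_chunking` above easier to follow, here is a minimal, model-free sketch of how the chunks, attention-mask slices and `cache_position` values line up. The prompt length and chunk size are toy values; no cache or forward pass is involved.

```python
import torch

input_ids = torch.arange(10).unsqueeze(0)      # toy prompt of length 10
attention_mask = torch.ones_like(input_ids)
chunk_size = 4

# Everything except the last token is prefilled in chunks; the last token is left
# to the regular decoding loop.
input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)

past_length = 0
for chunk in input_chunks:
    current_length = past_length + chunk.shape[-1]
    chunk_mask = attention_mask[:, :current_length]
    cache_position = torch.arange(past_length, current_length, dtype=torch.long)
    print(chunk.tolist(), chunk_mask.shape[-1], cache_position.tolist())
    past_length = current_length

# After the loop, decoding would resume from the last prompt token with
# cache_position = [past_length], exactly as model_kwargs is patched above.
```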

@ -53,7 +53,7 @@ _import_structure = {
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
"finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
"fsdp": ["is_fsdp_managed_module"],
"ggml": [
@ -192,7 +192,7 @@ if TYPE_CHECKING:
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
from .fsdp import is_fsdp_managed_module
from .ggml import (

View File

@ -0,0 +1,196 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Since https://github.com/huggingface/transformers/pull/36963, loading is always performed with models on the meta
device. But since the `init_empty_weights` and `find_tied_parameters` functions are from accelerate, and accelerate is
still only a soft dependency, we copy the functions here to be used natively in Transformers.
The `init_empty_weights` and `init_on_device` functions were copied from `accelerate.big_modeling.py`, and
`find_tied_parameters` was copied from `accelerate.utils.modeling.py`.
"""
from contextlib import contextmanager
from ..utils import is_torch_available, logging
if is_torch_available():
import torch
import torch.nn as nn
logger = logging.get_logger(__name__)
@contextmanager
def init_empty_weights(include_buffers: bool = False):
"""
A context manager under which models are initialized with all parameters on the meta device, therefore creating an
empty model. Useful when just initializing the model would blow the available RAM.
Args:
include_buffers (`bool`, *optional*):
Whether or not to also put all buffers on the meta device while initializing.
Example:
```python
import torch.nn as nn
from accelerate import init_empty_weights
# Initialize a model with 100 billion parameters in no time and without using any RAM.
with init_empty_weights():
tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
```
<Tip warning={true}>
Any model created under this context manager has no weights. As such you can't do something like
`model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
called.
</Tip>
"""
with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
yield f
@contextmanager
def init_on_device(device: "torch.device", include_buffers: bool = False):
"""
A context manager under which models are initialized with all parameters on the specified device.
Args:
device (`torch.device`):
Device to initialize all parameters on.
include_buffers (`bool`, *optional*):
Whether or not to also put all buffers on the meta device while initializing.
Example:
```python
import torch.nn as nn
from accelerate import init_on_device
with init_on_device(device=torch.device("cuda")):
tst = nn.Linear(100, 100) # on `cuda` device
```
"""
if include_buffers:
with device:
yield
return
old_register_parameter = nn.Module.register_parameter
if include_buffers:
old_register_buffer = nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
# Patch tensor creation
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
try:
nn.Module.register_parameter = register_empty_parameter
if include_buffers:
nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
nn.Module.register_parameter = old_register_parameter
if include_buffers:
nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
def find_tied_parameters(model: "nn.Module", **kwargs):
"""
Find the tied parameters in a given model.
<Tip warning={true}>
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
them.
</Tip>
Args:
model (`torch.nn.Module`): The model to inspect.
Returns:
List[List[str]]: A list of lists of parameter names being all tied together.
Example:
```py
>>> from collections import OrderedDict
>>> import torch.nn as nn
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
>>> model.linear2.weight = model.linear1.weight
>>> find_tied_parameters(model)
[['linear1.weight', 'linear2.weight']]
```
"""
# get ALL model parameters and their names
all_named_parameters = dict(model.named_parameters(remove_duplicate=False))
# get ONLY unique named parameters,
# if a parameter is tied and has multiple names, it will be included only once
no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))
# the difference of the two sets will give us the tied parameters
tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
# 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
# which names refer to the same parameter. To identify this, we need to group them together.
tied_param_groups = {}
for tied_param_name in tied_param_names:
tied_param = all_named_parameters[tied_param_name]
for param_name, param in no_duplicate_named_parameters.items():
# compare if parameters are the same, if so, group their names together
if param is tied_param:
if param_name not in tied_param_groups:
tied_param_groups[param_name] = []
tied_param_groups[param_name].append(tied_param_name)
return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]

View File
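
A small usage sketch of the two copied helpers, importing them from the `transformers.integrations.accelerate` module that `modeling_utils.py` imports from further down; the model is a toy `nn.Sequential`:

```python
import torch.nn as nn

from transformers.integrations.accelerate import find_tied_parameters, init_empty_weights

# No memory is allocated for the weights: they live on the meta device.
with init_empty_weights():
    model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
print(model[0].weight.device)  # meta

# Tie the two weights and detect the tie.
model[1].weight = model[0].weight
print(find_tied_parameters(model))  # [['0.weight', '1.weight']]
```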

@ -0,0 +1,54 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.utils import is_torch_available
if is_torch_available():
import torch
import torch.nn as nn
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
def skip(*args, **kwargs):
pass
class CompressedExpertsLinear(nn.Module):
"""
A module that implements a compressed version of a list of expert modules.
This is specifically designed to work with Llama4TextExperts in MoE layers.
"""
def __init__(self, config):
# Skip random weight initialization for experts. Otherwise,
# the init of this module alone would take minutes. For a model
# with tens of expert layers, it can easily take over 20 minutes.
nn.init.kaiming_uniform_ = skip
nn.init.uniform_ = skip
nn.init.normal_ = skip
super().__init__()
self.num_experts = config.num_local_experts
self.expert_modules = nn.ModuleList([Llama4TextMLP(config) for _ in range(self.num_experts)])
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
expert_routed_out_list = []
for expert_idx in range(self.num_experts):
expert_routed_out_list.append(self.expert_modules[expert_idx](hidden_states[expert_idx]))
routed_out = torch.cat(expert_routed_out_list, dim=0)
return routed_out

View File
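
The forward pass of `CompressedExpertsLinear` relies on the tokens arriving already grouped by expert, so the batch can be reshaped to `(num_experts, tokens_per_expert, hidden_size)` and each expert applied to its own slice. A sketch of just that batching pattern, with plain `nn.Linear` layers standing in for `Llama4TextMLP` and toy sizes:

```python
import torch
import torch.nn as nn

num_experts, tokens_per_expert, hidden_size = 4, 3, 8
experts = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_experts))

# Tokens arrive pre-grouped by expert: (num_experts * tokens_per_expert, hidden_size)
hidden_states = torch.randn(num_experts * tokens_per_expert, hidden_size)
hidden_states = hidden_states.reshape(num_experts, -1, hidden_states.shape[-1])

# Each expert processes its own slice; outputs are concatenated back along dim 0
routed_out = torch.cat([experts[i](hidden_states[i]) for i in range(num_experts)], dim=0)
print(routed_out.shape)  # torch.Size([12, 8])
```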

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..activations import ACT2FN
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
@ -28,36 +29,36 @@ if is_fbgemm_gpu_available():
logger = logging.get_logger(__name__)
class FbgemmFp8Linear(torch.nn.Module):
class FbgemmFp8Linear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
super().__init__()
super().__init__(in_features, out_features, bias)
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
if bias:
self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
else:
self.bias = None
def forward(self, x):
num_tokens = None
# quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
output_shape = (*x.shape[:-1], -1)
# x_quantized and x_scale are not necessarily on the same device as x, which is an issue.
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
)
# moving x_quantized, x_scale here creates gibberish output ... However, if we move the output, it works
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
# The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
weight_scale_float32 = self.weight_scale.to(torch.float32)
output = torch.ops.fbgemm.f8f8bf16_rowwise(
x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
)
output = output + self.bias if self.bias is not None else output
# Hacky for now, we move the output to the device of x
@ -67,6 +68,92 @@ class FbgemmFp8Linear(torch.nn.Module):
return output
class FbgemmFp8Llama4TextExperts(nn.Module):
def __init__(self, config, dtype=torch.float32):
super().__init__()
self.num_experts = config.num_local_experts
self.intermediate_size = config.intermediate_size
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
self.act_fn = ACT2FN[config.hidden_act]
# Register FP8 buffers for gate_up_proj
self.gate_up_proj = torch.nn.Parameter(
torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn)
)
self.gate_up_proj_scale = torch.nn.Parameter(
torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32)
)
# Register FP8 buffers for down_proj
self.down_proj = torch.nn.Parameter(
torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn)
)
self.down_proj_scale = torch.nn.Parameter(
torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32)
)
# Register input scale upper bound
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
def forward(self, hidden_states):
"""
Args:
hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
Returns:
torch.Tensor: (batch_size * token_num, hidden_size)
"""
# Reshape hidden states for expert computation
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
num_tokens = None
# Pre-allocate tensor for all expert outputs with same shape as hidden_states
next_states = torch.empty_like(hidden_states)
for i in range(self.num_experts):
# Extract expert's hidden states
expert_hidden = hidden_states[i]
expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
# Quantize for this expert
expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
expert_hidden_reshaped, num_tokens, self.input_scale_ub
)
sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
gate = torch.ops.fbgemm.f8f8bf16_rowwise(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
expert_scale,
gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
use_fast_accum=True,
)
up = torch.ops.fbgemm.f8f8bf16_rowwise(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
expert_scale,
gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
use_fast_accum=True,
)
activated = up * self.act_fn(gate)
activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
activated, num_tokens, self.input_scale_ub
)
down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
activated_quantized,
self.down_proj[i].transpose(0, 1).contiguous(),
activated_scale,
down_proj_scale_float32[i].view(-1, 1).contiguous(),
use_fast_accum=True,
)
next_states[i] = expert_output
next_states = next_states.to(hidden_states.device)
return next_states.view(-1, self.hidden_size)
def _replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=None,
@ -74,12 +161,17 @@ def _replace_with_fbgemm_fp8_linear(
quantization_config=None,
has_been_replaced=False,
pre_quantized=False,
config=None,
tp_plan=None,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
import re
if current_key_name is None:
current_key_name = []
@ -105,9 +197,27 @@ def _replace_with_fbgemm_fp8_linear(
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# set non-persistent buffer outside of init_empty_weights
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub],
dtype=torch.float,
)
if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights(include_buffers=True):
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[
re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")
]
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
model._modules[name] = FbgemmFp8Llama4TextExperts(
config.text_config,
)
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub], dtype=torch.float
)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fbgemm_fp8_linear(
module,
@ -116,6 +226,8 @@ def _replace_with_fbgemm_fp8_linear(
quantization_config,
has_been_replaced=has_been_replaced,
pre_quantized=pre_quantized,
config=config,
tp_plan=tp_plan,
)
# Remove the last key for recursion
current_key_name.pop(-1)
@ -123,7 +235,13 @@ def _replace_with_fbgemm_fp8_linear(
def replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
pre_quantized=False,
config=None,
tp_plan=None,
):
"""
A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
@ -151,9 +269,14 @@ def replace_with_fbgemm_fp8_linear(
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
model,
modules_to_not_convert,
current_key_name,
quantization_config,
pre_quantized=pre_quantized,
config=config,
tp_plan=tp_plan,
)
if not has_been_replaced:
logger.warning(
"You are loading your model using FP8 quantization but no linear modules were found in your model."

View File
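
The fbgemm kernels above quantize activations per row before the FP8 matmuls. This is not the fbgemm implementation, just a plain-PyTorch sketch of the per-row idea (one scale per row, values mapped into the `float8_e4m3fn` range), assuming a torch build with float8 dtypes:

```python
import torch

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

x = torch.randn(4, 16)
row_scale = x.abs().amax(dim=-1, keepdim=True) / E4M3_MAX
x_fp8 = (x / row_scale).to(torch.float8_e4m3fn)

# Dequantize to check that the per-row reconstruction error stays small
x_hat = x_fp8.to(torch.float32) * row_scale
print((x - x_hat).abs().max())
```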

@ -31,13 +31,11 @@ from typing import Optional, Tuple, Union
import torch
from ..utils import is_torch_flex_attn_available
from ..utils.import_utils import _torch_version
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import (
BlockMask,
flex_attention,
)
from torch.nn.attention.flex_attention import BlockMask, flex_attention
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
@ -59,19 +57,36 @@ class WrappedFlexAttention:
return cls._instance
@torch.compiler.disable(recursive=False)
def __init__(self):
def __init__(self, training):
"""
Initialize or update the singleton instance.
"""
if self._is_flex_compiled is False:
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
if not self._is_flex_compiled:
# In PyTorch 2.6.0, there is a known issue with flex attention compilation which may
# cause errors during training. The suggested fix is to compile with "max-autotune-no-cudagraphs",
# see https://github.com/pytorch/pytorch/issues/146260
if _torch_version == "2.6.0" and training:
self._compiled_flex_attention = torch.compile(
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
)
else:
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
self._is_flex_compiled = True
def __call__(self):
return self._compiled_flex_attention
def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
Offset = Union[torch.Tensor, int]
def make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[Tuple[Offset, Offset]] = None,
) -> "BlockMask":
"""
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Creates the block-causal logic and passes it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
@ -94,10 +109,18 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
Returns:
BlockMask
"""
batch_size, total_seq_len = attention_mask_2d.shape
if not key_length:
key_length = total_seq_len
if not query_length:
query_length = total_seq_len
attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
document_ids = attention_mask_2d
batch_size, total_seq_len = document_ids.shape
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
@ -112,18 +135,30 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx
causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
padding_mask = document_ids[batch_idx, q_idx] > 0
return causal_mask & document_mask & padding_mask
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
final_mask = causal_mask & padding_mask & document_mask
return final_mask
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = causal_mask_mod
return create_block_causal_mask_flex(
mask_mod=causal_mask_mod,
mask_mod=mask_mod,
B=batch_size,
H=None, # attention head
Q_LEN=total_seq_len,
KV_LEN=total_seq_len,
Q_LEN=query_length,
KV_LEN=key_length,
device=device,
_compile=True,
)
@ -132,10 +167,11 @@ def compile_friendly_flex_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
training=False,
**kwargs,
) -> torch.Tensor:
# The first call initialises the singleton wrapper object, the second call invokes it to return the compiled flex attention
flex_attention_compiled = WrappedFlexAttention()()
flex_attention_compiled = WrappedFlexAttention(training)()
return flex_attention_compiled(
query,
key,
@ -144,6 +180,18 @@ def compile_friendly_flex_attention(
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def flex_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
@ -174,17 +222,29 @@ def flex_attention_forward(
score = score + head_mask[batch_idx][head_idx][0][0]
return score
enable_gqa = True
num_local_query_heads = query.shape[1]
# When running TP this helps:
if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
key = repeat_kv(key, query.shape[1] // key.shape[1])
value = repeat_kv(value, query.shape[1] // value.shape[1])
enable_gqa = False
kernel_options = kwargs.get("kernel_options", None)
attn_output, attention_weights = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
enable_gqa=enable_gqa,
scale=scaling,
kernel_options=kernel_options,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
training=module.training,
)
# lse is returned in float32
attention_weights = attention_weights.to(value.dtype)

View File
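
To visualize the chunked-attention part of `make_flex_block_causal_mask`, here is a small sketch that reproduces only the mask logic (causal plus same-chunk) with plain tensor ops, leaving out padding, offsets and the `BlockMask` creation; sequence length and chunk size are toy values:

```python
import torch

seq_len, chunk_size = 8, 3
idx = torch.arange(seq_len)

# Same trick as above: arange // chunk_size gives the chunk id of every position,
# e.g. [0, 0, 0, 1, 1, 1, 2, 2]
chunk_ids = idx // chunk_size

causal = idx[:, None] >= idx[None, :]                   # q_idx >= kv_idx
same_chunk = chunk_ids[:, None] == chunk_ids[None, :]   # attend only within a chunk
mask = causal & same_chunk
print(mask.int())
```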

@ -31,7 +31,7 @@ def sdpa_attention_forward(
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
if attention_mask is not None and causal_mask.ndim == 4:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions

View File
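
A short sketch of what the added `ndim == 4` check is guarding: a prepared 4D mask (e.g. from a static cache) is sliced down to the current key length, whereas a 2D padding mask of shape `(batch, seq_len)` cannot be indexed with four dimensions at all. Shapes below are hypothetical:

```python
import torch

batch, heads, q_len, cache_len, kv_len, head_dim = 1, 2, 1, 16, 5, 8
causal_mask = torch.zeros(batch, 1, q_len, cache_len, dtype=torch.bool)  # covers the full cache
key = torch.randn(batch, heads, kv_len, head_dim)                        # only the filled part

sliced = causal_mask[:, :, :, : key.shape[-2]]
print(sliced.shape)  # torch.Size([1, 1, 1, 5]) -- matches the key length

padding_mask = torch.ones(batch, kv_len)
try:
    padding_mask[:, :, :, : key.shape[-2]]
except IndexError as err:
    print(err)  # too many indices for a 2D mask, hence the ndim check
```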

@ -61,6 +61,21 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
return [single_size] * blocks
str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
}
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
"""
When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
@ -106,6 +121,12 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
tensors_slices += range(block_offset + start, block_offset + stop)
block_offset += block_size
slice_dtype = slice_.get_dtype()
# Handle the F8_E4M3 dtype by converting to float16 before slicing
# Without upcasting, the slicing causes: RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
if slice_dtype == "F8_E4M3":
slice_ = slice_[...].to(torch.float16)
if dim == 0:
tensor = slice_[tensors_slices, ...]
elif dim == 1 or dim == -2:
@ -114,7 +135,7 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
tensor = slice_[..., tensors_slices]
else:
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
return tensor
return tensor.to(str_to_torch_dtype[slice_dtype])
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
@ -199,11 +220,12 @@ class GatherParallel(TensorParallelLayer):
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
if isinstance(inputs[0], DTensor):
inputs[0] = inputs[0].to_local()
inputs = inputs[0].to_local()
return inputs
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# this op cannot be async, otherwise it completely breaks the outputs of models
torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
return outputs
@ -266,7 +288,7 @@ class ColwiseParallel(TensorParallelLayer):
# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
return input_tensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
@ -291,7 +313,7 @@ class ColwiseParallel(TensorParallelLayer):
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# outputs is a shard on last dimension DTensor, i.e. Shard(-1)
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
outputs = outputs.redistribute(placements=output_layouts, async_op=False)
# back to local tensor
return outputs.to_local() if use_local_output else outputs
@ -343,16 +365,6 @@ class RowwiseParallel(TensorParallelLayer):
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
# means Rowwise as nn.Linear is input * weight^T + bias, where
@ -371,6 +383,20 @@ class RowwiseParallel(TensorParallelLayer):
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
return nn.Parameter(parameter)
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
if hasattr(mod, "bias") and mod.bias is not None:
mod._bias = mod.bias
mod.bias = None
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# Rowwise sharding produces partial output, depending on output layouts:
@ -378,6 +404,8 @@ class RowwiseParallel(TensorParallelLayer):
# 2. to shard -> reduce_scatter
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
if hasattr(mod, "_bias"):
outputs += mod._bias
# back to local tensor if use_local_output is True
return outputs.to_local() if use_local_output else outputs
@ -418,6 +446,90 @@ class PackedRowwiseParallel(RowwiseParallel):
return nn.Parameter(parameter)
class SequenceParallel(TensorParallelLayer):
"""
SequenceParallel replicates a compatible ``nn.Module``'s parameters and runs the sharded computation with
input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
`RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__
This style implements the operation that is described in the paper
`Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__
If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded
on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input
passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would
redistribute the input to be sharded on the sequence dimension.
The output of the ``nn.Module`` will be sharded on the sequence dimension.
Keyword Args:
sequence_dim (int, optional):
The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to
become a DTensor that is sharded on the sequence dimension, default: 1.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
Returns:
A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
Example::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
>>> ...
.. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.
``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom
inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
to ensure that they are replicated.
"""
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
super().__init__()
self.input_layouts = (Replicate(),)
self.desired_input_layouts = (Shard(1),)
self.output_layouts = (Replicate(),)
self.use_local_output = use_local_output
self.use_dtensor = True
self.sequence_sharding = (Shard(sequence_dim),)
self.use_local_output = use_local_output
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
outputs = outputs.redistribute(
placements=(Replicate(),), async_op=True
) # maybe we have to replicate, because the next layer is not sharded
return outputs.to_local() # if use_local_output else outputs
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
parameter = param[:]
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
return nn.Parameter(parameter)
SUPPORTED_TP_STYLES = {
"colwise",
"rowwise",
@ -428,6 +540,7 @@ SUPPORTED_TP_STYLES = {
"local",
"gather",
"local_packed_rowwise",
"sequence_parallel",
}
@ -459,6 +572,8 @@ def translate_to_torch_parallel_style(style: str):
return GatherParallel()
elif style == "local_packed_rowwise":
return PackedRowwiseParallel(use_dtensor=False)
elif style == "sequence_parallel":
return SequenceParallel()
else:
raise ValueError(f"Unsupported parallel style value: {style}")
@ -518,6 +633,7 @@ def shard_and_distribute_module(
tp_plan = model._tp_plan
module_to_tp = model.get_submodule(param_name)
current_module_plan = None
rank = int(rank)
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
current_module_plan = tp_plan[generic_param_name]
@ -531,12 +647,18 @@ def shard_and_distribute_module(
module_to_tp._is_hooked = True
if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
try:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
except NotImplementedError as e:
print(
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
)
else:
# TODO log no plan modules in set
# print("No plan for", parameter_name,end ="\n")
param = param[...].to(param_casting_dtype)
if is_contiguous:
param = param.contiguous()

View File
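
The `get_packed_weights` path exists because packed projections such as `gate_up_proj` concatenate `[gate | up]` along one dimension, so a colwise shard must take each rank's slice of both halves rather than one contiguous block. A toy sketch (2 ranks, small sizes), not the actual `_blocks_to_block_sizes` implementation:

```python
import torch

hidden, expert_dim, world_size = 4, 6, 2
gate = torch.arange(expert_dim).repeat(hidden, 1)        # columns 0..5
up = torch.arange(expert_dim).repeat(hidden, 1) + 100    # columns 100..105
gate_up_proj = torch.cat([gate, up], dim=-1)             # packed: (hidden, 2 * expert_dim)

shard = expert_dim // world_size
for rank in range(world_size):
    cols = list(range(rank * shard, (rank + 1) * shard))                              # gate slice
    cols += list(range(expert_dim + rank * shard, expert_dim + (rank + 1) * shard))   # up slice
    print(f"rank {rank}:", gate_up_proj[:, cols][0].tolist())
# rank 0: [0, 1, 2, 100, 101, 102]
# rank 1: [3, 4, 5, 103, 104, 105]
```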

@ -57,7 +57,8 @@ from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.accelerate import find_tied_parameters, init_empty_weights
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
@ -131,12 +132,11 @@ XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
if is_accelerate_available():
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import add_hook_to_module
from accelerate.utils import (
check_tied_parameters_on_same_device,
extract_model_from_parallel,
find_tied_parameters,
get_balanced_memory,
get_max_memory,
load_offloaded_weights,
@ -153,6 +153,10 @@ if is_safetensors_available():
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
if is_deepspeed_available():
import deepspeed
logger = logging.get_logger(__name__)
@ -480,6 +484,7 @@ str_to_torch_dtype = {
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
}
if is_torch_greater_or_equal("2.1.0"):
@ -1910,16 +1915,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
if self.base_model is self:
self._pp_plan = (
self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
)
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
else:
self._tp_plan = self._tp_plan or {}
for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
for _, v in self._tp_plan.items():
@ -2021,8 +2021,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
@ -2662,8 +2660,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Since we are basically reusing the same old embeddings with new weight values, gathering is required
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
vocab_size = model_embeds.weight.shape[0]
else:
@ -2694,8 +2690,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Update new_num_tokens with the actual size of new_embeddings
if pad_to_multiple_of is not None:
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
new_num_tokens = new_embeddings.weight.shape[0]
else:
@ -2784,8 +2778,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
else:
@ -2830,8 +2822,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
added_num_tokens = new_num_tokens - old_num_tokens
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
self._init_added_embeddings_weights_with_mean(
old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
@ -2847,8 +2837,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
n = min(old_num_tokens, new_num_tokens)
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
@ -2859,8 +2847,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# This ensures correct functionality when a Custom Embedding class is passed as input.
# The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
old_embeddings.weight = new_embeddings.weight
@ -2918,8 +2904,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
old_num_tokens, old_lm_head_dim = (
old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
@ -2970,8 +2954,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
added_num_tokens = new_num_tokens - old_num_tokens
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
params = [old_lm_head.weight]
if has_new_lm_head_bias:
params += [old_lm_head.bias]
@ -2992,8 +2974,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
self._copy_lm_head_original_to_resized(
@ -3738,25 +3718,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return super().float(*args)
@classmethod
def get_init_context(
cls: Type[SpecificPreTrainedModelType],
is_quantized=None,
_is_ds_init_called=None,
):
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
set_zero3_state(),
no_init_weights(),
]
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
if is_deepspeed_zero3_enabled():
init_contexts = [no_init_weights()]
# We cannot initialize the model on meta device with deepspeed when not quantized
if not is_quantized and not _is_ds_init_called:
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
elif is_quantized:
init_contexts.extend([init_empty_weights(), set_quantized_state()])
else:
init_contexts = [no_init_weights(), init_empty_weights()]
if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())
return init_contexts
@classmethod
@ -4072,6 +4045,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
import sys
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assuming sharding the model onto the world
@ -4161,6 +4135,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if device_map is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
if not is_accelerate_available():
raise ValueError(
"Using a `device_map` or `tp_plan` requires `accelerate`. You can install it with `pip install accelerate`"
)
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
if load_in_4bit or load_in_8bit:
@ -4256,6 +4234,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
device_map = hf_quantizer.update_device_map(device_map)
config = hf_quantizer.update_tp_plan(config)
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
@ -4388,9 +4367,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config
)
# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
@ -4644,6 +4622,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
):
# Useful flags
is_quantized = hf_quantizer is not None
is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
]
# Get all the keys of the state dicts that we have to initialize the model
if sharded_metadata is not None:
@ -4805,15 +4788,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
# Warmup cuda to load the weights much faster on devices
if device_map is not None: # and hf_quantizer is None:
if device_map is not None and not is_hqq:
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
error_msgs = []
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
]
# Iterate on all the shards to load the weights
for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights
@ -4821,7 +4800,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
continue
map_location = "cpu"
if shard_file.endswith(".safetensors") and not is_hqq_or_bnb:
if (
shard_file.endswith(".safetensors")
and not is_hqq_or_bnb
and not (is_deepspeed_zero3_enabled() and not is_quantized)
):
map_location = "meta"
elif (
device_map is not None
@ -4843,7 +4826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
if is_deepspeed_zero3_enabled():
if is_deepspeed_zero3_enabled() and not is_quantized:
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
# Skip it with fsdp on ranks other than 0
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
@ -4919,7 +4902,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
name,
casting_dtype,
to_contiguous,
tp_device.index,
os.environ["RANK"],
device_mesh,
)
@ -5192,6 +5175,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
(where we want the speed-ups of compiled version with static shapes)."""
# Only reset it if not present or different from previous config
if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
return self.__call__
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
if (
not hasattr(self, "_compiled_call")
@ -5267,8 +5252,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
not_initialized_submodules = dict(self.named_modules())
# This will only initialize submodules that are not marked as initialized by the line above.
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
not_initialized_parameters = list(
set(
itertools.chain.from_iterable(

View File
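
The refactored `get_init_context` returns a plain list of context managers that the loading code then enters together. One way to enter such a list, shown here with only `init_empty_weights` and under the assumption that no DeepSpeed ZeRO-3 config is active:

```python
from contextlib import ExitStack

import torch.nn as nn

from transformers.integrations.accelerate import init_empty_weights

init_contexts = [init_empty_weights()]

with ExitStack() as stack:
    for ctx in init_contexts:
        stack.enter_context(ctx)
    model = nn.Linear(512, 512)

print(model.weight.device)  # meta: the skeleton was built without allocating weights
```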

@ -148,6 +148,7 @@ from . import (
levit,
lilt,
llama,
llama4,
llava,
llava_next,
llava_next_video,

View File

@ -1,62 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALBERT checkpoint."""
import argparse
import torch
from ...utils import logging
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = AlbertConfig.from_json_file(albert_config_file)
print(f"Building PyTorch model from configuration: {config}")
model = AlbertForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_albert(model, config, tf_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--albert_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained ALBERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)

View File

@ -579,6 +579,8 @@ class AlbertPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, AlbertMLMHead):
module.bias.data.zero_()
@dataclass

View File
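
The added branch matters because a prediction head that owns a standalone `nn.Parameter` bias (as `AlbertMLMHead` does) is neither an `nn.Linear` nor an `nn.LayerNorm`, so the generic branches of `_init_weights` never touch that bias. A minimal sketch with a hypothetical `TinyMLMHead` and a simplified init function:

```python
import torch
import torch.nn as nn


class TinyMLMHead(nn.Module):
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.bias = nn.Parameter(torch.randn(vocab_size))  # standalone bias parameter


def init_weights(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, TinyMLMHead):
        module.bias.data.zero_()  # only this branch resets the standalone bias


head = TinyMLMHead(8, 16)
head.apply(init_weights)
print(head.bias.abs().max())  # tensor(0., ...)
```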

@ -1,389 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALIGN checkpoints from the original repository."""
import argparse
import os
import align
import numpy as np
import requests
import tensorflow as tf
import torch
from PIL import Image
from tokenizer import Tokenizer
from transformers import (
AlignConfig,
AlignModel,
AlignProcessor,
BertConfig,
BertTokenizer,
EfficientNetConfig,
EfficientNetImageProcessor,
)
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def preprocess(image):
image = tf.image.resize(image, (346, 346))
image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
return image
def get_align_config():
vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
vision_config.image_size = 289
vision_config.hidden_dim = 640
vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
vision_config.depthwise_padding = []
text_config = BertConfig()
config = AlignConfig.from_text_vision_configs(
text_config=text_config, vision_config=vision_config, projection_dim=640
)
return config
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
def get_processor():
image_processor = EfficientNetImageProcessor(
do_center_crop=True,
rescale_factor=1 / 127.5,
rescale_offset=True,
do_normalize=False,
include_top=False,
resample=Image.BILINEAR,
)
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
tokenizer.model_max_length = 64
processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
return processor
# here we list all keys to be renamed (original name on the left, our name on the right)
def rename_keys(original_param_names):
# EfficientNet image encoder
block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
block_names = list(set(block_names))
block_names = sorted(block_names)
num_blocks = len(block_names)
block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}
rename_keys = []
rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))
for b in block_names:
hf_b = block_name_mapping[b]
rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
rename_keys.append(
(f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
)
rename_keys.append(
(f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
)
rename_keys.append(
(f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
)
rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
rename_keys.append(
(f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
)
rename_keys.append(
(f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
)
rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
rename_keys.append(
(f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
)
rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
rename_keys.append(
(f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
)
rename_keys.append(
(f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
)
key_mapping = {}
for item in rename_keys:
if item[0] in original_param_names:
key_mapping[item[0]] = "vision_model." + item[1]
# BERT text encoder
rename_keys = []
old = "tf_bert_model/bert"
new = "text_model"
for i in range(12):
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0",
f"{new}.encoder.layer.{i}.attention.self.query.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/query/bias:0",
f"{new}.encoder.layer.{i}.attention.self.query.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0",
f"{new}.encoder.layer.{i}.attention.self.key.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/key/bias:0",
f"{new}.encoder.layer.{i}.attention.self.key.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0",
f"{new}.encoder.layer.{i}.attention.self.value.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/value/bias:0",
f"{new}.encoder.layer.{i}.attention.self.value.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0",
f"{new}.encoder.layer.{i}.attention.output.dense.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0",
f"{new}.encoder.layer.{i}.attention.output.dense.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0",
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0",
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0",
f"{new}.encoder.layer.{i}.intermediate.dense.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0",
f"{new}.encoder.layer.{i}.intermediate.dense.bias",
)
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight")
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias")
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight")
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias")
)
rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight"))
rename_keys.append(
(f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight")
)
rename_keys.append(
(f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight")
)
rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight"))
rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias"))
rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight"))
rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias"))
rename_keys.append(("dense/kernel:0", "text_projection.weight"))
rename_keys.append(("dense/bias:0", "text_projection.bias"))
rename_keys.append(("dense/bias:0", "text_projection.bias"))
rename_keys.append(("temperature:0", "temperature"))
for item in rename_keys:
if item[0] in original_param_names:
key_mapping[item[0]] = item[1]
return key_mapping
def replace_params(hf_params, tf_params, key_mapping):
for key, value in tf_params.items():
if key not in key_mapping:
continue
hf_key = key_mapping[key]
if "_conv" in key and "kernel" in key:
new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
elif "embeddings" in key:
new_hf_value = torch.from_numpy(value)
elif "depthwise_kernel" in key:
new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
elif "kernel" in key:
new_hf_value = torch.from_numpy(np.transpose(value))
elif "temperature" in key:
new_hf_value = value
elif "bn/gamma" or "bn/beta" in key:
new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
else:
new_hf_value = torch.from_numpy(value)
# Replace HF parameters with original TF model parameters
hf_params[hf_key].copy_(new_hf_value)
@torch.no_grad()
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
"""
Copy/paste/tweak model's weights to our ALIGN structure.
"""
# Load original model
seq_length = 64
tok = Tokenizer(seq_length)
original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
original_model.compile()
original_model.load_weights(checkpoint_path)
tf_params = original_model.trainable_variables
tf_non_train_params = original_model.non_trainable_variables
tf_params = {param.name: param.numpy() for param in tf_params}
for param in tf_non_train_params:
tf_params[param.name] = param.numpy()
tf_param_names = list(tf_params.keys())
# Load HuggingFace model
config = get_align_config()
hf_model = AlignModel(config).eval()
hf_params = hf_model.state_dict()
# Create src-to-dst parameter name mapping dictionary
print("Converting parameters...")
key_mapping = rename_keys(tf_param_names)
replace_params(hf_params, tf_params, key_mapping)
# Initialize processor
processor = get_processor()
inputs = processor(
images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
)
# HF model inference
hf_model.eval()
with torch.no_grad():
outputs = hf_model(**inputs)
hf_image_features = outputs.image_embeds.detach().numpy()
hf_text_features = outputs.text_embeds.detach().numpy()
# Original model inference
original_model.trainable = False
tf_image_processor = EfficientNetImageProcessor(
do_center_crop=True,
do_rescale=False,
do_normalize=False,
include_top=False,
resample=Image.BILINEAR,
)
image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"]
text = tok(tf.constant(["A picture of a cat"]))
image_features = original_model.image_encoder(image, training=False)
text_features = original_model.text_encoder(text, training=False)
image_features = tf.nn.l2_normalize(image_features, axis=-1)
text_features = tf.nn.l2_normalize(text_features, axis=-1)
# Check whether original and HF model outputs match -> np.allclose
if not np.allclose(image_features, hf_image_features, atol=1e-3):
raise ValueError("The predicted image features are not the same.")
if not np.allclose(text_features, hf_text_features, atol=1e-3):
raise ValueError("The predicted text features are not the same.")
print("Model outputs match!")
if save_model:
# Create folder to save model
if not os.path.isdir(pytorch_dump_folder_path):
os.mkdir(pytorch_dump_folder_path)
# Save converted model and image processor
hf_model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
# Push model and image processor to hub
print("Pushing converted ALIGN to the hub...")
processor.push_to_hub("align-base")
hf_model.push_to_hub("align-base")
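# Example invocation (illustrative paths; run from the script's directory):
#   python convert_align_tf_to_hf.py --checkpoint_path ./weights/model-weights --pytorch_dump_folder_path hf_model --save_model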
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint_path",
default="./weights/model-weights",
type=str,
help="Path to the pretrained TF ALIGN checkpoint.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default="hf_model",
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument("--save_model", action="store_true", help="Save model to local")
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
args = parser.parse_args()
convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)

View File

@ -1,162 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import (
AddedToken,
AriaForConditionalGeneration,
AriaProcessor,
AutoConfig,
AutoTokenizer,
)
EPILOG_TXT = """Example:
python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
Example for creating the old state dict file with Python:
import torch
from aria.model.language_model.aria_llama import AriaTextForCausalLM
# load model
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", low_cpu_mem_usage=True, **kwargs)
# load vision tower
model.get_vision_tower().load_model()
# Save state dict
torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
"""
KEYS_TO_MODIFY_MAPPING = {
"vision_tower.vision_model": "vision_tower",
"ln_ffn": "layer_norm",
"ffn": "feed_forward",
"ln_kv": "layer_norm_kv",
}
def load_original_state_dict(model_id):
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
return original_state_dict
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".inv_freq"):
continue
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
return new_state_dict
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
torch.set_default_dtype(torch.float16)
tokenizer = AutoTokenizer.from_pretrained(
text_model_id,
extra_special_tokens={
"image_token": "<|img|>",
"pad_token": "<pad>",
},
)
tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
tokenizer.add_special_tokens({"pad_token": "<pad>"})
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
processor = AriaProcessor.from_pretrained(
text_model_id,
tokenizer=tokenizer,
)
config = AutoConfig.from_pretrained(text_model_id)
config.vision_config.hidden_size = 1152
config.vision_config.attention_heads = 16
config.pad_token_id = 2
config.image_token_index = 9
config.intermediate_size = config.moe_intermediate_size
config.auto_map = {
"AutoConfig": "modeling_aria.AriaConfig",
"AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
}
with torch.device("meta"):
model = AriaForConditionalGeneration(config)
state_dict = load_original_state_dict(old_state_dict_id)
state_dict = convert_state_dict_to_hf(state_dict)
model.load_state_dict(state_dict, strict=False, assign=True)
# print("Saving models")
# model.save_pretrained("local_aria", safe_serialization=False)
# processor.save_pretrained("local_aria")
print("Pushing to hub")
model.push_to_hub(output_hub_path, create_pr=True)
processor.push_to_hub(output_hub_path, create_pr=True)
def main():
parser = argparse.ArgumentParser(
epilog=EPILOG_TXT,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--text_model_id",
default="rhymes-ai/Aria",
help="Hub location of the text model",
)
parser.add_argument(
"--vision_model_id",
default="rhymes-ai/Aria",
help="Hub location of the vision model",
)
parser.add_argument(
"--output_hub_path",
default="rhymes-ai/Aria",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--old_state_dict_id",
default="rhymes-ai/Aria",
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
)
args = parser.parse_args()
convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
if __name__ == "__main__":
main()

View File

@ -1,279 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast"""
import argparse
import json
from pathlib import Path
import torch
import torchaudio
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_audio_spectrogram_transformer_config(model_name):
config = ASTConfig()
if "10-10" in model_name:
pass
elif "speech-commands" in model_name:
config.max_length = 128
elif "12-12" in model_name:
config.time_stride = 12
config.frequency_stride = 12
elif "14-14" in model_name:
config.time_stride = 14
config.frequency_stride = 14
elif "16-16" in model_name:
config.time_stride = 16
config.frequency_stride = 16
else:
raise ValueError("Model not supported")
repo_id = "huggingface/label-files"
if "speech-commands" in model_name:
config.num_labels = 35
filename = "speech-commands-v2-id2label.json"
else:
config.num_labels = 527
filename = "audioset-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
return config
def rename_key(name):
if "module.v" in name:
name = name.replace("module.v", "audio_spectrogram_transformer")
if "cls_token" in name:
name = name.replace("cls_token", "embeddings.cls_token")
if "dist_token" in name:
name = name.replace("dist_token", "embeddings.distillation_token")
if "pos_embed" in name:
name = name.replace("pos_embed", "embeddings.position_embeddings")
if "patch_embed.proj" in name:
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
# transformer blocks
if "blocks" in name:
name = name.replace("blocks", "encoder.layer")
if "attn.proj" in name:
name = name.replace("attn.proj", "attention.output.dense")
if "attn" in name:
name = name.replace("attn", "attention.self")
if "norm1" in name:
name = name.replace("norm1", "layernorm_before")
if "norm2" in name:
name = name.replace("norm2", "layernorm_after")
if "mlp.fc1" in name:
name = name.replace("mlp.fc1", "intermediate.dense")
if "mlp.fc2" in name:
name = name.replace("mlp.fc2", "output.dense")
# final layernorm
if "audio_spectrogram_transformer.norm" in name:
name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm")
# classifier head
if "module.mlp_head.0" in name:
name = name.replace("module.mlp_head.0", "classifier.layernorm")
if "module.mlp_head.1" in name:
name = name.replace("module.mlp_head.1", "classifier.dense")
return name
def convert_state_dict(orig_state_dict, config):
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
if "qkv" in key:
key_split = key.split(".")
layer_num = int(key_split[3])
dim = config.hidden_size
if "weight" in key:
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight"
] = val[:dim, :]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight"
] = val[dim : dim * 2, :]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight"
] = val[-dim:, :]
else:
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias"
] = val[:dim]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias"
] = val[dim : dim * 2]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias"
] = val[-dim:]
else:
orig_state_dict[rename_key(key)] = val
return orig_state_dict
def remove_keys(state_dict):
ignore_keys = [
"module.v.head.weight",
"module.v.head.bias",
"module.v.head_dist.weight",
"module.v.head_dist.bias",
]
for k in ignore_keys:
state_dict.pop(k, None)
@torch.no_grad()
def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
"""
Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure.
"""
config = get_audio_spectrogram_transformer_config(model_name)
model_name_to_url = {
"ast-finetuned-audioset-10-10-0.4593": (
"https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1"
),
"ast-finetuned-audioset-10-10-0.450": (
"https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1"
),
"ast-finetuned-audioset-10-10-0.448": (
"https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1"
),
"ast-finetuned-audioset-10-10-0.448-v2": (
"https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1"
),
"ast-finetuned-audioset-12-12-0.447": (
"https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1"
),
"ast-finetuned-audioset-14-14-0.443": (
"https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1"
),
"ast-finetuned-audioset-16-16-0.442": (
"https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1"
),
"ast-finetuned-speech-commands-v2": (
"https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1"
),
}
# load original state_dict
checkpoint_url = model_name_to_url[model_name]
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
# remove some keys
remove_keys(state_dict)
# rename some keys
new_state_dict = convert_state_dict(state_dict, config)
# load 🤗 model
model = ASTForAudioClassification(config)
model.eval()
model.load_state_dict(new_state_dict)
# verify outputs on dummy input
# source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62
mean = -4.2677393 if "speech-commands" not in model_name else -6.845978
std = 4.5689974 if "speech-commands" not in model_name else 5.5654526
max_length = 1024 if "speech-commands" not in model_name else 128
feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length)
if "speech-commands" in model_name:
# TODO: Convert dataset to Parquet
dataset = load_dataset("google/speech_commands", "v0.02", split="validation", trust_remote_code=True)
waveform = dataset[0]["audio"]["array"]
else:
filepath = hf_hub_download(
repo_id="nielsr/audio-spectogram-transformer-checkpoint",
filename="sample_audio.flac",
repo_type="dataset",
)
waveform, _ = torchaudio.load(filepath)
waveform = waveform.squeeze().numpy()
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
# forward pass
outputs = model(**inputs)
logits = outputs.logits
if model_name == "ast-finetuned-audioset-10-10-0.4593":
expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602])
elif model_name == "ast-finetuned-audioset-10-10-0.450":
expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718])
elif model_name == "ast-finetuned-audioset-10-10-0.448":
expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344])
elif model_name == "ast-finetuned-audioset-10-10-0.448-v2":
expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917])
elif model_name == "ast-finetuned-audioset-12-12-0.447":
expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843])
elif model_name == "ast-finetuned-audioset-14-14-0.443":
expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413])
elif model_name == "ast-finetuned-audioset-16-16-0.442":
expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470])
elif model_name == "ast-finetuned-speech-commands-v2":
expected_slice = torch.tensor([6.1589, -8.0566, -8.7984])
else:
raise ValueError("Unknown model name")
if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4):
raise ValueError("Logits don't match")
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
feature_extractor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing model and feature extractor to the hub...")
model.push_to_hub(f"MIT/{model_name}")
feature_extractor.push_to_hub(f"MIT/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="ast-finetuned-audioset-10-10-0.4593",
type=str,
help="Name of the Audio Spectrogram Transformer model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -544,10 +544,6 @@ class _BaseAutoModelClass:
if kwargs_orig.get("quantization_config", None) is not None:
kwargs["quantization_config"] = kwargs_orig["quantization_config"]
# AutoClass-specific config manipulation
config = copy.deepcopy(config)
config = cls._prepare_config_for_auto_class(config)
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
@ -570,6 +566,8 @@ class _BaseAutoModelClass:
)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
if model_class.config_class == config.sub_configs.get("text_config", None):
config = config.get_text_config()
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)

View File

@ -170,6 +170,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("levit", "LevitConfig"),
("lilt", "LiltConfig"),
("llama", "LlamaConfig"),
("llama4", "Llama4Config"),
("llama4_text", "Llama4TextConfig"),
("llava", "LlavaConfig"),
("llava_next", "LlavaNextConfig"),
("llava_next_video", "LlavaNextVideoConfig"),
@ -519,6 +521,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("llama", "LLaMA"),
("llama2", "Llama2"),
("llama3", "Llama3"),
("llama4", "Llama4"),
("llama4_text", "Llama4ForCausalLM"),
("llava", "LLaVa"),
("llava_next", "LLaVA-NeXT"),
("llava_next_video", "LLaVa-NeXT-Video"),
@ -776,6 +780,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
("sam_vision_model", "sam"),
("llama4_text", "llama4"),
]
)

View File

@ -104,6 +104,7 @@ else:
("layoutlmv2", ("LayoutLMv2ImageProcessor",)),
("layoutlmv3", ("LayoutLMv3ImageProcessor",)),
("levit", ("LevitImageProcessor",)),
("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
("llava_next_video", ("LlavaNextVideoImageProcessor",)),

View File

@ -17,7 +17,6 @@
import warnings
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from .auto_factory import (
_BaseAutoBackboneClass,
@ -161,6 +160,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("levit", "LevitModel"),
("lilt", "LiltModel"),
("llama", "LlamaModel"),
("llama4", "Llama4ForConditionalGeneration"),
("longformer", "LongformerModel"),
("longt5", "LongT5Model"),
("luke", "LukeModel"),
@ -547,6 +547,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"),
("llama", "LlamaForCausalLM"),
("llama4", "Llama4ForCausalLM"),
("llama4_text", "Llama4ForCausalLM"),
("mamba", "MambaForCausalLM"),
("mamba2", "Mamba2ForCausalLM"),
("marian", "MarianForCausalLM"),
@ -634,6 +636,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
("llama4", "Llama4VisionModel"),
("mllama", "MllamaVisionModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),
@ -849,6 +852,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("idefics3", "Idefics3ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llama4", "Llama4ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
@ -1492,6 +1496,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
("emu3", "Emu3TextModel"),
("flaubert", "FlaubertModel"),
("ibert", "IBertModel"),
("llama4", "Llama4TextModel"),
("longformer", "LongformerModel"),
("mllama", "MllamaTextModel"),
("mobilebert", "MobileBertModel"),
@ -1678,30 +1683,6 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag
class AutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
@classmethod
def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
"""
Additional autoclass-specific config post-loading manipulation. In this specific autoclass, if the config has
a nested text decoder section, uses that section instead.
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
"""
possible_text_config_names = ("decoder", "generator", "text_config")
text_config_names = []
for text_config_name in possible_text_config_names:
if hasattr(config, text_config_name):
text_config_names += [text_config_name]
text_config = config.get_text_config(decoder=True)
if text_config_names and type(text_config) in cls._model_mapping.keys():
warnings.warn(
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
FutureWarning,
)
return config
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")

View File

@ -77,6 +77,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("kosmos-2", "Kosmos2Processor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutlmv3", "LayoutLMv3Processor"),
("llama4", "Llama4Processor"),
("llava", "LlavaProcessor"),
("llava_next", "LlavaNextProcessor"),
("llava_next_video", "LlavaNextVideoProcessor"),

View File

@ -292,6 +292,20 @@ else:
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"llama4",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"llama4_text",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),

View File

@ -1,273 +0,0 @@
# coding=utf-8
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
import argparse
import json
import os
import re
from os import path
from typing import Dict, Optional, Union
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file
from transformers import AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from .configuration_bamba import BambaConfig
def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
state_dict = {}
for orig_k, param in original_sd.items():
k = orig_k.replace("backbone", "model")
# for embeddings
k = k.replace("embedding", "embed_tokens")
# for mixer
k = k.replace("mixer", "mamba")
# for final layernorm
k = k.replace("norm_f", "final_layernorm")
# for block layernorm
k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
# for mlp
k = k.replace("mlp.fc2", "feed_forward.down_proj")
if "mlp.fc1" in k:
param, param2 = torch.chunk(param, 2, dim=0)
k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
state_dict[k2] = param2
k = k.replace("mlp.fc1", "feed_forward.up_proj")
if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
"out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
):
# then this must be a mamba
pass
else:
# for attn
# - because mixer was replaced to mamba above
k = k.replace("mamba.out_proj", "self_attn.o_proj")
if "mamba.in_proj" in k:
m, n = param.shape
d = (m - n) // 2
param, param2, param3 = torch.split(param, [n, d, d], dim=0)
k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
state_dict[k2] = param2
k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
state_dict[k2] = param3
k = k.replace("mamba.in_proj", "self_attn.q_proj")
state_dict[k] = param
return state_dict
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
def convert_ssm_config_to_hf_config(
config_ssm: Dict,
**kwargs,
) -> BambaConfig:
"""Convert a config from mamba_ssm to a BambaConfig from here."""
hf_config: BambaConfig = BambaConfig(**kwargs)
hf_config.architectures = ["BambaForCausalLM"]
# Set important values from config and recalculate other resulting entries
hf_config.hidden_size = config_ssm["d_model"]
hf_config.intermediate_size = config_ssm["d_intermediate"]
hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
hf_config.num_hidden_layers = config_ssm["n_layer"]
hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
# currently this script assumes config_ssm belongs to v2
if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
raise ValueError("Conversion script only supports Mamba2")
# Set attention values
attn_cfg = config_ssm.get("attn_cfg")
if attn_cfg:
assert attn_cfg["causal"], "Only support non-causal attention."
assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
assert not attn_cfg["out_proj_bias"], "Only support no out bias."
hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
hf_config.num_attention_heads = attn_cfg["num_heads"]
hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
attention_layer_indices = config_ssm.get("attn_layer_idx")
if attention_layer_indices:
hf_config.attn_layer_indices = attention_layer_indices
# Pad the vocab size up to a multiple of pad_vocab_size_multiple (commonly 16, sometimes 32, depending on the model)
vocab_size = config_ssm["vocab_size"]
pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
if (vocab_size % pad_vocab_size_multiple) != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
hf_config.vocab_size = vocab_size
return hf_config
def save_single_safetensor(
state_dict: Dict,
save_directory: str,
metadata: Dict,
):
save_file(
state_dict,
os.path.join(save_directory, SAFE_WEIGHTS_NAME),
metadata,
)
def save_sharded_safetensors(
state_dict: Dict,
save_directory: str,
metadata: Dict,
max_shard_size: Union[int, str] = "5GB",
):
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
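# with SAFE_WEIGHTS_NAME == "model.safetensors", the pattern above becomes "model{suffix}.safetensors",
# so shards are written as e.g. "model-00001-of-00002.safetensors"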
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the index
with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
mamba_ssm_checkpoint_path: str,
precision: str,
output_dir: str,
tokenizer_path: Optional[str] = None,
save_model: Union[bool, str] = True,
) -> None:
# load tokenizer if provided, this will be used to set the
# token_ids in the config file
token_ids = {}
if tokenizer_path:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
for key in [
"bos_token_id",
"eos_token_id",
"pad_token_id",
]:
id = getattr(tokenizer, key, None)
if id:
token_ids[key] = id
# some config values cannot be derived from the mamba_ssm config, so
# if they differ from the defaults they have to be passed into
# the function explicitly
unsettables = {
"mamba_d_head": 64,
"mamba_d_state": 128,
"mamba_n_groups": 1,
"rms_norm_eps": 1e-5,
}
# Load and save config based on name
config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
with open(config_path, "r", encoding="utf-8") as json_file:
config = json.load(json_file)
# convert the config
hf_config = convert_ssm_config_to_hf_config(
config_ssm=config,
**token_ids,
**unsettables,
)
hf_config.save_pretrained(output_dir)
# Load state dict of the original model and transfer to hf model
state_dict = torch.load(
path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
map_location="cpu",
weights_only=True,
)
# FIXME: allow other parameters to pass in
state_dict = convert_state_dict_from_mamba_ssm(state_dict)
# Save new model to pytorch_dump_path
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
save_file_fn = None
if isinstance(save_model, bool) and save_model:
save_file_fn = save_single_safetensor
elif isinstance(save_model, str) and save_model == "sharded":
save_file_fn = save_sharded_safetensors
if save_file_fn:
save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--mamba_ssm_checkpoint_directory",
type=str,
required=True,
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
)
parser.add_argument(
"-p",
"--precision",
type=str,
default="fp16",
const="fp16",
required=True,
choices=("fp32", "fp16", "bf16"),
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
)
parser.add_argument(
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
)
parser.add_argument(
"-t",
"--tokenizer_model_path",
type=str,
default=None,
required=False,
help="Path to a the tokenizer file.",
)
args = parser.parse_args()
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
args.mamba_ssm_checkpoint_directory,
args.precision,
args.output_dir,
tokenizer_path=args.tokenizer_model_path,
)

View File

@ -1,263 +0,0 @@
"""Convert Bark checkpoint."""
import argparse
import os
from pathlib import Path
import torch
from bark.generation import _load_model as _bark_load_model
from huggingface_hub import hf_hub_download
from transformers import EncodecConfig, EncodecModel, set_seed
from transformers.models.bark.configuration_bark import (
BarkCoarseConfig,
BarkConfig,
BarkFineConfig,
BarkSemanticConfig,
)
from transformers.models.bark.generation_configuration_bark import (
BarkCoarseGenerationConfig,
BarkFineGenerationConfig,
BarkGenerationConfig,
BarkSemanticGenerationConfig,
)
from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
set_seed(770)
new_layer_name_dict = {
"c_attn": "att_proj",
"c_proj": "out_proj",
"c_fc": "in_proj",
"transformer.": "",
"h.": "layers.",
"ln_1": "layernorm_1",
"ln_2": "layernorm_2",
"ln_f": "layernorm_final",
"wpe": "position_embeds_layer",
"wte": "input_embeds_layer",
}
REMOTE_MODEL_PATHS = {
"text_small": {
"repo_id": "suno/bark",
"file_name": "text.pt",
},
"coarse_small": {
"repo_id": "suno/bark",
"file_name": "coarse.pt",
},
"fine_small": {
"repo_id": "suno/bark",
"file_name": "fine.pt",
},
"text": {
"repo_id": "suno/bark",
"file_name": "text_2.pt",
},
"coarse": {
"repo_id": "suno/bark",
"file_name": "coarse_2.pt",
},
"fine": {
"repo_id": "suno/bark",
"file_name": "fine_2.pt",
},
}
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
def _get_ckpt_path(model_type, use_small=False):
key = model_type
if use_small:
key += "_small"
return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
def _download(from_hf_path, file_name):
os.makedirs(CACHE_DIR, exist_ok=True)
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
if model_type == "text":
ModelClass = BarkSemanticModel
ConfigClass = BarkSemanticConfig
GenerationConfigClass = BarkSemanticGenerationConfig
elif model_type == "coarse":
ModelClass = BarkCoarseModel
ConfigClass = BarkCoarseConfig
GenerationConfigClass = BarkCoarseGenerationConfig
elif model_type == "fine":
ModelClass = BarkFineModel
ConfigClass = BarkFineConfig
GenerationConfigClass = BarkFineGenerationConfig
else:
raise NotImplementedError()
model_key = f"{model_type}_small" if use_small else model_type
model_info = REMOTE_MODEL_PATHS[model_key]
if not os.path.exists(ckpt_path):
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
_download(model_info["repo_id"], model_info["file_name"])
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack
model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args:
model_args["input_vocab_size"] = model_args["vocab_size"]
model_args["output_vocab_size"] = model_args["vocab_size"]
del model_args["vocab_size"]
# convert Bark model arguments to HF Bark model arguments
model_args["num_heads"] = model_args.pop("n_head")
model_args["hidden_size"] = model_args.pop("n_embd")
model_args["num_layers"] = model_args.pop("n_layer")
model_config = ConfigClass(**checkpoint["model_args"])
model = ModelClass(config=model_config)
model_generation_config = GenerationConfigClass()
model.generation_config = model_generation_config
state_dict = checkpoint["model"]
# fixup checkpoint
unwanted_prefix = "_orig_mod."
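# checkpoints saved from a torch.compile-d model prefix every key with "_orig_mod."; strip it and remap layer names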
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
# replace part of the key with corresponding layer name in HF implementation
new_k = k[len(unwanted_prefix) :]
for old_layer_name in new_layer_name_dict:
new_k = new_k.replace(old_layer_name, new_layer_name_dict[old_layer_name])
state_dict[new_k] = state_dict.pop(k)
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")}
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")}
if len(extra_keys) != 0:
raise ValueError(f"extra keys found: {extra_keys}")
if len(missing_keys) != 0:
raise ValueError(f"missing keys: {missing_keys}")
model.load_state_dict(state_dict, strict=False)
n_params = model.num_parameters(exclude_embeddings=True)
val_loss = checkpoint["best_val_loss"].item()
logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params, {round(val_loss, 3)} loss")
model.eval()
model.to(device)
del checkpoint, state_dict
return model
def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"):
if model_type not in ("text", "coarse", "fine"):
raise NotImplementedError()
device = "cpu" # do conversion on cpu
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small)
# load bark initial model
bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small)
if model_type == "text":
bark_model = bark_model["model"]
if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params():
raise ValueError("initial and new models don't have the same number of parameters")
# check if same output as the bark model
batch_size = 5
sequence_length = 10
if model_type in ["text", "coarse"]:
vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int)
output_old_model = bark_model(vec)[0]
output_new_model_total = model(vec)
# take last logits
output_new_model = output_new_model_total.logits[:, [-1], :]
else:
prediction_codebook_channel = 3
n_codes_total = 8
vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int)
output_new_model_total = model(prediction_codebook_channel, vec)
output_old_model = bark_model(prediction_codebook_channel, vec)
output_new_model = output_new_model_total.logits
# output difference should come from the difference of self-attention implementation design
if output_new_model.shape != output_old_model.shape:
raise ValueError("initial and new outputs don't have the same shape")
if (output_new_model - output_old_model).abs().max().item() > 1e-3:
raise ValueError("initial and new outputs are not equal")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
def load_whole_bark_model(
semantic_path,
coarse_path,
fine_path,
append_text,
hub_path,
folder_path,
):
pytorch_dump_folder_path = os.path.join(folder_path, append_text)
semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json"))
coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json"))
fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json"))
codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz")
semantic = BarkSemanticModel.from_pretrained(semantic_path)
coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path)
fineAcoustic = BarkFineModel.from_pretrained(fine_path)
codec = EncodecModel.from_pretrained("facebook/encodec_24khz")
bark_config = BarkConfig.from_sub_model_configs(
semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig
)
bark_generation_config = BarkGenerationConfig.from_sub_model_configs(
semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config
)
bark = BarkModel(bark_config)
bark.semantic = semantic
bark.coarse_acoustics = coarseAcoustic
bark.fine_acoustics = fineAcoustic
bark.codec_model = codec
bark.generation_config = bark_generation_config
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("model_type", type=str, help="text, coarse or fine.")
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.")
args = parser.parse_args()
load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small)

View File

@ -1,156 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BART checkpoint."""
import argparse
import os
from pathlib import Path
import fairseq
import torch
from packaging import version
from torch import nn
from transformers import (
BartConfig,
BartForConditionalGeneration,
BartForSequenceClassification,
BartModel,
BartTokenizer,
)
from transformers.utils import logging
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
raise Exception("requires fairseq >= 0.9.0")
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
SAMPLE_TEXT = " Hello world! cécé herlolip"
mnli_rename_keys = [
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
]
def remove_ignore_keys_(state_dict):
ignore_keys = [
"encoder.version",
"decoder.version",
"model.encoder.version",
"model.decoder.version",
"_float_tensor",
]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def load_xsum_checkpoint(checkpoint_path):
"""Checkpoint path should end in model.pt"""
sd = torch.load(checkpoint_path, map_location="cpu")
hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
hub_interface.model.load_state_dict(sd["model"])
return hub_interface
def make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
lin_layer.weight.data = emb.weight.data
return lin_layer
@torch.no_grad()
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
"""
Copy/paste/tweak model's weights to our BERT structure.
"""
if not os.path.exists(checkpoint_path):
bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
else:
bart = load_xsum_checkpoint(checkpoint_path)
bart.model.upgrade_state_dict(bart.model.state_dict())
if hf_checkpoint_name is None:
hf_checkpoint_name = checkpoint_path.replace(".", "-")
config = BartConfig.from_pretrained(hf_checkpoint_name)
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
if not torch.eq(tokens, tokens2).all():
raise ValueError(
f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
)
if checkpoint_path == "bart.large.mnli":
state_dict = bart.state_dict()
remove_ignore_keys_(state_dict)
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
for src, dest in mnli_rename_keys:
rename_key(state_dict, src, dest)
model = BartForSequenceClassification(config).eval()
model.load_state_dict(state_dict)
fairseq_output = bart.predict("mnli", tokens, return_logits=True)
new_model_outputs = model(tokens)[0] # logits
else: # no classification heads to worry about
state_dict = bart.model.state_dict()
remove_ignore_keys_(state_dict)
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
fairseq_output = bart.extract_features(tokens)
if hf_checkpoint_name == "facebook/bart-large":
model = BartModel(config).eval()
model.load_state_dict(state_dict)
new_model_outputs = model(tokens).model[0]
else:
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
model.model.load_state_dict(state_dict)
if hasattr(model, "lm_head"):
model.lm_head = make_linear_from_emb(model.model.shared)
new_model_outputs = model.model(tokens)[0]
# Check results
if fairseq_output.shape != new_model_outputs.shape:
raise ValueError(
f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
)
if (fairseq_output != new_model_outputs).any().item():
raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
)
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
)
args = parser.parse_args()
convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)

View File

@ -1,373 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BEiT checkpoints from the unilm repository."""
import argparse
import json
from pathlib import Path
import requests
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
BeitConfig,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitImageProcessor,
)
from transformers.image_utils import PILImageResampling
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
prefix = "backbone." if is_semantic else ""
rename_keys = []
for i in range(config.num_hidden_layers):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
)
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
)
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
# projection layer + position embeddings
rename_keys.extend(
[
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
]
)
if has_lm_head:
# mask token + shared relative position bias + layernorm
rename_keys.extend(
[
("mask_token", "beit.embeddings.mask_token"),
(
"rel_pos_bias.relative_position_bias_table",
"beit.encoder.relative_position_bias.relative_position_bias_table",
),
(
"rel_pos_bias.relative_position_index",
"beit.encoder.relative_position_bias.relative_position_index",
),
("norm.weight", "layernorm.weight"),
("norm.bias", "layernorm.bias"),
]
)
elif is_semantic:
# semantic segmentation classification heads
rename_keys.extend(
[
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
]
)
else:
# layernorm + classification head
rename_keys.extend(
[
("fc_norm.weight", "beit.pooler.layernorm.weight"),
("fc_norm.bias", "beit.pooler.layernorm.bias"),
("head.weight", "classifier.weight"),
("head.bias", "classifier.bias"),
]
)
return rename_keys
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
for i in range(config.num_hidden_layers):
prefix = "backbone." if is_semantic else ""
# queries, keys and values
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
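# BEiT keeps no bias on the key projection (it stays at zero in the original implementation); only q_bias and v_bias are stored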
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
# gamma_1 and gamma_2
# we call them lambda because otherwise they are renamed when using .from_pretrained
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
# relative_position bias table + index
if not has_lm_head:
# each layer has its own relative position bias
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
state_dict[
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
] = table
state_dict[
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
] = index
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our BEiT structure.
"""
# define default BEiT configuration
config = BeitConfig()
has_lm_head = False
is_semantic = False
repo_id = "huggingface/label-files"
# set config parameters based on URL
if checkpoint_url[-9:-4] == "pt22k":
# masked image modeling
config.use_shared_relative_position_bias = True
config.use_mask_token = True
has_lm_head = True
elif checkpoint_url[-9:-4] == "ft22k":
# intermediate fine-tuning on ImageNet-22k
config.use_relative_position_bias = True
config.num_labels = 21841
filename = "imagenet-22k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
# this dataset contains 21843 labels but the model only has 21841
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
del id2label[9205]
del id2label[15027]
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
elif checkpoint_url[-8:-4] == "to1k":
# fine-tuning on ImageNet-1k
config.use_relative_position_bias = True
config.num_labels = 1000
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
if "384" in checkpoint_url:
config.image_size = 384
if "512" in checkpoint_url:
config.image_size = 512
elif "ade20k" in checkpoint_url:
# fine-tuning
config.use_relative_position_bias = True
config.num_labels = 150
filename = "ade20k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.image_size = 640
is_semantic = True
else:
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
# size of the architecture
if "base" in checkpoint_url:
pass
elif "large" in checkpoint_url:
config.hidden_size = 1024
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
if "ade20k" in checkpoint_url:
config.image_size = 640
config.out_indices = [7, 11, 15, 23]
else:
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
# load state_dict of original model, remove and rename some keys
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
if is_semantic:
# add prefix to decoder keys
for key, val in state_dict.copy().items():
val = state_dict.pop(key)
if key.startswith("backbone.fpn"):
key = key.replace("backbone.fpn", "fpn")
state_dict[key] = val
# load HuggingFace model
if checkpoint_url[-9:-4] == "pt22k":
model = BeitForMaskedImageModeling(config)
elif "ade20k" in checkpoint_url:
model = BeitForSemanticSegmentation(config)
else:
model = BeitForImageClassification(config)
model.eval()
model.load_state_dict(state_dict)
# Check outputs on an image
if is_semantic:
image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
image = Image.open(ds[0]["file"])
else:
image_processor = BeitImageProcessor(
size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
)
image = prepare_img()
encoding = image_processor(images=image, return_tensors="pt")
pixel_values = encoding["pixel_values"]
outputs = model(pixel_values)
logits = outputs.logits
# verify logits
expected_shape = torch.Size([1, 1000])
if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
expected_shape = torch.Size([1, 196, 8192])
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
expected_shape = torch.Size([1, 196, 8192])
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
expected_shape = torch.Size([1, 21841])
expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
expected_class_idx = 2397
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
expected_shape = torch.Size([1, 21841])
expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
expected_class_idx = 2396
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
expected_class_idx = 285
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
expected_class_idx = 281
elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
]
)
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
]
)
else:
raise ValueError("Can't verify logits as model is not supported")
if logits.shape != expected_shape:
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
if not has_lm_head:
if is_semantic:
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
raise ValueError("First elements of logits not as expected")
else:
print("Predicted class idx:", logits.argmax(-1).item())
if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
raise ValueError("First elements of logits not as expected")
if logits.argmax(-1).item() != expected_class_idx:
raise ValueError("Predicted class index not as expected")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_url",
default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
type=str,
help="URL to the original PyTorch checkpoint (.pth file).",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
)
args = parser.parse_args()
convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
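# A minimal usage sketch for this converter (the script filename and the output folder below are placeholders,
# only the checkpoint URL is taken from the argparse default above):
#   python convert_beit_checkpoint.py \
#       --checkpoint_url https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth \
#       --pytorch_dump_folder_path ./beit_converted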


@ -1,246 +0,0 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now
deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
weight names to the original names, so the model can be imported with Hugging Face Transformers.
You may adapt this script to include classification/MLM/NSP/etc. heads.
Note: This script only works with an older version of the TensorFlow models repository (<= v2.3.0).
Models trained with newer versions are not compatible with this script.
"""
import argparse
import os
import re
import tensorflow as tf
import torch
from transformers import BertConfig, BertModel
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
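# As an illustration of the re-mapping performed in load_tf2_weights_in_bert below, a TF2.x variable path of the form
#   model/layer_with_weights-4/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE
# resolves to the PyTorch parameter
#   encoder.layer.0.attention.self.query.weight
# (layer_with_weights-0..3 hold the three embedding tables and the embedding LayerNorm, so encoder layer i lives
# at layer_with_weights-(i + 4)). The path shown here is representative, not copied from a specific checkpoint.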
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
layer_depth = []
for full_name, shape in init_vars:
# logger.info(f"Loading TF weight {name} with shape {shape}")
name = full_name.split("/")
if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
logger.info(f"Skipping non-model layer {full_name}")
continue
if "optimizer" in full_name:
logger.info(f"Skipping optimization layer {full_name}")
continue
if name[0] == "model":
# ignore initial 'model'
name = name[1:]
# figure out how many levels deep the name is
depth = 0
for _name in name:
if _name.startswith("layer_with_weights"):
depth += 1
else:
break
layer_depth.append(depth)
# read data
array = tf.train.load_variable(tf_path, full_name)
names.append("/".join(name))
arrays.append(array)
logger.info(f"Read a total of {len(arrays):,} layers")
# Sanity check
if len(set(layer_depth)) != 1:
raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
layer_depth = list(set(layer_depth))[0]
if layer_depth != 1:
raise ValueError(
"The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
" heads."
)
# convert layers
logger.info("Converting weights...")
for full_name, array in zip(names, arrays):
name = full_name.split("/")
pointer = model
trace = []
for i, m_name in enumerate(name):
if m_name == ".ATTRIBUTES":
# variable names end with .ATTRIBUTES/VARIABLE_VALUE
break
if m_name.startswith("layer_with_weights"):
layer_num = int(m_name.split("-")[-1])
if layer_num <= 2:
# embedding layers
# layer_num 0: word_embeddings
# layer_num 1: position_embeddings
# layer_num 2: token_type_embeddings
continue
elif layer_num == 3:
# embedding LayerNorm
trace.extend(["embeddings", "LayerNorm"])
pointer = getattr(pointer, "embeddings")
pointer = getattr(pointer, "LayerNorm")
elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
# encoder layers
trace.extend(["encoder", "layer", str(layer_num - 4)])
pointer = getattr(pointer, "encoder")
pointer = getattr(pointer, "layer")
pointer = pointer[layer_num - 4]
elif layer_num == config.num_hidden_layers + 4:
# pooler layer
trace.extend(["pooler", "dense"])
pointer = getattr(pointer, "pooler")
pointer = getattr(pointer, "dense")
elif m_name == "embeddings":
trace.append("embeddings")
pointer = getattr(pointer, "embeddings")
if layer_num == 0:
trace.append("word_embeddings")
pointer = getattr(pointer, "word_embeddings")
elif layer_num == 1:
trace.append("position_embeddings")
pointer = getattr(pointer, "position_embeddings")
elif layer_num == 2:
trace.append("token_type_embeddings")
pointer = getattr(pointer, "token_type_embeddings")
else:
raise ValueError(f"Unknown embedding layer with name {full_name}")
trace.append("weight")
pointer = getattr(pointer, "weight")
elif m_name == "_attention_layer":
# self-attention layer
trace.extend(["attention", "self"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "self")
elif m_name == "_attention_layer_norm":
# output attention norm
trace.extend(["attention", "output", "LayerNorm"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "LayerNorm")
elif m_name == "_attention_output_dense":
# output attention dense
trace.extend(["attention", "output", "dense"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "dense")
elif m_name == "_output_dense":
# output dense
trace.extend(["output", "dense"])
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "dense")
elif m_name == "_output_layer_norm":
# output layer norm
trace.extend(["output", "LayerNorm"])
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "LayerNorm")
elif m_name == "_key_dense":
# attention key
trace.append("key")
pointer = getattr(pointer, "key")
elif m_name == "_query_dense":
# attention query
trace.append("query")
pointer = getattr(pointer, "query")
elif m_name == "_value_dense":
# attention value
trace.append("value")
pointer = getattr(pointer, "value")
elif m_name == "_intermediate_dense":
# attention intermediate dense
trace.extend(["intermediate", "dense"])
pointer = getattr(pointer, "intermediate")
pointer = getattr(pointer, "dense")
elif m_name == "_output_layer_norm":
# output layer norm
trace.append("output")
pointer = getattr(pointer, "output")
# weights & biases
elif m_name in ["bias", "beta"]:
trace.append("bias")
pointer = getattr(pointer, "bias")
elif m_name in ["kernel", "gamma"]:
trace.append("weight")
pointer = getattr(pointer, "weight")
else:
logger.warning(f"Ignored {m_name}")
# for certain layers reshape is necessary
trace = ".".join(trace)
if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
r"(\S+)\.attention\.output\.dense\.weight", trace
):
array = array.reshape(pointer.data.shape)
if "kernel" in full_name:
array = array.transpose()
if pointer.shape == array.shape:
pointer.data = torch.from_numpy(array)
else:
raise ValueError(
f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
f" {array.shape}"
)
logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
return model
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
# Instantiate model
logger.info(f"Loading model based on config from {config_path}...")
config = BertConfig.from_json_file(config_path)
model = BertModel(config)
# Load weights from checkpoint
logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
# Save pytorch-model
logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
)
parser.add_argument(
"--bert_config_file",
type=str,
required=True,
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path",
type=str,
required=True,
help="Path to the output PyTorch model (must include filename).",
)
args = parser.parse_args()
convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)


@ -1,62 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""
import argparse
import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file)
print(f"Building PyTorch model from configuration: {config}")
model = BertForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--bert_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)


@ -1,112 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
import argparse
import os
import numpy as np
import tensorflow as tf
import torch
from transformers import BertModel
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
"""
Args:
model: BertModel Pytorch model instance to be converted
ckpt_dir: Tensorflow model directory
model_name: model name
Currently supported HF models:
- Y BertModel
- N BertForMaskedLM
- N BertForPreTraining
- N BertForMultipleChoice
- N BertForNextSentencePrediction
- N BertForSequenceClassification
- N BertForQuestionAnswering
"""
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
var_map = (
("layer.", "layer_"),
("word_embeddings.weight", "word_embeddings"),
("position_embeddings.weight", "position_embeddings"),
("token_type_embeddings.weight", "token_type_embeddings"),
(".", "/"),
("LayerNorm/weight", "LayerNorm/gamma"),
("LayerNorm/bias", "LayerNorm/beta"),
("weight", "kernel"),
)
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
state_dict = model.state_dict()
def to_tf_var_name(name: str):
for patt, repl in iter(var_map):
name = name.replace(patt, repl)
return f"bert/{name}"
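# For example, the BertModel parameter "encoder.layer.0.attention.self.query.weight" is rewritten by var_map
# step by step ("layer." -> "layer_", "." -> "/", "weight" -> "kernel") into the TF variable name
# "bert/encoder/layer_0/attention/self/query/kernel"; since it matches tensors_to_transpose, its value is
# also transposed before being written to the checkpoint.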
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
session.run(tf.variables_initializer([tf_var]))
session.run(tf_var)
return tf_var
tf.reset_default_graph()
with tf.Session() as session:
for var_name in state_dict:
tf_name = to_tf_var_name(var_name)
torch_tensor = state_dict[var_name].numpy()
if any(x in var_name for x in tensors_to_transpose):
torch_tensor = torch_tensor.T
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
tf_weight = session.run(tf_var)
print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")
saver = tf.train.Saver(tf.trainable_variables())
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
def main(raw_args=None):
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
parser.add_argument(
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
)
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
args = parser.parse_args(raw_args)
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir,
)
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
if __name__ == "__main__":
main()


@ -1,188 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT
model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository:
https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
"""
import argparse
import tensorflow as tf
import torch
from transformers import BertConfig, BertForMaskedLM
from transformers.models.bert.modeling_bert import (
BertIntermediate,
BertLayer,
BertOutput,
BertPooler,
BertSelfAttention,
BertSelfOutput,
)
from transformers.utils import logging
logging.set_verbosity_info()
def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
def get_masked_lm_array(name: str):
full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
def get_encoder_array(name: str):
full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
def get_encoder_layer_array(layer_index: int, name: str):
full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape):
full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
array = array.reshape(original_shape)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
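# As an example of the paths assembled above: layer_index=0 with name="_query_dense/kernel" reads the TF variable
# "encoder/_transformer_layers/0/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE", reshapes it to
# the target weight's shape and transposes it because the name contains "kernel".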
print(f"Loading model based on config from {config_path}...")
config = BertConfig.from_json_file(config_path)
model = BertForMaskedLM(config)
# Layers
for layer_index in range(0, config.num_hidden_layers):
layer: BertLayer = model.bert.encoder.layer[layer_index]
# Self-attention
self_attn: BertSelfAttention = layer.attention.self
self_attn.query.weight.data = get_encoder_attention_layer_array(
layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
)
self_attn.query.bias.data = get_encoder_attention_layer_array(
layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
)
self_attn.key.weight.data = get_encoder_attention_layer_array(
layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
)
self_attn.key.bias.data = get_encoder_attention_layer_array(
layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
)
self_attn.value.weight.data = get_encoder_attention_layer_array(
layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
)
self_attn.value.bias.data = get_encoder_attention_layer_array(
layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
)
# Self-attention Output
self_output: BertSelfOutput = layer.attention.output
self_output.dense.weight.data = get_encoder_attention_layer_array(
layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
)
self_output.dense.bias.data = get_encoder_attention_layer_array(
layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
)
self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
# Intermediate
intermediate: BertIntermediate = layer.intermediate
intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
# Output
bert_output: BertOutput = layer.output
bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
# Embeddings
model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
# LM Head
lm_head = model.cls.predictions.transform
lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
# Pooling
model.bert.pooler = BertPooler(config=config)
model.bert.pooler.dense.weight.data = get_encoder_array("_pooler_layer/kernel")
model.bert.pooler.dense.bias.data = get_encoder_array("_pooler_layer/bias")
# Export final model
model.save_pretrained(pytorch_dump_path)
# Integration test - should load without any errors ;)
new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
print(new_model.eval())
print("Model conversion was done successfully!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
)
parser.add_argument(
"--bert_config_file",
type=str,
required=True,
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path",
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)


@ -608,6 +608,8 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, BertGenerationOnlyLMHead):
module.bias.data.zero_()
BERT_GENERATION_START_DOCSTRING = r"""


@ -1,69 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigBird checkpoint."""
import argparse
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
# Initialise PyTorch model
config = BigBirdConfig.from_json_file(big_bird_config_file)
print(f"Building PyTorch model from configuration: {config}")
if is_trivia_qa:
model = BigBirdForQuestionAnswering(config)
else:
model = BigBirdForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--big_bird_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(
args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
)


@ -1769,6 +1769,8 @@ class BigBirdPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, BigBirdLMPredictionHead):
module.bias.data.zero_()
BIG_BIRD_START_DOCSTRING = r"""


@ -1,170 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Dict
import tensorflow as tf
import torch
from tqdm import tqdm
from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration
INIT_COMMON = [
# tf -> hf
("/", "."),
("layer_", "layers."),
("kernel", "weight"),
("beta", "bias"),
("gamma", "weight"),
("pegasus", "model"),
]
END_COMMON = [
(".output.dense", ".fc2"),
("intermediate.LayerNorm", "final_layer_norm"),
("intermediate.dense", "fc1"),
]
DECODER_PATTERNS = (
INIT_COMMON
+ [
("attention.self.LayerNorm", "self_attn_layer_norm"),
("attention.output.dense", "self_attn.out_proj"),
("attention.self", "self_attn"),
("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
("attention.encdec_output.dense", "encoder_attn.out_proj"),
("attention.encdec", "encoder_attn"),
("key", "k_proj"),
("value", "v_proj"),
("query", "q_proj"),
("decoder.LayerNorm", "decoder.layernorm_embedding"),
]
+ END_COMMON
)
REMAINING_PATTERNS = (
INIT_COMMON
+ [
("embeddings.word_embeddings", "shared.weight"),
("embeddings.position_embeddings", "embed_positions.weight"),
("attention.self.LayerNorm", "self_attn_layer_norm"),
("attention.output.dense", "self_attn.output"),
("attention.self", "self_attn.self"),
("encoder.LayerNorm", "encoder.layernorm_embedding"),
]
+ END_COMMON
)
KEYS_TO_IGNORE = [
"encdec/key/bias",
"encdec/query/bias",
"encdec/value/bias",
"self/key/bias",
"self/query/bias",
"self/value/bias",
"encdec_output/dense/bias",
"attention/output/dense/bias",
]
def rename_state_dict_key(k, patterns):
for tf_name, hf_name in patterns:
k = k.replace(tf_name, hf_name)
return k
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
cfg = BigBirdPegasusConfig(**config_update)
torch_model = BigBirdPegasusForConditionalGeneration(cfg)
state_dict = torch_model.state_dict()
mapping = {}
# separating decoder weights
decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
if any(conditions):
continue
patterns = DECODER_PATTERNS
new_k = rename_state_dict_key(k, patterns)
if new_k not in state_dict:
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if any(i in k for i in ["dense", "query", "key", "value"]):
v = v.T
mapping[new_k] = torch.from_numpy(v)
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
if any(conditions):
continue
patterns = REMAINING_PATTERNS
new_k = rename_state_dict_key(k, patterns)
if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if any(i in k for i in ["dense", "query", "key", "value"]):
v = v.T
mapping[new_k] = torch.from_numpy(v)
if k != "pegasus/embeddings/position_embeddings":
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
missing, extra = torch_model.load_state_dict(mapping, strict=False)
unexpected_missing = [
k
for k in missing
if k
not in [
"final_logits_bias",
"model.encoder.embed_tokens.weight",
"model.decoder.embed_tokens.weight",
"lm_head.weight",
]
]
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
assert extra == [], f"no matches found for the following tf keys {extra}"
return torch_model
def get_tf_weights_as_numpy(path) -> Dict:
init_vars = tf.train.list_variables(path)
tf_weights = {}
ignore_name = ["global_step"]
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
skip_key = any(pat in name for pat in ignore_name)
if skip_key:
continue
array = tf.train.load_variable(path, name)
tf_weights[name] = array
return tf_weights
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
tf_weights = get_tf_weights_as_numpy(ckpt_path)
torch_model = convert_bigbird_pegasus(tf_weights, config_update)
torch_model.save_pretrained(save_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
config_update = {}
convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)


@ -1,292 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
import shutil
import torch
from transformers import BioGptConfig, BioGptForCausalLM
from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
from transformers.utils import WEIGHTS_NAME, logging
logging.set_verbosity_warning()
json_indent = 2
# modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18
class Dictionary:
"""A mapping from symbols to consecutive integers"""
def __init__(
self,
*, # begin keyword-only arguments
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
extra_special_symbols=None,
):
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
def __eq__(self, other):
return self.indices == other.indices
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def __contains__(self, sym):
return sym in self.indices
@classmethod
def load(cls, f):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
d = cls()
d.add_from_file(f)
return d
def add_symbol(self, word, n=1, overwrite=False):
"""Adds a word to the dictionary"""
if word in self.indices and not overwrite:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def _load_meta(self, lines):
return 0
def add_from_file(self, f):
"""
Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
"""
if isinstance(f, str):
try:
with open(f, "r", encoding="utf-8") as fd:
self.add_from_file(fd)
except FileNotFoundError as fnfe:
raise fnfe
except UnicodeError:
raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f))
return
lines = f.readlines()
indices_start_line = self._load_meta(lines)
for line in lines[indices_start_line:]:
try:
line, field = line.rstrip().rsplit(" ", 1)
if field == "#fairseq:overwrite":
overwrite = True
line, field = line.rsplit(" ", 1)
else:
overwrite = False
count = int(field)
word = line
if word in self and not overwrite:
raise RuntimeError(
"Duplicate word found when loading Dictionary: '{}'. "
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. If using the Camembert model, please "
"download an updated copy of the model file.".format(word)
)
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'")
def rewrite_dict_keys(d):
# (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
# e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
keep_keys = "<s> <pad> </s> <unk>".split()
# restore the special tokens
for k in keep_keys:
del d2[f"{k}</w>"]
d2[k] = d[k] # restore
return d2
def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path):
# prep
if not os.path.exists(biogpt_checkpoint_path):
raise ValueError(f"path {biogpt_checkpoint_path} does not exist!")
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
print(f"Writing results to {pytorch_dump_folder_path}")
# handle various types of models
checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
if not os.path.isfile(checkpoint_file):
raise ValueError(f"path to the file {checkpoint_file} does not exist!")
chkpt = torch.load(checkpoint_file, map_location="cpu")
args = chkpt["cfg"]["model"]
# dicts
dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt")
if not os.path.isfile(dict_file):
raise ValueError(f"path to the file {dict_file} does not exist!")
src_dict = Dictionary.load(dict_file)
src_vocab = rewrite_dict_keys(src_dict.indices)
src_vocab_size = len(src_vocab)
src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"])
print(f"Generating {src_vocab_file} of {src_vocab_size} records")
with open(src_vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
# merges_file (bpecodes)
bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes")
if not os.path.isfile(bpecodes_file):
raise ValueError(f"path to the file {bpecodes_file} does not exist!")
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
shutil.copyfile(bpecodes_file, merges_file)
# model config
biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")
model_conf = {
"activation_dropout": args["activation_dropout"],
"architectures": ["BioGptForCausalLM"],
"attention_probs_dropout_prob": args["attention_dropout"],
"bos_token_id": 0,
"eos_token_id": 2,
"hidden_act": args["activation_fn"],
"hidden_dropout_prob": args["dropout"],
"hidden_size": args["decoder_embed_dim"],
"initializer_range": 0.02,
"intermediate_size": args["decoder_ffn_embed_dim"],
"layer_norm_eps": 1e-12,
"layerdrop": args["decoder_layerdrop"],
"max_position_embeddings": args["max_target_positions"],
"model_type": "biogpt",
"num_attention_heads": args["decoder_attention_heads"],
"num_hidden_layers": args["decoder_layers"],
"pad_token_id": 1,
"scale_embedding": not args["no_scale_embedding"],
"tie_word_embeddings": args["share_decoder_input_output_embed"],
"vocab_size": src_vocab_size,
}
# good hparam defaults to start with
print(f"Generating {biogpt_model_config_file}")
with open(biogpt_model_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))
# tokenizer config
biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)
tokenizer_conf = {
"bos_token": "<s>",
"eos_token": "</s>",
"model_max_length": 1024,
"pad_token": "<pad>",
"special_tokens_map_file": None,
"tokenizer_class": "BioGptTokenizer",
"unk_token": "<unk>",
}
print(f"Generating {biogpt_tokenizer_config_file}")
with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))
# model
model_state_dict = chkpt["model"]
# remove unneeded keys
ignore_keys = [
"decoder.version",
]
for k in ignore_keys:
model_state_dict.pop(k, None)
layer_names = list(model_state_dict.keys())
for layer_name in layer_names:
if layer_name.endswith("output_projection.weight"):
model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name)
else:
model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name)
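# after this loop every "decoder.*" key has been renamed to "biogpt.*", except the output projection, which keeps
# the top-level name "output_projection.weight" expected by the BioGptForCausalLM head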
config = BioGptConfig.from_pretrained(pytorch_dump_folder_path)
model_new = BioGptForCausalLM(config)
# check that it loads ok
model_new.load_state_dict(model_state_dict)
# save
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
print(f"Generating {pytorch_weights_dump_path}")
torch.save(model_state_dict, pytorch_weights_dump_path)
print("Conversion is done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--biogpt_checkpoint_path",
default=None,
type=str,
required=True,
help=(
"Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
" bpecodes, etc."
),
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path)
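# A hedged usage sketch (the script filename is assumed and paths are placeholders); the directory passed via
# --biogpt_checkpoint_path must contain the files checked above: checkpoint.pt, dict.txt and bpecodes.
#   python convert_biogpt_original_pytorch_checkpoint_to_pytorch.py \
#       --biogpt_checkpoint_path /path/to/biogpt_checkpoint_dir \
#       --pytorch_dump_folder_path /path/to/output_dir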


@ -1,177 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BiT checkpoints from the timm library."""
import argparse
import json
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers import BitConfig, BitForImageClassification, BitImageProcessor
from transformers.image_utils import PILImageResampling
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_config(model_name):
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
conv_layer = "std_conv" if "bit" in model_name else False
# note that when using BiT as backbone for ViT-hybrid checkpoints,
# one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same",
# config.conv_layer = "std_conv_same"
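# e.g., a hypothetical ViT-hybrid style backbone configuration following the note above would look like:
#   config = BitConfig(layer_type="bottleneck", stem_type="same", conv_layer="std_conv_same",
#                      num_labels=1000, id2label=id2label, label2id=label2id)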
config = BitConfig(
conv_layer=conv_layer,
num_labels=1000,
id2label=id2label,
label2id=label2id,
)
return config
def rename_key(name):
if "stem.conv" in name:
name = name.replace("stem.conv", "bit.embedder.convolution")
if "blocks" in name:
name = name.replace("blocks", "layers")
if "head.fc" in name:
name = name.replace("head.fc", "classifier.1")
if name.startswith("norm"):
name = "bit." + name
if "bit" not in name and "classifier" not in name:
name = "bit.encoder." + name
return name
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
"""
Copy/paste/tweak model's weights to our BiT structure.
"""
# define default BiT configuration
config = get_config(model_name)
# load original model from timm
timm_model = create_model(model_name, pretrained=True)
timm_model.eval()
# load state_dict of original model
state_dict = timm_model.state_dict()
for key in state_dict.copy().keys():
val = state_dict.pop(key)
state_dict[rename_key(key)] = val.squeeze() if "head" in key else val
# load HuggingFace model
model = BitForImageClassification(config)
model.eval()
model.load_state_dict(state_dict)
# create image processor
transform = create_transform(**resolve_data_config({}, model=timm_model))
timm_transforms = transform.transforms
pillow_resamplings = {
"bilinear": PILImageResampling.BILINEAR,
"bicubic": PILImageResampling.BICUBIC,
"nearest": PILImageResampling.NEAREST,
}
processor = BitImageProcessor(
do_resize=True,
size={"shortest_edge": timm_transforms[0].size},
resample=pillow_resamplings[timm_transforms[0].interpolation.value],
do_center_crop=True,
crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
do_normalize=True,
image_mean=timm_transforms[-1].mean.tolist(),
image_std=timm_transforms[-1].std.tolist(),
)
image = prepare_img()
timm_pixel_values = transform(image).unsqueeze(0)
pixel_values = processor(image, return_tensors="pt").pixel_values
# verify pixel values
assert torch.allclose(timm_pixel_values, pixel_values)
# verify logits
with torch.no_grad():
outputs = model(pixel_values)
logits = outputs.logits
print("Logits:", logits[0, :3])
print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
timm_logits = timm_model(pixel_values)
assert timm_logits.shape == outputs.logits.shape
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print(f"Pushing model {model_name} and processor to the hub")
model.push_to_hub(f"ybelkada/{model_name}")
processor.push_to_hub(f"ybelkada/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="resnetv2_50x1_bitm",
type=str,
help="Name of the BiT timm model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the model to the hub.",
)
args = parser.parse_args()
convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
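# A hedged usage sketch (script filename assumed, output path a placeholder; the model name is the argparse default):
#   python convert_bit_to_pytorch.py --model_name resnetv2_50x1_bitm \
#       --pytorch_dump_folder_path ./bit_converted --push_to_hub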


@ -1,114 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Blenderbot checkpoint."""
import argparse
import torch
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
PATTERNS = [
["attention", "attn"],
["encoder_attention", "encoder_attn"],
["q_lin", "q_proj"],
["k_lin", "k_proj"],
["v_lin", "v_proj"],
["out_lin", "out_proj"],
["norm_embeddings", "layernorm_embedding"],
["position_embeddings", "embed_positions"],
["embeddings", "embed_tokens"],
["ffn.lin", "fc"],
]
def rename_state_dict_key(k):
if k == "embeddings.weight":
return "shared.weight"
for parlai_name, hf_name in PATTERNS:
k = k.replace(parlai_name, hf_name)
if k.startswith("encoder"):
k = k.replace(".attn", ".self_attn")
k = k.replace("norm1", "self_attn_layer_norm")
k = k.replace("norm2", "final_layer_norm")
elif k.startswith("decoder"):
k = k.replace("norm1", "self_attn_layer_norm")
k = k.replace("norm2", "encoder_attn_layer_norm")
k = k.replace("norm3", "final_layer_norm")
return k
def rename_layernorm_keys(sd):
keys = [
"model.encoder.layernorm_embedding.weight",
"model.encoder.layernorm_embedding.bias",
"model.decoder.layernorm_embedding.weight",
"model.decoder.layernorm_embedding.bias",
]
for k in keys:
v = sd.pop(k)
new_k = k.replace("layernorm_embedding", "layer_norm")
assert new_k not in sd
sd[new_k] = v
IGNORE_KEYS = ["START"]
@torch.no_grad()
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
"""
Copy/paste/tweak model's weights to our Blenderbot structure.
"""
model = torch.load(checkpoint_path, map_location="cpu")
sd = model["model"]
cfg = BlenderbotConfig.from_json_file(config_json_path)
m = BlenderbotForConditionalGeneration(cfg)
valid_keys = m.model.state_dict().keys()
failures = []
mapping = {}
for k, v in sd.items():
if k in IGNORE_KEYS:
continue
new_k = rename_state_dict_key(k)
if new_k not in valid_keys:
failures.append([k, new_k])
else:
mapping[new_k] = v
if cfg.normalize_before: # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
rename_layernorm_keys(sd)
m.model.load_state_dict(mapping, strict=True)
m.half()
m.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
parser.add_argument(
"--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
)
args = parser.parse_args()
convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)


@ -1,191 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re
import requests
import torch
# git clone https://github.com/salesforce/BLIP.git
from models.blip import blip_decoder
from models.blip_itm import blip_itm
from models.blip_vqa import blip_vqa
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import (
BertTokenizer,
BlipConfig,
BlipForConditionalGeneration,
BlipForImageTextRetrieval,
BlipForQuestionAnswering,
)
def load_demo_image(image_size, device):
img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
transform = transforms.Compose(
[
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]
)
image = transform(raw_image).unsqueeze(0).to(device)
return image
def rename_key(key):
if "visual_encoder" in key:
key = re.sub("visual_encoder*", "vision_model.encoder", key)
if "blocks" in key:
key = re.sub(r"blocks", "layers", key)
if "attn" in key:
key = re.sub(r"attn", "self_attn", key)
if "norm1" in key:
key = re.sub(r"norm1", "layer_norm1", key)
if "norm2" in key:
key = re.sub(r"norm2", "layer_norm2", key)
if "encoder.norm" in key:
key = re.sub(r"encoder.norm", "post_layernorm", key)
if "encoder.patch_embed.proj" in key:
key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)
if "encoder.pos_embed" in key:
key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
if "encoder.cls_token" in key:
key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)
if "self_attn" in key:
key = re.sub(r"self_attn.proj", "self_attn.projection", key)
return key
@torch.no_grad()
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = BlipConfig.from_pretrained(config_path)
else:
config = BlipConfig(projection_dim=512, text_config={}, vision_config={})
hf_model = BlipForConditionalGeneration(config).eval()
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
pt_model = pt_model.eval()
modified_state_dict = pt_model.state_dict()
for key in modified_state_dict.copy():
value = modified_state_dict.pop(key)
renamed_key = rename_key(key)
modified_state_dict[renamed_key] = value
hf_model.load_state_dict(modified_state_dict)
image_size = 384
image = load_demo_image(image_size=image_size, device="cpu")
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
input_ids = tokenizer(["a picture of"]).input_ids
out = hf_model.generate(image, input_ids)
assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
out = hf_model.generate(image)
assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
if pytorch_dump_folder_path is not None:
hf_model.save_pretrained(pytorch_dump_folder_path)
# model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
model_url = (
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
)
vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
vqa_model.eval()
modified_state_dict = vqa_model.state_dict()
for key in modified_state_dict.copy():
value = modified_state_dict.pop(key)
renamed_key = rename_key(key)
modified_state_dict[renamed_key] = value
hf_vqa_model = BlipForQuestionAnswering(config)
hf_vqa_model.load_state_dict(modified_state_dict)
question = ["How many dogs are in this image?"]
question_input_ids = tokenizer(question, return_tensors="pt").input_ids
answer = hf_vqa_model.generate(question_input_ids, image)
print(tokenizer.decode(answer[0]))
assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"
if pytorch_dump_folder_path is not None:
hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"
itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
itm_model.eval()
modified_state_dict = itm_model.state_dict()
for key in modified_state_dict.copy():
value = modified_state_dict.pop(key)
renamed_key = rename_key(key)
modified_state_dict[renamed_key] = value
hf_itm_model = BlipForImageTextRetrieval(config)
question = ["A picture of a woman with a dog sitting in a beach"]
question_input_ids = tokenizer(
question,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=35,
).input_ids
hf_itm_model.load_state_dict(modified_state_dict)
hf_itm_model.eval()
out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
out = hf_itm_model(question_input_ids, image, use_itm_head=False)
assert out[0].item() == 0.2110687494277954
assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127
if pytorch_dump_folder_path is not None:
hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
args = parser.parse_args()
convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
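For orientation, here is a small sketch of how rename_key above remaps the original BLIP ViT keys to the Hugging Face vision-model names; it assumes rename_key is in scope and the keys shown are illustrative.

# Minimal sketch, assuming rename_key above is in scope; keys are illustrative.
for key in [
    "visual_encoder.blocks.0.norm1.weight",
    "visual_encoder.blocks.0.attn.qkv.weight",
    "visual_encoder.blocks.0.attn.proj.weight",
]:
    print(key, "->", rename_key(key))
# Expected, roughly:
#   visual_encoder.blocks.0.norm1.weight -> vision_model.encoder.layers.0.layer_norm1.weight
#   visual_encoder.blocks.0.attn.qkv.weight -> vision_model.encoder.layers.0.self_attn.qkv.weight
#   visual_encoder.blocks.0.attn.proj.weight -> vision_model.encoder.layers.0.self_attn.projection.weight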


@@ -1,390 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert BLIP-2 checkpoints from the original repository.
URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
"""
import argparse
import requests
import torch
# pip3 install salesforce-lavis
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
# to make sure we can compare both original and HF implementation in float32
from lavis.models import load_model_and_preprocess
from PIL import Image
from transformers import (
AutoTokenizer,
BertTokenizer,
Blip2Config,
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
Blip2Processor,
Blip2QFormerConfig,
Blip2VisionConfig,
BlipImageProcessor,
OPTConfig,
T5Config,
set_seed,
)
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
def load_demo_image():
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
return image
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, model_name):
rename_keys = []
# fmt: off
# vision encoder
rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))
for i in range(config.vision_config.num_hidden_layers):
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
# QFormer
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
if "itm" in model_name:
rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
rename_keys.append(("text_proj.weight", "text_projection.weight"))
rename_keys.append(("text_proj.bias", "text_projection.bias"))
# fmt: on
return rename_keys
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def read_in_q_v_bias(state_dict, config):
for i in range(config.vision_config.num_hidden_layers):
# read in original q and v biases
q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")
# next, set bias in the state dict
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
def get_blip2_config(model_name, eos_token_id):
image_size = 364 if "coco" in model_name else 224
vision_config = Blip2VisionConfig(image_size=image_size).to_dict()
# make sure the models have proper bos_token_id and eos_token_id set (important for generation)
# seems like flan-T5 models don't have bos_token_id properly set?
if "opt-2.7b" in model_name:
text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
elif "opt-6.7b" in model_name:
text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
elif "t5-xl" in model_name:
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
elif "t5-xxl" in model_name:
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
elif "itm" in model_name:
text_config = {}
else:
raise ValueError("Model name not supported")
if "itm" in model_name:
config = Blip2Config(
vision_config=vision_config,
qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
)
else:
config = Blip2Config(vision_config=vision_config, text_config=text_config)
return config, image_size
@torch.no_grad()
def convert_blip2_checkpoint(
model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
):
"""
Copy/paste/tweak model's weights to Transformers design.
"""
if "opt" in model_name:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
elif "itm" in model_name:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
else:
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
if "itm" in model_name:
eos_token_id = None
else:
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
if "itm" in model_name:
hf_model = Blip2ForImageTextRetrieval(config).eval()
else:
hf_model = Blip2ForConditionalGeneration(config).eval()
model_name_to_original = {
"blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
"blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
"blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
"blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
"blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
"blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
}
name, type = model_name_to_original[model_name]
# load original model
print("Loading original model...")
original_model, vis_processors, _ = load_model_and_preprocess(
name=name, model_type=type, is_eval=True, device=lavis_device
)
original_model.eval()
print("Done!")
# update state dict keys
state_dict = original_model.state_dict()
rename_keys = create_rename_keys(config, model_name)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
# some keys can be renamed efficiently
for key, val in state_dict.copy().items():
val = state_dict.pop(key)
if key.startswith("Qformer.bert"):
key = key.replace("Qformer.bert", "qformer")
if "attention.self" in key:
key = key.replace("self", "attention")
if "opt_proj" in key:
key = key.replace("opt_proj", "language_projection")
if "t5_proj" in key:
key = key.replace("t5_proj", "language_projection")
if key.startswith("opt"):
key = key.replace("opt", "language")
if key.startswith("t5"):
key = key.replace("t5", "language")
state_dict[key] = val
# read in qv biases
read_in_q_v_bias(state_dict, config)
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
assert len(missing_keys) == 0
if "itm" in model_name:
unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
else:
assert unexpected_keys == ["qformer.embeddings.position_ids"]
image = load_demo_image()
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
# create processor
image_processor = BlipImageProcessor(
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
)
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
# make sure processor creates exact same pixel values
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
original_model.to(lavis_device)
hf_model.to(hf_model_device)
if "itm" in model_name:
caption = "a large fountain spewing water into the air"
input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)
with torch.no_grad():
original_logits = original_model(
{"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
)
logits = hf_model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
use_image_text_matching_head=True,
)
assert original_logits.shape == logits.logits_per_image.shape
print("First values of original logits:", original_logits[0, :3])
print("First values of HF logits:", logits.logits_per_image[0, :3])
# assert values
# cast to same type
target_dtype = logits.logits_per_image.dtype
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
print("Looks ok!")
with torch.no_grad():
original_logits = original_model(
{"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
)
logits = hf_model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
use_image_text_matching_head=False,
)
assert original_logits.shape == logits.logits_per_image.shape
print("First values of original logits:", original_logits[0, :3])
print("First values of HF logits:", logits.logits_per_image[0, :3])
# assert values
# cast to same type
target_dtype = logits.logits_per_image.dtype
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
print("Looks ok!")
else:
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
with torch.no_grad():
if "opt" in model_name:
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
logits = hf_model(pixel_values, input_ids).logits
else:
original_logits = original_model(
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
).logits
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
logits = hf_model(pixel_values, input_ids, labels=labels).logits
assert original_logits.shape == logits.shape
print("First values of original logits:", original_logits[0, :3, :3])
print("First values of HF logits:", logits[0, :3, :3])
# assert values
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
print("Looks ok!")
print("Generating a caption...")
prompt = "Question: what object is in this image? Answer:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
set_seed(42)
original_outputs = original_model.generate(
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
)
outputs = hf_model.generate(
pixel_values,
input_ids,
do_sample=True,
num_beams=5,
max_length=30,
min_length=1,
top_p=0.9,
repetition_penalty=1.0,
length_penalty=1.0,
temperature=1,
)
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
output_text = [text.strip() for text in output_text]
print("Original generation:", original_outputs)
print("HF generation:", output_text)
if pytorch_dump_folder_path is not None:
processor.save_pretrained(pytorch_dump_folder_path)
hf_model.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
processor.push_to_hub(f"nielsr/{model_name}")
hf_model.push_to_hub(f"nielsr/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
choices = [
"blip2-opt-2.7b",
"blip2-opt-6.7b",
"blip2-opt-2.7b-coco",
"blip2-opt-6.7b-coco",
"blip2-flan-t5-xl",
"blip2-flan-t5-xl-coco",
"blip2-flan-t5-xxl",
"blip2-itm-vit-g",
"blip2-itm-vit-g-coco",
]
parser.add_argument(
"--model_name",
default="blip2-opt-2.7b",
choices=choices,
type=str,
help="Path to hf config.json of model to convert",
)
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the model and processor to the hub after converting",
)
# note: this script was tested on 2 GPUs, since the models are compared in float32,
# which requires quite a lot of memory; loading each model on a separate device
# makes the comparison easiest
parser.add_argument(
"--lavis_device", default="cpu", type=str, help="Torch device to run the original LAVIS model on (cpu or cuda)."
)
parser.add_argument(
"--hf_model_device", default="cpu", type=str, help="Torch device to run the HF model on (cpu or cuda)."
)
args = parser.parse_args()
convert_blip2_checkpoint(
args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
)
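The read_in_q_v_bias helper above relies on the fact that the original ViT attention keeps separate q and v biases with no k bias, while the HF model expects one fused qkv bias. A self-contained toy sketch of that fusion (sizes are illustrative):

import torch

hidden_size = 4  # illustrative size, not the real model dimension
q_bias = torch.arange(hidden_size, dtype=torch.float32)
v_bias = -torch.arange(hidden_size, dtype=torch.float32)
# zeros are inserted where the (absent) k bias would go
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias), v_bias))
assert qkv_bias.shape == (3 * hidden_size,)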


@@ -1,254 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigScience BLOOM checkpoint."""
import argparse
import json
import os
import re
import torch
from transformers import BloomConfig, BloomModel
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging
logging.set_verbosity_info()
WEIGHTS_TO_AVERAGE_ENDSWITH = [
"word_embeddings_layernorm.weight",
"word_embeddings_layernorm.bias",
"input_layernorm.weight",
"input_layernorm.bias",
"post_attention_layernorm.weight",
"post_attention_layernorm.bias",
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
"ln_f.weight",
"ln_f.bias",
]
WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
"mlp.dense_4h_to_h.weight",
"self_attention.dense.weight",
]
def layer_name_mapping(key, file):
"""Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
# Handle first and last layers
layer_rename_map = {
"word_embeddings.weight": "word_embeddings.weight",
"word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
"word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
"weight": "ln_f.weight",
"bias": "ln_f.bias",
}
if key in layer_rename_map:
return layer_rename_map[key]
# Handle transformer blocks
layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
layer_number -= 3
return f"h.{layer_number}." + key
def get_dtype_size(dtype):
if dtype == torch.bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
def convert_bloom_checkpoint_to_pytorch(
bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
):
# Construct model
if bloom_config_file == "":
config = BloomConfig()
else:
config = BloomConfig.from_json_file(bloom_config_file)
if shard_model:
file_names = os.listdir(bloom_checkpoint_path)
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
index_dict = {"weight_map": {}, "metadata": {}}
total_size = 0
missing_keys = None
config = BloomConfig()
for j, file in enumerate(file_names):
print("Processing file: {}".format(file))
tensors = None
for i in range(pretraining_tp):
# load all TP files
f_name = file.replace("model_00", f"model_0{i}")
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
# Rename keys in the transformers names
keys = list(temp.keys())
for key in keys:
temp[layer_name_mapping(key, file)] = temp.pop(key)
if tensors is None:
tensors = temp
else:
for key in tensors.keys():
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
# We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
tensors[key] += temp[key]
else:
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
# We concatenate these weights across TP ranks
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
# Divide the weights we want to average by the number of TP ranks
for key in tensors.keys():
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
tensors[key] = tensors[key] / pretraining_tp
torch.save(
tensors,
os.path.join(
pytorch_dump_folder_path,
"pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(len(file_names)).zfill(5)),
),
)
for key in tensors.keys():
value = tensors[key]
total_size += value.numel() * get_dtype_size(value.dtype)
if key not in index_dict["weight_map"]:
index_dict["weight_map"][key] = "pytorch_model_{}-of-{}.bin".format(
str(j + 1).zfill(5), str(len(file_names)).zfill(5)
)
config = BloomConfig()
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
index_dict["metadata"]["total_size"] = total_size
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
f.write(json_config)
else:
model = BloomModel(config)
file_names = os.listdir(bloom_checkpoint_path)
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
missing_keys = None
for i, file in enumerate(file_names):
tensors = None
for i in range(pretraining_tp):
# load all TP files
f_name = file.replace("model_00", f"model_0{i}")
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
# Rename keys in the transformers names
keys = list(temp.keys())
for key in keys:
temp[layer_name_mapping(key, file)] = temp.pop(key)
if tensors is None:
tensors = temp
else:
for key in tensors.keys():
# We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
tensors[key] += temp[key]
else:
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
# We concatenate these weights across TP ranks
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
# Divide the weights we want to average by the number of TP ranks
for key in tensors.keys():
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
tensors[key] = tensors[key] / pretraining_tp
other_keys = model.load_state_dict(tensors, strict=False)
assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
if missing_keys is None:
missing_keys = set(other_keys.missing_keys)
else:
missing_keys = missing_keys.intersection(set(other_keys.missing_keys))
assert not missing_keys, f"The keys {missing_keys} are missing"
# Save pytorch-model
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
if config.torch_dtype is not None:
model = model.to(config.torch_dtype)
torch.save(model.state_dict(), pytorch_weights_dump_path)
print(f"Save configuration file to {pytorch_config_dump_path}")
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--bloom_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the Megatron-LM checkpoint path.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--bloom_config_file",
default="",
type=str,
help=(
"An optional config json file corresponding to the pre-trained model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--shard_model",
action="store_true",
help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
)
parser.add_argument(
"--pretraining_tp",
default=4,
type=int,
help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
)
args = parser.parse_args()
convert_bloom_checkpoint_to_pytorch(
args.bloom_checkpoint_path,
args.bloom_config_file,
args.pytorch_dump_folder_path,
args.shard_model,
args.pretraining_tp,
)
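To make the tensor-parallel merge rules above concrete, here is a self-contained toy sketch (the shapes and TP degree are illustrative): layer norms and the listed biases are averaged across shards, row-parallel weights are concatenated along dim 1, and the remaining weights along dim 0.

import torch

pretraining_tp = 2  # illustrative TP degree
ln_shards = [torch.ones(4) * (i + 1) for i in range(pretraining_tp)]   # averaged across ranks
row_shards = [torch.randn(4, 3) for _ in range(pretraining_tp)]        # e.g. mlp.dense_4h_to_h.weight
col_shards = [torch.randn(3, 4) for _ in range(pretraining_tp)]        # column-parallel weights

ln_merged = sum(ln_shards) / pretraining_tp   # averaged -> tensor of 1.5s here
row_merged = torch.cat(row_shards, dim=1)     # shape (4, 6)
col_merged = torch.cat(col_shards, dim=0)     # shape (6, 4)
print(ln_merged[0].item(), row_merged.shape, col_merged.shape)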


@@ -1,145 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Bros checkpoints."""
import argparse
import bros # original repo
import torch
from transformers import BrosConfig, BrosModel, BrosProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_configs(model_name):
bros_config = BrosConfig.from_pretrained(model_name)
return bros_config
def remove_ignore_keys_(state_dict):
ignore_keys = [
"embeddings.bbox_sinusoid_emb.inv_freq",
]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(name):
if name == "embeddings.bbox_projection.weight":
name = "bbox_embeddings.bbox_projection.weight"
if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"
if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"
return name
def convert_state_dict(orig_state_dict, model):
# rename keys
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
orig_state_dict[rename_key(key)] = val
# remove ignore keys
remove_ignore_keys_(orig_state_dict)
return orig_state_dict
def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
# load original model
original_model = bros.BrosModel.from_pretrained(model_name).eval()
# load HuggingFace Model
bros_config = get_configs(model_name)
model = BrosModel.from_pretrained(model_name, config=bros_config)
model.eval()
state_dict = original_model.state_dict()
new_state_dict = convert_state_dict(state_dict, model)
model.load_state_dict(new_state_dict)
# verify results
# the original BROS model requires 4 points (8 float values) for each bbox, so prepare bbox with [batch_size, seq_len, 8] shape
bbox = torch.tensor(
[
[
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
]
]
)
processor = BrosProcessor.from_pretrained(model_name)
encoding = processor("His name is Rocco.", return_tensors="pt")
encoding["bbox"] = bbox
original_hidden_states = original_model(**encoding).last_hidden_state
# pixel_values = processor(image, return_tensors="pt").pixel_values
last_hidden_states = model(**encoding).last_hidden_state
assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)
if pytorch_dump_folder_path is not None:
print(f"Saving model and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="jinho8345/bros-base-uncased",
required=False,
type=str,
help="Name of the original model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
required=False,
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether or not to push the converted model and processor to the 🤗 hub.",
)
args = parser.parse_args()
convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
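A quick sketch of what convert_state_dict above does to a state dict, assuming rename_key and remove_ignore_keys_ from this script are in scope; the toy keys and values are illustrative.

# Minimal sketch, assuming rename_key and remove_ignore_keys_ above are in scope.
toy_sd = {
    "embeddings.bbox_projection.weight": 0,
    "embeddings.bbox_sinusoid_emb.inv_freq": 0,
    "encoder.layer.0.attention.self.query.weight": 0,
}
for k in list(toy_sd):
    toy_sd[rename_key(k)] = toy_sd.pop(k)
remove_ignore_keys_(toy_sd)
print(sorted(toy_sd))
# roughly: ['bbox_embeddings.bbox_projection.weight',
#           'encoder.layer.0.attention.self.query.weight']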


@@ -1,59 +0,0 @@
# coding=utf-8
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert T5 checkpoint."""
import argparse
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
# Initialise PyTorch model
config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
model = T5ForConditionalGeneration(config)
# Load weights from tf checkpoint
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
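For completeness, an illustrative direct invocation of the converter above from Python; the paths are hypothetical and must point at a real TF checkpoint and its T5 config JSON.

# Hypothetical paths; assumes the function above is in scope.
convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="./t5_tf_checkpoint/model.ckpt",
    config_file="./t5_tf_checkpoint/config.json",
    pytorch_dump_path="./t5_pytorch_dump",
)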

Some files were not shown because too many files have changed in this diff.