Compare commits

...

27 Commits

Author SHA1 Message Date
d79b2d981f v4.55.4 2025-08-22 14:39:20 +02:00
90792b730a Revert "Fix GPT-OSS swiglu_limit not passed in for MXFP4 #40197"
The cherry-picked commit does not match the changes nor the PR
This reverts commit e75d67ec39b4f8bc03dbeea9016fff16c2375c5a.
2025-08-22 11:21:18 +02:00
a03df6acd4 Fix GPT-OSS swiglu_limit not passed in for MXFP4 (#40197)
Add swiglu_limit = 7.0
2025-08-22 11:20:23 +02:00
170b2708cb Fixes #40262 2025-08-21 11:03:16 +02:00
7dbc054e2a v4.55.3 2025-08-18 14:46:54 +02:00
c097a43898 [bugfix] fix flash-attention2 unavailable error for Ascend NPU (#40151)
* [bugfix] fix flash-attention2 unavailable error for Ascend NPU

* remove redundant apply_rotary_emb usage

* fix ruff check error

* pad_input and unpad_input use same implementation as fa2

* rollback redundant codes

* fix ruff check error

* optimize fa2 judgement logic
2025-08-18 14:45:23 +02:00
663cbb0d04 [FA2] Fix it finally - revert fa kwargs preparation (#40161)
revert
2025-08-18 14:44:58 +02:00
c7bd5350f0 Fix fsdp for generic-task models #40191 2025-08-18 14:44:16 +02:00
e75d67ec39 Fix GPT-OSS swiglu_limit not passed in for MXFP4 #40197 2025-08-18 14:43:31 +02:00
d7f67d2006 Fix mamba caches (#40203)
fix mamba models caches inheritance
2025-08-18 14:27:04 +02:00
acf295aec3 v4.55.2 2025-08-13 20:14:33 +02:00
aaa3169aa2 qfix bad cherry-pick 2025-08-13 18:13:21 +00:00
ea2eee0bc8 v4.55.1 2025-08-13 10:33:42 +02:00
956be23fff [bugfix] Fix tensor device in Idefics2, Idefics3, and SmolVLM (#39975)
* [bugfix] ensure correct tensor device in Idefics2, Idefics3, and SmolVLM models

* to cuda
2025-08-13 10:33:17 +02:00
79a9ffc520 fix merge conlicts 2025-08-13 10:25:20 +02:00
99404c7098 Default to dequantize if cpu in device_map for mxfp4 (#39993)
* default to dq if cpu

* an other check

* style

* revert some changes
2025-08-13 10:22:01 +02:00
0d6908038c [GPT Big Code] Fix attention scaling (#40041)
* fix

* update integration tests

* fmt

* add regression test
2025-08-13 10:22:01 +02:00
b8e97fbfd2 fix: resolve triton version check compatibility on windows (#39986)
* fix: resolve triton version check compatibility on windows

* style: remove trailing space

* fix: fix typo

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-08-13 10:22:01 +02:00
586b6e693b Fix missing None default values for Gemma3n model in get_placeholder_mask (#39991) (#40024)
* Fix missing None default values for Gemma3n model in get_placeholder_mask (#39991)

* Switched definition of optional from| None to Optiona[] (Issue #39991)

---------

Co-authored-by: Laurenz Ruzicka <Laurenz.Ruzicka@ait.ac.at>
2025-08-13 10:22:01 +02:00
95ae07d11f Fix broken image inference for Fuyu model (#39915)
* fix fuyu

Signed-off-by: Isotr0py <2037008807@qq.com>

* oops

Signed-off-by: Isotr0py <2037008807@qq.com>

* run test on GPU

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* clean unused

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* revert

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* add fuyu multimodal test

Signed-off-by: Isotr0py <2037008807@qq.com>

* fix

Signed-off-by: Isotr0py <2037008807@qq.com>

---------

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
2025-08-13 10:22:01 +02:00
0d9032ae71 Fix missing video inputs for PerceptionLM. (#39971)
* Fix missing video inputs for PerceptionLM.

* Minor fix for vanilla input image (only C,H,W, no tiles dim).

* Revert "Minor fix for vanilla input image (only C,H,W, no tiles dim)."

This reverts commit 181d87b964e59c4118035a9fd4f530c6e551ba9f.
2025-08-13 10:22:01 +02:00
1d42803aac [Idefics] fix device mismatch (#39981)
fix
2025-08-13 10:22:01 +02:00
382717e543 remove triton_kernels dep with kernels instead (#39926)
* remove dep

* style

* rm import

* fix

* style

* simplify

* style
2025-08-13 10:22:01 +02:00
cc98f42d22 Enable gpt-oss mxfp4 on older hardware (sm75+) (#39940)
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2025-08-13 10:22:01 +02:00
d2f7266367 Fix MXFP4 quantizer validation to allow CPU inference with dequantize option (#39953)
* Fix MXFP4 quantizer validation to enable CPU dequantization

Move dequantize check before CUDA availability check to allow
CPU inference when quantization_config.dequantize is True.
This enables users to run MXFP4 models on CPU by automatically
converting them to BF16 format.

* Add tests for MXFP4 quantizer CPU dequantization validation

* fix: format mxfp4 test file with ruff
2025-08-13 10:22:01 +02:00
daab2db33f [CI] post-GptOss fixes for green CI (#39929) 2025-08-07 16:27:00 +02:00
06f8004e5c Release: v4.55.0 2025-08-05 18:09:15 +02:00
391 changed files with 1160 additions and 78204 deletions

View File

@ -511,6 +511,8 @@
title: GPT2
- local: model_doc/gpt_bigcode
title: GPTBigCode
- local: model_doc/gpt_oss
title: GptOss
- local: model_doc/gptsan-japanese
title: GPTSAN Japanese
- local: model_doc/gpt-sw3
@ -617,8 +619,6 @@
title: OLMoE
- local: model_doc/open-llama
title: Open-Llama
- local: model_doc/openai_moe
title: OpenAIMoe
- local: model_doc/opt
title: OPT
- local: model_doc/pegasus

View File

@ -65,6 +65,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
[[autodoc]] HqqConfig
## Mxfp4Config
[[autodoc]] Mxfp4Config
## FbgemmFp8Config
[[autodoc]] FbgemmFp8Config

View File

@ -24,11 +24,11 @@ rendered properly in your Markdown viewer.
</div>
</div>
# OpenAIMoE
# GptOss
## Overview
The OpenAIMoE model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
The GptOss model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>
The abstract from the paper is the following:
@ -43,16 +43,16 @@ This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
## OpenAIMoeConfig
## GptOssConfig
[[autodoc]] OpenAIMoeConfig
[[autodoc]] GptOssConfig
## OpenAIMoeModel
## GptOssModel
[[autodoc]] OpenAIMoeModel
[[autodoc]] GptOssModel
- forward
## OpenAIMoeForCausalLM
## GptOssForCausalLM
[[autodoc]] OpenAIMoeForCausalLM
[[autodoc]] GptOssForCausalLM
- forward

View File

@ -60,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "datasets[audio]>=1.14.0",
# "evaluate",
# "librosa",
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
# "datasets>=1.8.0",
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate>=0.12.0",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
@ -68,7 +68,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate>=0.12.0",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = get_logger(__name__)

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
# "datasets>=1.8.0",
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
# "datasets>=1.8.0",
@ -56,7 +56,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
# "datasets>=1.8.0",
@ -61,7 +61,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "timm",
# "datasets",
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "timm",
# "datasets",
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -69,7 +69,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = get_logger(__name__)

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -72,7 +72,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -74,7 +74,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = get_logger(__name__)

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -68,7 +68,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = get_logger(__name__)
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "sentencepiece != 0.1.92",
# "protobuf",
@ -57,7 +57,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = logging.getLogger(__name__)

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "sentencepiece != 0.1.92",
# "protobuf",
@ -65,7 +65,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "timm",
# "datasets>=4.0",
@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "albumentations >= 1.4.16",
# "timm",
# "datasets>=4.0",
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "datasets >= 2.0.0",
# "torch >= 1.3",
# "accelerate",
@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "datasets >= 2.0.0",
# "torch >= 1.3",
# "accelerate",
@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = get_logger(__name__)

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "datasets[audio] >= 1.12.0",
# "torch >= 1.5",
# "torchaudio",

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "datasets[audio] >= 1.18.0",
# "torch >= 1.5",
# "torchaudio",
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "datasets[audio] >= 1.18.0",
# "torch >= 1.5",
# "torchaudio",
@ -64,7 +64,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "datasets[audio] >= 1.18.0",
# "torch >= 1.5",
# "torchaudio",
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -67,7 +67,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = get_logger(__name__)

View File

@ -16,7 +16,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -16,7 +16,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.21.0",
# "sentencepiece != 0.1.92",
# "protobuf",

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.21.0",
# "sentencepiece != 0.1.92",
# "protobuf",

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "seqeval",
# "datasets >= 1.8.0",
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "seqeval",
# "datasets >= 1.8.0",
@ -67,7 +67,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -66,7 +66,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.55.4",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version(
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = logging.getLogger(__name__)

View File

@ -61,7 +61,7 @@ except (ModuleNotFoundError, ImportError):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
logger = logging.getLogger(__name__)

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
task_to_keys = {
"cola": ("sentence", None),

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.55.0.dev0")
check_min_version("4.55.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -463,7 +463,7 @@ install_requires = [
setup(
name="transformers",
version="4.55.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.55.4", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.55.0.dev0"
__version__ = "4.55.4"
from pathlib import Path
from typing import TYPE_CHECKING

View File

@ -677,24 +677,6 @@ class GenerationMixin(ContinuousMixin):
if encoder_attention_mask is not None:
model_inputs["attention_mask"] = encoder_attention_mask
if "flash" in self.config._attn_implementation and self._supports_attention_backend:
tensor_kws = {"dtype": torch.int32, "device": self.device}
pos = model_inputs["position_ids"][:, -1]
cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
max_length_k = int(pos.max()) + 1
bs, seq_len = input_ids.size()
q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
max_length_q = int(q_len.max())
model_inputs.update(
cu_seq_lens_q=cu_seq_lens_q.to(self.device),
cu_seq_lens_k=cu_seq_lens_k.to(self.device),
max_length_q=max_length_q,
max_length_k=max_length_k,
)
# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
for key, value in kwargs.items():
if key not in model_inputs:
@ -1816,7 +1798,8 @@ class GenerationMixin(ContinuousMixin):
if model_kwargs.get("past_key_values") is not None:
cache = model_kwargs["past_key_values"]
past_length = 0
if not isinstance(cache, Cache):
# Support for BC tuple cache format
if isinstance(cache, tuple):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length()

View File

@ -49,7 +49,7 @@ FP4_VALUES = [
# Copied from GPT_OSS repo and vllm
def quantize_to_mxfp4(w):
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
downcast_to_mxfp = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp
w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
w, w_scale = swizzle_mxfp4(w, w_scale)
@ -57,9 +57,13 @@ def quantize_to_mxfp4(w):
def swizzle_mxfp4(w, w_scale):
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout
FP4, convert_layout, wrap_torch_tensor = (
triton_kernels_hub.tensor.FP4,
triton_kernels_hub.tensor.convert_layout,
triton_kernels_hub.tensor.wrap_torch_tensor,
)
layout = triton_kernels_hub.tensor_details.layout
StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
@ -168,16 +172,20 @@ class Mxfp4GptOssExperts(nn.Module):
torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
)
self.alpha = 1.702
self.limit = getattr(config, "swiglu_limit", 7.0)
self.gate_up_proj_precision_config = None
self.down_proj_precision_config = None
def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
from triton_kernels.swiglu import swiglu_fn
FnSpecs, FusedActivation, matmul_ogs = (
triton_kernels_hub.matmul_ogs.FnSpecs,
triton_kernels_hub.matmul_ogs.FusedActivation,
triton_kernels_hub.matmul_ogs.matmul_ogs,
)
swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
with torch.cuda.device(hidden_states.device):
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2)
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2)
intermediate_cache1 = matmul_ogs(
hidden_states,
@ -211,7 +219,12 @@ def routing_torch_dist(
):
import os
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch
GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
triton_kernels_hub.routing.GatherIndx,
triton_kernels_hub.routing.RoutingData,
triton_kernels_hub.routing.ScatterIndx,
triton_kernels_hub.routing.compute_expt_data_torch,
)
with torch.cuda.device(logits.device):
world_size = torch.distributed.get_world_size()
@ -274,13 +287,16 @@ def mlp_forward(self, hidden_states):
if dist.is_available() and dist.is_initialized():
routing = routing_torch_dist
else:
from triton_kernels.routing import routing
routing = triton_kernels_hub.routing.routing
routing = routing
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
with torch.cuda.device(router_logits.device):
routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
return routed_out, router_logits
@ -334,8 +350,11 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs):
from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
PrecisionConfig, FlexCtx, InFlexData = (
triton_kernels_hub.matmul_ogs.PrecisionConfig,
triton_kernels_hub.matmul_ogs.FlexCtx,
triton_kernels_hub.matmul_ogs.InFlexData,
)
from ..integrations.tensor_parallel import shard_and_distribute_module
model = kwargs.get("model", None)
@ -447,6 +466,11 @@ def replace_with_mxfp4_linear(
):
if quantization_config.dequantize:
return model
else:
from kernels import get_kernel
global triton_kernels_hub
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

View File

@ -10,20 +10,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import torch
import torch.nn.functional as F
from ..utils.import_utils import is_torch_npu_available
if is_torch_npu_available():
import math
import torch_npu
from einops import rearrange, repeat
from torch_npu import npu_rotary_mul
from torch_npu import npu_fusion_attention
# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@ -52,117 +48,6 @@ def is_npu_fa2_top_left_aligned_causal_mask():
return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis = IndexFirstAxis.apply
# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
# dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
"""
all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)
def npu_flash_attn_func(
q,
k,
@ -179,11 +64,11 @@ def npu_flash_attn_func(
if not causal:
head_num = q.shape[2]
output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
else:
attn_mask_npu = get_attn_mask_npu(q.device)
head_num = q.shape[2]
output = torch_npu.npu_fusion_attention(
output = npu_fusion_attention(
q,
k,
v,
@ -218,7 +103,7 @@ def npu_flash_attn_varlen_func(
if not causal:
head_num = q.shape[1]
output = torch_npu.npu_fusion_attention(
output = npu_fusion_attention(
q,
k,
v,
@ -234,7 +119,7 @@ def npu_flash_attn_varlen_func(
else:
attn_mask_npu = get_attn_mask_npu(q.device)
head_num = q.shape[1]
output = torch_npu.npu_fusion_attention(
output = npu_fusion_attention(
q,
k,
v,
@ -251,19 +136,3 @@ def npu_flash_attn_varlen_func(
)[0]
return output
def npu_apply_rotary_emb(x, cos, sin, **kwargs):
# cos tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
cos = cos.repeat(1, 2)
# cos tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
cos = cos.unsqueeze(0).unsqueeze(2)
# sin tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
sin = sin.repeat(1, 2)
# sin tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
sin = sin.unsqueeze(0).unsqueeze(2)
return npu_rotary_mul(x, cos, sin)

View File

@ -1,4 +1,4 @@
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,17 +14,15 @@
import inspect
import os
import warnings
from functools import partial
from typing import Optional, TypedDict
import torch
import torch.nn.functional as F
from transformers.utils.import_utils import is_kernels_available
from .utils import (
is_flash_attn_2_available,
is_flash_attn_3_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_npu_available,
logging,
@ -34,18 +32,139 @@ from .utils import (
logger = logging.get_logger(__name__)
def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:])
return reshaped[indices]
# TODO Deprecate when all models have the attention interface
def flash_attn_supports_top_left_mask():
if is_flash_attn_3_available():
return False
if is_flash_attn_2_available():
return not is_flash_attn_greater_or_equal_2_10()
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
return is_npu_fa2_top_left_aligned_causal_mask()
def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
# TODO Deprecate when all models have the attention interface
def is_flash_attn_available():
return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
# `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves
_flash_fn = None
_flash_varlen_fn = None
_pad_fn = None
_unpad_fn = None
# function that processes kwargs, generalized to handle any supported kwarg within the function
_process_flash_kwargs_fn = None
# exceptions where hf API doesn't match the original flash attention API
_hf_api_to_flash_mapping = {
"dropout": "dropout_p",
"sliding_window": "window_size",
}
def _lazy_imports(implementation: Optional[str]):
"""
FA3-compatible unpad_input function.
Lazy loads the respective flash attention implementations.
Return:
flash_attn_func: The base flash attention function.
flash_attn_varlen_func: The flash attention function supporting variable sequence lengths,
e.g. for padding-free training.
pad_input: The function to pad inputs into one sequence and returning the respective kwargs.
unpad_input: The function to unpad outputs based on the kwargs (from pad_input).
"""
is_fa2 = is_flash_attn_2_available()
is_fa3 = is_flash_attn_3_available()
pad_input, unpad_input = _pad_input, _unpad_input
if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3):
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input
elif is_torch_npu_available():
# Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError
# Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
else:
if implementation == "flash_attention_3" or (implementation is None and is_fa3):
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
# Kernels fallback
else:
flash_attn_func = getattr(implementation, "flash_attn_func", None)
flash_attn_varlen_func = getattr(implementation, "flash_attn_varlen_func", None)
if flash_attn_varlen_func is None or flash_attn_func is None:
raise ValueError(
f"Could not find the currently requested flash attention implementation at `{implementation}`."
f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`."
)
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input
def _lazy_define_process_function(flash_function):
"""
Depending on the version and kernel some features are not supported. Due to limitations in
`torch.compile`, we opt to statically type which (optional) kwarg parameters are supported
within `_process_flash_attention_kwargs`.
NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`.
This might be confusing for kwargs that we use in any case, e.g. `is_causal`.
"""
global _process_flash_kwargs_fn, _hf_api_to_flash_mapping
flash_parameters = inspect.signature(flash_function).parameters
process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters
supports_mapping = {}
for param in process_parameters:
fa_param = _hf_api_to_flash_mapping.get(param, param)
supports_mapping[fa_param] = fa_param in flash_parameters
return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)
def lazy_import_flash_attention(implementation: Optional[str]):
"""
Lazy loading flash attention and returning the respective functions + flags back
NOTE: For fullgraph, this needs to be called before compile while no fullgraph can
can work without preloading. See `_check_and_adjust_attn_implementation` in `modeling_utils`.
"""
global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
global _process_flash_kwargs_fn
if _process_flash_kwargs_fn is None:
_process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
def _index_first_axis(tensor, indices):
"""
A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
after flattening the first two dimensions of the tensor. This is functionally equivalent to
FA2's `index_first_axis` and replaces the need to import it.
"""
# The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
# two dimensions to get (total_tokens, ...) before indexing.
reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
return reshaped_tensor[indices]
def _unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
@ -69,14 +188,16 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
)
def _fa3_pad_input(hidden_states, indices, batch, seqlen):
def _pad_input(hidden_states, indices, batch, seqlen):
"""
FA3-compatible pad_input function.
pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
@ -89,9 +210,11 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
"""
Retrieves indexing data required to repad unpadded (ragged) tensors.
Arguments:
attention_mask (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
Return:
indices (`torch.Tensor`):
The indices of non-masked tokens from the flattened input sequence.
@ -125,6 +248,7 @@ def _upad_input(
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
tensors for query, key, value tensors.
Arguments:
query_layer (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
@ -138,6 +262,7 @@ def _upad_input(
Target length.
unpad_input_func:
The function to use for unpadding the input tensors.
Return:
query_layer (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
@ -190,12 +315,79 @@ def _upad_input(
)
def _prepare_from_posids(query, key, value, position_ids):
def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = True):
"""
This function returns all the necessary kwargs to call `flash_attn_varlen_func`
extracted from position_ids. The `position_ids` can be either packed sequence or
the usual padded position ids, for example in inference time.
Arguments:
position_ids (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
is_packed_sequence (`bool`, *optional*, defaults to `True`):
Whether the input position ids are a packed sequence or not.
Return:
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into
ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
`max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
# If the lengths are not equal, most probably we are in decoding stage with cache
# In that case the position ids will not always start with `0` and we need a better way to infer
# cumulative seq lengths.
if not is_packed_sequence:
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
last_position_ids = position_ids[:, -1]
q_len = (
torch.ones(position_ids.size(0), **tensor_kwargs)
if position_ids.shape[-1] == 1
else last_position_ids.add(1)
)
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kwargs), q_len.cumsum(0).to(torch.int32)], 0)
cu_seq_lens_k = torch.cat(
[torch.zeros(1, **tensor_kwargs), last_position_ids.add(1).cumsum(0).to(torch.int32)], 0
)
max_length_q = int(q_len.max())
max_length_k = int(last_position_ids.max()) + 1
else:
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
cu_seq_lens_q = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
)
cu_seq_lens_k = cu_seq_lens_q
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
# for some models (e.g. qwen2-vl).
max_length_q = cu_seq_lens_q.diff().max()
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
max_length_q = max_length_q.item()
max_length_k = max_length_q
return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)
def _prepare_from_posids(query, key, value, position_ids, query_length):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
Cumulative lengths of each examples in the batch will be extracted from position_ids.
NOTE: ideally cumulative lengths should be prepared at the data collator stage
Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
@ -205,6 +397,9 @@ def _prepare_from_posids(query, key, value, position_ids):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
position_ids (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
query_length (`int`):
Sequence length of the input queries.
Return:
query (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
@ -219,123 +414,152 @@ def _prepare_from_posids(query, key, value, position_ids):
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
kv_length = key.shape[1]
is_packed_sequence = query_length == kv_length
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
(cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
position_ids, is_packed_sequence=is_packed_sequence
)
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
# for some models (e.g. qwen2-vl).
max_length = cu_seq_lens.diff().max().item()
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
warnings.warn(
"prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids",
"The function `_prepare_flash_attention_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_from_posids` instead.",
FutureWarning,
)
return _prepare_from_posids(query, key, value, position_ids)
def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None):
def _is_packed_sequence(position_ids, batch_size):
"""
Check the position ids whether packed sequences are indicated or not
1. Position ids exist
2. Flattened sequences only are supported
3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
"""
if position_ids is None:
return False
increasing_position_sequences = (
torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
)
return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
def fa_peft_integration_check(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
target_dtype: Optional[torch.dtype] = None,
):
"""
PEFT usually casts the layer norms in float32 for training stability reasons
therefore the input hidden states gets silently casted in float32. Hence, we need
cast them back in float16 / bfloat16 just to be sure everything works as expected.
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
"""
if target_dtype and q.dtype == torch.float32:
logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
return q, k, v
def _lazy_imports(impl: Optional[str]):
# returns funcs and pad/unpad based on impl
is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
is_fa3 = is_flash_attn_3_available()
if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False
except ImportError as e:
if not globals().get("use_remote_fa2", None):
use_remote_fa2 = (
input(
"Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? "
)
.strip()
.lower()
)
globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"}
if globals()["use_remote_fa2"]:
if not is_kernels_available():
raise ImportError("You need to install kernels: `pip install kernels`")
from kernels import get_kernel
impl = get_kernel("kernels-community/flash-attn")
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
return (
getattr(impl, "flash_attn_func", None),
getattr(impl, "flash_attn_varlen_func"),
pad_input,
unpad_input,
True,
)
else:
raise ImportError(
"Failed to import flash attention 2, please install it or use another implementation."
) from e
if impl == "flash_attention_3" or (impl is None and is_fa3):
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True
else:
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
return (
getattr(impl, "flash_attn_func", None),
getattr(impl, "flash_attn_varlen_func"),
pad_input,
unpad_input,
True,
)
_flash_supports_window = None
def is_flash_attn_available():
return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
def flash_attn_supports_top_left_mask():
if is_flash_attn_3_available():
return False
if is_flash_attn_2_available():
return not is_flash_attn_greater_or_equal_2_10()
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
return is_npu_fa2_top_left_aligned_causal_mask()
class FlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for Flash Attention with Compile.
Attributes:
cumulative_seqlens_q (`torch.LongTensor`, *optional*)
Gets cumulative sequence length for query state.
cumulative_seqlens_k (`torch.LongTensor`, *optional*)
Gets cumulative sequence length for key state.
max_length_q (`int`, *optional*):
Maximum sequence length for query state.
max_length_k (`int`, *optional*):
Maximum sequence length for key state.
"""
cumulative_seqlens_q: Optional[torch.LongTensor]
cumulative_seqlens_k: Optional[torch.LongTensor]
max_length_q: Optional[int]
max_length_k: Optional[int]
def _process_flash_attention_kwargs(
query_length: int,
key_length: int,
is_causal: bool,
dropout: float = 0.0,
softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: Optional[bool] = None,
s_aux: Optional[torch.Tensor] = None,
supports_mapping: Optional[dict[str, bool]] = None,
**kwargs,
):
"""
Returns a set of kwargs that are passed down to the according flash attention function based on
requested features and whether it is supported - depends on the version and kernel implementation
which is dynamically configued at `lazy_import_flash_attention`. The (un)supported features can be
inspected in `supports_mapping`, see `_lazy_define_process_function` for more details.
Args:
query_length (`int`):
Length of the query states
key_length (`int`):
Length of the key states
is_causal (`bool`):
Whether we perform causal (decoder) attention or full attention.
dropout (`float`):
Attention dropout.
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`.
sliding_window (`int`, *optional*):
The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back.
use_top_left_mask (`bool`):
Deprecated behavior of older versions of flash attention requiring different masking.
softcap (`float`, *optional*):
Softcap for the attention logits, used e.g. in gemma2.
deterministic (`bool`, *optional*):
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
s_aux (`torch.Tensor`, *optional*):
Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
Return:
flash_kwargs (`dict`):
A dict of kwargs that are requested and supported.
"""
flash_kwargs = {
"causal": is_causal and not (use_top_left_mask and query_length == 1),
"softmax_scale": softmax_scale,
}
if supports_mapping["dropout_p"]:
flash_kwargs["dropout_p"] = dropout
if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window:
flash_kwargs["window_size"] = (sliding_window, sliding_window)
if supports_mapping["deterministic"]:
flash_kwargs["deterministic"] = (
deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
)
if supports_mapping["softcap"] and softcap is not None:
flash_kwargs["softcap"] = softcap
# Only within kernel implementation atm
if supports_mapping["s_aux"] and s_aux is not None:
flash_kwargs["s_aux"] = s_aux
return flash_kwargs
def _flash_attention_forward(
@ -360,100 +584,121 @@ def _flash_attention_forward(
implementation: Optional[str] = None,
**kwargs,
):
if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")):
flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation)
globals()["_flash_fn"] = flash_fn
globals()["_flash_varlen_fn"] = flash_varlen_fn
globals()["_pad_fn"] = pad_fn
globals()["_unpad_fn"] = unpad_fn
globals()["_is_fa3"] = is_fa3
flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters
globals()["_flash_supports_window"] = flash_supports_window
else:
flash_fn = globals()["_flash_fn"]
flash_varlen_fn = globals()["_flash_varlen_fn"]
pad_fn = globals()["_pad_fn"]
unpad_fn = globals()["_unpad_fn"]
is_fa3 = globals()["_is_fa3"]
flash_supports_window = globals()["_flash_supports_window"]
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
causal = is_causal and not (use_top_left_mask and query_length == 1)
use_sw = (
(_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window
(Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`, *optional*):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
implementation (`str`, *optional*):
The attention implementation to use. If None, will default to the one based on the environment.
"""
(flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
implementation
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {}
if not is_fa3:
flash_kwargs["dropout_p"] = dropout
if is_flash_attn_greater_or_equal("2.4.1"):
det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
flash_kwargs["deterministic"] = det
if softcap is not None:
flash_kwargs["softcap"] = softcap
if "s_aux" in kwargs:
flash_kwargs["s_aux"] = kwargs.get("s_aux")
# PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
query_states, key_states, value_states = fa_peft_integration_check(
query_states, key_states, value_states, target_dtype
)
use_mask = position_ids is not None or all(
k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]
# Extract the flash attention kwargs that have been requested (and are supported by the implementation)
flash_kwargs = process_flash_kwargs_fn(
query_length=query_length,
key_length=key_states.size(1),
is_causal=is_causal,
dropout=dropout,
softmax_scale=softmax_scale,
sliding_window=sliding_window,
use_top_left_mask=use_top_left_mask,
softcap=softcap,
deterministic=deterministic,
**kwargs,
)
# We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
# Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
# Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
# use `flash_varlen_fn` knowing we already have all necessary the kwargs.
#
# NOTE: it is user's responsibility to take care of flattenning `position_ids` if that's needed by the model.
# See #39121 for more information.
is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0))
is_fa_with_varlen_kwargs = all(
kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
)
# Contains at least one padding token in the sequence
if attention_mask is not None:
q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
query_states, key_states, value_states, attention_mask, query_length, unpad_fn
)
# TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p
# TODO for now this is required to work with
# https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
if "mps" in str(q.device):
cu_k = cu_k.clone()
cu_seq_lens_k = cu_seq_lens_k.clone()
out_unpad = flash_varlen_fn(
q,
k,
v,
cu_seqlens_q=cu_q.to(torch.int32),
cu_seqlens_k=cu_k.to(torch.int32),
max_seqlen_q=mq,
max_seqlen_k=mk,
softmax_scale=softmax_scale,
causal=causal,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
**flash_kwargs,
)
if isinstance(out_unpad, tuple):
out_unpad = out_unpad[0]
out = pad_fn(out_unpad, idx, query_states.shape[0], query_length)
elif use_mask:
out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length)
# Padding free, i.e. sequences flattened into one total sequence
elif is_fa_with_varlen_kwargs or is_fa_with_position_ids:
if cu_seq_lens_q is None or cu_seq_lens_k is None:
if position_ids is None:
raise ValueError(
"Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed."
)
q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
query_states, key_states, value_states, position_ids
q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids(
query_states, key_states, value_states, position_ids, query_length=query_length
)
else:
q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
mq, mk = max_length_q, max_length_k
cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k
# TODO for now this is required to work with
# https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
if "mps" in str(q.device):
cu_k = cu_k.clone()
cu_seq_lens_k = cu_seq_lens_k.clone()
out = flash_varlen_fn(
q,
k,
v,
cu_seqlens_q=cu_q.to(torch.int32),
cu_seqlens_k=cu_k.to(torch.int32),
max_seqlen_q=mq,
max_seqlen_k=mk,
softmax_scale=softmax_scale,
causal=causal,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
**flash_kwargs,
)
if isinstance(out, tuple):
out = out[0]
out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1))
else:
out = flash_fn(
query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
)
return out[0] if isinstance(out, tuple) else out
out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1))
# No padding
else:
out = flash_fn(query_states, key_states, value_states, **flash_kwargs)
if isinstance(out, tuple):
out = out[0]
return out

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from functools import partial
from typing import Optional
@ -95,7 +94,7 @@ class GradientCheckpointingLayer(nn.Module):
@auto_docstring
class GenericForSequenceClassification(ABC):
class GenericForSequenceClassification(object):
base_model_prefix = "model"
def __init__(self, config):
@ -170,7 +169,7 @@ class GenericForSequenceClassification(ABC):
@auto_docstring
class GenericForQuestionAnswering(ABC):
class GenericForQuestionAnswering(object):
base_model_prefix = "model"
def __init__(self, config):
@ -231,7 +230,7 @@ class GenericForQuestionAnswering(ABC):
@auto_docstring
class GenericForTokenClassification(ABC):
class GenericForTokenClassification(object):
base_model_prefix = "model"
def __init__(self, config):

View File

@ -74,6 +74,7 @@ from .integrations.tensor_parallel import (
)
from .loss.loss_utils import LOSS_MAPPING
from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from .modeling_flash_attention_utils import lazy_import_flash_attention
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
@ -2126,7 +2127,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
_pp_plan = None
# This flag signal that the model can be used as an efficient backend in TGI and vLLM
# In practice, it means that they support attention interface functions, fully pass the kwargs
# In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
_supports_attention_backend = False
_can_record_outputs = None
@ -2482,6 +2483,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if not is_flash_attn_2_available():
preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
# package `flash-attn` can not be installed on Ascend NPU, following validation logics can be ignored.
if is_torch_npu_available():
logger.info("Detect using FlashAttention2 on Ascend NPU.")
return True
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logi
if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
@ -2740,6 +2746,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if attention_wrapper is None:
attention_wrapper = flash_attention_forward
kernel_function = partial(attention_wrapper, implementation=kernel)
lazy_import_flash_attention(kernel)
elif kernel_name is not None:
kernel_function = getattr(kernel, kernel_name)
ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
@ -2755,7 +2762,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
attn_implementation = "sdpa" # Try to fallback to sdpa in this case
return attn_implementation
else:
return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
attn_implementation = self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
# preload flash attention here to allow compile with fullgraph
if applicable_attn_implementation.startswith("flash_attention"):
lazy_import_flash_attention(applicable_attn_implementation)
return attn_implementation
def get_correct_attn_implementation(self, _requested_attention: str, is_init_check: bool = False) -> str:
requested_attention = "sdpa" if _requested_attention is None else _requested_attention

View File

@ -1,269 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re
from typing import Optional
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import (
Aimv2Config,
Aimv2Model,
Aimv2VisionConfig,
Aimv2VisionModel,
AutoImageProcessor,
AutoProcessor,
)
ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL = {
# Embeddings
r"preprocessor.patchifier.proj": r"embeddings.patch_embed",
r"preprocessor.pos_embed": r"embeddings.position_embedding.weight",
r"preprocessor.patchifier.norm.weight": r"embeddings.rms_norm.weight",
# Encoder Layers
r"trunk.blocks.(\d+).attn.qkv": r"encoder.layers.\1.attention.qkv",
r"trunk.blocks.(\d+).attn.proj": r"encoder.layers.\1.attention.out_proj",
r"trunk.blocks.(\d+).mlp.fc1": r"encoder.layers.\1.ffn.gate_proj",
r"trunk.blocks.(\d+).mlp.fc2": r"encoder.layers.\1.ffn.down_proj",
r"trunk.blocks.(\d+).mlp.fc3": r"encoder.layers.\1.ffn.up_proj",
# Normalization Layers
r"trunk.blocks.(\d+).norm_1": r"encoder.layers.\1.rms_norm1",
r"trunk.blocks.(\d+).norm_2": r"encoder.layers.\1.rms_norm2",
# Final Norm
r"trunk.post_trunk_norm": r"rms_norm",
}
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
# Vision Embeddings
r"image_encoder.preprocessor.patchifier.proj": r"vision_model.embeddings.patch_embed",
r"image_encoder.preprocessor.pos_embed": r"vision_model.embeddings.position_embedding.weight",
r"image_encoder.preprocessor.patchifier.norm.weight": r"vision_model.embeddings.rms_norm.weight",
# Vision Encoder Layers
r"image_encoder.trunk.blocks.(\d+).attn.qkv": r"vision_model.encoder.layers.\1.attention.qkv",
r"image_encoder.trunk.blocks.(\d+).attn.proj": r"vision_model.encoder.layers.\1.attention.out_proj",
r"image_encoder.trunk.blocks.(\d+).mlp.fc1": r"vision_model.encoder.layers.\1.ffn.gate_proj",
r"image_encoder.trunk.blocks.(\d+).mlp.fc2": r"vision_model.encoder.layers.\1.ffn.down_proj",
r"image_encoder.trunk.blocks.(\d+).mlp.fc3": r"vision_model.encoder.layers.\1.ffn.up_proj",
# Normalization Layers
r"image_encoder.trunk.blocks.(\d+).norm_1": r"vision_model.encoder.layers.\1.rms_norm1",
r"image_encoder.trunk.blocks.(\d+).norm_2": r"vision_model.encoder.layers.\1.rms_norm2",
r"image_encoder.trunk.post_trunk_norm": r"vision_model.rms_norm",
r"image_projector": r"visual_projection",
# Vision Head
r"image_encoder.head.cls_token": r"vision_model.head.cls_token",
r"image_encoder.head.k": r"vision_model.head.k_proj",
r"image_encoder.head.v": r"vision_model.head.v_proj",
r"image_encoder.head.linear": r"vision_model.head.output_proj",
# Text Embeddings
r"text_encoder.preprocessor.text_embedding.weight": r"text_model.embeddings.token_embedding.weight",
r"text_encoder.preprocessor.positional_embedding": r"text_model.embeddings.position_embedding.weight",
# Text Encoder Layers
r"text_encoder.trunk.blocks.(\d+).attn.qkv": r"text_model.encoder.layers.\1.attention.qkv",
r"text_encoder.trunk.blocks.(\d+).attn.proj": r"text_model.encoder.layers.\1.attention.out_proj",
r"text_encoder.trunk.blocks.(\d+).mlp.fc1": r"text_model.encoder.layers.\1.ffn.gate_proj",
r"text_encoder.trunk.blocks.(\d+).mlp.fc2": r"text_model.encoder.layers.\1.ffn.down_proj",
r"text_encoder.trunk.blocks.(\d+).mlp.fc3": r"text_model.encoder.layers.\1.ffn.up_proj",
# Text Normalization Layers
r"text_encoder.trunk.blocks.(\d+).norm_1": r"text_model.encoder.layers.\1.rms_norm1",
r"text_encoder.trunk.blocks.(\d+).norm_2": r"text_model.encoder.layers.\1.rms_norm2",
r"text_encoder.trunk.post_trunk_norm": r"text_model.rms_norm",
r"text_projector": r"text_projection",
r"log_logit_scale": r"logit_scale",
}
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
# Download only the model.safetensors file
directory_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["model.safetensors"],
)
original_state_dict = {}
safetensor_path = f"{directory_path}/model.safetensors"
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
return original_state_dict
def convert_old_keys_to_new_keys(state_dict_keys: dict, ORIGINAL_TO_CONVERTED_KEY_MAPPING: dict):
"""Converts state dict keys from the old format to the new format."""
output_dict = {}
if state_dict_keys is not None:
old_text = "\n".join(state_dict_keys)
new_text = old_text
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
if replacement is None:
new_text = re.sub(pattern, "", new_text) # an empty line
continue
new_text = re.sub(pattern, replacement, new_text)
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
return output_dict
def split_qkv_tensor(key, tensor):
"""Splits a qkv tensor into separate q, k, v tensors and updates the key accordingly."""
new_keys = ["q_proj", "k_proj", "v_proj"]
split_size = tensor.shape[0] // 3
split_tensors = torch.split(tensor, split_size, dim=0)
return {key.replace("qkv", new_key): split_tensors[i] for i, new_key in enumerate(new_keys)}
def get_model_config_mapping(model_id: str):
"""Determines the correct model, config, and key mappings based on the checkpoint name."""
if model_id == "apple/aimv2-large-patch14-224-lit":
return Aimv2Model, Aimv2Config, ORIGINAL_TO_CONVERTED_KEY_MAPPING
else:
return Aimv2VisionModel, Aimv2VisionConfig, ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL
def write_model(
hf_repo_id: str,
output_dir: str,
safe_serialization: bool = True,
):
"""
Converts a model checkpoint to Hugging Face format and saves it.
Args:
hf_repo_id (str): The Hugging Face repo ID to load from.
output_dir (str): The directory to save the converted model.
safe_serialization (bool): Whether to use safe serialization.
Returns:
model: The reloaded Hugging Face model.
"""
os.makedirs(output_dir, exist_ok=True)
# Get the appropriate model, config, and key mapping
model_class, config_class, key_mapping = get_model_config_mapping(hf_repo_id)
# Load config and original state dict
config = config_class.from_pretrained(hf_repo_id)
# Checkpoint `apple/aimv2-large-patch14-224-lit` uses AttentionPoolingHead hence set the required attr in config.
if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
config.use_head = False
if hf_repo_id == "apple/aimv2-large-patch14-native":
config.is_native = True
original_state_dict = load_original_state_dict(hf_repo_id)
print("Converting model...")
state_dict = {}
result = convert_old_keys_to_new_keys(original_state_dict, key_mapping)
all_keys = list(original_state_dict.keys())
for key in all_keys:
value = original_state_dict[key]
new_key = result.pop(key)
if "qkv" in new_key:
qkv_state_dict = split_qkv_tensor(new_key, value)
state_dict.update(qkv_state_dict)
else:
state_dict[new_key] = value
# Check if position embeddings exist before squeezing
if new_key.endswith("position_embedding.weight"):
state_dict[new_key] = value.squeeze(0)
print(f"Loading the checkpoint in a {model_class.__name__}.")
model = model_class(config)
model.load_state_dict(state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
print("Saving the model.")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
del state_dict, model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
model = model_class.from_pretrained(output_dir, device_map="auto")
print("Model reloaded successfully.")
return model
def write_image_processor(hf_repo_id: str, output_dir: str):
if hf_repo_id == "apple/aimv2-large-patch14-224-lit":
image_processor = AutoProcessor.from_pretrained(hf_repo_id, use_fast=True)
else:
image_processor = AutoImageProcessor.from_pretrained(hf_repo_id, use_fast=True)
image_processor.save_pretrained(output_dir)
return image_processor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--hf_repo_id",
default="apple/aimv2-large-patch14-224",
help="Location of official weights from apple on HF",
)
parser.add_argument(
"--output_dir",
default="aimv2_model",
help="Location to write the converted model and processor",
)
parser.add_argument(
"--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
)
parser.add_argument(
"--push_to_hub",
action=argparse.BooleanOptionalAction,
help="Whether or not to push the converted model to the huggingface hub.",
)
parser.add_argument(
"--hub_repo_id",
default=None,
help="Huggingface hub repo to write the converted model and processor",
)
args = parser.parse_args()
model = write_model(
hf_repo_id=args.hf_repo_id,
output_dir=args.output_dir,
safe_serialization=args.safe_serialization,
)
image_processor = write_image_processor(
hf_repo_id=args.hf_repo_id,
output_dir=args.output_dir,
)
if args.push_to_hub:
print("Pushing to hub...")
model.push_to_hub(args.hub_repo_id)
image_processor.push_to_hub(args.hub_repo_id)
if __name__ == "__main__":
main()

View File

@ -1,62 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALBERT checkpoint."""
import argparse
import torch
from ...utils import logging
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = AlbertConfig.from_json_file(albert_config_file)
print(f"Building PyTorch model from configuration: {config}")
model = AlbertForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_albert(model, config, tf_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--albert_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained ALBERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)

View File

@ -1,389 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALIGN checkpoints from the original repository."""
import argparse
import os
import align
import numpy as np
import requests
import tensorflow as tf
import torch
from PIL import Image
from tokenizer import Tokenizer
from transformers import (
AlignConfig,
AlignModel,
AlignProcessor,
BertConfig,
BertTokenizer,
EfficientNetConfig,
EfficientNetImageProcessor,
)
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def preprocess(image):
image = tf.image.resize(image, (346, 346))
image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
return image
def get_align_config():
vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
vision_config.image_size = 289
vision_config.hidden_dim = 640
vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
vision_config.depthwise_padding = []
text_config = BertConfig()
config = AlignConfig.from_text_vision_configs(
text_config=text_config, vision_config=vision_config, projection_dim=640
)
return config
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
def get_processor():
image_processor = EfficientNetImageProcessor(
do_center_crop=True,
rescale_factor=1 / 127.5,
rescale_offset=True,
do_normalize=False,
include_top=False,
resample=Image.BILINEAR,
)
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
tokenizer.model_max_length = 64
processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
return processor
# here we list all keys to be renamed (original name on the left, our name on the right)
def rename_keys(original_param_names):
# EfficientNet image encoder
block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
block_names = list(set(block_names))
block_names = sorted(block_names)
num_blocks = len(block_names)
block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}
rename_keys = []
rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))
for b in block_names:
hf_b = block_name_mapping[b]
rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
rename_keys.append(
(f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
)
rename_keys.append(
(f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
)
rename_keys.append(
(f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
)
rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
rename_keys.append(
(f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
)
rename_keys.append(
(f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
)
rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
rename_keys.append(
(f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
)
rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
rename_keys.append(
(f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
)
rename_keys.append(
(f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
)
key_mapping = {}
for item in rename_keys:
if item[0] in original_param_names:
key_mapping[item[0]] = "vision_model." + item[1]
# BERT text encoder
rename_keys = []
old = "tf_bert_model/bert"
new = "text_model"
for i in range(12):
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0",
f"{new}.encoder.layer.{i}.attention.self.query.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/query/bias:0",
f"{new}.encoder.layer.{i}.attention.self.query.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0",
f"{new}.encoder.layer.{i}.attention.self.key.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/key/bias:0",
f"{new}.encoder.layer.{i}.attention.self.key.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0",
f"{new}.encoder.layer.{i}.attention.self.value.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/value/bias:0",
f"{new}.encoder.layer.{i}.attention.self.value.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0",
f"{new}.encoder.layer.{i}.attention.output.dense.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0",
f"{new}.encoder.layer.{i}.attention.output.dense.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0",
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0",
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0",
f"{new}.encoder.layer.{i}.intermediate.dense.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0",
f"{new}.encoder.layer.{i}.intermediate.dense.bias",
)
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight")
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias")
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight")
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias")
)
rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight"))
rename_keys.append(
(f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight")
)
rename_keys.append(
(f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight")
)
rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight"))
rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias"))
rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight"))
rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias"))
rename_keys.append(("dense/kernel:0", "text_projection.weight"))
rename_keys.append(("dense/bias:0", "text_projection.bias"))
rename_keys.append(("dense/bias:0", "text_projection.bias"))
rename_keys.append(("temperature:0", "temperature"))
for item in rename_keys:
if item[0] in original_param_names:
key_mapping[item[0]] = item[1]
return key_mapping
def replace_params(hf_params, tf_params, key_mapping):
list(hf_params.keys())
for key, value in tf_params.items():
if key not in key_mapping:
continue
hf_key = key_mapping[key]
if "_conv" in key and "kernel" in key:
new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
elif "embeddings" in key:
new_hf_value = torch.from_numpy(value)
elif "depthwise_kernel" in key:
new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
elif "kernel" in key:
new_hf_value = torch.from_numpy(np.transpose(value))
elif "temperature" in key:
new_hf_value = value
elif "bn/gamma" in key or "bn/beta" in key:
new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
else:
new_hf_value = torch.from_numpy(value)
# Replace HF parameters with original TF model parameters
hf_params[hf_key].copy_(new_hf_value)
@torch.no_grad()
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
"""
Copy/paste/tweak model's weights to our ALIGN structure.
"""
# Load original model
seq_length = 64
tok = Tokenizer(seq_length)
original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
original_model.compile()
original_model.load_weights(checkpoint_path)
tf_params = original_model.trainable_variables
tf_non_train_params = original_model.non_trainable_variables
tf_params = {param.name: param.numpy() for param in tf_params}
for param in tf_non_train_params:
tf_params[param.name] = param.numpy()
tf_param_names = list(tf_params.keys())
# Load HuggingFace model
config = get_align_config()
hf_model = AlignModel(config).eval()
hf_params = hf_model.state_dict()
# Create src-to-dst parameter name mapping dictionary
print("Converting parameters...")
key_mapping = rename_keys(tf_param_names)
replace_params(hf_params, tf_params, key_mapping)
# Initialize processor
processor = get_processor()
inputs = processor(
images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
)
# HF model inference
hf_model.eval()
with torch.no_grad():
outputs = hf_model(**inputs)
hf_image_features = outputs.image_embeds.detach().numpy()
hf_text_features = outputs.text_embeds.detach().numpy()
# Original model inference
original_model.trainable = False
tf_image_processor = EfficientNetImageProcessor(
do_center_crop=True,
do_rescale=False,
do_normalize=False,
include_top=False,
resample=Image.BILINEAR,
)
image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"]
text = tok(tf.constant(["A picture of a cat"]))
image_features = original_model.image_encoder(image, training=False)
text_features = original_model.text_encoder(text, training=False)
image_features = tf.nn.l2_normalize(image_features, axis=-1)
text_features = tf.nn.l2_normalize(text_features, axis=-1)
# Check whether original and HF model outputs match -> np.allclose
if not np.allclose(image_features, hf_image_features, atol=1e-3):
raise ValueError("The predicted image features are not the same.")
if not np.allclose(text_features, hf_text_features, atol=1e-3):
raise ValueError("The predicted text features are not the same.")
print("Model outputs match!")
if save_model:
# Create folder to save model
if not os.path.isdir(pytorch_dump_folder_path):
os.mkdir(pytorch_dump_folder_path)
# Save converted model and image processor
hf_model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
# Push model and image processor to hub
print("Pushing converted ALIGN to the hub...")
processor.push_to_hub("align-base")
hf_model.push_to_hub("align-base")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint_path",
default="./weights/model-weights",
type=str,
help="Path to the pretrained TF ALIGN checkpoint.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default="hf_model",
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument("--save_model", action="store_true", help="Save model to local")
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
args = parser.parse_args()
convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)

View File

@ -1,162 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import (
AddedToken,
AriaForConditionalGeneration,
AriaProcessor,
AutoConfig,
AutoTokenizer,
)
EPILOG_TXT = """Example:
python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
Example for creating the old state dict file with Python:
import torch
from aria.model.language_model.aria_llama import AriaTextForCausalLM
# load model
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", **kwargs)
# load vision tower
model.get_vision_tower().load_model()
# Save state dict
torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
"""
KEYS_TO_MODIFY_MAPPING = {
"vision_tower.vision_model": "vision_tower",
"ln_ffn": "layer_norm",
"ffn": "feed_forward",
"ln_kv": "layer_norm_kv",
}
def load_original_state_dict(model_id):
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
return original_state_dict
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".inv_freq"):
continue
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
return new_state_dict
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
torch.set_default_dtype(torch.float16)
tokenizer = AutoTokenizer.from_pretrained(
text_model_id,
extra_special_tokens={
"image_token": "<|img|>",
"pad_token": "<pad>",
},
)
tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
tokenizer.add_special_tokens({"pad_token": "<pad>"})
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
processor = AriaProcessor.from_pretrained(
text_model_id,
tokenizer=tokenizer,
)
config = AutoConfig.from_pretrained(text_model_id)
config.vision_config.hidden_size = 1152
config.vision_config.attention_heads = 16
config.pad_token_id = 2
config.image_token_id = 9
config.intermediate_size = config.moe_intermediate_size
config.auto_map = {
"AutoConfig": "modeling_aria.AriaConfig",
"AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
}
with torch.device("meta"):
model = AriaForConditionalGeneration(config)
state_dict = load_original_state_dict(old_state_dict_id)
state_dict = convert_state_dict_to_hf(state_dict)
model.load_state_dict(state_dict, strict=False, assign=True)
# print("Saving models")
# model.save_pretrained("local_aria", safe_serialization=False)
# processor.save_pretrained("local_aria")
print("Pushing to hub")
model.push_to_hub(output_hub_path, create_pr=True)
processor.push_to_hub(output_hub_path, create_pr=True)
def main():
parser = argparse.ArgumentParser(
epilog=EPILOG_TXT,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--text_model_id",
default="rhymes-ai/Aria",
help="Hub location of the text model",
)
parser.add_argument(
"--vision_model_id",
default="rhymes-ai/Aria",
help="Hub location of the vision model",
)
parser.add_argument(
"--output_hub_path",
default="rhymes-ai/Aria",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--old_state_dict_id",
default="rhymes-ai/Aria",
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
)
args = parser.parse_args()
convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
if __name__ == "__main__":
main()

View File

@ -1,279 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast"""
import argparse
import json
from pathlib import Path
import torch
import torchaudio
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_audio_spectrogram_transformer_config(model_name):
config = ASTConfig()
if "10-10" in model_name:
pass
elif "speech-commands" in model_name:
config.max_length = 128
elif "12-12" in model_name:
config.time_stride = 12
config.frequency_stride = 12
elif "14-14" in model_name:
config.time_stride = 14
config.frequency_stride = 14
elif "16-16" in model_name:
config.time_stride = 16
config.frequency_stride = 16
else:
raise ValueError("Model not supported")
repo_id = "huggingface/label-files"
if "speech-commands" in model_name:
config.num_labels = 35
filename = "speech-commands-v2-id2label.json"
else:
config.num_labels = 527
filename = "audioset-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
return config
def rename_key(name):
if "module.v" in name:
name = name.replace("module.v", "audio_spectrogram_transformer")
if "cls_token" in name:
name = name.replace("cls_token", "embeddings.cls_token")
if "dist_token" in name:
name = name.replace("dist_token", "embeddings.distillation_token")
if "pos_embed" in name:
name = name.replace("pos_embed", "embeddings.position_embeddings")
if "patch_embed.proj" in name:
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
# transformer blocks
if "blocks" in name:
name = name.replace("blocks", "encoder.layer")
if "attn.proj" in name:
name = name.replace("attn.proj", "attention.output.dense")
if "attn" in name:
name = name.replace("attn", "attention.self")
if "norm1" in name:
name = name.replace("norm1", "layernorm_before")
if "norm2" in name:
name = name.replace("norm2", "layernorm_after")
if "mlp.fc1" in name:
name = name.replace("mlp.fc1", "intermediate.dense")
if "mlp.fc2" in name:
name = name.replace("mlp.fc2", "output.dense")
# final layernorm
if "audio_spectrogram_transformer.norm" in name:
name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm")
# classifier head
if "module.mlp_head.0" in name:
name = name.replace("module.mlp_head.0", "classifier.layernorm")
if "module.mlp_head.1" in name:
name = name.replace("module.mlp_head.1", "classifier.dense")
return name
def convert_state_dict(orig_state_dict, config):
for key in orig_state_dict.copy():
val = orig_state_dict.pop(key)
if "qkv" in key:
key_split = key.split(".")
layer_num = int(key_split[3])
dim = config.hidden_size
if "weight" in key:
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight"
] = val[:dim, :]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight"
] = val[dim : dim * 2, :]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight"
] = val[-dim:, :]
else:
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias"
] = val[:dim]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias"
] = val[dim : dim * 2]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias"
] = val[-dim:]
else:
orig_state_dict[rename_key(key)] = val
return orig_state_dict
def remove_keys(state_dict):
ignore_keys = [
"module.v.head.weight",
"module.v.head.bias",
"module.v.head_dist.weight",
"module.v.head_dist.bias",
]
for k in ignore_keys:
state_dict.pop(k, None)
@torch.no_grad()
def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
"""
Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure.
"""
config = get_audio_spectrogram_transformer_config(model_name)
model_name_to_url = {
"ast-finetuned-audioset-10-10-0.4593": (
"https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1"
),
"ast-finetuned-audioset-10-10-0.450": (
"https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1"
),
"ast-finetuned-audioset-10-10-0.448": (
"https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1"
),
"ast-finetuned-audioset-10-10-0.448-v2": (
"https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1"
),
"ast-finetuned-audioset-12-12-0.447": (
"https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1"
),
"ast-finetuned-audioset-14-14-0.443": (
"https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1"
),
"ast-finetuned-audioset-16-16-0.442": (
"https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1"
),
"ast-finetuned-speech-commands-v2": (
"https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1"
),
}
# load original state_dict
checkpoint_url = model_name_to_url[model_name]
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
# remove some keys
remove_keys(state_dict)
# rename some keys
new_state_dict = convert_state_dict(state_dict, config)
# load 🤗 model
model = ASTForAudioClassification(config)
model.eval()
model.load_state_dict(new_state_dict)
# verify outputs on dummy input
# source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62
mean = -4.2677393 if "speech-commands" not in model_name else -6.845978
std = 4.5689974 if "speech-commands" not in model_name else 5.5654526
max_length = 1024 if "speech-commands" not in model_name else 128
feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length)
if "speech-commands" in model_name:
# TODO: Convert dataset to Parquet
dataset = load_dataset("google/speech_commands", "v0.02", split="validation")
waveform = dataset[0]["audio"]["array"]
else:
filepath = hf_hub_download(
repo_id="nielsr/audio-spectogram-transformer-checkpoint",
filename="sample_audio.flac",
repo_type="dataset",
)
waveform, _ = torchaudio.load(filepath)
waveform = waveform.squeeze().numpy()
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
# forward pass
outputs = model(**inputs)
logits = outputs.logits
if model_name == "ast-finetuned-audioset-10-10-0.4593":
expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602])
elif model_name == "ast-finetuned-audioset-10-10-0.450":
expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718])
elif model_name == "ast-finetuned-audioset-10-10-0.448":
expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344])
elif model_name == "ast-finetuned-audioset-10-10-0.448-v2":
expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917])
elif model_name == "ast-finetuned-audioset-12-12-0.447":
expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843])
elif model_name == "ast-finetuned-audioset-14-14-0.443":
expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413])
elif model_name == "ast-finetuned-audioset-16-16-0.442":
expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470])
elif model_name == "ast-finetuned-speech-commands-v2":
expected_slice = torch.tensor([6.1589, -8.0566, -8.7984])
else:
raise ValueError("Unknown model name")
if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4):
raise ValueError("Logits don't match")
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
feature_extractor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing model and feature extractor to the hub...")
model.push_to_hub(f"MIT/{model_name}")
feature_extractor.push_to_hub(f"MIT/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="ast-finetuned-audioset-10-10-0.4593",
type=str,
help="Name of the Audio Spectrogram Transformer model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -1,273 +0,0 @@
# coding=utf-8
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
import argparse
import json
import os
import re
from os import path
from typing import Optional, Union
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file
from transformers import AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from .configuration_bamba import BambaConfig
def convert_state_dict_from_mamba_ssm(original_sd: dict) -> dict[str, torch.Tensor]:
state_dict = {}
for orig_k, param in original_sd.items():
k = orig_k.replace("backbone", "model")
# for embeddings
k = k.replace("embedding", "embed_tokens")
# for mixer
k = k.replace("mixer", "mamba")
# for final layernorm
k = k.replace("norm_f", "final_layernorm")
# for block layernorm
k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
# for mlp
k = k.replace("mlp.fc2", "feed_forward.down_proj")
if "mlp.fc1" in k:
param, param2 = torch.chunk(param, 2, dim=0)
k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
state_dict[k2] = param2
k = k.replace("mlp.fc1", "feed_forward.up_proj")
if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
"out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
):
# then this must be a mamba
pass
else:
# for attn
# - because mixer was replaced to mamba above
k = k.replace("mamba.out_proj", "self_attn.o_proj")
if "mamba.in_proj" in k:
m, n = param.shape
d = (m - n) // 2
param, param2, param3 = torch.split(param, [n, d, d], dim=0)
k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
state_dict[k2] = param2
k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
state_dict[k2] = param3
k = k.replace("mamba.in_proj", "self_attn.q_proj")
state_dict[k] = param
return state_dict
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
def convert_ssm_config_to_hf_config(
config_ssm: dict,
**kwargs,
) -> BambaConfig:
"""Convert a config from mamba_ssm to a BambaConfig from here."""
hf_config: BambaConfig = BambaConfig(**kwargs)
hf_config.architectures = ["BambaForCausalLM"]
# Set important values from config and recalculate other resulting entries
hf_config.hidden_size = config_ssm["d_model"]
hf_config.intermediate_size = config_ssm["d_intermediate"]
hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
hf_config.num_hidden_layers = config_ssm["n_layer"]
hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
# currently this script assumes config_ssm belongs to v2
if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
raise ValueError("Conversion script only supports Mamba2")
# Set attention values
attn_cfg = config_ssm.get("attn_cfg")
if attn_cfg:
assert attn_cfg["causal"], "Only support non-causal attention."
assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
assert not attn_cfg["out_proj_bias"], "Only support no out bias."
hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
hf_config.num_attention_heads = attn_cfg["num_heads"]
hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
attention_layer_indices = config_ssm.get("attn_layer_idx")
if attention_layer_indices:
hf_config.attn_layer_indices = attention_layer_indices
# Padded vocab size, mostly of 16 but 32 is also very common in different models
vocab_size = config_ssm["vocab_size"]
pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
if (vocab_size % pad_vocab_size_multiple) != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
hf_config.vocab_size = vocab_size
return hf_config
def save_single_safetensor(
state_dict: dict,
save_directory: str,
metadata: dict,
):
save_file(
state_dict,
os.path.join(save_directory, SAFE_WEIGHTS_NAME),
metadata,
)
def save_sharded_safetensors(
state_dict: dict,
save_directory: str,
metadata: dict,
max_shard_size: Union[int, str] = "5GB",
):
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the index
with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
mamba_ssm_checkpoint_path: str,
precision: str,
output_dir: str,
tokenizer_path: Optional[str] = None,
save_model: Union[bool, str] = True,
) -> None:
# load tokenizer if provided, this will be used to set the
# token_ids in the config file
token_ids = {}
if tokenizer_path:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
for key in [
"bos_token_id",
"eos_token_id",
"pad_token_id",
]:
id = getattr(tokenizer, key, None)
if id:
token_ids[key] = id
# there are some configs unsettable by mamba_ssn config, so
# if there are changes from the defaults, have to pass them into
# the function
unsettables = {
"mamba_d_head": 64,
"mamba_d_state": 128,
"mamba_n_groups": 1,
"rms_norm_eps": 1e-5,
}
# Load and save config based on name
config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
with open(config_path, "r", encoding="utf-8") as json_file:
config = json.load(json_file)
# convert the config
hf_config = convert_ssm_config_to_hf_config(
config_ssm=config,
**token_ids,
**unsettables,
)
hf_config.save_pretrained(output_dir)
# Load state dict of the original model and transfer to hf model
state_dict = torch.load(
path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
map_location="cpu",
weights_only=True,
)
# FIXME: allow other parameters to pass in
state_dict = convert_state_dict_from_mamba_ssm(state_dict)
# Save new model to pytorch_dump_path
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
save_file_fn = None
if isinstance(save_model, bool) and save_model:
save_file_fn = save_single_safetensor
elif isinstance(save_model, str) and save_model == "sharded":
save_file_fn = save_sharded_safetensors
if save_file_fn:
save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--mamba_ssm_checkpoint_directory",
type=str,
required=True,
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
)
parser.add_argument(
"-p",
"--precision",
type=str,
default="fp16",
required=True,
choices=("fp32", "fp16", "bf16"),
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
)
parser.add_argument(
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
)
parser.add_argument(
"-t",
"--tokenizer_model_path",
type=str,
default=None,
required=False,
help="Path to a the tokenizer file.",
)
args = parser.parse_args()
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
args.mamba_ssm_checkpoint_directory,
args.precision,
args.output_dir,
save_model="sharded",
)

View File

@ -31,7 +31,7 @@ from torch import nn
from transformers.activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, DynamicLayer
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
@ -85,7 +85,7 @@ class BambaFlashAttentionKwargs(TypedDict, total=False):
seq_idx: torch.IntTensor
class HybridMambaAttentionDynamicCache(Cache):
class HybridMambaAttentionDynamicCache:
"""
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
(which has a constant shape regardless of seq_len).
@ -104,7 +104,6 @@ class HybridMambaAttentionDynamicCache(Cache):
is_compileable = False
def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
super().__init__(layer_classes=DynamicLayer)
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
conv_kernel_size = config.mamba_d_conv

View File

@ -42,7 +42,6 @@ from transformers.models.mamba2.modeling_mamba2 import (
segment_sum,
)
from ...cache_utils import DynamicLayer
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
@ -114,7 +113,6 @@ class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache):
"""
def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
HybridMambaAttentionDynamicCache.__init__(layer_classes=DynamicLayer)
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
conv_kernel_size = config.mamba_d_conv

View File

@ -1,263 +0,0 @@
"""Convert Bark checkpoint."""
import argparse
import os
from pathlib import Path
import torch
from bark.generation import _load_model as _bark_load_model
from huggingface_hub import hf_hub_download
from transformers import EncodecConfig, EncodecModel, set_seed
from transformers.models.bark.configuration_bark import (
BarkCoarseConfig,
BarkConfig,
BarkFineConfig,
BarkSemanticConfig,
)
from transformers.models.bark.generation_configuration_bark import (
BarkCoarseGenerationConfig,
BarkFineGenerationConfig,
BarkGenerationConfig,
BarkSemanticGenerationConfig,
)
from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
set_seed(770)
new_layer_name_dict = {
"c_attn": "att_proj",
"c_proj": "out_proj",
"c_fc": "in_proj",
"transformer.": "",
"h.": "layers.",
"ln_1": "layernorm_1",
"ln_2": "layernorm_2",
"ln_f": "layernorm_final",
"wpe": "position_embeds_layer",
"wte": "input_embeds_layer",
}
REMOTE_MODEL_PATHS = {
"text_small": {
"repo_id": "suno/bark",
"file_name": "text.pt",
},
"coarse_small": {
"repo_id": "suno/bark",
"file_name": "coarse.pt",
},
"fine_small": {
"repo_id": "suno/bark",
"file_name": "fine.pt",
},
"text": {
"repo_id": "suno/bark",
"file_name": "text_2.pt",
},
"coarse": {
"repo_id": "suno/bark",
"file_name": "coarse_2.pt",
},
"fine": {
"repo_id": "suno/bark",
"file_name": "fine_2.pt",
},
}
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
def _get_ckpt_path(model_type, use_small=False):
key = model_type
if use_small:
key += "_small"
return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
def _download(from_hf_path, file_name):
os.makedirs(CACHE_DIR, exist_ok=True)
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
if model_type == "text":
ModelClass = BarkSemanticModel
ConfigClass = BarkSemanticConfig
GenerationConfigClass = BarkSemanticGenerationConfig
elif model_type == "coarse":
ModelClass = BarkCoarseModel
ConfigClass = BarkCoarseConfig
GenerationConfigClass = BarkCoarseGenerationConfig
elif model_type == "fine":
ModelClass = BarkFineModel
ConfigClass = BarkFineConfig
GenerationConfigClass = BarkFineGenerationConfig
else:
raise NotImplementedError()
model_key = f"{model_type}_small" if use_small else model_type
model_info = REMOTE_MODEL_PATHS[model_key]
if not os.path.exists(ckpt_path):
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
_download(model_info["repo_id"], model_info["file_name"])
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
# this is a hack
model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args:
model_args["input_vocab_size"] = model_args["vocab_size"]
model_args["output_vocab_size"] = model_args["vocab_size"]
del model_args["vocab_size"]
# convert Bark model arguments to HF Bark model arguments
model_args["num_heads"] = model_args.pop("n_head")
model_args["hidden_size"] = model_args.pop("n_embd")
model_args["num_layers"] = model_args.pop("n_layer")
model_config = ConfigClass(**checkpoint["model_args"])
model = ModelClass(config=model_config)
model_generation_config = GenerationConfigClass()
model.generation_config = model_generation_config
state_dict = checkpoint["model"]
# fixup checkpoint
unwanted_prefix = "_orig_mod."
for k in state_dict:
if k.startswith(unwanted_prefix):
# replace part of the key with corresponding layer name in HF implementation
new_k = k[len(unwanted_prefix) :]
for old_layer_name, new_layer_name in new_layer_name_dict.items():
new_k = new_k.replace(old_layer_name, new_layer_name)
state_dict[new_k] = state_dict.pop(k)
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")}
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")}
if len(extra_keys) != 0:
raise ValueError(f"extra keys found: {extra_keys}")
if len(missing_keys) != 0:
raise ValueError(f"missing keys: {missing_keys}")
model.load_state_dict(state_dict, strict=False)
n_params = model.num_parameters(exclude_embeddings=True)
val_loss = checkpoint["best_val_loss"].item()
logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params, {round(val_loss, 3)} loss")
model.eval()
model.to(device)
del checkpoint, state_dict
return model
def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"):
if model_type not in ("text", "coarse", "fine"):
raise NotImplementedError()
device = "cpu" # do conversion on cpu
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small)
# load bark initial model
bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small)
if model_type == "text":
bark_model = bark_model["model"]
if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params():
raise ValueError("initial and new models don't have the same number of parameters")
# check if same output as the bark model
batch_size = 5
sequence_length = 10
if model_type in ["text", "coarse"]:
vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int)
output_old_model = bark_model(vec)[0]
output_new_model_total = model(vec)
# take last logits
output_new_model = output_new_model_total.logits[:, [-1], :]
else:
prediction_codebook_channel = 3
n_codes_total = 8
vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int)
output_new_model_total = model(prediction_codebook_channel, vec)
output_old_model = bark_model(prediction_codebook_channel, vec)
output_new_model = output_new_model_total.logits
# output difference should come from the difference of self-attention implementation design
if output_new_model.shape != output_old_model.shape:
raise ValueError("initial and new outputs don't have the same shape")
if (output_new_model - output_old_model).abs().max().item() > 1e-3:
raise ValueError("initial and new outputs are not equal")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
def load_whole_bark_model(
semantic_path,
coarse_path,
fine_path,
append_text,
hub_path,
folder_path,
):
pytorch_dump_folder_path = os.path.join(folder_path, append_text)
semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json"))
coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json"))
fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json"))
codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz")
semantic = BarkSemanticModel.from_pretrained(semantic_path)
coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path)
fineAcoustic = BarkFineModel.from_pretrained(fine_path)
codec = EncodecModel.from_pretrained("facebook/encodec_24khz")
bark_config = BarkConfig.from_sub_model_configs(
semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig
)
bark_generation_config = BarkGenerationConfig.from_sub_model_configs(
semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config
)
bark = BarkModel(bark_config)
bark.semantic = semantic
bark.coarse_acoustics = coarseAcoustic
bark.fine_acoustics = fineAcoustic
bark.codec_model = codec
bark.generation_config = bark_generation_config
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("model_type", type=str, help="text, coarse or fine.")
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.")
args = parser.parse_args()
load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small)

View File

@ -1,156 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BART checkpoint."""
import argparse
import os
from pathlib import Path
import fairseq
import torch
from packaging import version
from torch import nn
from transformers import (
BartConfig,
BartForConditionalGeneration,
BartForSequenceClassification,
BartModel,
BartTokenizer,
)
from transformers.utils import logging
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
raise Exception("requires fairseq >= 0.9.0")
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
SAMPLE_TEXT = " Hello world! cécé herlolip"
mnli_rename_keys = [
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
]
def remove_ignore_keys_(state_dict):
ignore_keys = [
"encoder.version",
"decoder.version",
"model.encoder.version",
"model.decoder.version",
"_float_tensor",
]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def load_xsum_checkpoint(checkpoint_path):
"""Checkpoint path should end in model.pt"""
sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
hub_interface.model.load_state_dict(sd["model"])
return hub_interface
def make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
lin_layer.weight.data = emb.weight.data
return lin_layer
@torch.no_grad()
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
"""
Copy/paste/tweak model's weights to our BERT structure.
"""
if not os.path.exists(checkpoint_path):
bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
else:
bart = load_xsum_checkpoint(checkpoint_path)
bart.model.upgrade_state_dict(bart.model.state_dict())
if hf_checkpoint_name is None:
hf_checkpoint_name = checkpoint_path.replace(".", "-")
config = BartConfig.from_pretrained(hf_checkpoint_name)
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
if not torch.eq(tokens, tokens2).all():
raise ValueError(
f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
)
if checkpoint_path == "bart.large.mnli":
state_dict = bart.state_dict()
remove_ignore_keys_(state_dict)
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
for src, dest in mnli_rename_keys:
rename_key(state_dict, src, dest)
model = BartForSequenceClassification(config).eval()
model.load_state_dict(state_dict)
fairseq_output = bart.predict("mnli", tokens, return_logits=True)
new_model_outputs = model(tokens)[0] # logits
else: # no classification heads to worry about
state_dict = bart.model.state_dict()
remove_ignore_keys_(state_dict)
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
fairseq_output = bart.extract_features(tokens)
if hf_checkpoint_name == "facebook/bart-large":
model = BartModel(config).eval()
model.load_state_dict(state_dict)
new_model_outputs = model(tokens).model[0]
else:
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
model.model.load_state_dict(state_dict)
if hasattr(model, "lm_head"):
model.lm_head = make_linear_from_emb(model.model.shared)
new_model_outputs = model.model(tokens)[0]
# Check results
if fairseq_output.shape != new_model_outputs.shape:
raise ValueError(
f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
)
if (fairseq_output != new_model_outputs).any().item():
raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
)
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
)
args = parser.parse_args()
convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)

View File

@ -1,373 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BEiT checkpoints from the unilm repository."""
import argparse
import json
from pathlib import Path
import requests
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
BeitConfig,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitImageProcessor,
)
from transformers.image_utils import PILImageResampling
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
prefix = "backbone." if is_semantic else ""
rename_keys = []
for i in range(config.num_hidden_layers):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
)
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
)
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
# projection layer + position embeddings
rename_keys.extend(
[
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
]
)
if has_lm_head:
# mask token + shared relative position bias + layernorm
rename_keys.extend(
[
("mask_token", "beit.embeddings.mask_token"),
(
"rel_pos_bias.relative_position_bias_table",
"beit.encoder.relative_position_bias.relative_position_bias_table",
),
(
"rel_pos_bias.relative_position_index",
"beit.encoder.relative_position_bias.relative_position_index",
),
("norm.weight", "layernorm.weight"),
("norm.bias", "layernorm.bias"),
]
)
elif is_semantic:
# semantic segmentation classification heads
rename_keys.extend(
[
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
]
)
else:
# layernorm + classification head
rename_keys.extend(
[
("fc_norm.weight", "beit.pooler.layernorm.weight"),
("fc_norm.bias", "beit.pooler.layernorm.bias"),
("head.weight", "classifier.weight"),
("head.bias", "classifier.bias"),
]
)
return rename_keys
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
for i in range(config.num_hidden_layers):
prefix = "backbone." if is_semantic else ""
# queries, keys and values
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
# gamma_1 and gamma_2
# we call them lambda because otherwise they are renamed when using .from_pretrained
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
# relative_position bias table + index
if not has_lm_head:
# each layer has its own relative position bias
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
state_dict[
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
] = table
state_dict[
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
] = index
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our BEiT structure.
"""
# define default BEiT configuration
config = BeitConfig()
has_lm_head = False
is_semantic = False
repo_id = "huggingface/label-files"
# set config parameters based on URL
if checkpoint_url[-9:-4] == "pt22k":
# masked image modeling
config.use_shared_relative_position_bias = True
config.use_mask_token = True
has_lm_head = True
elif checkpoint_url[-9:-4] == "ft22k":
# intermediate fine-tuning on ImageNet-22k
config.use_relative_position_bias = True
config.num_labels = 21841
filename = "imagenet-22k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
# this dataset contains 21843 labels but the model only has 21841
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
del id2label[9205]
del id2label[15027]
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
elif checkpoint_url[-8:-4] == "to1k":
# fine-tuning on ImageNet-1k
config.use_relative_position_bias = True
config.num_labels = 1000
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
if "384" in checkpoint_url:
config.image_size = 384
if "512" in checkpoint_url:
config.image_size = 512
elif "ade20k" in checkpoint_url:
# fine-tuning
config.use_relative_position_bias = True
config.num_labels = 150
filename = "ade20k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.image_size = 640
is_semantic = True
else:
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
# size of the architecture
if "base" in checkpoint_url:
pass
elif "large" in checkpoint_url:
config.hidden_size = 1024
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
if "ade20k" in checkpoint_url:
config.image_size = 640
config.out_indices = [7, 11, 15, 23]
else:
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
# load state_dict of original model, remove and rename some keys
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
if is_semantic:
# add prefix to decoder keys
for key, val in state_dict.copy().items():
val = state_dict.pop(key)
if key.startswith("backbone.fpn"):
key = key.replace("backbone.fpn", "fpn")
state_dict[key] = val
# load HuggingFace model
if checkpoint_url[-9:-4] == "pt22k":
model = BeitForMaskedImageModeling(config)
elif "ade20k" in checkpoint_url:
model = BeitForSemanticSegmentation(config)
else:
model = BeitForImageClassification(config)
model.eval()
model.load_state_dict(state_dict)
# Check outputs on an image
if is_semantic:
image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = Image.open(ds[0]["file"])
else:
image_processor = BeitImageProcessor(
size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
)
image = prepare_img()
encoding = image_processor(images=image, return_tensors="pt")
pixel_values = encoding["pixel_values"]
outputs = model(pixel_values)
logits = outputs.logits
# verify logits
expected_shape = torch.Size([1, 1000])
if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
expected_shape = torch.Size([1, 196, 8192])
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
expected_shape = torch.Size([1, 196, 8192])
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
expected_shape = torch.Size([1, 21841])
expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
expected_class_idx = 2397
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
expected_shape = torch.Size([1, 21841])
expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
expected_class_idx = 2396
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
expected_class_idx = 285
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
expected_class_idx = 281
elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
]
)
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
]
)
else:
raise ValueError("Can't verify logits as model is not supported")
if logits.shape != expected_shape:
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
if not has_lm_head:
if is_semantic:
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
raise ValueError("First elements of logits not as expected")
else:
print("Predicted class idx:", logits.argmax(-1).item())
if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
raise ValueError("First elements of logits not as expected")
if logits.argmax(-1).item() != expected_class_idx:
raise ValueError("Predicted class index not as expected")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_url",
default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
type=str,
help="URL to the original PyTorch checkpoint (.pth file).",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
)
args = parser.parse_args()
convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)

View File

@ -1,246 +0,0 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now
deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
weight names to the original names, so the model can be imported with Huggingface/transformer.
You may adapt this script to include classification/MLM/NSP/etc. heads.
Note: This script is only working with an older version of the TensorFlow models repository (<= v2.3.0).
Models trained with never versions are not compatible with this script.
"""
import argparse
import os
import re
import tensorflow as tf
import torch
from transformers import BertConfig, BertModel
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
layer_depth = []
for full_name, shape in init_vars:
# logger.info(f"Loading TF weight {name} with shape {shape}")
name = full_name.split("/")
if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
logger.info(f"Skipping non-model layer {full_name}")
continue
if "optimizer" in full_name:
logger.info(f"Skipping optimization layer {full_name}")
continue
if name[0] == "model":
# ignore initial 'model'
name = name[1:]
# figure out how many levels deep the name is
depth = 0
for _name in name:
if _name.startswith("layer_with_weights"):
depth += 1
else:
break
layer_depth.append(depth)
# read data
array = tf.train.load_variable(tf_path, full_name)
names.append("/".join(name))
arrays.append(array)
logger.info(f"Read a total of {len(arrays):,} layers")
# Sanity check
if len(set(layer_depth)) != 1:
raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
layer_depth = list(set(layer_depth))[0]
if layer_depth != 1:
raise ValueError(
"The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
" heads."
)
# convert layers
logger.info("Converting weights...")
for full_name, array in zip(names, arrays):
name = full_name.split("/")
pointer = model
trace = []
for i, m_name in enumerate(name):
if m_name == ".ATTRIBUTES":
# variable names end with .ATTRIBUTES/VARIABLE_VALUE
break
if m_name.startswith("layer_with_weights"):
layer_num = int(m_name.split("-")[-1])
if layer_num <= 2:
# embedding layers
# layer_num 0: word_embeddings
# layer_num 1: position_embeddings
# layer_num 2: token_type_embeddings
continue
elif layer_num == 3:
# embedding LayerNorm
trace.extend(["embeddings", "LayerNorm"])
pointer = getattr(pointer, "embeddings")
pointer = getattr(pointer, "LayerNorm")
elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
# encoder layers
trace.extend(["encoder", "layer", str(layer_num - 4)])
pointer = getattr(pointer, "encoder")
pointer = getattr(pointer, "layer")
pointer = pointer[layer_num - 4]
elif layer_num == config.num_hidden_layers + 4:
# pooler layer
trace.extend(["pooler", "dense"])
pointer = getattr(pointer, "pooler")
pointer = getattr(pointer, "dense")
elif m_name == "embeddings":
trace.append("embeddings")
pointer = getattr(pointer, "embeddings")
if layer_num == 0:
trace.append("word_embeddings")
pointer = getattr(pointer, "word_embeddings")
elif layer_num == 1:
trace.append("position_embeddings")
pointer = getattr(pointer, "position_embeddings")
elif layer_num == 2:
trace.append("token_type_embeddings")
pointer = getattr(pointer, "token_type_embeddings")
else:
raise ValueError(f"Unknown embedding layer with name {full_name}")
trace.append("weight")
pointer = getattr(pointer, "weight")
elif m_name == "_attention_layer":
# self-attention layer
trace.extend(["attention", "self"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "self")
elif m_name == "_attention_layer_norm":
# output attention norm
trace.extend(["attention", "output", "LayerNorm"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "LayerNorm")
elif m_name == "_attention_output_dense":
# output attention dense
trace.extend(["attention", "output", "dense"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "dense")
elif m_name == "_output_dense":
# output dense
trace.extend(["output", "dense"])
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "dense")
elif m_name == "_output_layer_norm":
# output dense
trace.extend(["output", "LayerNorm"])
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "LayerNorm")
elif m_name == "_key_dense":
# attention key
trace.append("key")
pointer = getattr(pointer, "key")
elif m_name == "_query_dense":
# attention query
trace.append("query")
pointer = getattr(pointer, "query")
elif m_name == "_value_dense":
# attention value
trace.append("value")
pointer = getattr(pointer, "value")
elif m_name == "_intermediate_dense":
# attention intermediate dense
trace.extend(["intermediate", "dense"])
pointer = getattr(pointer, "intermediate")
pointer = getattr(pointer, "dense")
elif m_name == "_output_layer_norm":
# output layer norm
trace.append("output")
pointer = getattr(pointer, "output")
# weights & biases
elif m_name in ["bias", "beta"]:
trace.append("bias")
pointer = getattr(pointer, "bias")
elif m_name in ["kernel", "gamma"]:
trace.append("weight")
pointer = getattr(pointer, "weight")
else:
logger.warning(f"Ignored {m_name}")
# for certain layers reshape is necessary
trace = ".".join(trace)
if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
r"(\S+)\.attention\.output\.dense\.weight", trace
):
array = array.reshape(pointer.data.shape)
if "kernel" in full_name:
array = array.transpose()
if pointer.shape == array.shape:
pointer.data = torch.from_numpy(array)
else:
raise ValueError(
f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
f" {array.shape}"
)
logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
return model
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
# Instantiate model
logger.info(f"Loading model based on config from {config_path}...")
config = BertConfig.from_json_file(config_path)
model = BertModel(config)
# Load weights from checkpoint
logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
# Save pytorch-model
logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
)
parser.add_argument(
"--bert_config_file",
type=str,
required=True,
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path",
type=str,
required=True,
help="Path to the output PyTorch model (must include filename).",
)
args = parser.parse_args()
convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

View File

@ -1,62 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""
import argparse
import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file)
print(f"Building PyTorch model from configuration: {config}")
model = BertForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--bert_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

View File

@ -1,112 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
import argparse
import os
import numpy as np
import tensorflow as tf
import torch
from transformers import BertModel
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
"""
Args:
model: BertModel Pytorch model instance to be converted
ckpt_dir: Tensorflow model directory
model_name: model name
Currently supported HF models:
- Y BertModel
- N BertForMaskedLM
- N BertForPreTraining
- N BertForMultipleChoice
- N BertForNextSentencePrediction
- N BertForSequenceClassification
- N BertForQuestionAnswering
"""
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
var_map = (
("layer.", "layer_"),
("word_embeddings.weight", "word_embeddings"),
("position_embeddings.weight", "position_embeddings"),
("token_type_embeddings.weight", "token_type_embeddings"),
(".", "/"),
("LayerNorm/weight", "LayerNorm/gamma"),
("LayerNorm/bias", "LayerNorm/beta"),
("weight", "kernel"),
)
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
state_dict = model.state_dict()
def to_tf_var_name(name: str):
for patt, repl in iter(var_map):
name = name.replace(patt, repl)
return f"bert/{name}"
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
session.run(tf.variables_initializer([tf_var]))
session.run(tf_var)
return tf_var
tf.reset_default_graph()
with tf.Session() as session:
for var_name in state_dict:
tf_name = to_tf_var_name(var_name)
torch_tensor = state_dict[var_name].numpy()
if any(x in var_name for x in tensors_to_transpose):
torch_tensor = torch_tensor.T
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
tf_weight = session.run(tf_var)
print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")
saver = tf.train.Saver(tf.trainable_variables())
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
def main(raw_args=None):
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
parser.add_argument(
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
)
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
args = parser.parse_args(raw_args)
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path, weights_only=True),
cache_dir=args.cache_dir,
)
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
if __name__ == "__main__":
main()

View File

@ -1,188 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT
model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository:
https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
"""
import argparse
import tensorflow as tf
import torch
from transformers import BertConfig, BertForMaskedLM
from transformers.models.bert.modeling_bert import (
BertIntermediate,
BertLayer,
BertOutput,
BertPooler,
BertSelfAttention,
BertSelfOutput,
)
from transformers.utils import logging
logging.set_verbosity_info()
def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
def get_masked_lm_array(name: str):
full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
def get_encoder_array(name: str):
full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
def get_encoder_layer_array(layer_index: int, name: str):
full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape):
full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
array = array.reshape(original_shape)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
print(f"Loading model based on config from {config_path}...")
config = BertConfig.from_json_file(config_path)
model = BertForMaskedLM(config)
# Layers
for layer_index in range(0, config.num_hidden_layers):
layer: BertLayer = model.bert.encoder.layer[layer_index]
# Self-attention
self_attn: BertSelfAttention = layer.attention.self
self_attn.query.weight.data = get_encoder_attention_layer_array(
layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
)
self_attn.query.bias.data = get_encoder_attention_layer_array(
layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
)
self_attn.key.weight.data = get_encoder_attention_layer_array(
layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
)
self_attn.key.bias.data = get_encoder_attention_layer_array(
layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
)
self_attn.value.weight.data = get_encoder_attention_layer_array(
layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
)
self_attn.value.bias.data = get_encoder_attention_layer_array(
layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
)
# Self-attention Output
self_output: BertSelfOutput = layer.attention.output
self_output.dense.weight.data = get_encoder_attention_layer_array(
layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
)
self_output.dense.bias.data = get_encoder_attention_layer_array(
layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
)
self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
# Intermediate
intermediate: BertIntermediate = layer.intermediate
intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
# Output
bert_output: BertOutput = layer.output
bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
# Embeddings
model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
# LM Head
lm_head = model.cls.predictions.transform
lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
# Pooling
model.bert.pooler = BertPooler(config=config)
model.bert.pooler.dense.weight.data: BertPooler = get_encoder_array("_pooler_layer/kernel")
model.bert.pooler.dense.bias.data: BertPooler = get_encoder_array("_pooler_layer/bias")
# Export final model
model.save_pretrained(pytorch_dump_path)
# Integration test - should load without any errors ;)
new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
print(new_model.eval())
print("Model conversion was done successfully!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
)
parser.add_argument(
"--bert_config_file",
type=str,
required=True,
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path",
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

View File

@ -1,69 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigBird checkpoint."""
import argparse
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
# Initialise PyTorch model
config = BigBirdConfig.from_json_file(big_bird_config_file)
print(f"Building PyTorch model from configuration: {config}")
if is_trivia_qa:
model = BigBirdForQuestionAnswering(config)
else:
model = BigBirdForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--big_bird_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(
args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
)

View File

@ -1,169 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import tensorflow as tf
import torch
from tqdm import tqdm
from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration
INIT_COMMON = [
# tf -> hf
("/", "."),
("layer_", "layers."),
("kernel", "weight"),
("beta", "bias"),
("gamma", "weight"),
("pegasus", "model"),
]
END_COMMON = [
(".output.dense", ".fc2"),
("intermediate.LayerNorm", "final_layer_norm"),
("intermediate.dense", "fc1"),
]
DECODER_PATTERNS = (
INIT_COMMON
+ [
("attention.self.LayerNorm", "self_attn_layer_norm"),
("attention.output.dense", "self_attn.out_proj"),
("attention.self", "self_attn"),
("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
("attention.encdec_output.dense", "encoder_attn.out_proj"),
("attention.encdec", "encoder_attn"),
("key", "k_proj"),
("value", "v_proj"),
("query", "q_proj"),
("decoder.LayerNorm", "decoder.layernorm_embedding"),
]
+ END_COMMON
)
REMAINING_PATTERNS = (
INIT_COMMON
+ [
("embeddings.word_embeddings", "shared.weight"),
("embeddings.position_embeddings", "embed_positions.weight"),
("attention.self.LayerNorm", "self_attn_layer_norm"),
("attention.output.dense", "self_attn.output"),
("attention.self", "self_attn.self"),
("encoder.LayerNorm", "encoder.layernorm_embedding"),
]
+ END_COMMON
)
KEYS_TO_IGNORE = [
"encdec/key/bias",
"encdec/query/bias",
"encdec/value/bias",
"self/key/bias",
"self/query/bias",
"self/value/bias",
"encdec_output/dense/bias",
"attention/output/dense/bias",
]
def rename_state_dict_key(k, patterns):
for tf_name, hf_name in patterns:
k = k.replace(tf_name, hf_name)
return k
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
cfg = BigBirdPegasusConfig(**config_update)
torch_model = BigBirdPegasusForConditionalGeneration(cfg)
state_dict = torch_model.state_dict()
mapping = {}
# separating decoder weights
decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
if any(conditions):
continue
patterns = DECODER_PATTERNS
new_k = rename_state_dict_key(k, patterns)
if new_k not in state_dict:
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if any(i in k for i in ["dense", "query", "key", "value"]):
v = v.T
mapping[new_k] = torch.from_numpy(v)
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
if any(conditions):
continue
patterns = REMAINING_PATTERNS
new_k = rename_state_dict_key(k, patterns)
if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if any(i in k for i in ["dense", "query", "key", "value"]):
v = v.T
mapping[new_k] = torch.from_numpy(v)
if k != "pegasus/embeddings/position_embeddings":
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
missing, extra = torch_model.load_state_dict(mapping, strict=False)
unexpected_missing = [
k
for k in missing
if k
not in [
"final_logits_bias",
"model.encoder.embed_tokens.weight",
"model.decoder.embed_tokens.weight",
"lm_head.weight",
]
]
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
assert extra == [], f"no matches found for the following tf keys {extra}"
return torch_model
def get_tf_weights_as_numpy(path) -> dict:
init_vars = tf.train.list_variables(path)
tf_weights = {}
ignore_name = ["global_step"]
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
skip_key = any(pat in name for pat in ignore_name)
if skip_key:
continue
array = tf.train.load_variable(path, name)
tf_weights[name] = array
return tf_weights
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
tf_weights = get_tf_weights_as_numpy(ckpt_path)
torch_model = convert_bigbird_pegasus(tf_weights, config_update)
torch_model.save_pretrained(save_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
config_update = {}
convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)

View File

@ -1,292 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
import shutil
import torch
from transformers import BioGptConfig, BioGptForCausalLM
from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
from transformers.utils import WEIGHTS_NAME, logging
logging.set_verbosity_warning()
json_indent = 2
# modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18
class Dictionary:
"""A mapping from symbols to consecutive integers"""
def __init__(
self,
*, # begin keyword-only arguments
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
extra_special_symbols=None,
):
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
def __eq__(self, other):
return self.indices == other.indices
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def __contains__(self, sym):
return sym in self.indices
@classmethod
def load(cls, f):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
d = cls()
d.add_from_file(f)
return d
def add_symbol(self, word, n=1, overwrite=False):
"""Adds a word to the dictionary"""
if word in self.indices and not overwrite:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def _load_meta(self, lines):
return 0
def add_from_file(self, f):
"""
Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
"""
if isinstance(f, str):
try:
with open(f, "r", encoding="utf-8") as fd:
self.add_from_file(fd)
except FileNotFoundError as fnfe:
raise fnfe
except UnicodeError:
raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset")
return
lines = f.readlines()
indices_start_line = self._load_meta(lines)
for line in lines[indices_start_line:]:
try:
line, field = line.rstrip().rsplit(" ", 1)
if field == "#fairseq:overwrite":
overwrite = True
line, field = line.rsplit(" ", 1)
else:
overwrite = False
count = int(field)
word = line
if word in self and not overwrite:
raise RuntimeError(
f"Duplicate word found when loading Dictionary: '{word}'. "
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. If using the Camembert model, please "
"download an updated copy of the model file."
)
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'")
def rewrite_dict_keys(d):
# (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
# e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
keep_keys = "<s> <pad> </s> <unk>".split()
# restore the special tokens
for k in keep_keys:
del d2[f"{k}</w>"]
d2[k] = d[k] # restore
return d2
def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path):
# prep
if not os.path.exists(biogpt_checkpoint_path):
raise ValueError(f"path {biogpt_checkpoint_path} does not exist!")
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
print(f"Writing results to {pytorch_dump_folder_path}")
# handle various types of models
checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
if not os.path.isfile(checkpoint_file):
raise ValueError(f"path to the file {checkpoint_file} does not exist!")
chkpt = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
args = chkpt["cfg"]["model"]
# dicts
dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt")
if not os.path.isfile(dict_file):
raise ValueError(f"path to the file {dict_file} does not exist!")
src_dict = Dictionary.load(dict_file)
src_vocab = rewrite_dict_keys(src_dict.indices)
src_vocab_size = len(src_vocab)
src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"])
print(f"Generating {src_vocab_file} of {src_vocab_size} records")
with open(src_vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
# merges_file (bpecodes)
bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes")
if not os.path.isfile(bpecodes_file):
raise ValueError(f"path to the file {bpecodes_file} does not exist!")
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
shutil.copyfile(bpecodes_file, merges_file)
# model config
biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")
model_conf = {
"activation_dropout": args["activation_dropout"],
"architectures": ["BioGptForCausalLM"],
"attention_probs_dropout_prob": args["attention_dropout"],
"bos_token_id": 0,
"eos_token_id": 2,
"hidden_act": args["activation_fn"],
"hidden_dropout_prob": args["dropout"],
"hidden_size": args["decoder_embed_dim"],
"initializer_range": 0.02,
"intermediate_size": args["decoder_ffn_embed_dim"],
"layer_norm_eps": 1e-12,
"layerdrop": args["decoder_layerdrop"],
"max_position_embeddings": args["max_target_positions"],
"model_type": "biogpt",
"num_attention_heads": args["decoder_attention_heads"],
"num_hidden_layers": args["decoder_layers"],
"pad_token_id": 1,
"scale_embedding": not args["no_scale_embedding"],
"tie_word_embeddings": args["share_decoder_input_output_embed"],
"vocab_size": src_vocab_size,
}
# good hparam defaults to start with
print(f"Generating {biogpt_model_config_file}")
with open(biogpt_model_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))
# tokenizer config
biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)
tokenizer_conf = {
"bos_token": "<s>",
"eos_token": "</s>",
"model_max_length": 1024,
"pad_token": "<pad>",
"special_tokens_map_file": None,
"tokenizer_class": "BioGptTokenizer",
"unk_token": "<unk>",
}
print(f"Generating {biogpt_tokenizer_config_file}")
with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))
# model
model_state_dict = chkpt["model"]
# remove unneeded keys
ignore_keys = [
"decoder.version",
]
for k in ignore_keys:
model_state_dict.pop(k, None)
layer_names = list(model_state_dict.keys())
for layer_name in layer_names:
if layer_name.endswith("output_projection.weight"):
model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name)
else:
model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name)
config = BioGptConfig.from_pretrained(pytorch_dump_folder_path)
model_new = BioGptForCausalLM(config)
# check that it loads ok
model_new.load_state_dict(model_state_dict)
# save
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
print(f"Generating {pytorch_weights_dump_path}")
torch.save(model_state_dict, pytorch_weights_dump_path)
print("Conversion is done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--biogpt_checkpoint_path",
default=None,
type=str,
required=True,
help=(
"Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
" bpecodes, etc."
),
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path)

View File

@ -1,177 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BiT checkpoints from the timm library."""
import argparse
import json
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers import BitConfig, BitForImageClassification, BitImageProcessor
from transformers.image_utils import PILImageResampling
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_config(model_name):
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
conv_layer = "std_conv" if "bit" in model_name else False
# note that when using BiT as backbone for ViT-hybrid checkpoints,
# one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same",
# config.conv_layer = "std_conv_same"
config = BitConfig(
conv_layer=conv_layer,
num_labels=1000,
id2label=id2label,
label2id=label2id,
)
return config
def rename_key(name):
if "stem.conv" in name:
name = name.replace("stem.conv", "bit.embedder.convolution")
if "blocks" in name:
name = name.replace("blocks", "layers")
if "head.fc" in name:
name = name.replace("head.fc", "classifier.1")
if name.startswith("norm"):
name = "bit." + name
if "bit" not in name and "classifier" not in name:
name = "bit.encoder." + name
return name
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
"""
Copy/paste/tweak model's weights to our BiT structure.
"""
# define default BiT configuration
config = get_config(model_name)
# load original model from timm
timm_model = create_model(model_name, pretrained=True)
timm_model.eval()
# load state_dict of original model
state_dict = timm_model.state_dict()
for key in state_dict.copy():
val = state_dict.pop(key)
state_dict[rename_key(key)] = val.squeeze() if "head" in key else val
# load HuggingFace model
model = BitForImageClassification(config)
model.eval()
model.load_state_dict(state_dict)
# create image processor
transform = create_transform(**resolve_data_config({}, model=timm_model))
timm_transforms = transform.transforms
pillow_resamplings = {
"bilinear": PILImageResampling.BILINEAR,
"bicubic": PILImageResampling.BICUBIC,
"nearest": PILImageResampling.NEAREST,
}
processor = BitImageProcessor(
do_resize=True,
size={"shortest_edge": timm_transforms[0].size},
resample=pillow_resamplings[timm_transforms[0].interpolation.value],
do_center_crop=True,
crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
do_normalize=True,
image_mean=timm_transforms[-1].mean.tolist(),
image_std=timm_transforms[-1].std.tolist(),
)
image = prepare_img()
timm_pixel_values = transform(image).unsqueeze(0)
pixel_values = processor(image, return_tensors="pt").pixel_values
# verify pixel values
assert torch.allclose(timm_pixel_values, pixel_values)
# verify logits
with torch.no_grad():
outputs = model(pixel_values)
logits = outputs.logits
print("Logits:", logits[0, :3])
print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
timm_logits = timm_model(pixel_values)
assert timm_logits.shape == outputs.logits.shape
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print(f"Pushing model {model_name} and processor to the hub")
model.push_to_hub(f"ybelkada/{model_name}")
processor.push_to_hub(f"ybelkada/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="resnetv2_50x1_bitm",
type=str,
help="Name of the BiT timm model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the model to the hub.",
)
args = parser.parse_args()
convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -1,114 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Blenderbot checkpoint."""
import argparse
import torch
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
PATTERNS = [
["attention", "attn"],
["encoder_attention", "encoder_attn"],
["q_lin", "q_proj"],
["k_lin", "k_proj"],
["v_lin", "v_proj"],
["out_lin", "out_proj"],
["norm_embeddings", "layernorm_embedding"],
["position_embeddings", "embed_positions"],
["embeddings", "embed_tokens"],
["ffn.lin", "fc"],
]
def rename_state_dict_key(k):
if k == "embeddings.weight":
return "shared.weight"
for parlai_name, hf_name in PATTERNS:
k = k.replace(parlai_name, hf_name)
if k.startswith("encoder"):
k = k.replace(".attn", ".self_attn")
k = k.replace("norm1", "self_attn_layer_norm")
k = k.replace("norm2", "final_layer_norm")
elif k.startswith("decoder"):
k = k.replace("norm1", "self_attn_layer_norm")
k = k.replace("norm2", "encoder_attn_layer_norm")
k = k.replace("norm3", "final_layer_norm")
return k
def rename_layernorm_keys(sd):
keys = [
"model.encoder.layernorm_embedding.weight",
"model.encoder.layernorm_embedding.bias",
"model.decoder.layernorm_embedding.weight",
"model.decoder.layernorm_embedding.bias",
]
for k in keys:
v = sd.pop(k)
new_k = k.replace("layernorm_embedding", "layer_norm")
assert new_k not in sd
sd[new_k] = v
IGNORE_KEYS = ["START"]
@torch.no_grad()
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
"""
Copy/paste/tweak model's weights to our BERT structure.
"""
model = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
sd = model["model"]
cfg = BlenderbotConfig.from_json_file(config_json_path)
m = BlenderbotForConditionalGeneration(cfg)
valid_keys = m.model.state_dict().keys()
failures = []
mapping = {}
for k, v in sd.items():
if k in IGNORE_KEYS:
continue
new_k = rename_state_dict_key(k)
if new_k not in valid_keys:
failures.append([k, new_k])
else:
mapping[new_k] = v
if cfg.normalize_before: # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
rename_layernorm_keys(sd)
m.model.load_state_dict(mapping, strict=True)
m.half()
m.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
parser.add_argument(
"--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
)
args = parser.parse_args()
convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)

View File

@ -1,191 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re
import requests
import torch
# git clone https://github.com/salesforce/BLIP.git
from models.blip import blip_decoder
from models.blip_itm import blip_itm
from models.blip_vqa import blip_vqa
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import (
BertTokenizer,
BlipConfig,
BlipForConditionalGeneration,
BlipForImageTextRetrieval,
BlipForQuestionAnswering,
)
def load_demo_image(image_size, device):
img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
transform = transforms.Compose(
[
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]
)
image = transform(raw_image).unsqueeze(0).to(device)
return image
def rename_key(key):
if "visual_encoder" in key:
key = re.sub("visual_encoder*", "vision_model.encoder", key)
if "blocks" in key:
key = re.sub(r"blocks", "layers", key)
if "attn" in key:
key = re.sub(r"attn", "self_attn", key)
if "norm1" in key:
key = re.sub(r"norm1", "layer_norm1", key)
if "norm2" in key:
key = re.sub(r"norm2", "layer_norm2", key)
if "encoder.norm" in key:
key = re.sub(r"encoder.norm", "post_layernorm", key)
if "encoder.patch_embed.proj" in key:
key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)
if "encoder.pos_embed" in key:
key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
if "encoder.cls_token" in key:
key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)
if "self_attn" in key:
key = re.sub(r"self_attn.proj", "self_attn.projection", key)
return key
@torch.no_grad()
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = BlipConfig.from_pretrained(config_path)
else:
config = BlipConfig(projection_dim=512, text_config={}, vision_config={})
hf_model = BlipForConditionalGeneration(config).eval()
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
pt_model = pt_model.eval()
modified_state_dict = pt_model.state_dict()
for key in modified_state_dict.copy():
value = modified_state_dict.pop(key)
renamed_key = rename_key(key)
modified_state_dict[renamed_key] = value
hf_model.load_state_dict(modified_state_dict)
image_size = 384
image = load_demo_image(image_size=image_size, device="cpu")
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
input_ids = tokenizer(["a picture of"]).input_ids
out = hf_model.generate(image, input_ids)
assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
out = hf_model.generate(image)
assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
if pytorch_dump_folder_path is not None:
hf_model.save_pretrained(pytorch_dump_folder_path)
# model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
model_url = (
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
)
vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
vqa_model.eval()
modified_state_dict = vqa_model.state_dict()
for key in modified_state_dict.copy():
value = modified_state_dict.pop(key)
renamed_key = rename_key(key)
modified_state_dict[renamed_key] = value
hf_vqa_model = BlipForQuestionAnswering(config)
hf_vqa_model.load_state_dict(modified_state_dict)
question = ["How many dogs are in this image?"]
question_input_ids = tokenizer(question, return_tensors="pt").input_ids
answer = hf_vqa_model.generate(question_input_ids, image)
print(tokenizer.decode(answer[0]))
assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"
if pytorch_dump_folder_path is not None:
hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"
itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
itm_model.eval()
modified_state_dict = itm_model.state_dict()
for key in modified_state_dict.copy():
value = modified_state_dict.pop(key)
renamed_key = rename_key(key)
modified_state_dict[renamed_key] = value
hf_itm_model = BlipForImageTextRetrieval(config)
question = ["A picture of a woman with a dog sitting in a beach"]
question_input_ids = tokenizer(
question,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=35,
).input_ids
hf_itm_model.load_state_dict(modified_state_dict)
hf_itm_model.eval()
out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
out = hf_itm_model(question_input_ids, image, use_itm_head=False)
assert out[0].item() == 0.2110687494277954
assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127
if pytorch_dump_folder_path is not None:
hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
args = parser.parse_args()
convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)

View File

@ -1,390 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert BLIP-2 checkpoints from the original repository.
URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
"""
import argparse
import requests
import torch
# pip3 install salesforce-lavis
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
# to make sure we can compare both original and HF implementation in float32
from lavis.models import load_model_and_preprocess
from PIL import Image
from transformers import (
AutoTokenizer,
BertTokenizer,
Blip2Config,
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
Blip2Processor,
Blip2QFormerConfig,
Blip2VisionConfig,
BlipImageProcessor,
OPTConfig,
T5Config,
set_seed,
)
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
def load_demo_image():
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
return image
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, model_name):
rename_keys = []
# fmt: off
# vision encoder
rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))
for i in range(config.vision_config.num_hidden_layers):
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
# QFormer
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
if "itm" in model_name:
rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
rename_keys.append(("text_proj.weight", "text_projection.weight"))
rename_keys.append(("text_proj.bias", "text_projection.bias"))
# fmt: on
return rename_keys
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def read_in_q_v_bias(state_dict, config):
for i in range(config.vision_config.num_hidden_layers):
# read in original q and v biases
q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")
# next, set bias in the state dict
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
def get_blip2_config(model_name, eos_token_id):
image_size = 364 if "coco" in model_name else 224
vision_config = Blip2VisionConfig(image_size=image_size).to_dict()
# make sure the models have proper bos_token_id and eos_token_id set (important for generation)
# seems like flan-T5 models don't have bos_token_id properly set?
if "opt-2.7b" in model_name:
text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
elif "opt-6.7b" in model_name:
text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
elif "t5-xl" in model_name:
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
elif "t5-xxl" in model_name:
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
elif "itm" in model_name:
text_config = {}
else:
raise ValueError("Model name not supported")
if "itm" in model_name:
config = Blip2Config(
vision_config=vision_config,
qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
)
else:
config = Blip2Config(vision_config=vision_config, text_config=text_config)
return config, image_size
@torch.no_grad()
def convert_blip2_checkpoint(
model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
):
"""
Copy/paste/tweak model's weights to Transformers design.
"""
if "opt" in model_name:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
elif "itm" in model_name:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
else:
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
if "itm" in model_name:
eos_token_id = None
else:
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
if "itm" in model_name:
hf_model = Blip2ForImageTextRetrieval(config).eval()
else:
hf_model = Blip2ForConditionalGeneration(config).eval()
model_name_to_original = {
"blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
"blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
"blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
"blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
"blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
"blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
}
name, type = model_name_to_original[model_name]
# load original model
print("Loading original model...")
original_model, vis_processors, _ = load_model_and_preprocess(
name=name, model_type=type, is_eval=True, device=lavis_device
)
original_model.eval()
print("Done!")
# update state dict keys
state_dict = original_model.state_dict()
rename_keys = create_rename_keys(config, model_name)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
# some keys can be renamed efficiently
for key, val in state_dict.copy().items():
val = state_dict.pop(key)
if key.startswith("Qformer.bert"):
key = key.replace("Qformer.bert", "qformer")
if "attention.self" in key:
key = key.replace("self", "attention")
if "opt_proj" in key:
key = key.replace("opt_proj", "language_projection")
if "t5_proj" in key:
key = key.replace("t5_proj", "language_projection")
if key.startswith("opt"):
key = key.replace("opt", "language")
if key.startswith("t5"):
key = key.replace("t5", "language")
state_dict[key] = val
# read in qv biases
read_in_q_v_bias(state_dict, config)
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
assert len(missing_keys) == 0
if "itm" in model_name:
unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
else:
assert unexpected_keys == ["qformer.embeddings.position_ids"]
image = load_demo_image()
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
# create processor
image_processor = BlipImageProcessor(
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
)
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
# make sure processor creates exact same pixel values
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
original_model.to(lavis_device)
hf_model.to(hf_model_device)
if "itm" in model_name:
caption = "a large fountain spewing water into the air"
input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)
with torch.no_grad():
original_logits = original_model(
{"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
)
logits = hf_model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
use_image_text_matching_head=True,
)
assert original_logits.shape == logits.logits_per_image.shape
print("First values of original logits:", original_logits[0, :3])
print("First values of HF logits:", logits.logits_per_image[0, :3])
# assert values
# cast to same type
target_dtype = logits.logits_per_image.dtype
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
print("Looks ok!")
with torch.no_grad():
original_logits = original_model(
{"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
)
logits = hf_model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
use_image_text_matching_head=False,
)
assert original_logits.shape == logits.logits_per_image.shape
print("First values of original logits:", original_logits[0, :3])
print("First values of HF logits:", logits.logits_per_image[0, :3])
# assert values
# cast to same type
target_dtype = logits.logits_per_image.dtype
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
print("Looks ok!")
else:
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
with torch.no_grad():
if "opt" in model_name:
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
logits = hf_model(pixel_values, input_ids).logits
else:
original_logits = original_model(
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
).logits
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
logits = hf_model(pixel_values, input_ids, labels=labels).logits
assert original_logits.shape == logits.shape
print("First values of original logits:", original_logits[0, :3, :3])
print("First values of HF logits:", logits[0, :3, :3])
# assert values
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
print("Looks ok!")
print("Generating a caption...")
prompt = "Question: what object is in this image? Answer:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
set_seed(42)
original_outputs = original_model.generate(
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
)
outputs = hf_model.generate(
pixel_values,
input_ids,
do_sample=True,
num_beams=5,
max_length=30,
min_length=1,
top_p=0.9,
repetition_penalty=1.0,
length_penalty=1.0,
temperature=1,
)
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
output_text = [text.strip() for text in output_text]
print("Original generation:", original_outputs)
print("HF generation:", output_text)
if pytorch_dump_folder_path is not None:
processor.save_pretrained(pytorch_dump_folder_path)
hf_model.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
processor.push_to_hub(f"nielsr/{model_name}")
hf_model.push_to_hub(f"nielsr/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
choices = [
"blip2-opt-2.7b",
"blip2-opt-6.7b",
"blip2-opt-2.7b-coco",
"blip2-opt-6.7b-coco",
"blip2-flan-t5-xl",
"blip2-flan-t5-xl-coco",
"blip2-flan-t5-xxl",
"blip2-itm-vit-g",
"blip2-itm-vit-g-coco",
]
parser.add_argument(
"--model_name",
default="blip2-opt-2.7b",
choices=choices,
type=str,
help="Path to hf config.json of model to convert",
)
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the model and processor to the hub after converting",
)
# note: this script is tested on 2 GPUs, as models are compared in float32,
# which requires quite some memory. Hence loading both on a
# separate device is the easiest to compare
parser.add_argument(
"--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
)
parser.add_argument(
"--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
)
args = parser.parse_args()
convert_blip2_checkpoint(
args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
)

View File

@ -1,254 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigScience BLOOM checkpoint."""
import argparse
import json
import os
import re
import torch
from transformers import BloomConfig, BloomModel
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging
logging.set_verbosity_info()
WEIGHTS_TO_AVERAGE_ENDSWITH = [
"word_embeddings_layernorm.weight",
"word_embeddings_layernorm.bias",
"input_layernorm.weight",
"input_layernorm.bias",
"post_attention_layernorm.weight",
"post_attention_layernorm.bias",
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
"ln_f.weight",
"ln_f.bias",
]
WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
"mlp.dense_4h_to_h.weight",
"self_attention.dense.weight",
]
def layer_name_mapping(key, file):
"""Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
# Handle first and last layers
layer_rename_map = {
"word_embeddings.weight": "word_embeddings.weight",
"word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
"word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
"weight": "ln_f.weight",
"bias": "ln_f.bias",
}
if key in layer_rename_map:
return layer_rename_map[key]
# Handle transformer blocks
layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
layer_number -= 3
return f"h.{layer_number}." + key
def get_dtype_size(dtype):
if dtype == torch.bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
def convert_bloom_checkpoint_to_pytorch(
bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
):
# Construct model
if bloom_config_file == "":
config = BloomConfig()
else:
config = BloomConfig.from_json_file(bloom_config_file)
if shard_model:
file_names = os.listdir(bloom_checkpoint_path)
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
index_dict = {"weight_map": {}, "metadata": {}}
total_size = 0
missing_keys = None
config = BloomConfig()
for j, file in enumerate(file_names):
print(f"Processing file: {file}")
tensors = None
for i in range(pretraining_tp):
# load all TP files
f_name = file.replace("model_00", f"model_0{i}")
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
# Rename keys in the transformers names
keys = list(temp.keys())
for key in keys:
temp[layer_name_mapping(key, file)] = temp.pop(key)
if tensors is None:
tensors = temp
else:
for key in tensors:
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
# We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
tensors[key] += temp[key]
else:
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
# We concatenate these weights across TP ranks
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
# Divide by the number of TP the weights we want to average
for key in tensors:
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
tensors[key] = tensors[key] / pretraining_tp
torch.save(
tensors,
os.path.join(
pytorch_dump_folder_path,
f"pytorch_model_{str(j + 1).zfill(5)}-of-{str(len(file_names)).zfill(5)}.bin",
),
)
for key in tensors:
value = tensors[key]
total_size += value.numel() * get_dtype_size(value.dtype)
if key not in index_dict["weight_map"]:
index_dict["weight_map"][key] = (
f"pytorch_model_{str(j + 1).zfill(5)}-of-{str(len(file_names)).zfill(5)}.bin"
)
config = BloomConfig()
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
index_dict["metadata"]["total_size"] = total_size
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
f.write(json_config)
else:
model = BloomModel(config)
file_names = os.listdir(bloom_checkpoint_path)
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
missing_keys = None
for i, file in enumerate(file_names):
tensors = None
for i in range(pretraining_tp):
# load all TP files
f_name = file.replace("model_00", f"model_0{i}")
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
# Rename keys in the transformers names
keys = list(temp.keys())
for key in keys:
temp[layer_name_mapping(key, file)] = temp.pop(key)
if tensors is None:
tensors = temp
else:
for key in tensors:
# We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
tensors[key] += temp[key]
else:
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
# We concatenate these weights across TP ranks
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
# Divide by the number of TP the weights we want to average
for key in tensors:
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
tensors[key] = tensors[key] / pretraining_tp
other_keys = model.load_state_dict(tensors, strict=False)
assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
if missing_keys is None:
missing_keys = set(other_keys.missing_keys)
else:
missing_keys = missing_keys.intersection(set(other_keys.missing_keys))
assert not missing_keys, f"The keys {missing_keys} are missing"
# Save pytorch-model
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
if config.torch_dtype is not None:
model = model.to(config.torch_dtype)
torch.save(model.state_dict(), pytorch_weights_dump_path)
print(f"Save configuration file to {pytorch_config_dump_path}")
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--bloom_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the Megatron-LM checkpoint path.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--bloom_config_file",
default="",
type=str,
help=(
"An optional config json file corresponding to the pre-trained model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--shard_model",
action="store_true",
help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
)
parser.add_argument(
"--pretraining_tp",
default=4,
type=int,
help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
)
args = parser.parse_args()
convert_bloom_checkpoint_to_pytorch(
args.bloom_checkpoint_path,
args.bloom_config_file,
args.pytorch_dump_folder_path,
args.shard_model,
args.pretraining_tp,
)

View File

@ -1,145 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Bros checkpoints."""
import argparse
import bros # original repo
import torch
from transformers import BrosConfig, BrosModel, BrosProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_configs(model_name):
bros_config = BrosConfig.from_pretrained(model_name)
return bros_config
def remove_ignore_keys_(state_dict):
ignore_keys = [
"embeddings.bbox_sinusoid_emb.inv_freq",
]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(name):
if name == "embeddings.bbox_projection.weight":
name = "bbox_embeddings.bbox_projection.weight"
if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"
if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"
return name
def convert_state_dict(orig_state_dict, model):
# rename keys
for key in orig_state_dict.copy():
val = orig_state_dict.pop(key)
orig_state_dict[rename_key(key)] = val
# remove ignore keys
remove_ignore_keys_(orig_state_dict)
return orig_state_dict
def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
# load original model
original_model = bros.BrosModel.from_pretrained(model_name).eval()
# load HuggingFace Model
bros_config = get_configs(model_name)
model = BrosModel.from_pretrained(model_name, config=bros_config)
model.eval()
state_dict = original_model.state_dict()
new_state_dict = convert_state_dict(state_dict, model)
model.load_state_dict(new_state_dict)
# verify results
# original BROS model require 4 points (8 float values) for each bbox, prepare bbox with [batch_size, seq_len, 8] shape
bbox = torch.tensor(
[
[
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
]
]
)
processor = BrosProcessor.from_pretrained(model_name)
encoding = processor("His name is Rocco.", return_tensors="pt")
encoding["bbox"] = bbox
original_hidden_states = original_model(**encoding).last_hidden_state
# pixel_values = processor(image, return_tensors="pt").pixel_values
last_hidden_states = model(**encoding).last_hidden_state
assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)
if pytorch_dump_folder_path is not None:
print(f"Saving model and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="jinho8345/bros-base-uncased",
required=False,
type=str,
help="Name of the original model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
required=False,
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether or not to push the converted model and processor to the 🤗 hub.",
)
args = parser.parse_args()
convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -1,59 +0,0 @@
# coding=utf-8
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert T5 checkpoint."""
import argparse
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
# Initialise PyTorch model
config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
model = T5ForConditionalGeneration(config)
# Load weights from tf checkpoint
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)

View File

@ -1,65 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert CANINE checkpoint."""
import argparse
from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path):
# Initialize PyTorch model
config = CanineConfig()
model = CanineModel(config)
model.eval()
print(f"Building PyTorch model from configuration: {config}")
# Load weights from tf checkpoint
load_tf_weights_in_canine(model, config, tf_checkpoint_path)
# Save pytorch-model (weights and configuration)
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
# Save tokenizer files
tokenizer = CanineTokenizer()
print(f"Save tokenizer files to {pytorch_dump_path}")
tokenizer.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the TensorFlow checkpoint. Should end with model.ckpt",
)
parser.add_argument(
"--pytorch_dump_path",
default=None,
type=str,
required=True,
help="Path to a folder where the PyTorch model will be placed.",
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path)

View File

@ -1,478 +0,0 @@
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
import requests
import torch
import yaml
from accelerate import init_empty_weights
from PIL import Image
from transformers import (
ChameleonConfig,
ChameleonForConditionalGeneration,
ChameleonImageProcessor,
ChameleonProcessor,
)
try:
from transformers import LlamaTokenizerFast
except ImportError:
raise ValueError(
"Chameleon conversion supports only FastTokenizer and LlamaTokenizerFast can't be imported! "
"Update your `tokenizers` library and re-run the tokenizer conversion."
)
"""
Sample usage:
```
python src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py \
--input_dir /path/to/downloaded/chameleon/weights --model_size 7B --output_dir /output/path
```
Thereafter, models can be loaded via:
```py
from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast
model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""
NUM_SHARDS = {
"7B": 1,
"30B": 4,
}
VOCAB_SIZE = 65536
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(model_path, input_base_path, model_size, chameleon_version=1):
os.makedirs(model_path, exist_ok=True)
input_model_path = os.path.join(input_base_path, "models", model_size.lower())
params_path = os.path.join(input_model_path, "params.json")
consolidate_params_path = os.path.join(input_model_path, "consolidate_params.json")
params = read_json(params_path)
if os.path.isfile(consolidate_params_path):
params = {**params, **read_json(consolidate_params_path)}
num_shards = NUM_SHARDS[model_size]
model_parallel_size = params["model_parallel_size"]
params = params.get("model", params)
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads
base = params.get("rope_theta", 10000.0)
swin_norm = params["swin_norm"]
if base > 10000.0:
max_position_embeddings = 16384
else:
# Depending on the Chameleon version, the default max_position_embeddings has different values.
if chameleon_version == 1:
max_position_embeddings = 4096
else:
raise NotImplementedError(
f"Version {chameleon_version} of chameleon is not supported yet. "
"Current supported versions of chameleon are [1]."
)
if params.get("n_kv_heads", None) is not None:
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
key_value_dim = dim // num_key_value_heads
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
num_local_key_value_heads = n_heads_per_shard
key_value_dim = dim
print(f"Fetching all parameters from the checkpoint at {input_model_path}.")
# Load weights
if num_shards == 1:
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
loaded = None
for possible_name in ["consolidated.pth", "consolidated.00.pth"]:
possible_path = os.path.join(input_model_path, possible_name)
if os.path.exists(possible_path):
loaded = torch.load(possible_path, map_location="cpu", weights_only=True)
break
assert loaded is not None
else:
# Sharded
loaded = [
torch.load(
os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu", weights_only=True
)
for i in range(num_shards)
]
# permute for sliced rotary
def permute(w, n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
# Load weights to the state dict
state_dict = {}
for layer_i in range(n_layers):
if num_shards == 1:
# Unsharded
state_dict.update(
{
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
),
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wk.weight"],
n_heads=num_key_value_heads,
dim1=key_value_dim,
),
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
f"model.layers.{layer_i}.input_layernorm.weight": loaded[
f"layers.{layer_i}.attention_norm.weight"
],
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
f"layers.{layer_i}.ffn_norm.weight"
],
}
)
# qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
loaded[f"layers.{layer_i}.attention.q_normalization.weight"]
.view(dims_per_head // 2, 2)
.t()
.reshape(1, -1)
.repeat_interleave(n_heads, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
loaded[f"layers.{layer_i}.attention.q_normalization.bias"]
.view(dims_per_head // 2, 2)
.t()
.reshape(1, -1)
.repeat_interleave(n_heads, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
loaded[f"layers.{layer_i}.attention.k_normalization.weight"]
.view(dims_per_head // 2, 2)
.t()
.reshape(1, -1)
.repeat_interleave(num_key_value_heads, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
loaded[f"layers.{layer_i}.attention.k_normalization.bias"]
.view(dims_per_head // 2, 2)
.t()
.reshape(1, -1)
.repeat_interleave(num_key_value_heads, 0)
)
else:
# Sharded
state_dict.update(
{
f"model.layers.{layer_i}.input_layernorm.weight": torch.stack(
[l[f"layers.{layer_i}.attention_norm.weight"] for l in loaded]
).mean(dim=0),
f"model.layers.{layer_i}.post_attention_layernorm.weight": torch.stack(
[l[f"layers.{layer_i}.ffn_norm.weight"] for l in loaded]
).mean(dim=0),
}
)
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
for i in range(num_shards)
],
dim=0,
).reshape(dim, dim),
n_heads=n_heads,
)
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim),
n_heads=num_key_value_heads,
dim1=key_value_dim,
)
# qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
torch.cat([l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0) for l in loaded])
.view(num_shards, dims_per_head // 2, 2)
.transpose(1, 2)
.reshape(num_shards, -1)
.repeat_interleave(n_heads // num_shards, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
torch.cat([l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0) for l in loaded])
.view(num_shards, dims_per_head // 2, 2)
.transpose(1, 2)
.reshape(num_shards, -1)
.repeat_interleave(n_heads // num_shards, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
torch.cat([l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0) for l in loaded])
.view(num_shards, dims_per_head // 2, 2)
.transpose(1, 2)
.reshape(num_shards, -1)
.repeat_interleave(num_key_value_heads // num_shards, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
torch.cat([l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0) for l in loaded])
.view(num_shards, dims_per_head // 2, 2)
.transpose(1, 2)
.reshape(num_shards, -1)
.repeat_interleave(num_key_value_heads // num_shards, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
)
if num_shards == 1:
# Unsharded
state_dict.update(
{
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
"model.norm.weight": loaded["norm.weight"],
"lm_head.weight": loaded["output.weight"],
}
)
else:
state_dict.update(
{
"model.embed_tokens.weight": torch.cat(
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
),
"model.norm.weight": torch.stack([loaded[i]["norm.weight"] for i in range(num_shards)]).mean(dim=0),
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
}
)
# Load VQGAN weights
vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt")
vqgan_state_dict = torch.load(vqgan_path, map_location="cpu", weights_only=True)["state_dict"]
for k, v in vqgan_state_dict.items():
if "decoder" in k:
continue # we dont do image generation yet
state_dict[f"model.vqmodel.{k}"] = v
# Write configs
ffn_dim_multiplier = params.get("ffn_dim_multiplier", 1)
multiple_of = params.get("multiple_of", 256)
with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file:
tokenizer_config = json.load(tokenizer_file)
vocabulary_map = tokenizer_config["model"]["vocab"]
vocabulary_map["<image>"] = vocabulary_map[
"<reserved08707>"
] # use a reserved token instead of adding a new one
del vocabulary_map["<reserved08707>"]
for token in tokenizer_config["added_tokens"]:
if token["content"] == "<reserved08707>":
token["content"] = "<image>"
with open(os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), "w") as f:
json.dump(tokenizer_config, f) # save the new file to init tokenizer later
vq_keys_to_replace = [
("ch", "base_channels"),
("out_ch", "out_channels"),
("n_embed", "num_embeddings"),
("ch_mult", "channel_multiplier"),
("double_z", "double_latent"),
("z_channels", "latent_channels"),
]
with open(os.path.join(input_base_path, "tokenizer/vqgan.yaml")) as vqgan_cfg_file:
vq_config = yaml.safe_load(vqgan_cfg_file)["model"]["params"]
vq_config.update(**vq_config["ddconfig"])
for old, new in vq_keys_to_replace:
vq_config[new] = vq_config[old]
del vq_config["ddconfig"]
del vq_config["ckpt_path"]
del vq_config["lossconfig"]
config = ChameleonConfig(
hidden_size=dim,
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
num_key_value_heads=num_key_value_heads,
vocab_size=VOCAB_SIZE,
rope_theta=base,
max_position_embeddings=max_position_embeddings,
model_parallel_size=model_parallel_size,
swin_norm=swin_norm,
vq_config=vq_config,
vocabulary_map=vocabulary_map,
)
with init_empty_weights():
model = ChameleonForConditionalGeneration(config)
model.load_state_dict(state_dict, assign=True, strict=False)
model.save_pretrained(model_path, safe_serialization=True)
# Load and save the processor
tokenizer = LlamaTokenizerFast(
tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False
)
tokenizer.sep_token_id = 8710 # assign <reserved08706> to sep so that we can append it after input text
tokenizer.pad_token_id = 1 # assign <pad> to special pad_token
image_processor = ChameleonImageProcessor()
processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer)
processor.save_pretrained(model_path)
# Make space so we can load the model properly now.
del state_dict
del loaded
del vqgan_state_dict
gc.collect()
# Short inference on a few examples to check if generation makes sense
# taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
print("Loading the checkpoint in a Chameleon model...")
print("*" * 100)
model = ChameleonForConditionalGeneration.from_pretrained(
model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
)
processor = ChameleonProcessor.from_pretrained(model_path)
prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
image = Image.open(
requests.get(
"https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
).raw
)
inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
length = inputs.input_ids.shape[1]
out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
print(f"Generation for single-image: {generated_text}")
print("*" * 100)
# Multi-image example
prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
image = Image.open(
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
)
image_2 = Image.open(
requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
)
inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
length = inputs.input_ids.shape[1]
out = model.generate(**inputs, max_new_tokens=50, do_sample=False)
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
print(f"Generation for multi-image: {generated_text}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of Chameleon weights",
)
parser.add_argument(
"--model_size",
choices=["7B", "30B"],
help=""
" models correspond to the finetuned versions, and are specific to the Chameleon official release. For more details on Chameleon, check out the original repo: https://github.com/facebookresearch/chameleon",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model",
)
parser.add_argument(
"--test_inference",
action="store_true",
help="Whether to load the model for generation to test it's converted correctly.",
)
# Different Chameleon versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
parser.add_argument(
"--chameleon_version",
choices=[1],
default=1,
type=int,
help="Version of the Chameleon model to convert",
)
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
chameleon_version=args.chameleon_version,
)
if __name__ == "__main__":
main()

View File

@ -1,134 +0,0 @@
# coding=utf-8
# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from transformers import ChineseCLIPConfig, ChineseCLIPModel
def copy_attn_layer(hf_attn_layer, pt_weights, prefix):
q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0)
q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0)
out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"]
out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"]
hf_attn_layer.q_proj.weight.data = q_proj
hf_attn_layer.q_proj.bias.data = q_proj_bias
hf_attn_layer.k_proj.weight.data = k_proj
hf_attn_layer.k_proj.bias.data = k_proj_bias
hf_attn_layer.v_proj.weight.data = v_proj
hf_attn_layer.v_proj.bias.data = v_proj_bias
hf_attn_layer.out_proj.weight.data = out_proj_weights
hf_attn_layer.out_proj.bias.data = out_proj_bias
def copy_mlp(hf_mlp, pt_weights, prefix):
copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc")
copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj")
def copy_linear(hf_linear, pt_weights, prefix):
hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data
hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data
def copy_layer(hf_layer, pt_weights, prefix):
# copy layer norms
copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1")
copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2")
# copy MLP
copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp")
# copy attn
copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn")
def copy_layers(hf_layers, pt_weights, prefix):
for layer_id, hf_layer in enumerate(hf_layers):
copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}")
def copy_text_model_and_projection(hf_model, pt_weights):
# copy projection
hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T
# copy text encoder
for name, param in hf_model.text_model.named_parameters():
param.data = pt_weights[f"bert.{name}"].data
def copy_vision_model_and_projection(hf_model, pt_weights):
# copy projection
hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T
# copy layer norms
copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre")
copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post")
# copy embeddings
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data
hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data
# copy encoder
copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks")
@torch.no_grad()
def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
"""
Copy/paste/tweak model's weights to transformers design.
"""
assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size."
config = ChineseCLIPConfig.from_pretrained(config_path)
hf_model = ChineseCLIPModel(config).eval()
pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
copy_text_model_and_projection(hf_model, pt_weights)
copy_vision_model_and_projection(hf_model, pt_weights)
hf_model.logit_scale.data = pt_weights["logit_scale"].data
hf_model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output folder storing converted hf PyTorch model.",
)
parser.add_argument(
"--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint."
)
parser.add_argument(
"--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert."
)
args = parser.parse_args()
convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
print("The conversion is finished!")

View File

@ -1,133 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re
from laion_clap import CLAP_Module
from transformers import AutoFeatureExtractor, ClapConfig, ClapModel
KEYS_TO_MODIFY_MAPPING = {
"text_branch": "text_model",
"audio_branch": "audio_model.audio_encoder",
"attn": "attention.self",
"self.proj": "output.dense",
"attention.self_mask": "attn_mask",
"mlp.fc1": "intermediate.dense",
"mlp.fc2": "output.dense",
"norm1": "layernorm_before",
"norm2": "layernorm_after",
"bn0": "batch_norm",
}
processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")
def init_clap(checkpoint_path, model_type, enable_fusion=False):
model = CLAP_Module(
amodel=model_type,
enable_fusion=enable_fusion,
)
model.load_ckpt(checkpoint_path)
return model
def get_config_from_original(clap_model):
audio_config = {
"patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
"depths": clap_model.model.audio_branch.depths,
"hidden_size": clap_model.model.audio_projection[0].in_features,
}
text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}
return ClapConfig(audio_config=audio_config, text_config=text_config)
def rename_state_dict(state_dict):
model_state_dict = {}
sequential_layers_pattern = r".*sequential.(\d+).*"
text_projection_pattern = r".*_projection.(\d+).*"
for key, value in state_dict.items():
# check if any key needs to be modified
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
if re.match(sequential_layers_pattern, key):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
# Because in CLAP they use `nn.Sequential`...
transformers_projection_layer = 1 if projecton_layer == 0 else 2
key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")
if "audio" and "qkv" in key:
# split qkv into query key and value
mixed_qkv = value
qkv_dim = mixed_qkv.size(0) // 3
query_layer = mixed_qkv[:qkv_dim]
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
value_layer = mixed_qkv[qkv_dim * 2 :]
model_state_dict[key.replace("qkv", "query")] = query_layer
model_state_dict[key.replace("qkv", "key")] = key_layer
model_state_dict[key.replace("qkv", "value")] = value_layer
else:
model_state_dict[key] = value
return model_state_dict
def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)
clap_model.eval()
state_dict = clap_model.model.state_dict()
state_dict = rename_state_dict(state_dict)
transformers_config = get_config_from_original(clap_model)
transformers_config.audio_config.enable_fusion = enable_fusion
model = ClapModel(transformers_config)
# ignore the spectrogram embedding layer
model.load_state_dict(state_dict, strict=False)
model.save_pretrained(pytorch_dump_folder_path)
transformers_config.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Whether to enable fusion or not")
args = parser.parse_args()
convert_clap_checkpoint(
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
)

View File

@ -1,156 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from clip import load
from transformers import CLIPConfig, CLIPModel
def copy_attn_layer(hf_attn_layer, pt_attn_layer):
q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
out_proj_weights = pt_attn_layer.out_proj.weight
out_proj_bias = pt_attn_layer.out_proj.bias
hf_attn_layer.q_proj.weight.data = q_proj
hf_attn_layer.q_proj.bias.data = q_proj_bias
hf_attn_layer.k_proj.weight.data = k_proj
hf_attn_layer.k_proj.bias.data = k_proj_bias
hf_attn_layer.v_proj.weight.data = v_proj
hf_attn_layer.v_proj.bias.data = v_proj_bias
hf_attn_layer.out_proj.weight = out_proj_weights
hf_attn_layer.out_proj.bias = out_proj_bias
def copy_mlp(hf_mlp, pt_mlp):
copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
copy_linear(hf_mlp.fc2, pt_mlp.c_proj)
def copy_linear(hf_linear, pt_linear):
hf_linear.weight = pt_linear.weight
hf_linear.bias = pt_linear.bias
def copy_layer(hf_layer, pt_layer):
# copy layer norms
copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)
# copy MLP
copy_mlp(hf_layer.mlp, pt_layer.mlp)
# copy attn
copy_attn_layer(hf_layer.self_attn, pt_layer.attn)
def copy_layers(hf_layers, pt_layers):
for hf_layer, pt_layer in zip(hf_layers, pt_layers):
copy_layer(hf_layer, pt_layer)
def copy_encoder(hf_encoder, pt_model):
# copy embeds
hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding
# copy layer norm
copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
# copy hidden layers
copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)
def copy_text_model_and_projection(hf_model, pt_model):
# copy projection
hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()
# copy text encoder
copy_encoder(hf_model.text_model, pt_model)
def copy_vison_model_and_projection(hf_model, pt_model):
# copy projection
hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous()
# copy layer norms
copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)
# copy embeds
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data
# copy encoder
copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)
@torch.no_grad()
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = CLIPConfig.from_pretrained(config_path)
else:
config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})
hf_model = CLIPModel(config).eval()
pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
pt_model = pt_model.eval()
copy_text_model_and_projection(hf_model, pt_model)
copy_vison_model_and_projection(hf_model, pt_model)
hf_model.logit_scale = pt_model.logit_scale
# Use `eos_token` so the example is more meaningful
input_ids = torch.tensor(
[
[config.text_config.bos_token_id]
+ list(range(3, 77))
+ [config.text_config.eos_token_id]
+ [config.text_config.pad_token_id]
]
)
pixel_values = torch.randn(1, 3, 224, 224)
hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
hf_logits_per_image = hf_outputs.logits_per_image
hf_logits_per_text = hf_outputs.logits_per_text
pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)
assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)
hf_model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
args = parser.parse_args()
convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)

View File

@ -1,264 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert CLIPSeg checkpoints from the original repository. URL: https://github.com/timojl/clipseg."""
import argparse
import requests
import torch
from PIL import Image
from transformers import (
CLIPSegConfig,
CLIPSegForImageSegmentation,
CLIPSegProcessor,
CLIPSegTextConfig,
CLIPSegVisionConfig,
CLIPTokenizer,
ViTImageProcessor,
)
def get_clipseg_config(model_name):
text_config = CLIPSegTextConfig()
vision_config = CLIPSegVisionConfig(patch_size=16)
use_complex_transposed_convolution = "refined" in model_name
reduce_dim = 16 if "rd16" in model_name else 64
config = CLIPSegConfig.from_text_vision_configs(
text_config,
vision_config,
use_complex_transposed_convolution=use_complex_transposed_convolution,
reduce_dim=reduce_dim,
)
return config
def rename_key(name):
# update prefixes
if "clip_model" in name:
name = name.replace("clip_model", "clip")
if "transformer" in name:
if "visual" in name:
name = name.replace("visual.transformer", "vision_model")
else:
name = name.replace("transformer", "text_model")
if "resblocks" in name:
name = name.replace("resblocks", "encoder.layers")
if "ln_1" in name:
name = name.replace("ln_1", "layer_norm1")
if "ln_2" in name:
name = name.replace("ln_2", "layer_norm2")
if "c_fc" in name:
name = name.replace("c_fc", "fc1")
if "c_proj" in name:
name = name.replace("c_proj", "fc2")
if "attn" in name and "self" not in name:
name = name.replace("attn", "self_attn")
# text encoder
if "token_embedding" in name:
name = name.replace("token_embedding", "text_model.embeddings.token_embedding")
if "positional_embedding" in name and "visual" not in name:
name = name.replace("positional_embedding", "text_model.embeddings.position_embedding.weight")
if "ln_final" in name:
name = name.replace("ln_final", "text_model.final_layer_norm")
# vision encoder
if "visual.class_embedding" in name:
name = name.replace("visual.class_embedding", "vision_model.embeddings.class_embedding")
if "visual.conv1" in name:
name = name.replace("visual.conv1", "vision_model.embeddings.patch_embedding")
if "visual.positional_embedding" in name:
name = name.replace("visual.positional_embedding", "vision_model.embeddings.position_embedding.weight")
if "visual.ln_pre" in name:
name = name.replace("visual.ln_pre", "vision_model.pre_layrnorm")
if "visual.ln_post" in name:
name = name.replace("visual.ln_post", "vision_model.post_layernorm")
# projection layers
if "visual.proj" in name:
name = name.replace("visual.proj", "visual_projection.weight")
if "text_projection" in name:
name = name.replace("text_projection", "text_projection.weight")
# decoder
if "trans_conv" in name:
name = name.replace("trans_conv", "transposed_convolution")
if "film_mul" in name or "film_add" in name or "reduce" in name or "transposed_convolution" in name:
name = "decoder." + name
if "blocks" in name:
name = name.replace("blocks", "decoder.layers")
if "linear1" in name:
name = name.replace("linear1", "mlp.fc1")
if "linear2" in name:
name = name.replace("linear2", "mlp.fc2")
if "norm1" in name and "layer_" not in name:
name = name.replace("norm1", "layer_norm1")
if "norm2" in name and "layer_" not in name:
name = name.replace("norm2", "layer_norm2")
return name
def convert_state_dict(orig_state_dict, config):
for key in orig_state_dict.copy():
val = orig_state_dict.pop(key)
if key.startswith("clip_model") and "attn.in_proj" in key:
key_split = key.split(".")
if "visual" in key:
layer_num = int(key_split[4])
dim = config.vision_config.hidden_size
prefix = "vision_model"
else:
layer_num = int(key_split[3])
dim = config.text_config.hidden_size
prefix = "text_model"
if "weight" in key:
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[
dim : dim * 2, :
]
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
else:
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
elif "self_attn" in key and "out_proj" not in key:
key_split = key.split(".")
layer_num = int(key_split[1])
dim = config.reduce_dim
if "weight" in key:
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[dim : dim * 2, :]
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
else:
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
else:
new_name = rename_key(key)
if "visual_projection" in new_name or "text_projection" in new_name:
val = val.T
orig_state_dict[new_name] = val
return orig_state_dict
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
return image
def convert_clipseg_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub):
config = get_clipseg_config(model_name)
model = CLIPSegForImageSegmentation(config)
model.eval()
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
# remove some keys
for key in state_dict.copy():
if key.startswith("model"):
state_dict.pop(key, None)
# rename some keys
state_dict = convert_state_dict(state_dict, config)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys != ["clip.text_model.embeddings.position_ids", "clip.vision_model.embeddings.position_ids"]:
raise ValueError(f"Missing keys that are not expected: {missing_keys}")
if unexpected_keys != ["decoder.reduce.weight", "decoder.reduce.bias"]:
raise ValueError(f"Unexpected keys: {unexpected_keys}")
image_processor = ViTImageProcessor(size=352)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)
image = prepare_img()
text = ["a glass", "something to fill", "wood", "a jar"]
inputs = processor(text=text, images=[image] * len(text), padding="max_length", return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
# verify values
expected_conditional = torch.tensor([0.1110, -0.1882, 0.1645])
expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328])
if model_name == "clipseg-rd64-refined":
expected_masks_slice = torch.tensor(
[[-10.0407, -9.9431, -10.2646], [-9.9751, -9.7064, -9.9586], [-9.6891, -9.5645, -9.9618]]
)
elif model_name == "clipseg-rd64":
expected_masks_slice = torch.tensor(
[[-7.2877, -7.2711, -7.2463], [-7.2652, -7.2780, -7.2520], [-7.2239, -7.2204, -7.2001]]
)
elif model_name == "clipseg-rd16":
expected_masks_slice = torch.tensor(
[[-6.3955, -6.4055, -6.4151], [-6.3911, -6.4033, -6.4100], [-6.3474, -6.3702, -6.3762]]
)
else:
raise ValueError(f"Model name {model_name} not supported.")
assert torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)
assert torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)
assert torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print(f"Pushing model and processor for {model_name} to the hub")
model.push_to_hub(f"CIDAS/{model_name}")
processor.push_to_hub(f"CIDAS/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="clipseg-rd64",
type=str,
choices=["clipseg-rd16", "clipseg-rd64", "clipseg-rd64-refined"],
help=(
"Name of the model. Supported models are: clipseg-rd64, clipseg-rd16 and clipseg-rd64-refined (rd meaning"
" reduce dimension)"
),
)
parser.add_argument(
"--checkpoint_path",
default="/Users/nielsrogge/Documents/CLIPSeg/clip_plus_rd64-uni.pth",
type=str,
help=(
"Path to the original checkpoint. Note that the script assumes that the checkpoint includes both CLIP and"
" the decoder weights."
),
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_clipseg_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -1,234 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Weights conversion script for CLVP
"""
import argparse
import os
import torch
from huggingface_hub import hf_hub_download
from transformers import ClvpConfig, ClvpModelForConditionalGeneration
_MODELS = {
"clvp": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/clvp2.pth",
"decoder": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/autoregressive.pth",
}
dim = 1024
sub_dim = dim // 16
CLVP_ENCODERS_MAPPING = {
"text_transformer.transformer.attn_layers": "text_encoder_model",
"speech_transformer.transformer.attn_layers": "speech_encoder_model",
"text_transformer.transformer.norm": "text_encoder_model.final_layer_norm",
"speech_transformer.transformer.norm": "speech_encoder_model.final_layer_norm",
"to_text_latent": "text_encoder_model.projection",
"to_speech_latent": "speech_encoder_model.projection",
"text_emb": "text_encoder_model.token_embedding",
"speech_emb": "speech_encoder_model.token_embedding",
"1.wrap.net.0": "mlp.fc1",
"1.wrap.net.3": "mlp.fc2",
"1.wrap": "self_attn",
"to_out": "out_proj",
"to_q": "q_proj",
"to_k": "k_proj",
"to_v": "v_proj",
"temperature": "logit_scale",
}
CLVP_DECODER_MAPPING = {
"conditioning_encoder.init": "conditioning_encoder.mel_conv",
"conditioning_encoder.attn": "conditioning_encoder.mel_attn_blocks",
"mel_attn_blocks": "group_norms",
".norm.weight": ".weight",
".norm.bias": ".bias",
"text_embedding": "conditioning_encoder.text_token_embedding",
"text_pos_embedding.emb": "conditioning_encoder.text_position_embedding",
"final_norm": "speech_decoder_model.final_norm",
"mel_head": "speech_decoder_model.lm_head",
"gpt.ln_f": "speech_decoder_model.model.decoder.layer_norm",
"mel_embedding": "speech_decoder_model.model.decoder.input_embeds_layer",
"mel_pos_embedding.emb": "speech_decoder_model.model.decoder.position_embeds_layer",
"gpt.h": "speech_decoder_model.model.decoder.layers",
"ln_1": "input_layernorm",
"ln_2": "post_attention_layernorm",
}
def update_index(present_index):
if present_index % 2 == 0:
return int(present_index / 2)
else:
return int((present_index - 1) / 2)
def convert_encoder_weights(original_weights):
converted_weights = {}
original_weights_keys = sorted(original_weights.keys())
for original_key in original_weights_keys:
updated_key = original_key
# for input_rmsnorm.weight and post_attention_rmsnorm.weight
if "0.0.g" in updated_key:
present_index = updated_key.split(".")[4]
if int(present_index) % 2 == 0:
updated_key = updated_key.replace("0.0.g", "input_rmsnorm.weight")
else:
updated_key = updated_key.replace("0.0.g", "post_attention_rmsnorm.weight")
if "transformer.attn_layers.layers" in updated_key:
present_index = updated_key.split(".")[4]
updated_index = update_index(int(present_index))
updated_key = updated_key.replace(
f"transformer.attn_layers.layers.{present_index}", f"transformer.attn_layers.layers.{updated_index}"
)
for k, v in CLVP_ENCODERS_MAPPING.items():
if k in updated_key:
updated_key = updated_key.replace(k, v)
converted_weights[updated_key] = original_weights.pop(original_key)
return converted_weights
def convert_decoder_weights(original_weights):
converted_weights = {}
original_weights_keys = sorted(original_weights.keys())
for original_key in original_weights_keys:
updated_key = original_key
if len(updated_key.split(".")) > 3:
index, attr = updated_key.split(".")[2], updated_key.split(".")[-1]
# for decoder attention
if "attn.c_attn" in updated_key:
if attr == "weight":
slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).T.split(split_size=dim, dim=0)
else:
slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.q_proj.{attr}"] = slice1
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.k_proj.{attr}"] = slice2
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.v_proj.{attr}"] = slice3
continue
if "attn.c_proj" in updated_key:
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.out_proj.{attr}"] = (
original_weights[updated_key].squeeze(-1).T
)
continue
if "attn.bias" in updated_key or "attn.masked_bias" in updated_key or "text_head" in updated_key:
original_weights.pop(updated_key)
continue
# conditional encoder attention
if "qkv" in updated_key:
if attr == "weight":
slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).split(split_size=dim, dim=0)
else:
slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
indices = torch.arange(dim)
index1, index2, index3 = (
indices.unfold(0, sub_dim, sub_dim * 3).flatten(),
indices[sub_dim:].unfold(0, sub_dim, sub_dim * 3).flatten(),
indices[2 * sub_dim :].unfold(0, sub_dim, sub_dim * 3).flatten(),
)
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.q_proj.{attr}"] = torch.concatenate(
[slice1[index1], slice2[index3], slice3[index2]],
axis=0,
)
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.k_proj.{attr}"] = torch.concatenate(
[slice1[index2], slice2[index1], slice3[index3]],
axis=0,
)
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.v_proj.{attr}"] = torch.concatenate(
[slice1[index3], slice2[index2], slice3[index1]],
axis=0,
)
continue
if "proj_out" in updated_key:
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.out_proj.{attr}"] = original_weights[
updated_key
].squeeze(-1)
continue
for k, v in CLVP_DECODER_MAPPING.items():
if k in updated_key:
updated_key = updated_key.replace(k, v)
converted_weights[updated_key] = original_weights.pop(original_key)
return converted_weights
def _download(url: str, root: str):
repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
filename = f"{url.split('/')[-2]}/{url.split('/')[-1]}"
hf_hub_download(
repo_id=repo_id,
filename=filename,
force_filename=root,
local_dir_use_symlinks=False,
)
def convert_clvp_weights(checkpoint_path, pytorch_dump_folder_path):
converted_checkpoint = {}
for each_model_name, each_model_url in _MODELS.items():
each_model_path = os.path.join(checkpoint_path, each_model_url.split("/")[-1])
if not os.path.exists(each_model_path):
print(f"\n{each_model_name} was not found! Downloading it to {each_model_path}")
_download(url=each_model_url, root=each_model_path)
if each_model_name == "clvp":
clvp_checkpoint = torch.load(each_model_path, map_location="cpu", weights_only=True)
else:
decoder_checkpoint = torch.load(each_model_path, map_location="cpu", weights_only=True)
# Converting the weights
converted_checkpoint.update(**convert_encoder_weights(clvp_checkpoint))
converted_checkpoint.update(**convert_decoder_weights(decoder_checkpoint))
config = ClvpConfig.from_pretrained("susnato/clvp_dev")
model = ClvpModelForConditionalGeneration(config)
model.load_state_dict(converted_checkpoint, strict=True)
model.save_pretrained(pytorch_dump_folder_path)
print(f"Model saved at {pytorch_dump_folder_path}!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# # Required parameters
parser.add_argument(
"--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model. (Please enter full path)",
)
args = parser.parse_args()
convert_clvp_weights(args.checkpoint_path, args.pytorch_dump_folder_path)

View File

@ -1,214 +0,0 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert ColPali weights from the original repository to the HF model format.
Original repository: https://github.com/illuin-tech/colpali.
NOTE: This script was originally run using `torch==2.5.1` and with:
```bash
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
--model_id vidore/colpali-v1.2-merged \
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
--output_dir vidore/colpali-v1.2-hf-internal \
--push_to_hub
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
--model_id vidore/colpali-v1.3-merged \
--revision 5b955e3415a7c5468ab33119d98d6d45c3a5b2c3 \
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
--output_dir vidore/colpali-v1.3-hf \
--push_to_hub
```
"""
import argparse
import glob
from pathlib import Path
from typing import Any, Optional
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig
from transformers.models.colpali import ColPaliForRetrieval
from transformers.models.colpali.configuration_colpali import ColPaliConfig
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
ORIGINAL_DTYPE = torch.bfloat16
def rename_state_dict_keys(state_dict: dict[str, Any]) -> dict[str, Any]:
new_state_dict = {}
for key, value in state_dict.items():
new_key = key
if key.startswith("custom_text_proj"):
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
if key.startswith("model."):
new_key = key.replace("model.", "vlm.", 1)
new_state_dict[new_key] = value
return new_state_dict
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
directory_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["*.safetensors"],
)
original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
# Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
if "lm_head.weight" not in original_state_dict:
original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[
"model.language_model.model.embed_tokens.weight"
].clone()
return original_state_dict
@torch.no_grad()
def convert_colpali_weights_to_hf(
model_id: str,
output_dir: str,
push_to_hub: bool,
revision: Optional[str] = None,
original_vlm_name_or_path: Optional[str] = None,
):
# Load the original model data
original_config = AutoConfig.from_pretrained(
model_id,
revision=revision,
)
if original_vlm_name_or_path is not None:
original_config._name_or_path = original_vlm_name_or_path
if hasattr(original_config, "architectures"):
delattr(original_config, "architectures")
original_state_dict = load_original_state_dict(model_id, revision=revision)
# Format the state_dict keys
original_state_dict = rename_state_dict_keys(original_state_dict)
# Create the new config
config = ColPaliConfig(
vlm_config=original_config,
embedding_dim=128, # hardcoded in the original model
)
config.model_type = "colpali"
config.is_composition = False
# Load the untrained model
model = ColPaliForRetrieval(config=config).to("cpu").eval()
print("Created model with new config and randomly initialized weights")
# NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision.
# There are two ways to set the model's dtype:
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
# the new weights' dtypes match the original model.
for param in model.parameters():
param.data = param.data.to(ORIGINAL_DTYPE)
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
# Load the original weights
model.load_state_dict(original_state_dict)
print("Loaded original model weights")
# Tie the weights (following ColPali's `__init__`` step)
if model.vlm.language_model._tied_weights_keys is not None:
model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys]
# Sanity check: ensure all keys are the same
state_dict_keys_old = set(original_state_dict.keys())
state_dict_keys_new = set(model.state_dict().keys())
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
if disjoint_keys:
raise ValueError(f"Incompatible keys: {disjoint_keys}")
# Save the model
if push_to_hub:
model.push_to_hub(output_dir, private=True)
print(f"Model pushed to the hub at `{output_dir}`")
else:
Path(output_dir).mkdir(exist_ok=True, parents=True)
model.save_pretrained(output_dir)
print(f"Model saved to `{output_dir}`")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
This script converts the original ColPali model to the HF model format.
Example usage:
```bash
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
--model_id vidore/colpali-v1.2-merged \
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
--output_dir vidore/colpali-v1.2-hf \
--push_to_hub
```
"""
)
parser.add_argument(
"--model_id",
help="Model ID of the original model to convert",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
action="store_true",
default=False,
)
parser.add_argument(
"--revision",
help="Revision of the model to download",
default=None,
)
parser.add_argument(
"--original_vlm_name_or_path",
help="Name or path of the original VLM backbone model",
default=None,
)
args = parser.parse_args()
convert_colpali_weights_to_hf(
model_id=args.model_id,
output_dir=args.output_dir,
push_to_hub=args.push_to_hub,
revision=args.revision,
original_vlm_name_or_path=args.original_vlm_name_or_path,
)

View File

@ -1,212 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert ColQwen2 weights from the original repository to the HF model format.
Don't forget to manually upload the processor-related files to the HF model repository
after running this script.
Original repository: https://github.com/illuin-tech/colqwen2.
NOTE: This script was originally run using `torch==2.5.1` and with:
```bash
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
--model_id vidore/colqwen2-v1.0-merged \
--revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
--original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
--output_dir vidore/colqwen2-v1.0-hf-internal \
--push_to_hub
```
"""
import argparse
import glob
from pathlib import Path
from typing import Any, Optional
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig
from transformers.models.colqwen2 import ColQwen2ForRetrieval
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
ORIGINAL_DTYPE = torch.bfloat16
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
directory_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["*.safetensors"],
)
original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
# Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
if "lm_head.weight" not in original_state_dict:
original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
return original_state_dict
def rename_state_dict_keys(state_dict: dict[str, Any]) -> dict[str, Any]:
new_state_dict: dict[str, Any] = {}
for key, value in state_dict.items():
if key.startswith("custom_text_proj"):
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
else:
# The original ColQwen2 inherits from Qwen2VL, so we simply need to add the `vlm.` prefix
# to all remaining keys.
if key.startswith("model."):
key = key.replace("model.", "model.language_model.")
if key.startswith("visual."):
key = key.replace("visual.", "model.visual.")
new_key = "vlm." + key
new_state_dict[new_key] = value
return new_state_dict
@torch.no_grad()
def convert_colqwen2_weights_to_hf(
model_id: str,
output_dir: str,
push_to_hub: bool,
revision: Optional[str] = None,
original_vlm_name_or_path: Optional[str] = None,
):
# Load the original model data
original_config = AutoConfig.from_pretrained(
model_id,
revision=revision,
)
if original_vlm_name_or_path is not None:
original_config._name_or_path = original_vlm_name_or_path
if hasattr(original_config, "architectures"):
delattr(original_config, "architectures")
original_state_dict = load_original_state_dict(model_id, revision=revision)
# Format the state_dict keys
original_state_dict = rename_state_dict_keys(original_state_dict)
# Create the new config
config = ColQwen2Config(
vlm_config=original_config,
embedding_dim=128, # hardcoded in the original model
)
config.model_type = "colqwen2"
config.is_composition = False
# Load the untrained model
model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
print("Created model with new config and randomly initialized weights")
# NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
# There are two ways to set the model's dtype:
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
# the new weights' dtypes match the original model.
for param in model.parameters():
param.data = param.data.to(ORIGINAL_DTYPE)
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
# Load the original weights
model.load_state_dict(original_state_dict)
print("Loaded original model weights")
# # Sanity check: ensure all keys are the same
state_dict_keys_old = set(original_state_dict.keys())
state_dict_keys_new = set(model.state_dict().keys())
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
if disjoint_keys:
raise ValueError(f"Incompatible keys: {disjoint_keys}")
# Save the model
if push_to_hub:
model.push_to_hub(output_dir, private=True)
print(f"Model pushed to the hub at `{output_dir}`")
else:
Path(output_dir).mkdir(exist_ok=True, parents=True)
model.save_pretrained(output_dir)
print(f"Model saved to `{output_dir}`")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
This script converts the original ColQwen2 model to the HF model format.
Don't forget to manually upload the processor-related files to the HF model repository
after running this script.
Example usage:
```bash
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
--model_id vidore/colqwen2-v1.0-merged \
--revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
--original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
--output_dir vidore/colqwen2-v1.0-hf-internal \
--push_to_hub
```
"""
)
parser.add_argument(
"--model_id",
help="Model ID of the original model to convert",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
action="store_true",
default=False,
)
parser.add_argument(
"--revision",
help="Revision of the model to download",
default=None,
)
parser.add_argument(
"--original_vlm_name_or_path",
help="Name or path of the original VLM backbone model",
default=None,
)
args = parser.parse_args()
convert_colqwen2_weights_to_hf(
model_id=args.model_id,
output_dir=args.output_dir,
push_to_hub=args.push_to_hub,
revision=args.revision,
original_vlm_name_or_path=args.original_vlm_name_or_path,
)

View File

@ -1,324 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Conditional DETR checkpoints."""
import argparse
import json
from collections import OrderedDict
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
ConditionalDetrConfig,
ConditionalDetrForObjectDetection,
ConditionalDetrForSegmentation,
ConditionalDetrImageProcessor,
)
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# here we list all keys to be renamed (original name on the left, our name on the right)
rename_keys = []
for i in range(6):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append(
(f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight")
)
rename_keys.append(
(f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
)
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
rename_keys.append(
(f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
)
rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias"))
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight"))
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
rename_keys.append(
(f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
)
rename_keys.append(
(
f"transformer.decoder.layers.{i}.cross_attn.out_proj.weight",
f"decoder.layers.{i}.encoder_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"transformer.decoder.layers.{i}.cross_attn.out_proj.bias",
f"decoder.layers.{i}.encoder_attn.out_proj.bias",
)
)
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
)
rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
)
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight"))
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))
# q, k, v projections in self/cross-attention in decoder for conditional DETR
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_qcontent_proj.weight", f"decoder.layers.{i}.sa_qcontent_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_kcontent_proj.weight", f"decoder.layers.{i}.sa_kcontent_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_qpos_proj.weight", f"decoder.layers.{i}.sa_qpos_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_kpos_proj.weight", f"decoder.layers.{i}.sa_kpos_proj.weight")
)
rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.weight", f"decoder.layers.{i}.sa_v_proj.weight"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_qcontent_proj.weight", f"decoder.layers.{i}.ca_qcontent_proj.weight")
)
# rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.weight", f"decoder.layers.{i}.ca_qpos_proj.weight"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_kcontent_proj.weight", f"decoder.layers.{i}.ca_kcontent_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_kpos_proj.weight", f"decoder.layers.{i}.ca_kpos_proj.weight")
)
rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.weight", f"decoder.layers.{i}.ca_v_proj.weight"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.weight", f"decoder.layers.{i}.ca_qpos_sine_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_qcontent_proj.bias", f"decoder.layers.{i}.sa_qcontent_proj.bias")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_kcontent_proj.bias", f"decoder.layers.{i}.sa_kcontent_proj.bias")
)
rename_keys.append((f"transformer.decoder.layers.{i}.sa_qpos_proj.bias", f"decoder.layers.{i}.sa_qpos_proj.bias"))
rename_keys.append((f"transformer.decoder.layers.{i}.sa_kpos_proj.bias", f"decoder.layers.{i}.sa_kpos_proj.bias"))
rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.bias", f"decoder.layers.{i}.sa_v_proj.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_qcontent_proj.bias", f"decoder.layers.{i}.ca_qcontent_proj.bias")
)
# rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.bias", f"decoder.layers.{i}.ca_qpos_proj.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_kcontent_proj.bias", f"decoder.layers.{i}.ca_kcontent_proj.bias")
)
rename_keys.append((f"transformer.decoder.layers.{i}.ca_kpos_proj.bias", f"decoder.layers.{i}.ca_kpos_proj.bias"))
rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.bias", f"decoder.layers.{i}.ca_v_proj.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.bias", f"decoder.layers.{i}.ca_qpos_sine_proj.bias")
)
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
# for conditional DETR, also convert reference point head and query scale MLP
rename_keys.extend(
[
("input_proj.weight", "input_projection.weight"),
("input_proj.bias", "input_projection.bias"),
("query_embed.weight", "query_position_embeddings.weight"),
("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
("class_embed.weight", "class_labels_classifier.weight"),
("class_embed.bias", "class_labels_classifier.bias"),
("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
("transformer.decoder.ref_point_head.layers.0.weight", "decoder.ref_point_head.layers.0.weight"),
("transformer.decoder.ref_point_head.layers.0.bias", "decoder.ref_point_head.layers.0.bias"),
("transformer.decoder.ref_point_head.layers.1.weight", "decoder.ref_point_head.layers.1.weight"),
("transformer.decoder.ref_point_head.layers.1.bias", "decoder.ref_point_head.layers.1.bias"),
("transformer.decoder.query_scale.layers.0.weight", "decoder.query_scale.layers.0.weight"),
("transformer.decoder.query_scale.layers.0.bias", "decoder.query_scale.layers.0.bias"),
("transformer.decoder.query_scale.layers.1.weight", "decoder.query_scale.layers.1.weight"),
("transformer.decoder.query_scale.layers.1.bias", "decoder.query_scale.layers.1.bias"),
("transformer.decoder.layers.0.ca_qpos_proj.weight", "decoder.layers.0.ca_qpos_proj.weight"),
("transformer.decoder.layers.0.ca_qpos_proj.bias", "decoder.layers.0.ca_qpos_proj.bias"),
]
)
def rename_key(state_dict, old, new):
val = state_dict.pop(old)
state_dict[new] = val
def rename_backbone_keys(state_dict):
new_state_dict = OrderedDict()
for key, value in state_dict.items():
if "backbone.0.body" in key:
new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model")
new_state_dict[new_key] = value
else:
new_state_dict[key] = value
return new_state_dict
def read_in_q_k_v(state_dict, is_panoptic=False):
prefix = ""
if is_panoptic:
prefix = "conditional_detr."
# first: transformer encoder
for i in range(6):
# read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_conditional_detr_checkpoint(model_name, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our CONDITIONAL_DETR structure.
"""
# load default config
config = ConditionalDetrConfig()
# set backbone and dilation attributes
if "resnet101" in model_name:
config.backbone = "resnet101"
if "dc5" in model_name:
config.dilation = True
is_panoptic = "panoptic" in model_name
if is_panoptic:
config.num_labels = 250
else:
config.num_labels = 91
repo_id = "huggingface/label-files"
filename = "coco-detection-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
# load image processor
format = "coco_panoptic" if is_panoptic else "coco_detection"
image_processor = ConditionalDetrImageProcessor(format=format)
# prepare image
img = prepare_img()
encoding = image_processor(images=img, return_tensors="pt")
pixel_values = encoding["pixel_values"]
logger.info(f"Converting model {model_name}...")
# load original model from torch hub
conditional_detr = torch.hub.load("DeppMeng/ConditionalDETR", model_name, pretrained=True).eval()
state_dict = conditional_detr.state_dict()
# rename keys
for src, dest in rename_keys:
if is_panoptic:
src = "conditional_detr." + src
rename_key(state_dict, src, dest)
state_dict = rename_backbone_keys(state_dict)
# query, key and value matrices need special treatment
read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
prefix = "conditional_detr.model." if is_panoptic else "model."
for key in state_dict.copy():
if is_panoptic:
if (
key.startswith("conditional_detr")
and not key.startswith("class_labels_classifier")
and not key.startswith("bbox_predictor")
):
val = state_dict.pop(key)
state_dict["conditional_detr.model" + key[4:]] = val
elif "class_labels_classifier" in key or "bbox_predictor" in key:
val = state_dict.pop(key)
state_dict["conditional_detr." + key] = val
elif key.startswith("bbox_attention") or key.startswith("mask_head"):
continue
else:
val = state_dict.pop(key)
state_dict[prefix + key] = val
else:
if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
val = state_dict.pop(key)
state_dict[prefix + key] = val
# finally, create HuggingFace model and load state dict
model = ConditionalDetrForSegmentation(config) if is_panoptic else ConditionalDetrForObjectDetection(config)
model.load_state_dict(state_dict)
model.eval()
model.push_to_hub(repo_id=model_name, organization="DepuMeng", commit_message="Add model")
# verify our conversion
original_outputs = conditional_detr(pixel_values)
outputs = model(pixel_values)
assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
if is_panoptic:
assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
# Save model and image processor
logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="conditional_detr_resnet50",
type=str,
help="Name of the CONDITIONAL_DETR model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
)
args = parser.parse_args()
convert_conditional_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)

Some files were not shown because too many files have changed in this diff Show More