mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Compare commits
27 Commits
contributi
...
v4.55.4
Author | SHA1 | Date | |
---|---|---|---|
d79b2d981f | |||
90792b730a | |||
a03df6acd4 | |||
170b2708cb | |||
7dbc054e2a | |||
c097a43898 | |||
663cbb0d04 | |||
c7bd5350f0 | |||
e75d67ec39 | |||
d7f67d2006 | |||
acf295aec3 | |||
aaa3169aa2 | |||
ea2eee0bc8 | |||
956be23fff | |||
79a9ffc520 | |||
99404c7098 | |||
0d6908038c | |||
b8e97fbfd2 | |||
586b6e693b | |||
95ae07d11f | |||
0d9032ae71 | |||
1d42803aac | |||
382717e543 | |||
cc98f42d22 | |||
d2f7266367 | |||
daab2db33f | |||
06f8004e5c |
@ -511,6 +511,8 @@
|
||||
title: GPT2
|
||||
- local: model_doc/gpt_bigcode
|
||||
title: GPTBigCode
|
||||
- local: model_doc/gpt_oss
|
||||
title: GptOss
|
||||
- local: model_doc/gptsan-japanese
|
||||
title: GPTSAN Japanese
|
||||
- local: model_doc/gpt-sw3
|
||||
@ -617,8 +619,6 @@
|
||||
title: OLMoE
|
||||
- local: model_doc/open-llama
|
||||
title: Open-Llama
|
||||
- local: model_doc/openai_moe
|
||||
title: OpenAIMoe
|
||||
- local: model_doc/opt
|
||||
title: OPT
|
||||
- local: model_doc/pegasus
|
||||
|
@ -65,6 +65,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
|
||||
|
||||
[[autodoc]] HqqConfig
|
||||
|
||||
## Mxfp4Config
|
||||
|
||||
[[autodoc]] Mxfp4Config
|
||||
|
||||
## FbgemmFp8Config
|
||||
|
||||
[[autodoc]] FbgemmFp8Config
|
||||
|
@ -24,11 +24,11 @@ rendered properly in your Markdown viewer.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# OpenAIMoE
|
||||
# GptOss
|
||||
|
||||
## Overview
|
||||
|
||||
The OpenAIMoE model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
||||
The GptOss model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
||||
<INSERT SHORT SUMMARY HERE>
|
||||
|
||||
The abstract from the paper is the following:
|
||||
@ -43,16 +43,16 @@ This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface
|
||||
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
|
||||
|
||||
|
||||
## OpenAIMoeConfig
|
||||
## GptOssConfig
|
||||
|
||||
[[autodoc]] OpenAIMoeConfig
|
||||
[[autodoc]] GptOssConfig
|
||||
|
||||
## OpenAIMoeModel
|
||||
## GptOssModel
|
||||
|
||||
[[autodoc]] OpenAIMoeModel
|
||||
[[autodoc]] GptOssModel
|
||||
- forward
|
||||
|
||||
## OpenAIMoeForCausalLM
|
||||
## GptOssForCausalLM
|
||||
|
||||
[[autodoc]] OpenAIMoeForCausalLM
|
||||
[[autodoc]] GptOssForCausalLM
|
||||
- forward
|
@ -60,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "datasets[audio]>=1.14.0",
|
||||
# "evaluate",
|
||||
# "librosa",
|
||||
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "torch>=1.5.0",
|
||||
# "torchvision>=0.6.0",
|
||||
# "datasets>=1.8.0",
|
||||
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate>=0.12.0",
|
||||
# "torch>=1.5.0",
|
||||
# "torchvision>=0.6.0",
|
||||
@ -68,7 +68,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate>=0.12.0",
|
||||
# "torch>=1.5.0",
|
||||
# "torchvision>=0.6.0",
|
||||
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "torch>=1.5.0",
|
||||
# "torchvision>=0.6.0",
|
||||
# "datasets>=1.8.0",
|
||||
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "torch>=1.5.0",
|
||||
# "torchvision>=0.6.0",
|
||||
# "datasets>=1.8.0",
|
||||
@ -56,7 +56,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "torch>=1.5.0",
|
||||
# "torchvision>=0.6.0",
|
||||
# "datasets>=1.8.0",
|
||||
@ -61,7 +61,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "timm",
|
||||
# "datasets",
|
||||
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "timm",
|
||||
# "datasets",
|
||||
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "torch >= 1.3",
|
||||
@ -69,7 +69,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "torch >= 1.3",
|
||||
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "torch >= 1.3",
|
||||
@ -72,7 +72,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "torch >= 1.3",
|
||||
@ -74,7 +74,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "torch >= 1.3",
|
||||
@ -68,7 +68,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "torch >= 1.3",
|
||||
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "torch >= 1.3",
|
||||
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
# "protobuf",
|
||||
@ -57,7 +57,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
# "protobuf",
|
||||
@ -65,7 +65,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "timm",
|
||||
# "datasets>=4.0",
|
||||
@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "albumentations >= 1.4.16",
|
||||
# "timm",
|
||||
# "datasets>=4.0",
|
||||
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "datasets >= 2.0.0",
|
||||
# "torch >= 1.3",
|
||||
# "accelerate",
|
||||
@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "datasets >= 2.0.0",
|
||||
# "torch >= 1.3",
|
||||
# "accelerate",
|
||||
@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "datasets[audio] >= 1.12.0",
|
||||
# "torch >= 1.5",
|
||||
# "torchaudio",
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "datasets[audio] >= 1.18.0",
|
||||
# "torch >= 1.5",
|
||||
# "torchaudio",
|
||||
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "datasets[audio] >= 1.18.0",
|
||||
# "torch >= 1.5",
|
||||
# "torchaudio",
|
||||
@ -64,7 +64,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "datasets[audio] >= 1.18.0",
|
||||
# "torch >= 1.5",
|
||||
# "torchaudio",
|
||||
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "datasets >= 1.8.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
@ -67,7 +67,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "datasets >= 1.8.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "datasets >= 1.8.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "datasets >= 1.8.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "datasets >= 1.8.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "datasets >= 1.8.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.21.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
# "protobuf",
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.21.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
# "protobuf",
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "seqeval",
|
||||
# "datasets >= 1.8.0",
|
||||
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "seqeval",
|
||||
# "datasets >= 1.8.0",
|
||||
@ -67,7 +67,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "datasets >= 1.8.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
@ -66,7 +66,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "transformers @ git+https://github.com/huggingface/transformers.git",
|
||||
# "transformers==4.55.4",
|
||||
# "accelerate >= 0.12.0",
|
||||
# "datasets >= 1.8.0",
|
||||
# "sentencepiece != 0.1.92",
|
||||
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -61,7 +61,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.55.0.dev0")
|
||||
check_min_version("4.55.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -463,7 +463,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.55.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.55.4", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.55.0.dev0"
|
||||
__version__ = "4.55.4"
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
@ -677,24 +677,6 @@ class GenerationMixin(ContinuousMixin):
|
||||
if encoder_attention_mask is not None:
|
||||
model_inputs["attention_mask"] = encoder_attention_mask
|
||||
|
||||
if "flash" in self.config._attn_implementation and self._supports_attention_backend:
|
||||
tensor_kws = {"dtype": torch.int32, "device": self.device}
|
||||
pos = model_inputs["position_ids"][:, -1]
|
||||
|
||||
cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
|
||||
max_length_k = int(pos.max()) + 1
|
||||
|
||||
bs, seq_len = input_ids.size()
|
||||
q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
|
||||
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
|
||||
max_length_q = int(q_len.max())
|
||||
|
||||
model_inputs.update(
|
||||
cu_seq_lens_q=cu_seq_lens_q.to(self.device),
|
||||
cu_seq_lens_k=cu_seq_lens_k.to(self.device),
|
||||
max_length_q=max_length_q,
|
||||
max_length_k=max_length_k,
|
||||
)
|
||||
# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
|
||||
for key, value in kwargs.items():
|
||||
if key not in model_inputs:
|
||||
@ -1816,7 +1798,8 @@ class GenerationMixin(ContinuousMixin):
|
||||
if model_kwargs.get("past_key_values") is not None:
|
||||
cache = model_kwargs["past_key_values"]
|
||||
past_length = 0
|
||||
if not isinstance(cache, Cache):
|
||||
# Support for BC tuple cache format
|
||||
if isinstance(cache, tuple):
|
||||
past_length = cache[0][0].shape[2]
|
||||
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
|
||||
past_length = cache.get_seq_length()
|
||||
|
@ -49,7 +49,7 @@ FP4_VALUES = [
|
||||
|
||||
# Copied from GPT_OSS repo and vllm
|
||||
def quantize_to_mxfp4(w):
|
||||
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
|
||||
downcast_to_mxfp = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp
|
||||
|
||||
w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
|
||||
w, w_scale = swizzle_mxfp4(w, w_scale)
|
||||
@ -57,9 +57,13 @@ def quantize_to_mxfp4(w):
|
||||
|
||||
|
||||
def swizzle_mxfp4(w, w_scale):
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.tensor_details.layout import StridedLayout
|
||||
FP4, convert_layout, wrap_torch_tensor = (
|
||||
triton_kernels_hub.tensor.FP4,
|
||||
triton_kernels_hub.tensor.convert_layout,
|
||||
triton_kernels_hub.tensor.wrap_torch_tensor,
|
||||
)
|
||||
layout = triton_kernels_hub.tensor_details.layout
|
||||
StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout
|
||||
|
||||
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
|
||||
@ -168,16 +172,20 @@ class Mxfp4GptOssExperts(nn.Module):
|
||||
torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
self.alpha = 1.702
|
||||
|
||||
self.limit = getattr(config, "swiglu_limit", 7.0)
|
||||
self.gate_up_proj_precision_config = None
|
||||
self.down_proj_precision_config = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
|
||||
from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
|
||||
from triton_kernels.swiglu import swiglu_fn
|
||||
FnSpecs, FusedActivation, matmul_ogs = (
|
||||
triton_kernels_hub.matmul_ogs.FnSpecs,
|
||||
triton_kernels_hub.matmul_ogs.FusedActivation,
|
||||
triton_kernels_hub.matmul_ogs.matmul_ogs,
|
||||
)
|
||||
swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
|
||||
|
||||
with torch.cuda.device(hidden_states.device):
|
||||
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2)
|
||||
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2)
|
||||
|
||||
intermediate_cache1 = matmul_ogs(
|
||||
hidden_states,
|
||||
@ -211,7 +219,12 @@ def routing_torch_dist(
|
||||
):
|
||||
import os
|
||||
|
||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch
|
||||
GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
|
||||
triton_kernels_hub.routing.GatherIndx,
|
||||
triton_kernels_hub.routing.RoutingData,
|
||||
triton_kernels_hub.routing.ScatterIndx,
|
||||
triton_kernels_hub.routing.compute_expt_data_torch,
|
||||
)
|
||||
|
||||
with torch.cuda.device(logits.device):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
@ -274,13 +287,16 @@ def mlp_forward(self, hidden_states):
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
routing = routing_torch_dist
|
||||
else:
|
||||
from triton_kernels.routing import routing
|
||||
routing = triton_kernels_hub.routing.routing
|
||||
|
||||
routing = routing
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
|
||||
router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
|
||||
routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
|
||||
|
||||
with torch.cuda.device(router_logits.device):
|
||||
routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
|
||||
|
||||
routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
|
||||
routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
|
||||
return routed_out, router_logits
|
||||
@ -334,8 +350,11 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
|
||||
|
||||
|
||||
def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs):
|
||||
from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
|
||||
|
||||
PrecisionConfig, FlexCtx, InFlexData = (
|
||||
triton_kernels_hub.matmul_ogs.PrecisionConfig,
|
||||
triton_kernels_hub.matmul_ogs.FlexCtx,
|
||||
triton_kernels_hub.matmul_ogs.InFlexData,
|
||||
)
|
||||
from ..integrations.tensor_parallel import shard_and_distribute_module
|
||||
|
||||
model = kwargs.get("model", None)
|
||||
@ -447,6 +466,11 @@ def replace_with_mxfp4_linear(
|
||||
):
|
||||
if quantization_config.dequantize:
|
||||
return model
|
||||
else:
|
||||
from kernels import get_kernel
|
||||
|
||||
global triton_kernels_hub
|
||||
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
|
||||
|
||||
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
|
||||
|
||||
|
@ -10,20 +10,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils.import_utils import is_torch_npu_available
|
||||
|
||||
|
||||
if is_torch_npu_available():
|
||||
import math
|
||||
|
||||
import torch_npu
|
||||
from einops import rearrange, repeat
|
||||
from torch_npu import npu_rotary_mul
|
||||
from torch_npu import npu_fusion_attention
|
||||
|
||||
|
||||
# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
|
||||
@ -52,117 +48,6 @@ def is_npu_fa2_top_left_aligned_causal_mask():
|
||||
return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
|
||||
|
||||
|
||||
# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
|
||||
class IndexFirstAxis(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, indices):
|
||||
ctx.save_for_backward(indices)
|
||||
assert input.ndim >= 2
|
||||
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
||||
second_dim = other_shape.numel()
|
||||
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
||||
# return input[indices]
|
||||
return torch.gather(
|
||||
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
|
||||
).reshape(-1, *other_shape)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
(indices,) = ctx.saved_tensors
|
||||
assert grad_output.ndim >= 2
|
||||
other_shape = grad_output.shape[1:]
|
||||
grad_output = rearrange(grad_output, "b ... -> b (...)")
|
||||
grad_input = torch.zeros(
|
||||
[ctx.first_axis_dim, grad_output.shape[1]],
|
||||
device=grad_output.device,
|
||||
dtype=grad_output.dtype,
|
||||
)
|
||||
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
||||
# grad_input[indices] = grad_output
|
||||
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
|
||||
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
||||
|
||||
|
||||
index_first_axis = IndexFirstAxis.apply
|
||||
|
||||
|
||||
# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
|
||||
class IndexPutFirstAxis(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, values, indices, first_axis_dim):
|
||||
ctx.save_for_backward(indices)
|
||||
assert indices.ndim == 1
|
||||
assert values.ndim >= 2
|
||||
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
|
||||
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
||||
output[indices] = values
|
||||
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
(indices,) = ctx.saved_tensors
|
||||
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
||||
grad_values = grad_output[indices]
|
||||
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
|
||||
return grad_values, None, None
|
||||
|
||||
|
||||
index_put_first_axis = IndexPutFirstAxis.apply
|
||||
|
||||
|
||||
# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
|
||||
def pad_input(hidden_states, indices, batch, seqlen):
|
||||
"""
|
||||
Arguments:
|
||||
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
||||
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
||||
batch: int, batch size for the padded sequence.
|
||||
seqlen: int, maximum sequence length for the padded sequence.
|
||||
Return:
|
||||
hidden_states: (batch, seqlen, ...)
|
||||
"""
|
||||
# dim = hidden_states.shape[-1]
|
||||
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
# output[indices] = hidden_states
|
||||
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
||||
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
||||
|
||||
|
||||
# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
|
||||
def unpad_input(hidden_states, attention_mask, unused_mask=None):
|
||||
"""
|
||||
Arguments:
|
||||
hidden_states: (batch, seqlen, ...)
|
||||
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
||||
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
|
||||
Return:
|
||||
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
|
||||
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
|
||||
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
||||
max_seqlen_in_batch: int
|
||||
seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
|
||||
"""
|
||||
all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
|
||||
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
|
||||
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
||||
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
||||
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
||||
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
||||
# so we write custom forward and backward to make it a bit faster.
|
||||
return (
|
||||
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
|
||||
indices,
|
||||
cu_seqlens,
|
||||
max_seqlen_in_batch,
|
||||
used_seqlens_in_batch,
|
||||
)
|
||||
|
||||
|
||||
def npu_flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
@ -179,11 +64,11 @@ def npu_flash_attn_func(
|
||||
|
||||
if not causal:
|
||||
head_num = q.shape[2]
|
||||
output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
|
||||
output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
|
||||
else:
|
||||
attn_mask_npu = get_attn_mask_npu(q.device)
|
||||
head_num = q.shape[2]
|
||||
output = torch_npu.npu_fusion_attention(
|
||||
output = npu_fusion_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@ -218,7 +103,7 @@ def npu_flash_attn_varlen_func(
|
||||
|
||||
if not causal:
|
||||
head_num = q.shape[1]
|
||||
output = torch_npu.npu_fusion_attention(
|
||||
output = npu_fusion_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@ -234,7 +119,7 @@ def npu_flash_attn_varlen_func(
|
||||
else:
|
||||
attn_mask_npu = get_attn_mask_npu(q.device)
|
||||
head_num = q.shape[1]
|
||||
output = torch_npu.npu_fusion_attention(
|
||||
output = npu_fusion_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@ -251,19 +136,3 @@ def npu_flash_attn_varlen_func(
|
||||
)[0]
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def npu_apply_rotary_emb(x, cos, sin, **kwargs):
|
||||
# cos tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
|
||||
if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
|
||||
cos = cos.repeat(1, 2)
|
||||
# cos tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
|
||||
cos = cos.unsqueeze(0).unsqueeze(2)
|
||||
|
||||
# sin tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
|
||||
if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
|
||||
sin = sin.repeat(1, 2)
|
||||
# sin tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
|
||||
sin = sin.unsqueeze(0).unsqueeze(2)
|
||||
|
||||
return npu_rotary_mul(x, cos, sin)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -14,17 +14,15 @@
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers.utils.import_utils import is_kernels_available
|
||||
|
||||
from .utils import (
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_3_available,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_npu_available,
|
||||
logging,
|
||||
@ -34,18 +32,139 @@ from .utils import (
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:])
|
||||
return reshaped[indices]
|
||||
# TODO Deprecate when all models have the attention interface
|
||||
def flash_attn_supports_top_left_mask():
|
||||
if is_flash_attn_3_available():
|
||||
return False
|
||||
if is_flash_attn_2_available():
|
||||
return not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
|
||||
|
||||
return is_npu_fa2_top_left_aligned_causal_mask()
|
||||
|
||||
|
||||
def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
|
||||
# TODO Deprecate when all models have the attention interface
|
||||
def is_flash_attn_available():
|
||||
return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
|
||||
|
||||
|
||||
# `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves
|
||||
_flash_fn = None
|
||||
_flash_varlen_fn = None
|
||||
_pad_fn = None
|
||||
_unpad_fn = None
|
||||
|
||||
# function that processes kwargs, generalized to handle any supported kwarg within the function
|
||||
_process_flash_kwargs_fn = None
|
||||
# exceptions where hf API doesn't match the original flash attention API
|
||||
_hf_api_to_flash_mapping = {
|
||||
"dropout": "dropout_p",
|
||||
"sliding_window": "window_size",
|
||||
}
|
||||
|
||||
|
||||
def _lazy_imports(implementation: Optional[str]):
|
||||
"""
|
||||
FA3-compatible unpad_input function.
|
||||
Lazy loads the respective flash attention implementations.
|
||||
|
||||
Return:
|
||||
flash_attn_func: The base flash attention function.
|
||||
flash_attn_varlen_func: The flash attention function supporting variable sequence lengths,
|
||||
e.g. for padding-free training.
|
||||
pad_input: The function to pad inputs into one sequence and returning the respective kwargs.
|
||||
unpad_input: The function to unpad outputs based on the kwargs (from pad_input).
|
||||
"""
|
||||
is_fa2 = is_flash_attn_2_available()
|
||||
is_fa3 = is_flash_attn_3_available()
|
||||
|
||||
pad_input, unpad_input = _pad_input, _unpad_input
|
||||
|
||||
if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3):
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
elif is_torch_npu_available():
|
||||
# Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError
|
||||
# Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module
|
||||
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
|
||||
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
|
||||
else:
|
||||
if implementation == "flash_attention_3" or (implementation is None and is_fa3):
|
||||
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
|
||||
# Kernels fallback
|
||||
else:
|
||||
flash_attn_func = getattr(implementation, "flash_attn_func", None)
|
||||
flash_attn_varlen_func = getattr(implementation, "flash_attn_varlen_func", None)
|
||||
if flash_attn_varlen_func is None or flash_attn_func is None:
|
||||
raise ValueError(
|
||||
f"Could not find the currently requested flash attention implementation at `{implementation}`."
|
||||
f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`."
|
||||
)
|
||||
|
||||
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input
|
||||
|
||||
|
||||
def _lazy_define_process_function(flash_function):
|
||||
"""
|
||||
Depending on the version and kernel some features are not supported. Due to limitations in
|
||||
`torch.compile`, we opt to statically type which (optional) kwarg parameters are supported
|
||||
within `_process_flash_attention_kwargs`.
|
||||
|
||||
NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`.
|
||||
This might be confusing for kwargs that we use in any case, e.g. `is_causal`.
|
||||
"""
|
||||
global _process_flash_kwargs_fn, _hf_api_to_flash_mapping
|
||||
|
||||
flash_parameters = inspect.signature(flash_function).parameters
|
||||
process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters
|
||||
|
||||
supports_mapping = {}
|
||||
for param in process_parameters:
|
||||
fa_param = _hf_api_to_flash_mapping.get(param, param)
|
||||
supports_mapping[fa_param] = fa_param in flash_parameters
|
||||
|
||||
return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)
|
||||
|
||||
|
||||
def lazy_import_flash_attention(implementation: Optional[str]):
|
||||
"""
|
||||
Lazy loading flash attention and returning the respective functions + flags back
|
||||
|
||||
NOTE: For fullgraph, this needs to be called before compile while no fullgraph can
|
||||
can work without preloading. See `_check_and_adjust_attn_implementation` in `modeling_utils`.
|
||||
"""
|
||||
global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
|
||||
if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
|
||||
_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
|
||||
|
||||
global _process_flash_kwargs_fn
|
||||
if _process_flash_kwargs_fn is None:
|
||||
_process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
|
||||
|
||||
return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
|
||||
|
||||
|
||||
def _index_first_axis(tensor, indices):
|
||||
"""
|
||||
A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
|
||||
after flattening the first two dimensions of the tensor. This is functionally equivalent to
|
||||
FA2's `index_first_axis` and replaces the need to import it.
|
||||
"""
|
||||
# The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
|
||||
# two dimensions to get (total_tokens, ...) before indexing.
|
||||
reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
|
||||
return reshaped_tensor[indices]
|
||||
|
||||
|
||||
def _unpad_input(hidden_states, attention_mask, unused_mask=None):
|
||||
"""
|
||||
unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
|
||||
|
||||
Arguments:
|
||||
hidden_states: (batch, seqlen, ...)
|
||||
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
||||
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
|
||||
|
||||
Return:
|
||||
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
|
||||
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
|
||||
@ -69,14 +188,16 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
|
||||
)
|
||||
|
||||
|
||||
def _fa3_pad_input(hidden_states, indices, batch, seqlen):
|
||||
def _pad_input(hidden_states, indices, batch, seqlen):
|
||||
"""
|
||||
FA3-compatible pad_input function.
|
||||
pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
|
||||
|
||||
Arguments:
|
||||
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
||||
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
||||
batch: int, batch size for the padded sequence.
|
||||
seqlen: int, maximum sequence length for the padded sequence.
|
||||
|
||||
Return:
|
||||
hidden_states: (batch, seqlen, ...)
|
||||
"""
|
||||
@ -89,9 +210,11 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
|
||||
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
|
||||
"""
|
||||
Retrieves indexing data required to repad unpadded (ragged) tensors.
|
||||
|
||||
Arguments:
|
||||
attention_mask (`torch.Tensor`):
|
||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||
|
||||
Return:
|
||||
indices (`torch.Tensor`):
|
||||
The indices of non-masked tokens from the flattened input sequence.
|
||||
@ -125,6 +248,7 @@ def _upad_input(
|
||||
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
|
||||
This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
|
||||
tensors for query, key, value tensors.
|
||||
|
||||
Arguments:
|
||||
query_layer (`torch.Tensor`):
|
||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||
@ -138,6 +262,7 @@ def _upad_input(
|
||||
Target length.
|
||||
unpad_input_func:
|
||||
The function to use for unpadding the input tensors.
|
||||
|
||||
Return:
|
||||
query_layer (`torch.Tensor`):
|
||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||
@ -190,12 +315,79 @@ def _upad_input(
|
||||
)
|
||||
|
||||
|
||||
def _prepare_from_posids(query, key, value, position_ids):
|
||||
def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = True):
|
||||
"""
|
||||
This function returns all the necessary kwargs to call `flash_attn_varlen_func`
|
||||
extracted from position_ids. The `position_ids` can be either packed sequence or
|
||||
the usual padded position ids, for example in inference time.
|
||||
|
||||
Arguments:
|
||||
position_ids (`torch.Tensor`):
|
||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||
is_packed_sequence (`bool`, *optional*, defaults to `True`):
|
||||
Whether the input position ids are a packed sequence or not.
|
||||
|
||||
Return:
|
||||
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
|
||||
The cumulative sequence lengths for the target (query) and source (key, value), used to index into
|
||||
ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
||||
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
|
||||
`max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||
"""
|
||||
# If the lengths are not equal, most probably we are in decoding stage with cache
|
||||
# In that case the position ids will not always start with `0` and we need a better way to infer
|
||||
# cumulative seq lengths.
|
||||
if not is_packed_sequence:
|
||||
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
|
||||
|
||||
last_position_ids = position_ids[:, -1]
|
||||
q_len = (
|
||||
torch.ones(position_ids.size(0), **tensor_kwargs)
|
||||
if position_ids.shape[-1] == 1
|
||||
else last_position_ids.add(1)
|
||||
)
|
||||
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kwargs), q_len.cumsum(0).to(torch.int32)], 0)
|
||||
cu_seq_lens_k = torch.cat(
|
||||
[torch.zeros(1, **tensor_kwargs), last_position_ids.add(1).cumsum(0).to(torch.int32)], 0
|
||||
)
|
||||
|
||||
max_length_q = int(q_len.max())
|
||||
max_length_k = int(last_position_ids.max()) + 1
|
||||
else:
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||
|
||||
cu_seq_lens_q = torch.cat(
|
||||
(
|
||||
indices_q[position_ids == 0],
|
||||
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
||||
)
|
||||
)
|
||||
cu_seq_lens_k = cu_seq_lens_q
|
||||
|
||||
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
|
||||
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
||||
# for some models (e.g. qwen2-vl).
|
||||
max_length_q = cu_seq_lens_q.diff().max()
|
||||
# NOTE: With torch compile, this will cause a graph break if you don't set
|
||||
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
||||
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
||||
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
|
||||
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
|
||||
max_length_q = max_length_q.item()
|
||||
max_length_k = max_length_q
|
||||
|
||||
return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)
|
||||
|
||||
|
||||
def _prepare_from_posids(query, key, value, position_ids, query_length):
|
||||
"""
|
||||
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
||||
All three query, key, value states will be flattened.
|
||||
Cumulative lengths of each examples in the batch will be extracted from position_ids.
|
||||
NOTE: ideally cumulative lengths should be prepared at the data collator stage
|
||||
|
||||
Arguments:
|
||||
query (`torch.Tensor`):
|
||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||
@ -205,6 +397,9 @@ def _prepare_from_posids(query, key, value, position_ids):
|
||||
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
position_ids (`torch.Tensor`):
|
||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||
query_length (`int`):
|
||||
Sequence length of the input queries.
|
||||
|
||||
Return:
|
||||
query (`torch.Tensor`):
|
||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||
@ -219,123 +414,152 @@ def _prepare_from_posids(query, key, value, position_ids):
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
||||
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||
"""
|
||||
kv_length = key.shape[1]
|
||||
is_packed_sequence = query_length == kv_length
|
||||
|
||||
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
|
||||
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
||||
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
||||
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||
|
||||
cu_seq_lens = torch.cat(
|
||||
(
|
||||
indices_q[position_ids == 0],
|
||||
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
||||
)
|
||||
(cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
|
||||
position_ids, is_packed_sequence=is_packed_sequence
|
||||
)
|
||||
# NOTE: With torch compile, this will cause a graph break if you don't set
|
||||
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
||||
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
||||
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
|
||||
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
|
||||
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
|
||||
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
||||
# for some models (e.g. qwen2-vl).
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
|
||||
|
||||
return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))
|
||||
|
||||
|
||||
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
|
||||
warnings.warn(
|
||||
"prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids",
|
||||
"The function `_prepare_flash_attention_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_from_posids` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
return _prepare_from_posids(query, key, value, position_ids)
|
||||
|
||||
|
||||
def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None):
|
||||
def _is_packed_sequence(position_ids, batch_size):
|
||||
"""
|
||||
Check the position ids whether packed sequences are indicated or not
|
||||
1. Position ids exist
|
||||
2. Flattened sequences only are supported
|
||||
3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
|
||||
"""
|
||||
if position_ids is None:
|
||||
return False
|
||||
|
||||
increasing_position_sequences = (
|
||||
torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
|
||||
)
|
||||
return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
|
||||
|
||||
|
||||
def fa_peft_integration_check(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
target_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""
|
||||
PEFT usually casts the layer norms in float32 for training stability reasons
|
||||
therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
cast them back in float16 / bfloat16 just to be sure everything works as expected.
|
||||
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
|
||||
"""
|
||||
if target_dtype and q.dtype == torch.float32:
|
||||
logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
|
||||
q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def _lazy_imports(impl: Optional[str]):
|
||||
# returns funcs and pad/unpad based on impl
|
||||
is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
|
||||
is_fa3 = is_flash_attn_3_available()
|
||||
if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
|
||||
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False
|
||||
|
||||
except ImportError as e:
|
||||
if not globals().get("use_remote_fa2", None):
|
||||
use_remote_fa2 = (
|
||||
input(
|
||||
"Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? "
|
||||
)
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"}
|
||||
if globals()["use_remote_fa2"]:
|
||||
if not is_kernels_available():
|
||||
raise ImportError("You need to install kernels: `pip install kernels`")
|
||||
from kernels import get_kernel
|
||||
|
||||
impl = get_kernel("kernels-community/flash-attn")
|
||||
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
|
||||
return (
|
||||
getattr(impl, "flash_attn_func", None),
|
||||
getattr(impl, "flash_attn_varlen_func"),
|
||||
pad_input,
|
||||
unpad_input,
|
||||
True,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ImportError(
|
||||
"Failed to import flash attention 2, please install it or use another implementation."
|
||||
) from e
|
||||
if impl == "flash_attention_3" or (impl is None and is_fa3):
|
||||
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
|
||||
|
||||
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
|
||||
return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True
|
||||
else:
|
||||
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
|
||||
return (
|
||||
getattr(impl, "flash_attn_func", None),
|
||||
getattr(impl, "flash_attn_varlen_func"),
|
||||
pad_input,
|
||||
unpad_input,
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
_flash_supports_window = None
|
||||
|
||||
|
||||
def is_flash_attn_available():
|
||||
return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
|
||||
|
||||
|
||||
def flash_attn_supports_top_left_mask():
|
||||
if is_flash_attn_3_available():
|
||||
return False
|
||||
if is_flash_attn_2_available():
|
||||
return not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
|
||||
|
||||
return is_npu_fa2_top_left_aligned_causal_mask()
|
||||
|
||||
|
||||
class FlashAttentionKwargs(TypedDict, total=False):
|
||||
"""
|
||||
Keyword arguments for Flash Attention with Compile.
|
||||
|
||||
Attributes:
|
||||
cumulative_seqlens_q (`torch.LongTensor`, *optional*)
|
||||
Gets cumulative sequence length for query state.
|
||||
cumulative_seqlens_k (`torch.LongTensor`, *optional*)
|
||||
Gets cumulative sequence length for key state.
|
||||
max_length_q (`int`, *optional*):
|
||||
Maximum sequence length for query state.
|
||||
max_length_k (`int`, *optional*):
|
||||
Maximum sequence length for key state.
|
||||
"""
|
||||
|
||||
cumulative_seqlens_q: Optional[torch.LongTensor]
|
||||
cumulative_seqlens_k: Optional[torch.LongTensor]
|
||||
max_length_q: Optional[int]
|
||||
max_length_k: Optional[int]
|
||||
|
||||
|
||||
def _process_flash_attention_kwargs(
|
||||
query_length: int,
|
||||
key_length: int,
|
||||
is_causal: bool,
|
||||
dropout: float = 0.0,
|
||||
softmax_scale: Optional[float] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
use_top_left_mask: bool = False,
|
||||
softcap: Optional[float] = None,
|
||||
deterministic: Optional[bool] = None,
|
||||
s_aux: Optional[torch.Tensor] = None,
|
||||
supports_mapping: Optional[dict[str, bool]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a set of kwargs that are passed down to the according flash attention function based on
|
||||
requested features and whether it is supported - depends on the version and kernel implementation
|
||||
which is dynamically configued at `lazy_import_flash_attention`. The (un)supported features can be
|
||||
inspected in `supports_mapping`, see `_lazy_define_process_function` for more details.
|
||||
|
||||
Args:
|
||||
query_length (`int`):
|
||||
Length of the query states
|
||||
key_length (`int`):
|
||||
Length of the key states
|
||||
is_causal (`bool`):
|
||||
Whether we perform causal (decoder) attention or full attention.
|
||||
dropout (`float`):
|
||||
Attention dropout.
|
||||
softmax_scale (`float`, *optional*):
|
||||
The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`.
|
||||
sliding_window (`int`, *optional*):
|
||||
The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back.
|
||||
use_top_left_mask (`bool`):
|
||||
Deprecated behavior of older versions of flash attention requiring different masking.
|
||||
softcap (`float`, *optional*):
|
||||
Softcap for the attention logits, used e.g. in gemma2.
|
||||
deterministic (`bool`, *optional*):
|
||||
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
|
||||
s_aux (`torch.Tensor`, *optional*):
|
||||
Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
|
||||
Return:
|
||||
flash_kwargs (`dict`):
|
||||
A dict of kwargs that are requested and supported.
|
||||
"""
|
||||
flash_kwargs = {
|
||||
"causal": is_causal and not (use_top_left_mask and query_length == 1),
|
||||
"softmax_scale": softmax_scale,
|
||||
}
|
||||
|
||||
if supports_mapping["dropout_p"]:
|
||||
flash_kwargs["dropout_p"] = dropout
|
||||
|
||||
if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window:
|
||||
flash_kwargs["window_size"] = (sliding_window, sliding_window)
|
||||
|
||||
if supports_mapping["deterministic"]:
|
||||
flash_kwargs["deterministic"] = (
|
||||
deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
)
|
||||
|
||||
if supports_mapping["softcap"] and softcap is not None:
|
||||
flash_kwargs["softcap"] = softcap
|
||||
|
||||
# Only within kernel implementation atm
|
||||
if supports_mapping["s_aux"] and s_aux is not None:
|
||||
flash_kwargs["s_aux"] = s_aux
|
||||
|
||||
return flash_kwargs
|
||||
|
||||
|
||||
def _flash_attention_forward(
|
||||
@ -360,100 +584,121 @@ def _flash_attention_forward(
|
||||
implementation: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")):
|
||||
flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation)
|
||||
globals()["_flash_fn"] = flash_fn
|
||||
globals()["_flash_varlen_fn"] = flash_varlen_fn
|
||||
globals()["_pad_fn"] = pad_fn
|
||||
globals()["_unpad_fn"] = unpad_fn
|
||||
globals()["_is_fa3"] = is_fa3
|
||||
flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters
|
||||
globals()["_flash_supports_window"] = flash_supports_window
|
||||
else:
|
||||
flash_fn = globals()["_flash_fn"]
|
||||
flash_varlen_fn = globals()["_flash_varlen_fn"]
|
||||
pad_fn = globals()["_pad_fn"]
|
||||
unpad_fn = globals()["_unpad_fn"]
|
||||
is_fa3 = globals()["_is_fa3"]
|
||||
flash_supports_window = globals()["_flash_supports_window"]
|
||||
"""
|
||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||
first unpad the input, then computes the attention scores and pad the final attention scores.
|
||||
|
||||
causal = is_causal and not (use_top_left_mask and query_length == 1)
|
||||
use_sw = (
|
||||
(_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window
|
||||
(Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`.
|
||||
|
||||
Args:
|
||||
query_states (`torch.Tensor`):
|
||||
Input query states to be passed to Flash Attention API
|
||||
key_states (`torch.Tensor`):
|
||||
Input key states to be passed to Flash Attention API
|
||||
value_states (`torch.Tensor`):
|
||||
Input value states to be passed to Flash Attention API
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
||||
position of padding tokens and 1 for the position of non-padding tokens.
|
||||
implementation (`str`, *optional*):
|
||||
The attention implementation to use. If None, will default to the one based on the environment.
|
||||
"""
|
||||
(flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
|
||||
implementation
|
||||
)
|
||||
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {}
|
||||
if not is_fa3:
|
||||
flash_kwargs["dropout_p"] = dropout
|
||||
if is_flash_attn_greater_or_equal("2.4.1"):
|
||||
det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
flash_kwargs["deterministic"] = det
|
||||
if softcap is not None:
|
||||
flash_kwargs["softcap"] = softcap
|
||||
if "s_aux" in kwargs:
|
||||
flash_kwargs["s_aux"] = kwargs.get("s_aux")
|
||||
|
||||
# PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
|
||||
query_states, key_states, value_states = fa_peft_integration_check(
|
||||
query_states, key_states, value_states, target_dtype
|
||||
)
|
||||
use_mask = position_ids is not None or all(
|
||||
k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]
|
||||
|
||||
# Extract the flash attention kwargs that have been requested (and are supported by the implementation)
|
||||
flash_kwargs = process_flash_kwargs_fn(
|
||||
query_length=query_length,
|
||||
key_length=key_states.size(1),
|
||||
is_causal=is_causal,
|
||||
dropout=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
sliding_window=sliding_window,
|
||||
use_top_left_mask=use_top_left_mask,
|
||||
softcap=softcap,
|
||||
deterministic=deterministic,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
|
||||
# Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
|
||||
# Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
|
||||
# use `flash_varlen_fn` knowing we already have all necessary the kwargs.
|
||||
#
|
||||
# NOTE: it is user's responsibility to take care of flattenning `position_ids` if that's needed by the model.
|
||||
# See #39121 for more information.
|
||||
is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0))
|
||||
is_fa_with_varlen_kwargs = all(
|
||||
kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
|
||||
)
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
if attention_mask is not None:
|
||||
q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
|
||||
q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
|
||||
query_states, key_states, value_states, attention_mask, query_length, unpad_fn
|
||||
)
|
||||
# TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p
|
||||
|
||||
# TODO for now this is required to work with
|
||||
# https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
|
||||
if "mps" in str(q.device):
|
||||
cu_k = cu_k.clone()
|
||||
cu_seq_lens_k = cu_seq_lens_k.clone()
|
||||
|
||||
out_unpad = flash_varlen_fn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_q.to(torch.int32),
|
||||
cu_seqlens_k=cu_k.to(torch.int32),
|
||||
max_seqlen_q=mq,
|
||||
max_seqlen_k=mk,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
cu_seqlens_q=cu_seq_lens_q,
|
||||
cu_seqlens_k=cu_seq_lens_k,
|
||||
max_seqlen_q=max_length_q,
|
||||
max_seqlen_k=max_length_k,
|
||||
**flash_kwargs,
|
||||
)
|
||||
if isinstance(out_unpad, tuple):
|
||||
out_unpad = out_unpad[0]
|
||||
out = pad_fn(out_unpad, idx, query_states.shape[0], query_length)
|
||||
elif use_mask:
|
||||
|
||||
out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length)
|
||||
|
||||
# Padding free, i.e. sequences flattened into one total sequence
|
||||
elif is_fa_with_varlen_kwargs or is_fa_with_position_ids:
|
||||
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||
if position_ids is None:
|
||||
raise ValueError(
|
||||
"Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed."
|
||||
)
|
||||
q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
|
||||
query_states, key_states, value_states, position_ids
|
||||
q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids(
|
||||
query_states, key_states, value_states, position_ids, query_length=query_length
|
||||
)
|
||||
else:
|
||||
q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
|
||||
k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
|
||||
v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
|
||||
mq, mk = max_length_q, max_length_k
|
||||
cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k
|
||||
|
||||
# TODO for now this is required to work with
|
||||
# https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
|
||||
if "mps" in str(q.device):
|
||||
cu_k = cu_k.clone()
|
||||
cu_seq_lens_k = cu_seq_lens_k.clone()
|
||||
|
||||
out = flash_varlen_fn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_q.to(torch.int32),
|
||||
cu_seqlens_k=cu_k.to(torch.int32),
|
||||
max_seqlen_q=mq,
|
||||
max_seqlen_k=mk,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
cu_seqlens_q=cu_seq_lens_q,
|
||||
cu_seqlens_k=cu_seq_lens_k,
|
||||
max_seqlen_q=max_length_q,
|
||||
max_seqlen_k=max_length_k,
|
||||
**flash_kwargs,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1))
|
||||
else:
|
||||
out = flash_fn(
|
||||
query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
||||
)
|
||||
|
||||
return out[0] if isinstance(out, tuple) else out
|
||||
out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1))
|
||||
|
||||
# No padding
|
||||
else:
|
||||
out = flash_fn(query_states, key_states, value_states, **flash_kwargs)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
|
||||
return out
|
||||
|
@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from abc import ABC
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
@ -95,7 +94,7 @@ class GradientCheckpointingLayer(nn.Module):
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class GenericForSequenceClassification(ABC):
|
||||
class GenericForSequenceClassification(object):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config):
|
||||
@ -170,7 +169,7 @@ class GenericForSequenceClassification(ABC):
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class GenericForQuestionAnswering(ABC):
|
||||
class GenericForQuestionAnswering(object):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config):
|
||||
@ -231,7 +230,7 @@ class GenericForQuestionAnswering(ABC):
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class GenericForTokenClassification(ABC):
|
||||
class GenericForTokenClassification(object):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config):
|
||||
|
@ -74,6 +74,7 @@ from .integrations.tensor_parallel import (
|
||||
)
|
||||
from .loss.loss_utils import LOSS_MAPPING
|
||||
from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
|
||||
from .modeling_flash_attention_utils import lazy_import_flash_attention
|
||||
from .pytorch_utils import ( # noqa: F401
|
||||
Conv1D,
|
||||
apply_chunking_to_forward,
|
||||
@ -2126,7 +2127,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
_pp_plan = None
|
||||
|
||||
# This flag signal that the model can be used as an efficient backend in TGI and vLLM
|
||||
# In practice, it means that they support attention interface functions, fully pass the kwargs
|
||||
# In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
|
||||
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
|
||||
_supports_attention_backend = False
|
||||
_can_record_outputs = None
|
||||
@ -2482,6 +2483,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
if not is_flash_attn_2_available():
|
||||
preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
|
||||
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
|
||||
|
||||
# package `flash-attn` can not be installed on Ascend NPU, following validation logics can be ignored.
|
||||
if is_torch_npu_available():
|
||||
logger.info("Detect using FlashAttention2 on Ascend NPU.")
|
||||
return True
|
||||
|
||||
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logi
|
||||
if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
|
||||
@ -2740,6 +2746,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
if attention_wrapper is None:
|
||||
attention_wrapper = flash_attention_forward
|
||||
kernel_function = partial(attention_wrapper, implementation=kernel)
|
||||
lazy_import_flash_attention(kernel)
|
||||
elif kernel_name is not None:
|
||||
kernel_function = getattr(kernel, kernel_name)
|
||||
ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
|
||||
@ -2755,7 +2762,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||
return attn_implementation
|
||||
else:
|
||||
return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
|
||||
attn_implementation = self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
|
||||
|
||||
# preload flash attention here to allow compile with fullgraph
|
||||
if applicable_attn_implementation.startswith("flash_attention"):
|
||||
lazy_import_flash_attention(applicable_attn_implementation)
|
||||
|
||||
return attn_implementation
|
||||
|
||||
def get_correct_attn_implementation(self, _requested_attention: str, is_init_check: bool = False) -> str:
|
||||
requested_attention = "sdpa" if _requested_attention is None else _requested_attention
|
||||
|
@ -1,269 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import (
|
||||
Aimv2Config,
|
||||
Aimv2Model,
|
||||
Aimv2VisionConfig,
|
||||
Aimv2VisionModel,
|
||||
AutoImageProcessor,
|
||||
AutoProcessor,
|
||||
)
|
||||
|
||||
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL = {
|
||||
# Embeddings
|
||||
r"preprocessor.patchifier.proj": r"embeddings.patch_embed",
|
||||
r"preprocessor.pos_embed": r"embeddings.position_embedding.weight",
|
||||
r"preprocessor.patchifier.norm.weight": r"embeddings.rms_norm.weight",
|
||||
# Encoder Layers
|
||||
r"trunk.blocks.(\d+).attn.qkv": r"encoder.layers.\1.attention.qkv",
|
||||
r"trunk.blocks.(\d+).attn.proj": r"encoder.layers.\1.attention.out_proj",
|
||||
r"trunk.blocks.(\d+).mlp.fc1": r"encoder.layers.\1.ffn.gate_proj",
|
||||
r"trunk.blocks.(\d+).mlp.fc2": r"encoder.layers.\1.ffn.down_proj",
|
||||
r"trunk.blocks.(\d+).mlp.fc3": r"encoder.layers.\1.ffn.up_proj",
|
||||
# Normalization Layers
|
||||
r"trunk.blocks.(\d+).norm_1": r"encoder.layers.\1.rms_norm1",
|
||||
r"trunk.blocks.(\d+).norm_2": r"encoder.layers.\1.rms_norm2",
|
||||
# Final Norm
|
||||
r"trunk.post_trunk_norm": r"rms_norm",
|
||||
}
|
||||
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||
# Vision Embeddings
|
||||
r"image_encoder.preprocessor.patchifier.proj": r"vision_model.embeddings.patch_embed",
|
||||
r"image_encoder.preprocessor.pos_embed": r"vision_model.embeddings.position_embedding.weight",
|
||||
r"image_encoder.preprocessor.patchifier.norm.weight": r"vision_model.embeddings.rms_norm.weight",
|
||||
# Vision Encoder Layers
|
||||
r"image_encoder.trunk.blocks.(\d+).attn.qkv": r"vision_model.encoder.layers.\1.attention.qkv",
|
||||
r"image_encoder.trunk.blocks.(\d+).attn.proj": r"vision_model.encoder.layers.\1.attention.out_proj",
|
||||
r"image_encoder.trunk.blocks.(\d+).mlp.fc1": r"vision_model.encoder.layers.\1.ffn.gate_proj",
|
||||
r"image_encoder.trunk.blocks.(\d+).mlp.fc2": r"vision_model.encoder.layers.\1.ffn.down_proj",
|
||||
r"image_encoder.trunk.blocks.(\d+).mlp.fc3": r"vision_model.encoder.layers.\1.ffn.up_proj",
|
||||
# Normalization Layers
|
||||
r"image_encoder.trunk.blocks.(\d+).norm_1": r"vision_model.encoder.layers.\1.rms_norm1",
|
||||
r"image_encoder.trunk.blocks.(\d+).norm_2": r"vision_model.encoder.layers.\1.rms_norm2",
|
||||
r"image_encoder.trunk.post_trunk_norm": r"vision_model.rms_norm",
|
||||
r"image_projector": r"visual_projection",
|
||||
# Vision Head
|
||||
r"image_encoder.head.cls_token": r"vision_model.head.cls_token",
|
||||
r"image_encoder.head.k": r"vision_model.head.k_proj",
|
||||
r"image_encoder.head.v": r"vision_model.head.v_proj",
|
||||
r"image_encoder.head.linear": r"vision_model.head.output_proj",
|
||||
# Text Embeddings
|
||||
r"text_encoder.preprocessor.text_embedding.weight": r"text_model.embeddings.token_embedding.weight",
|
||||
r"text_encoder.preprocessor.positional_embedding": r"text_model.embeddings.position_embedding.weight",
|
||||
# Text Encoder Layers
|
||||
r"text_encoder.trunk.blocks.(\d+).attn.qkv": r"text_model.encoder.layers.\1.attention.qkv",
|
||||
r"text_encoder.trunk.blocks.(\d+).attn.proj": r"text_model.encoder.layers.\1.attention.out_proj",
|
||||
r"text_encoder.trunk.blocks.(\d+).mlp.fc1": r"text_model.encoder.layers.\1.ffn.gate_proj",
|
||||
r"text_encoder.trunk.blocks.(\d+).mlp.fc2": r"text_model.encoder.layers.\1.ffn.down_proj",
|
||||
r"text_encoder.trunk.blocks.(\d+).mlp.fc3": r"text_model.encoder.layers.\1.ffn.up_proj",
|
||||
# Text Normalization Layers
|
||||
r"text_encoder.trunk.blocks.(\d+).norm_1": r"text_model.encoder.layers.\1.rms_norm1",
|
||||
r"text_encoder.trunk.blocks.(\d+).norm_2": r"text_model.encoder.layers.\1.rms_norm2",
|
||||
r"text_encoder.trunk.post_trunk_norm": r"text_model.rms_norm",
|
||||
r"text_projector": r"text_projection",
|
||||
r"log_logit_scale": r"logit_scale",
|
||||
}
|
||||
|
||||
|
||||
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
|
||||
# Download only the model.safetensors file
|
||||
directory_path = snapshot_download(
|
||||
repo_id=model_id,
|
||||
revision=revision,
|
||||
allow_patterns=["model.safetensors"],
|
||||
)
|
||||
|
||||
original_state_dict = {}
|
||||
safetensor_path = f"{directory_path}/model.safetensors"
|
||||
|
||||
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_old_keys_to_new_keys(state_dict_keys: dict, ORIGINAL_TO_CONVERTED_KEY_MAPPING: dict):
|
||||
"""Converts state dict keys from the old format to the new format."""
|
||||
|
||||
output_dict = {}
|
||||
if state_dict_keys is not None:
|
||||
old_text = "\n".join(state_dict_keys)
|
||||
new_text = old_text
|
||||
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
|
||||
if replacement is None:
|
||||
new_text = re.sub(pattern, "", new_text) # an empty line
|
||||
continue
|
||||
new_text = re.sub(pattern, replacement, new_text)
|
||||
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
|
||||
return output_dict
|
||||
|
||||
|
||||
def split_qkv_tensor(key, tensor):
|
||||
"""Splits a qkv tensor into separate q, k, v tensors and updates the key accordingly."""
|
||||
|
||||
new_keys = ["q_proj", "k_proj", "v_proj"]
|
||||
split_size = tensor.shape[0] // 3
|
||||
split_tensors = torch.split(tensor, split_size, dim=0)
|
||||
|
||||
return {key.replace("qkv", new_key): split_tensors[i] for i, new_key in enumerate(new_keys)}
|
||||
|
||||
|
||||
def get_model_config_mapping(model_id: str):
|
||||
"""Determines the correct model, config, and key mappings based on the checkpoint name."""
|
||||
|
||||
if model_id == "apple/aimv2-large-patch14-224-lit":
|
||||
return Aimv2Model, Aimv2Config, ORIGINAL_TO_CONVERTED_KEY_MAPPING
|
||||
else:
|
||||
return Aimv2VisionModel, Aimv2VisionConfig, ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL
|
||||
|
||||
|
||||
def write_model(
|
||||
hf_repo_id: str,
|
||||
output_dir: str,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
"""
|
||||
Converts a model checkpoint to Hugging Face format and saves it.
|
||||
|
||||
Args:
|
||||
hf_repo_id (str): The Hugging Face repo ID to load from.
|
||||
output_dir (str): The directory to save the converted model.
|
||||
safe_serialization (bool): Whether to use safe serialization.
|
||||
|
||||
Returns:
|
||||
model: The reloaded Hugging Face model.
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Get the appropriate model, config, and key mapping
|
||||
model_class, config_class, key_mapping = get_model_config_mapping(hf_repo_id)
|
||||
|
||||
# Load config and original state dict
|
||||
config = config_class.from_pretrained(hf_repo_id)
|
||||
|
||||
# Checkpoint `apple/aimv2-large-patch14-224-lit` uses AttentionPoolingHead hence set the required attr in config.
|
||||
if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
|
||||
config.use_head = False
|
||||
|
||||
if hf_repo_id == "apple/aimv2-large-patch14-native":
|
||||
config.is_native = True
|
||||
|
||||
original_state_dict = load_original_state_dict(hf_repo_id)
|
||||
|
||||
print("Converting model...")
|
||||
|
||||
state_dict = {}
|
||||
result = convert_old_keys_to_new_keys(original_state_dict, key_mapping)
|
||||
all_keys = list(original_state_dict.keys())
|
||||
|
||||
for key in all_keys:
|
||||
value = original_state_dict[key]
|
||||
new_key = result.pop(key)
|
||||
|
||||
if "qkv" in new_key:
|
||||
qkv_state_dict = split_qkv_tensor(new_key, value)
|
||||
state_dict.update(qkv_state_dict)
|
||||
else:
|
||||
state_dict[new_key] = value
|
||||
|
||||
# Check if position embeddings exist before squeezing
|
||||
if new_key.endswith("position_embedding.weight"):
|
||||
state_dict[new_key] = value.squeeze(0)
|
||||
|
||||
print(f"Loading the checkpoint in a {model_class.__name__}.")
|
||||
model = model_class(config)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
print("Checkpoint loaded successfully.")
|
||||
|
||||
print("Saving the model.")
|
||||
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
|
||||
del state_dict, model
|
||||
gc.collect()
|
||||
|
||||
print("Reloading the model to check if it's saved correctly.")
|
||||
model = model_class.from_pretrained(output_dir, device_map="auto")
|
||||
print("Model reloaded successfully.")
|
||||
return model
|
||||
|
||||
|
||||
def write_image_processor(hf_repo_id: str, output_dir: str):
|
||||
if hf_repo_id == "apple/aimv2-large-patch14-224-lit":
|
||||
image_processor = AutoProcessor.from_pretrained(hf_repo_id, use_fast=True)
|
||||
else:
|
||||
image_processor = AutoImageProcessor.from_pretrained(hf_repo_id, use_fast=True)
|
||||
image_processor.save_pretrained(output_dir)
|
||||
return image_processor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--hf_repo_id",
|
||||
default="apple/aimv2-large-patch14-224",
|
||||
help="Location of official weights from apple on HF",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default="aimv2_model",
|
||||
help="Location to write the converted model and processor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Whether or not to push the converted model to the huggingface hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub_repo_id",
|
||||
default=None,
|
||||
help="Huggingface hub repo to write the converted model and processor",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model = write_model(
|
||||
hf_repo_id=args.hf_repo_id,
|
||||
output_dir=args.output_dir,
|
||||
safe_serialization=args.safe_serialization,
|
||||
)
|
||||
|
||||
image_processor = write_image_processor(
|
||||
hf_repo_id=args.hf_repo_id,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
print("Pushing to hub...")
|
||||
model.push_to_hub(args.hub_repo_id)
|
||||
image_processor.push_to_hub(args.hub_repo_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,62 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ALBERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = AlbertConfig.from_json_file(albert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = AlbertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_albert(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--albert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained ALBERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
|
@ -1,389 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ALIGN checkpoints from the original repository."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import align
|
||||
import numpy as np
|
||||
import requests
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
from transformers import (
|
||||
AlignConfig,
|
||||
AlignModel,
|
||||
AlignProcessor,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
EfficientNetConfig,
|
||||
EfficientNetImageProcessor,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
image = tf.image.resize(image, (346, 346))
|
||||
image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
|
||||
return image
|
||||
|
||||
|
||||
def get_align_config():
|
||||
vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
|
||||
vision_config.image_size = 289
|
||||
vision_config.hidden_dim = 640
|
||||
vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
|
||||
vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
|
||||
vision_config.depthwise_padding = []
|
||||
|
||||
text_config = BertConfig()
|
||||
config = AlignConfig.from_text_vision_configs(
|
||||
text_config=text_config, vision_config=vision_config, projection_dim=640
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
def get_processor():
|
||||
image_processor = EfficientNetImageProcessor(
|
||||
do_center_crop=True,
|
||||
rescale_factor=1 / 127.5,
|
||||
rescale_offset=True,
|
||||
do_normalize=False,
|
||||
include_top=False,
|
||||
resample=Image.BILINEAR,
|
||||
)
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
tokenizer.model_max_length = 64
|
||||
processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
return processor
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def rename_keys(original_param_names):
|
||||
# EfficientNet image encoder
|
||||
block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
|
||||
block_names = list(set(block_names))
|
||||
block_names = sorted(block_names)
|
||||
num_blocks = len(block_names)
|
||||
block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}
|
||||
|
||||
rename_keys = []
|
||||
rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
|
||||
rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
|
||||
rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
|
||||
rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
|
||||
rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))
|
||||
|
||||
for b in block_names:
|
||||
hf_b = block_name_mapping[b]
|
||||
rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
|
||||
rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
|
||||
rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
|
||||
)
|
||||
rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
|
||||
rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
|
||||
)
|
||||
|
||||
rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
|
||||
rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
|
||||
rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
|
||||
rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
|
||||
)
|
||||
rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
|
||||
rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
|
||||
)
|
||||
|
||||
key_mapping = {}
|
||||
for item in rename_keys:
|
||||
if item[0] in original_param_names:
|
||||
key_mapping[item[0]] = "vision_model." + item[1]
|
||||
|
||||
# BERT text encoder
|
||||
rename_keys = []
|
||||
old = "tf_bert_model/bert"
|
||||
new = "text_model"
|
||||
for i in range(12):
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.query.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/query/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.query.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.key.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/key/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.key.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.value.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/value/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.value.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.dense.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.dense.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.intermediate.dense.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0",
|
||||
f"{new}.encoder.layer.{i}.intermediate.dense.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias")
|
||||
)
|
||||
|
||||
rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight"))
|
||||
rename_keys.append(
|
||||
(f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight")
|
||||
)
|
||||
rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight"))
|
||||
rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias"))
|
||||
|
||||
rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight"))
|
||||
rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias"))
|
||||
rename_keys.append(("dense/kernel:0", "text_projection.weight"))
|
||||
rename_keys.append(("dense/bias:0", "text_projection.bias"))
|
||||
rename_keys.append(("dense/bias:0", "text_projection.bias"))
|
||||
rename_keys.append(("temperature:0", "temperature"))
|
||||
|
||||
for item in rename_keys:
|
||||
if item[0] in original_param_names:
|
||||
key_mapping[item[0]] = item[1]
|
||||
return key_mapping
|
||||
|
||||
|
||||
def replace_params(hf_params, tf_params, key_mapping):
|
||||
list(hf_params.keys())
|
||||
|
||||
for key, value in tf_params.items():
|
||||
if key not in key_mapping:
|
||||
continue
|
||||
|
||||
hf_key = key_mapping[key]
|
||||
if "_conv" in key and "kernel" in key:
|
||||
new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
|
||||
elif "embeddings" in key:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
elif "depthwise_kernel" in key:
|
||||
new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
|
||||
elif "kernel" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value))
|
||||
elif "temperature" in key:
|
||||
new_hf_value = value
|
||||
elif "bn/gamma" in key or "bn/beta" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
|
||||
else:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
|
||||
# Replace HF parameters with original TF model parameters
|
||||
hf_params[hf_key].copy_(new_hf_value)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our ALIGN structure.
|
||||
"""
|
||||
# Load original model
|
||||
seq_length = 64
|
||||
tok = Tokenizer(seq_length)
|
||||
original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
|
||||
original_model.compile()
|
||||
original_model.load_weights(checkpoint_path)
|
||||
|
||||
tf_params = original_model.trainable_variables
|
||||
tf_non_train_params = original_model.non_trainable_variables
|
||||
tf_params = {param.name: param.numpy() for param in tf_params}
|
||||
for param in tf_non_train_params:
|
||||
tf_params[param.name] = param.numpy()
|
||||
tf_param_names = list(tf_params.keys())
|
||||
|
||||
# Load HuggingFace model
|
||||
config = get_align_config()
|
||||
hf_model = AlignModel(config).eval()
|
||||
hf_params = hf_model.state_dict()
|
||||
|
||||
# Create src-to-dst parameter name mapping dictionary
|
||||
print("Converting parameters...")
|
||||
key_mapping = rename_keys(tf_param_names)
|
||||
replace_params(hf_params, tf_params, key_mapping)
|
||||
|
||||
# Initialize processor
|
||||
processor = get_processor()
|
||||
inputs = processor(
|
||||
images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
|
||||
)
|
||||
|
||||
# HF model inference
|
||||
hf_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = hf_model(**inputs)
|
||||
|
||||
hf_image_features = outputs.image_embeds.detach().numpy()
|
||||
hf_text_features = outputs.text_embeds.detach().numpy()
|
||||
|
||||
# Original model inference
|
||||
original_model.trainable = False
|
||||
tf_image_processor = EfficientNetImageProcessor(
|
||||
do_center_crop=True,
|
||||
do_rescale=False,
|
||||
do_normalize=False,
|
||||
include_top=False,
|
||||
resample=Image.BILINEAR,
|
||||
)
|
||||
image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"]
|
||||
text = tok(tf.constant(["A picture of a cat"]))
|
||||
|
||||
image_features = original_model.image_encoder(image, training=False)
|
||||
text_features = original_model.text_encoder(text, training=False)
|
||||
|
||||
image_features = tf.nn.l2_normalize(image_features, axis=-1)
|
||||
text_features = tf.nn.l2_normalize(text_features, axis=-1)
|
||||
|
||||
# Check whether original and HF model outputs match -> np.allclose
|
||||
if not np.allclose(image_features, hf_image_features, atol=1e-3):
|
||||
raise ValueError("The predicted image features are not the same.")
|
||||
if not np.allclose(text_features, hf_text_features, atol=1e-3):
|
||||
raise ValueError("The predicted text features are not the same.")
|
||||
print("Model outputs match!")
|
||||
|
||||
if save_model:
|
||||
# Create folder to save model
|
||||
if not os.path.isdir(pytorch_dump_folder_path):
|
||||
os.mkdir(pytorch_dump_folder_path)
|
||||
# Save converted model and image processor
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
# Push model and image processor to hub
|
||||
print("Pushing converted ALIGN to the hub...")
|
||||
processor.push_to_hub("align-base")
|
||||
hf_model.push_to_hub("align-base")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default="./weights/model-weights",
|
||||
type=str,
|
||||
help="Path to the pretrained TF ALIGN checkpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default="hf_model",
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument("--save_model", action="store_true", help="Save model to local")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)
|
@ -1,162 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AriaForConditionalGeneration,
|
||||
AriaProcessor,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
|
||||
EPILOG_TXT = """Example:
|
||||
python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
|
||||
|
||||
Example for creating the old state dict file with Python:
|
||||
|
||||
import torch
|
||||
from aria.model.language_model.aria_llama import AriaTextForCausalLM
|
||||
|
||||
# load model
|
||||
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
|
||||
model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", **kwargs)
|
||||
|
||||
# load vision tower
|
||||
model.get_vision_tower().load_model()
|
||||
|
||||
# Save state dict
|
||||
torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
|
||||
"""
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"vision_tower.vision_model": "vision_tower",
|
||||
"ln_ffn": "layer_norm",
|
||||
"ffn": "feed_forward",
|
||||
"ln_kv": "layer_norm_kv",
|
||||
}
|
||||
|
||||
|
||||
def load_original_state_dict(model_id):
|
||||
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_hf(state_dict):
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key.endswith(".inv_freq"):
|
||||
continue
|
||||
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in key:
|
||||
key = key.replace(key_to_modify, new_key)
|
||||
|
||||
new_state_dict[key] = value
|
||||
new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
|
||||
new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
text_model_id,
|
||||
extra_special_tokens={
|
||||
"image_token": "<|img|>",
|
||||
"pad_token": "<pad>",
|
||||
},
|
||||
)
|
||||
tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
|
||||
processor = AriaProcessor.from_pretrained(
|
||||
text_model_id,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(text_model_id)
|
||||
config.vision_config.hidden_size = 1152
|
||||
config.vision_config.attention_heads = 16
|
||||
config.pad_token_id = 2
|
||||
config.image_token_id = 9
|
||||
config.intermediate_size = config.moe_intermediate_size
|
||||
config.auto_map = {
|
||||
"AutoConfig": "modeling_aria.AriaConfig",
|
||||
"AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
|
||||
}
|
||||
|
||||
with torch.device("meta"):
|
||||
model = AriaForConditionalGeneration(config)
|
||||
|
||||
state_dict = load_original_state_dict(old_state_dict_id)
|
||||
|
||||
state_dict = convert_state_dict_to_hf(state_dict)
|
||||
model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
|
||||
# print("Saving models")
|
||||
# model.save_pretrained("local_aria", safe_serialization=False)
|
||||
# processor.save_pretrained("local_aria")
|
||||
print("Pushing to hub")
|
||||
model.push_to_hub(output_hub_path, create_pr=True)
|
||||
processor.push_to_hub(output_hub_path, create_pr=True)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog=EPILOG_TXT,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_model_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Hub location of the text model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vision_model_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Hub location of the vision model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_hub_path",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Location on the hub of the converted model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old_state_dict_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,279 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_audio_spectrogram_transformer_config(model_name):
|
||||
config = ASTConfig()
|
||||
|
||||
if "10-10" in model_name:
|
||||
pass
|
||||
elif "speech-commands" in model_name:
|
||||
config.max_length = 128
|
||||
elif "12-12" in model_name:
|
||||
config.time_stride = 12
|
||||
config.frequency_stride = 12
|
||||
elif "14-14" in model_name:
|
||||
config.time_stride = 14
|
||||
config.frequency_stride = 14
|
||||
elif "16-16" in model_name:
|
||||
config.time_stride = 16
|
||||
config.frequency_stride = 16
|
||||
else:
|
||||
raise ValueError("Model not supported")
|
||||
|
||||
repo_id = "huggingface/label-files"
|
||||
if "speech-commands" in model_name:
|
||||
config.num_labels = 35
|
||||
filename = "speech-commands-v2-id2label.json"
|
||||
else:
|
||||
config.num_labels = 527
|
||||
filename = "audioset-id2label.json"
|
||||
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "module.v" in name:
|
||||
name = name.replace("module.v", "audio_spectrogram_transformer")
|
||||
if "cls_token" in name:
|
||||
name = name.replace("cls_token", "embeddings.cls_token")
|
||||
if "dist_token" in name:
|
||||
name = name.replace("dist_token", "embeddings.distillation_token")
|
||||
if "pos_embed" in name:
|
||||
name = name.replace("pos_embed", "embeddings.position_embeddings")
|
||||
if "patch_embed.proj" in name:
|
||||
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
|
||||
# transformer blocks
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "encoder.layer")
|
||||
if "attn.proj" in name:
|
||||
name = name.replace("attn.proj", "attention.output.dense")
|
||||
if "attn" in name:
|
||||
name = name.replace("attn", "attention.self")
|
||||
if "norm1" in name:
|
||||
name = name.replace("norm1", "layernorm_before")
|
||||
if "norm2" in name:
|
||||
name = name.replace("norm2", "layernorm_after")
|
||||
if "mlp.fc1" in name:
|
||||
name = name.replace("mlp.fc1", "intermediate.dense")
|
||||
if "mlp.fc2" in name:
|
||||
name = name.replace("mlp.fc2", "output.dense")
|
||||
# final layernorm
|
||||
if "audio_spectrogram_transformer.norm" in name:
|
||||
name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm")
|
||||
# classifier head
|
||||
if "module.mlp_head.0" in name:
|
||||
name = name.replace("module.mlp_head.0", "classifier.layernorm")
|
||||
if "module.mlp_head.1" in name:
|
||||
name = name.replace("module.mlp_head.1", "classifier.dense")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, config):
|
||||
for key in orig_state_dict.copy():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
if "qkv" in key:
|
||||
key_split = key.split(".")
|
||||
layer_num = int(key_split[3])
|
||||
dim = config.hidden_size
|
||||
if "weight" in key:
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight"
|
||||
] = val[:dim, :]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight"
|
||||
] = val[dim : dim * 2, :]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight"
|
||||
] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias"
|
||||
] = val[:dim]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias"
|
||||
] = val[dim : dim * 2]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias"
|
||||
] = val[-dim:]
|
||||
else:
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
def remove_keys(state_dict):
|
||||
ignore_keys = [
|
||||
"module.v.head.weight",
|
||||
"module.v.head.bias",
|
||||
"module.v.head_dist.weight",
|
||||
"module.v.head_dist.bias",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure.
|
||||
"""
|
||||
config = get_audio_spectrogram_transformer_config(model_name)
|
||||
|
||||
model_name_to_url = {
|
||||
"ast-finetuned-audioset-10-10-0.4593": (
|
||||
"https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.450": (
|
||||
"https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.448": (
|
||||
"https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.448-v2": (
|
||||
"https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-12-12-0.447": (
|
||||
"https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-14-14-0.443": (
|
||||
"https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-16-16-0.442": (
|
||||
"https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-speech-commands-v2": (
|
||||
"https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1"
|
||||
),
|
||||
}
|
||||
|
||||
# load original state_dict
|
||||
checkpoint_url = model_name_to_url[model_name]
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
|
||||
# remove some keys
|
||||
remove_keys(state_dict)
|
||||
# rename some keys
|
||||
new_state_dict = convert_state_dict(state_dict, config)
|
||||
|
||||
# load 🤗 model
|
||||
model = ASTForAudioClassification(config)
|
||||
model.eval()
|
||||
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# verify outputs on dummy input
|
||||
# source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62
|
||||
mean = -4.2677393 if "speech-commands" not in model_name else -6.845978
|
||||
std = 4.5689974 if "speech-commands" not in model_name else 5.5654526
|
||||
max_length = 1024 if "speech-commands" not in model_name else 128
|
||||
feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length)
|
||||
|
||||
if "speech-commands" in model_name:
|
||||
# TODO: Convert dataset to Parquet
|
||||
dataset = load_dataset("google/speech_commands", "v0.02", split="validation")
|
||||
waveform = dataset[0]["audio"]["array"]
|
||||
else:
|
||||
filepath = hf_hub_download(
|
||||
repo_id="nielsr/audio-spectogram-transformer-checkpoint",
|
||||
filename="sample_audio.flac",
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
waveform, _ = torchaudio.load(filepath)
|
||||
waveform = waveform.squeeze().numpy()
|
||||
|
||||
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
if model_name == "ast-finetuned-audioset-10-10-0.4593":
|
||||
expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.450":
|
||||
expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.448":
|
||||
expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.448-v2":
|
||||
expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917])
|
||||
elif model_name == "ast-finetuned-audioset-12-12-0.447":
|
||||
expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843])
|
||||
elif model_name == "ast-finetuned-audioset-14-14-0.443":
|
||||
expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413])
|
||||
elif model_name == "ast-finetuned-audioset-16-16-0.442":
|
||||
expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470])
|
||||
elif model_name == "ast-finetuned-speech-commands-v2":
|
||||
expected_slice = torch.tensor([6.1589, -8.0566, -8.7984])
|
||||
else:
|
||||
raise ValueError("Unknown model name")
|
||||
if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4):
|
||||
raise ValueError("Logits don't match")
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print("Pushing model and feature extractor to the hub...")
|
||||
model.push_to_hub(f"MIT/{model_name}")
|
||||
feature_extractor.push_to_hub(f"MIT/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="ast-finetuned-audioset-10-10-0.4593",
|
||||
type=str,
|
||||
help="Name of the Audio Spectrogram Transformer model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,273 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from os import path
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from .configuration_bamba import BambaConfig
|
||||
|
||||
|
||||
def convert_state_dict_from_mamba_ssm(original_sd: dict) -> dict[str, torch.Tensor]:
|
||||
state_dict = {}
|
||||
|
||||
for orig_k, param in original_sd.items():
|
||||
k = orig_k.replace("backbone", "model")
|
||||
|
||||
# for embeddings
|
||||
k = k.replace("embedding", "embed_tokens")
|
||||
|
||||
# for mixer
|
||||
k = k.replace("mixer", "mamba")
|
||||
|
||||
# for final layernorm
|
||||
k = k.replace("norm_f", "final_layernorm")
|
||||
|
||||
# for block layernorm
|
||||
k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
|
||||
k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
|
||||
|
||||
# for mlp
|
||||
k = k.replace("mlp.fc2", "feed_forward.down_proj")
|
||||
|
||||
if "mlp.fc1" in k:
|
||||
param, param2 = torch.chunk(param, 2, dim=0)
|
||||
k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
|
||||
state_dict[k2] = param2
|
||||
k = k.replace("mlp.fc1", "feed_forward.up_proj")
|
||||
|
||||
if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
|
||||
"out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
|
||||
):
|
||||
# then this must be a mamba
|
||||
pass
|
||||
else:
|
||||
# for attn
|
||||
# - because mixer was replaced to mamba above
|
||||
k = k.replace("mamba.out_proj", "self_attn.o_proj")
|
||||
if "mamba.in_proj" in k:
|
||||
m, n = param.shape
|
||||
d = (m - n) // 2
|
||||
param, param2, param3 = torch.split(param, [n, d, d], dim=0)
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
|
||||
state_dict[k2] = param2
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
|
||||
state_dict[k2] = param3
|
||||
k = k.replace("mamba.in_proj", "self_attn.q_proj")
|
||||
|
||||
state_dict[k] = param
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_ssm_config_to_hf_config(
|
||||
config_ssm: dict,
|
||||
**kwargs,
|
||||
) -> BambaConfig:
|
||||
"""Convert a config from mamba_ssm to a BambaConfig from here."""
|
||||
hf_config: BambaConfig = BambaConfig(**kwargs)
|
||||
|
||||
hf_config.architectures = ["BambaForCausalLM"]
|
||||
|
||||
# Set important values from config and recalculate other resulting entries
|
||||
hf_config.hidden_size = config_ssm["d_model"]
|
||||
hf_config.intermediate_size = config_ssm["d_intermediate"]
|
||||
hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
|
||||
hf_config.num_hidden_layers = config_ssm["n_layer"]
|
||||
hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
|
||||
|
||||
# currently this script assumes config_ssm belongs to v2
|
||||
if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
|
||||
raise ValueError("Conversion script only supports Mamba2")
|
||||
|
||||
# Set attention values
|
||||
attn_cfg = config_ssm.get("attn_cfg")
|
||||
if attn_cfg:
|
||||
assert attn_cfg["causal"], "Only support non-causal attention."
|
||||
assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
|
||||
assert not attn_cfg["out_proj_bias"], "Only support no out bias."
|
||||
hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
|
||||
hf_config.num_attention_heads = attn_cfg["num_heads"]
|
||||
hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
|
||||
|
||||
attention_layer_indices = config_ssm.get("attn_layer_idx")
|
||||
if attention_layer_indices:
|
||||
hf_config.attn_layer_indices = attention_layer_indices
|
||||
|
||||
# Padded vocab size, mostly of 16 but 32 is also very common in different models
|
||||
vocab_size = config_ssm["vocab_size"]
|
||||
pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
|
||||
if (vocab_size % pad_vocab_size_multiple) != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
||||
hf_config.vocab_size = vocab_size
|
||||
|
||||
return hf_config
|
||||
|
||||
|
||||
def save_single_safetensor(
|
||||
state_dict: dict,
|
||||
save_directory: str,
|
||||
metadata: dict,
|
||||
):
|
||||
save_file(
|
||||
state_dict,
|
||||
os.path.join(save_directory, SAFE_WEIGHTS_NAME),
|
||||
metadata,
|
||||
)
|
||||
|
||||
|
||||
def save_sharded_safetensors(
|
||||
state_dict: dict,
|
||||
save_directory: str,
|
||||
metadata: dict,
|
||||
max_shard_size: Union[int, str] = "5GB",
|
||||
):
|
||||
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||
)
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
# Save the index
|
||||
with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
mamba_ssm_checkpoint_path: str,
|
||||
precision: str,
|
||||
output_dir: str,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
save_model: Union[bool, str] = True,
|
||||
) -> None:
|
||||
# load tokenizer if provided, this will be used to set the
|
||||
# token_ids in the config file
|
||||
token_ids = {}
|
||||
if tokenizer_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
for key in [
|
||||
"bos_token_id",
|
||||
"eos_token_id",
|
||||
"pad_token_id",
|
||||
]:
|
||||
id = getattr(tokenizer, key, None)
|
||||
if id:
|
||||
token_ids[key] = id
|
||||
|
||||
# there are some configs unsettable by mamba_ssn config, so
|
||||
# if there are changes from the defaults, have to pass them into
|
||||
# the function
|
||||
unsettables = {
|
||||
"mamba_d_head": 64,
|
||||
"mamba_d_state": 128,
|
||||
"mamba_n_groups": 1,
|
||||
"rms_norm_eps": 1e-5,
|
||||
}
|
||||
|
||||
# Load and save config based on name
|
||||
config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
|
||||
with open(config_path, "r", encoding="utf-8") as json_file:
|
||||
config = json.load(json_file)
|
||||
|
||||
# convert the config
|
||||
hf_config = convert_ssm_config_to_hf_config(
|
||||
config_ssm=config,
|
||||
**token_ids,
|
||||
**unsettables,
|
||||
)
|
||||
hf_config.save_pretrained(output_dir)
|
||||
|
||||
# Load state dict of the original model and transfer to hf model
|
||||
state_dict = torch.load(
|
||||
path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
# FIXME: allow other parameters to pass in
|
||||
state_dict = convert_state_dict_from_mamba_ssm(state_dict)
|
||||
|
||||
# Save new model to pytorch_dump_path
|
||||
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
|
||||
|
||||
save_file_fn = None
|
||||
if isinstance(save_model, bool) and save_model:
|
||||
save_file_fn = save_single_safetensor
|
||||
elif isinstance(save_model, str) and save_model == "sharded":
|
||||
save_file_fn = save_sharded_safetensors
|
||||
|
||||
if save_file_fn:
|
||||
save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--mamba_ssm_checkpoint_directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
required=True,
|
||||
choices=("fp32", "fp16", "bf16"),
|
||||
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tokenizer_model_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Path to a the tokenizer file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
args.mamba_ssm_checkpoint_directory,
|
||||
args.precision,
|
||||
args.output_dir,
|
||||
save_model="sharded",
|
||||
)
|
@ -31,7 +31,7 @@ from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from ...cache_utils import Cache, DynamicCache, DynamicLayer
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
@ -85,7 +85,7 @@ class BambaFlashAttentionKwargs(TypedDict, total=False):
|
||||
seq_idx: torch.IntTensor
|
||||
|
||||
|
||||
class HybridMambaAttentionDynamicCache(Cache):
|
||||
class HybridMambaAttentionDynamicCache:
|
||||
"""
|
||||
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
|
||||
(which has a constant shape regardless of seq_len).
|
||||
@ -104,7 +104,6 @@ class HybridMambaAttentionDynamicCache(Cache):
|
||||
is_compileable = False
|
||||
|
||||
def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
|
||||
super().__init__(layer_classes=DynamicLayer)
|
||||
self.layers_block_type = config.layers_block_type
|
||||
self.has_previous_state = False # only used by mamba
|
||||
conv_kernel_size = config.mamba_d_conv
|
||||
|
@ -42,7 +42,6 @@ from transformers.models.mamba2.modeling_mamba2 import (
|
||||
segment_sum,
|
||||
)
|
||||
|
||||
from ...cache_utils import DynamicLayer
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
@ -114,7 +113,6 @@ class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache):
|
||||
"""
|
||||
|
||||
def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
|
||||
HybridMambaAttentionDynamicCache.__init__(layer_classes=DynamicLayer)
|
||||
self.layers_block_type = config.layers_block_type
|
||||
self.has_previous_state = False # only used by mamba
|
||||
conv_kernel_size = config.mamba_d_conv
|
||||
|
@ -1,263 +0,0 @@
|
||||
"""Convert Bark checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from bark.generation import _load_model as _bark_load_model
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import EncodecConfig, EncodecModel, set_seed
|
||||
from transformers.models.bark.configuration_bark import (
|
||||
BarkCoarseConfig,
|
||||
BarkConfig,
|
||||
BarkFineConfig,
|
||||
BarkSemanticConfig,
|
||||
)
|
||||
from transformers.models.bark.generation_configuration_bark import (
|
||||
BarkCoarseGenerationConfig,
|
||||
BarkFineGenerationConfig,
|
||||
BarkGenerationConfig,
|
||||
BarkSemanticGenerationConfig,
|
||||
)
|
||||
from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
set_seed(770)
|
||||
|
||||
|
||||
new_layer_name_dict = {
|
||||
"c_attn": "att_proj",
|
||||
"c_proj": "out_proj",
|
||||
"c_fc": "in_proj",
|
||||
"transformer.": "",
|
||||
"h.": "layers.",
|
||||
"ln_1": "layernorm_1",
|
||||
"ln_2": "layernorm_2",
|
||||
"ln_f": "layernorm_final",
|
||||
"wpe": "position_embeds_layer",
|
||||
"wte": "input_embeds_layer",
|
||||
}
|
||||
|
||||
|
||||
REMOTE_MODEL_PATHS = {
|
||||
"text_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "text.pt",
|
||||
},
|
||||
"coarse_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "coarse.pt",
|
||||
},
|
||||
"fine_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "fine.pt",
|
||||
},
|
||||
"text": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "text_2.pt",
|
||||
},
|
||||
"coarse": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "coarse_2.pt",
|
||||
},
|
||||
"fine": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "fine_2.pt",
|
||||
},
|
||||
}
|
||||
|
||||
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
|
||||
|
||||
|
||||
def _get_ckpt_path(model_type, use_small=False):
|
||||
key = model_type
|
||||
if use_small:
|
||||
key += "_small"
|
||||
return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
|
||||
|
||||
|
||||
def _download(from_hf_path, file_name):
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
|
||||
|
||||
|
||||
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
||||
if model_type == "text":
|
||||
ModelClass = BarkSemanticModel
|
||||
ConfigClass = BarkSemanticConfig
|
||||
GenerationConfigClass = BarkSemanticGenerationConfig
|
||||
elif model_type == "coarse":
|
||||
ModelClass = BarkCoarseModel
|
||||
ConfigClass = BarkCoarseConfig
|
||||
GenerationConfigClass = BarkCoarseGenerationConfig
|
||||
elif model_type == "fine":
|
||||
ModelClass = BarkFineModel
|
||||
ConfigClass = BarkFineConfig
|
||||
GenerationConfigClass = BarkFineGenerationConfig
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
model_key = f"{model_type}_small" if use_small else model_type
|
||||
model_info = REMOTE_MODEL_PATHS[model_key]
|
||||
if not os.path.exists(ckpt_path):
|
||||
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
||||
_download(model_info["repo_id"], model_info["file_name"])
|
||||
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
# this is a hack
|
||||
model_args = checkpoint["model_args"]
|
||||
if "input_vocab_size" not in model_args:
|
||||
model_args["input_vocab_size"] = model_args["vocab_size"]
|
||||
model_args["output_vocab_size"] = model_args["vocab_size"]
|
||||
del model_args["vocab_size"]
|
||||
|
||||
# convert Bark model arguments to HF Bark model arguments
|
||||
model_args["num_heads"] = model_args.pop("n_head")
|
||||
model_args["hidden_size"] = model_args.pop("n_embd")
|
||||
model_args["num_layers"] = model_args.pop("n_layer")
|
||||
|
||||
model_config = ConfigClass(**checkpoint["model_args"])
|
||||
model = ModelClass(config=model_config)
|
||||
model_generation_config = GenerationConfigClass()
|
||||
|
||||
model.generation_config = model_generation_config
|
||||
state_dict = checkpoint["model"]
|
||||
# fixup checkpoint
|
||||
unwanted_prefix = "_orig_mod."
|
||||
for k in state_dict:
|
||||
if k.startswith(unwanted_prefix):
|
||||
# replace part of the key with corresponding layer name in HF implementation
|
||||
new_k = k[len(unwanted_prefix) :]
|
||||
for old_layer_name, new_layer_name in new_layer_name_dict.items():
|
||||
new_k = new_k.replace(old_layer_name, new_layer_name)
|
||||
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
|
||||
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
|
||||
extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")}
|
||||
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
||||
missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")}
|
||||
if len(extra_keys) != 0:
|
||||
raise ValueError(f"extra keys found: {extra_keys}")
|
||||
if len(missing_keys) != 0:
|
||||
raise ValueError(f"missing keys: {missing_keys}")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
n_params = model.num_parameters(exclude_embeddings=True)
|
||||
val_loss = checkpoint["best_val_loss"].item()
|
||||
logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params, {round(val_loss, 3)} loss")
|
||||
model.eval()
|
||||
model.to(device)
|
||||
del checkpoint, state_dict
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"):
|
||||
if model_type not in ("text", "coarse", "fine"):
|
||||
raise NotImplementedError()
|
||||
|
||||
device = "cpu" # do conversion on cpu
|
||||
|
||||
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
|
||||
model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small)
|
||||
|
||||
# load bark initial model
|
||||
bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small)
|
||||
|
||||
if model_type == "text":
|
||||
bark_model = bark_model["model"]
|
||||
|
||||
if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params():
|
||||
raise ValueError("initial and new models don't have the same number of parameters")
|
||||
|
||||
# check if same output as the bark model
|
||||
batch_size = 5
|
||||
sequence_length = 10
|
||||
|
||||
if model_type in ["text", "coarse"]:
|
||||
vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int)
|
||||
output_old_model = bark_model(vec)[0]
|
||||
|
||||
output_new_model_total = model(vec)
|
||||
|
||||
# take last logits
|
||||
output_new_model = output_new_model_total.logits[:, [-1], :]
|
||||
|
||||
else:
|
||||
prediction_codebook_channel = 3
|
||||
n_codes_total = 8
|
||||
vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int)
|
||||
|
||||
output_new_model_total = model(prediction_codebook_channel, vec)
|
||||
output_old_model = bark_model(prediction_codebook_channel, vec)
|
||||
|
||||
output_new_model = output_new_model_total.logits
|
||||
|
||||
# output difference should come from the difference of self-attention implementation design
|
||||
if output_new_model.shape != output_old_model.shape:
|
||||
raise ValueError("initial and new outputs don't have the same shape")
|
||||
if (output_new_model - output_old_model).abs().max().item() > 1e-3:
|
||||
raise ValueError("initial and new outputs are not equal")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
def load_whole_bark_model(
|
||||
semantic_path,
|
||||
coarse_path,
|
||||
fine_path,
|
||||
append_text,
|
||||
hub_path,
|
||||
folder_path,
|
||||
):
|
||||
pytorch_dump_folder_path = os.path.join(folder_path, append_text)
|
||||
|
||||
semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json"))
|
||||
coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json"))
|
||||
fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json"))
|
||||
codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz")
|
||||
|
||||
semantic = BarkSemanticModel.from_pretrained(semantic_path)
|
||||
coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path)
|
||||
fineAcoustic = BarkFineModel.from_pretrained(fine_path)
|
||||
codec = EncodecModel.from_pretrained("facebook/encodec_24khz")
|
||||
|
||||
bark_config = BarkConfig.from_sub_model_configs(
|
||||
semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig
|
||||
)
|
||||
|
||||
bark_generation_config = BarkGenerationConfig.from_sub_model_configs(
|
||||
semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config
|
||||
)
|
||||
|
||||
bark = BarkModel(bark_config)
|
||||
|
||||
bark.semantic = semantic
|
||||
bark.coarse_acoustics = coarseAcoustic
|
||||
bark.fine_acoustics = fineAcoustic
|
||||
bark.codec_model = codec
|
||||
|
||||
bark.generation_config = bark_generation_config
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
|
||||
parser.add_argument("model_type", type=str, help="text, coarse or fine.")
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small)
|
@ -1,156 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BART checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import fairseq
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
BartConfig,
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartTokenizer,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
||||
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
|
||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
raise Exception("requires fairseq >= 0.9.0")
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SAMPLE_TEXT = " Hello world! cécé herlolip"
|
||||
|
||||
mnli_rename_keys = [
|
||||
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
|
||||
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
|
||||
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
|
||||
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
|
||||
]
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"encoder.version",
|
||||
"decoder.version",
|
||||
"model.encoder.version",
|
||||
"model.decoder.version",
|
||||
"_float_tensor",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def load_xsum_checkpoint(checkpoint_path):
|
||||
"""Checkpoint path should end in model.pt"""
|
||||
sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
|
||||
hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
|
||||
hub_interface.model.load_state_dict(sd["model"])
|
||||
return hub_interface
|
||||
|
||||
|
||||
def make_linear_from_emb(emb):
|
||||
vocab_size, emb_size = emb.weight.shape
|
||||
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||
lin_layer.weight.data = emb.weight.data
|
||||
return lin_layer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BERT structure.
|
||||
"""
|
||||
if not os.path.exists(checkpoint_path):
|
||||
bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
|
||||
else:
|
||||
bart = load_xsum_checkpoint(checkpoint_path)
|
||||
|
||||
bart.model.upgrade_state_dict(bart.model.state_dict())
|
||||
if hf_checkpoint_name is None:
|
||||
hf_checkpoint_name = checkpoint_path.replace(".", "-")
|
||||
config = BartConfig.from_pretrained(hf_checkpoint_name)
|
||||
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
|
||||
tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
|
||||
if not torch.eq(tokens, tokens2).all():
|
||||
raise ValueError(
|
||||
f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
|
||||
)
|
||||
|
||||
if checkpoint_path == "bart.large.mnli":
|
||||
state_dict = bart.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
|
||||
for src, dest in mnli_rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
model = BartForSequenceClassification(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
fairseq_output = bart.predict("mnli", tokens, return_logits=True)
|
||||
new_model_outputs = model(tokens)[0] # logits
|
||||
else: # no classification heads to worry about
|
||||
state_dict = bart.model.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||
fairseq_output = bart.extract_features(tokens)
|
||||
if hf_checkpoint_name == "facebook/bart-large":
|
||||
model = BartModel(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
new_model_outputs = model(tokens).model[0]
|
||||
else:
|
||||
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
|
||||
model.model.load_state_dict(state_dict)
|
||||
if hasattr(model, "lm_head"):
|
||||
model.lm_head = make_linear_from_emb(model.model.shared)
|
||||
new_model_outputs = model.model(tokens)[0]
|
||||
|
||||
# Check results
|
||||
if fairseq_output.shape != new_model_outputs.shape:
|
||||
raise ValueError(
|
||||
f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
|
||||
)
|
||||
if (fairseq_output != new_model_outputs).any().item():
|
||||
raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
|
||||
)
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
|
@ -1,373 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BEiT checkpoints from the unilm repository."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
BeitConfig,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
BeitImageProcessor,
|
||||
)
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
|
||||
rename_keys = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
|
||||
)
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
|
||||
|
||||
# projection layer + position embeddings
|
||||
rename_keys.extend(
|
||||
[
|
||||
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
|
||||
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
|
||||
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
if has_lm_head:
|
||||
# mask token + shared relative position bias + layernorm
|
||||
rename_keys.extend(
|
||||
[
|
||||
("mask_token", "beit.embeddings.mask_token"),
|
||||
(
|
||||
"rel_pos_bias.relative_position_bias_table",
|
||||
"beit.encoder.relative_position_bias.relative_position_bias_table",
|
||||
),
|
||||
(
|
||||
"rel_pos_bias.relative_position_index",
|
||||
"beit.encoder.relative_position_bias.relative_position_index",
|
||||
),
|
||||
("norm.weight", "layernorm.weight"),
|
||||
("norm.bias", "layernorm.bias"),
|
||||
]
|
||||
)
|
||||
elif is_semantic:
|
||||
# semantic segmentation classification heads
|
||||
rename_keys.extend(
|
||||
[
|
||||
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
|
||||
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
|
||||
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
|
||||
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
# layernorm + classification head
|
||||
rename_keys.extend(
|
||||
[
|
||||
("fc_norm.weight", "beit.pooler.layernorm.weight"),
|
||||
("fc_norm.bias", "beit.pooler.layernorm.bias"),
|
||||
("head.weight", "classifier.weight"),
|
||||
("head.bias", "classifier.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
|
||||
for i in range(config.num_hidden_layers):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
# queries, keys and values
|
||||
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
|
||||
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||
: config.hidden_size, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
config.hidden_size : config.hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||
-config.hidden_size :, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||
|
||||
# gamma_1 and gamma_2
|
||||
# we call them lambda because otherwise they are renamed when using .from_pretrained
|
||||
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
|
||||
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
|
||||
|
||||
# relative_position bias table + index
|
||||
if not has_lm_head:
|
||||
# each layer has its own relative position bias
|
||||
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
|
||||
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
|
||||
|
||||
state_dict[
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
|
||||
] = table
|
||||
state_dict[
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
|
||||
] = index
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BEiT structure.
|
||||
"""
|
||||
|
||||
# define default BEiT configuration
|
||||
config = BeitConfig()
|
||||
has_lm_head = False
|
||||
is_semantic = False
|
||||
repo_id = "huggingface/label-files"
|
||||
# set config parameters based on URL
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
# masked image modeling
|
||||
config.use_shared_relative_position_bias = True
|
||||
config.use_mask_token = True
|
||||
has_lm_head = True
|
||||
elif checkpoint_url[-9:-4] == "ft22k":
|
||||
# intermediate fine-tuning on ImageNet-22k
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 21841
|
||||
filename = "imagenet-22k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
# this dataset contains 21843 labels but the model only has 21841
|
||||
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
|
||||
del id2label[9205]
|
||||
del id2label[15027]
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
elif checkpoint_url[-8:-4] == "to1k":
|
||||
# fine-tuning on ImageNet-1k
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 1000
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
if "384" in checkpoint_url:
|
||||
config.image_size = 384
|
||||
if "512" in checkpoint_url:
|
||||
config.image_size = 512
|
||||
elif "ade20k" in checkpoint_url:
|
||||
# fine-tuning
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 150
|
||||
filename = "ade20k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
config.image_size = 640
|
||||
is_semantic = True
|
||||
else:
|
||||
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
|
||||
|
||||
# size of the architecture
|
||||
if "base" in checkpoint_url:
|
||||
pass
|
||||
elif "large" in checkpoint_url:
|
||||
config.hidden_size = 1024
|
||||
config.intermediate_size = 4096
|
||||
config.num_hidden_layers = 24
|
||||
config.num_attention_heads = 16
|
||||
if "ade20k" in checkpoint_url:
|
||||
config.image_size = 640
|
||||
config.out_indices = [7, 11, 15, 23]
|
||||
else:
|
||||
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
|
||||
|
||||
# load state_dict of original model, remove and rename some keys
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
|
||||
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
|
||||
|
||||
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
if is_semantic:
|
||||
# add prefix to decoder keys
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("backbone.fpn"):
|
||||
key = key.replace("backbone.fpn", "fpn")
|
||||
state_dict[key] = val
|
||||
|
||||
# load HuggingFace model
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
model = BeitForMaskedImageModeling(config)
|
||||
elif "ade20k" in checkpoint_url:
|
||||
model = BeitForSemanticSegmentation(config)
|
||||
else:
|
||||
model = BeitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# Check outputs on an image
|
||||
if is_semantic:
|
||||
image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
image = Image.open(ds[0]["file"])
|
||||
else:
|
||||
image_processor = BeitImageProcessor(
|
||||
size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
|
||||
)
|
||||
image = prepare_img()
|
||||
|
||||
encoding = image_processor(images=image, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify logits
|
||||
expected_shape = torch.Size([1, 1000])
|
||||
if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
|
||||
expected_shape = torch.Size([1, 196, 8192])
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
|
||||
expected_shape = torch.Size([1, 196, 8192])
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
|
||||
expected_shape = torch.Size([1, 21841])
|
||||
expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
|
||||
expected_class_idx = 2397
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
|
||||
expected_shape = torch.Size([1, 21841])
|
||||
expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
|
||||
expected_class_idx = 2396
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
|
||||
expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
|
||||
expected_class_idx = 285
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
|
||||
expected_class_idx = 281
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
|
||||
expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
|
||||
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
|
||||
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
|
||||
]
|
||||
)
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
|
||||
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
|
||||
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError("Can't verify logits as model is not supported")
|
||||
|
||||
if logits.shape != expected_shape:
|
||||
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
|
||||
if not has_lm_head:
|
||||
if is_semantic:
|
||||
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
else:
|
||||
print("Predicted class idx:", logits.argmax(-1).item())
|
||||
|
||||
if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
if logits.argmax(-1).item() != expected_class_idx:
|
||||
raise ValueError("Predicted class index not as expected")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
|
||||
type=str,
|
||||
help="URL to the original PyTorch checkpoint (.pth file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
|
@ -1,246 +0,0 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now
|
||||
deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert
|
||||
|
||||
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
|
||||
weight names to the original names, so the model can be imported with Huggingface/transformer.
|
||||
|
||||
You may adapt this script to include classification/MLM/NSP/etc. heads.
|
||||
|
||||
Note: This script is only working with an older version of the TensorFlow models repository (<= v2.3.0).
|
||||
Models trained with never versions are not compatible with this script.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
layer_depth = []
|
||||
for full_name, shape in init_vars:
|
||||
# logger.info(f"Loading TF weight {name} with shape {shape}")
|
||||
name = full_name.split("/")
|
||||
if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
|
||||
logger.info(f"Skipping non-model layer {full_name}")
|
||||
continue
|
||||
if "optimizer" in full_name:
|
||||
logger.info(f"Skipping optimization layer {full_name}")
|
||||
continue
|
||||
if name[0] == "model":
|
||||
# ignore initial 'model'
|
||||
name = name[1:]
|
||||
# figure out how many levels deep the name is
|
||||
depth = 0
|
||||
for _name in name:
|
||||
if _name.startswith("layer_with_weights"):
|
||||
depth += 1
|
||||
else:
|
||||
break
|
||||
layer_depth.append(depth)
|
||||
# read data
|
||||
array = tf.train.load_variable(tf_path, full_name)
|
||||
names.append("/".join(name))
|
||||
arrays.append(array)
|
||||
logger.info(f"Read a total of {len(arrays):,} layers")
|
||||
|
||||
# Sanity check
|
||||
if len(set(layer_depth)) != 1:
|
||||
raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
|
||||
layer_depth = list(set(layer_depth))[0]
|
||||
if layer_depth != 1:
|
||||
raise ValueError(
|
||||
"The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
|
||||
" heads."
|
||||
)
|
||||
|
||||
# convert layers
|
||||
logger.info("Converting weights...")
|
||||
for full_name, array in zip(names, arrays):
|
||||
name = full_name.split("/")
|
||||
pointer = model
|
||||
trace = []
|
||||
for i, m_name in enumerate(name):
|
||||
if m_name == ".ATTRIBUTES":
|
||||
# variable names end with .ATTRIBUTES/VARIABLE_VALUE
|
||||
break
|
||||
if m_name.startswith("layer_with_weights"):
|
||||
layer_num = int(m_name.split("-")[-1])
|
||||
if layer_num <= 2:
|
||||
# embedding layers
|
||||
# layer_num 0: word_embeddings
|
||||
# layer_num 1: position_embeddings
|
||||
# layer_num 2: token_type_embeddings
|
||||
continue
|
||||
elif layer_num == 3:
|
||||
# embedding LayerNorm
|
||||
trace.extend(["embeddings", "LayerNorm"])
|
||||
pointer = getattr(pointer, "embeddings")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
|
||||
# encoder layers
|
||||
trace.extend(["encoder", "layer", str(layer_num - 4)])
|
||||
pointer = getattr(pointer, "encoder")
|
||||
pointer = getattr(pointer, "layer")
|
||||
pointer = pointer[layer_num - 4]
|
||||
elif layer_num == config.num_hidden_layers + 4:
|
||||
# pooler layer
|
||||
trace.extend(["pooler", "dense"])
|
||||
pointer = getattr(pointer, "pooler")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "embeddings":
|
||||
trace.append("embeddings")
|
||||
pointer = getattr(pointer, "embeddings")
|
||||
if layer_num == 0:
|
||||
trace.append("word_embeddings")
|
||||
pointer = getattr(pointer, "word_embeddings")
|
||||
elif layer_num == 1:
|
||||
trace.append("position_embeddings")
|
||||
pointer = getattr(pointer, "position_embeddings")
|
||||
elif layer_num == 2:
|
||||
trace.append("token_type_embeddings")
|
||||
pointer = getattr(pointer, "token_type_embeddings")
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding layer with name {full_name}")
|
||||
trace.append("weight")
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif m_name == "_attention_layer":
|
||||
# self-attention layer
|
||||
trace.extend(["attention", "self"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "self")
|
||||
elif m_name == "_attention_layer_norm":
|
||||
# output attention norm
|
||||
trace.extend(["attention", "output", "LayerNorm"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif m_name == "_attention_output_dense":
|
||||
# output attention dense
|
||||
trace.extend(["attention", "output", "dense"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_dense":
|
||||
# output dense
|
||||
trace.extend(["output", "dense"])
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_layer_norm":
|
||||
# output dense
|
||||
trace.extend(["output", "LayerNorm"])
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif m_name == "_key_dense":
|
||||
# attention key
|
||||
trace.append("key")
|
||||
pointer = getattr(pointer, "key")
|
||||
elif m_name == "_query_dense":
|
||||
# attention query
|
||||
trace.append("query")
|
||||
pointer = getattr(pointer, "query")
|
||||
elif m_name == "_value_dense":
|
||||
# attention value
|
||||
trace.append("value")
|
||||
pointer = getattr(pointer, "value")
|
||||
elif m_name == "_intermediate_dense":
|
||||
# attention intermediate dense
|
||||
trace.extend(["intermediate", "dense"])
|
||||
pointer = getattr(pointer, "intermediate")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_layer_norm":
|
||||
# output layer norm
|
||||
trace.append("output")
|
||||
pointer = getattr(pointer, "output")
|
||||
# weights & biases
|
||||
elif m_name in ["bias", "beta"]:
|
||||
trace.append("bias")
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif m_name in ["kernel", "gamma"]:
|
||||
trace.append("weight")
|
||||
pointer = getattr(pointer, "weight")
|
||||
else:
|
||||
logger.warning(f"Ignored {m_name}")
|
||||
# for certain layers reshape is necessary
|
||||
trace = ".".join(trace)
|
||||
if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
|
||||
r"(\S+)\.attention\.output\.dense\.weight", trace
|
||||
):
|
||||
array = array.reshape(pointer.data.shape)
|
||||
if "kernel" in full_name:
|
||||
array = array.transpose()
|
||||
if pointer.shape == array.shape:
|
||||
pointer.data = torch.from_numpy(array)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
|
||||
f" {array.shape}"
|
||||
)
|
||||
logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
|
||||
return model
|
||||
|
||||
|
||||
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
|
||||
# Instantiate model
|
||||
logger.info(f"Loading model based on config from {config_path}...")
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
model = BertModel(config)
|
||||
|
||||
# Load weights from checkpoint
|
||||
logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
|
||||
load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
|
||||
|
||||
# Save pytorch-model
|
||||
logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model (must include filename).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,62 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(bert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = BertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,112 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertModel
|
||||
|
||||
|
||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
||||
"""
|
||||
Args:
|
||||
model: BertModel Pytorch model instance to be converted
|
||||
ckpt_dir: Tensorflow model directory
|
||||
model_name: model name
|
||||
|
||||
Currently supported HF models:
|
||||
|
||||
- Y BertModel
|
||||
- N BertForMaskedLM
|
||||
- N BertForPreTraining
|
||||
- N BertForMultipleChoice
|
||||
- N BertForNextSentencePrediction
|
||||
- N BertForSequenceClassification
|
||||
- N BertForQuestionAnswering
|
||||
"""
|
||||
|
||||
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
|
||||
|
||||
var_map = (
|
||||
("layer.", "layer_"),
|
||||
("word_embeddings.weight", "word_embeddings"),
|
||||
("position_embeddings.weight", "position_embeddings"),
|
||||
("token_type_embeddings.weight", "token_type_embeddings"),
|
||||
(".", "/"),
|
||||
("LayerNorm/weight", "LayerNorm/gamma"),
|
||||
("LayerNorm/bias", "LayerNorm/beta"),
|
||||
("weight", "kernel"),
|
||||
)
|
||||
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
os.makedirs(ckpt_dir)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
|
||||
def to_tf_var_name(name: str):
|
||||
for patt, repl in iter(var_map):
|
||||
name = name.replace(patt, repl)
|
||||
return f"bert/{name}"
|
||||
|
||||
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
|
||||
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
|
||||
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
|
||||
session.run(tf.variables_initializer([tf_var]))
|
||||
session.run(tf_var)
|
||||
return tf_var
|
||||
|
||||
tf.reset_default_graph()
|
||||
with tf.Session() as session:
|
||||
for var_name in state_dict:
|
||||
tf_name = to_tf_var_name(var_name)
|
||||
torch_tensor = state_dict[var_name].numpy()
|
||||
if any(x in var_name for x in tensors_to_transpose):
|
||||
torch_tensor = torch_tensor.T
|
||||
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
|
||||
tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
|
||||
tf_weight = session.run(tf_var)
|
||||
print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")
|
||||
|
||||
saver = tf.train.Saver(tf.trainable_variables())
|
||||
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
|
||||
|
||||
|
||||
def main(raw_args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
|
||||
parser.add_argument(
|
||||
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
|
||||
)
|
||||
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
|
||||
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
|
||||
args = parser.parse_args(raw_args)
|
||||
|
||||
model = BertModel.from_pretrained(
|
||||
pretrained_model_name_or_path=args.model_name,
|
||||
state_dict=torch.load(args.pytorch_model_path, weights_only=True),
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
|
||||
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,188 +0,0 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT
|
||||
model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository:
|
||||
|
||||
https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForMaskedLM
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertIntermediate,
|
||||
BertLayer,
|
||||
BertOutput,
|
||||
BertPooler,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
|
||||
def get_masked_lm_array(name: str):
|
||||
full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_array(name: str):
|
||||
full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_layer_array(layer_index: int, name: str):
|
||||
full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape):
|
||||
full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
array = array.reshape(original_shape)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
print(f"Loading model based on config from {config_path}...")
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
model = BertForMaskedLM(config)
|
||||
|
||||
# Layers
|
||||
for layer_index in range(0, config.num_hidden_layers):
|
||||
layer: BertLayer = model.bert.encoder.layer[layer_index]
|
||||
|
||||
# Self-attention
|
||||
self_attn: BertSelfAttention = layer.attention.self
|
||||
|
||||
self_attn.query.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
|
||||
)
|
||||
self_attn.query.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
|
||||
)
|
||||
self_attn.key.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
|
||||
)
|
||||
self_attn.key.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
|
||||
)
|
||||
self_attn.value.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
|
||||
)
|
||||
self_attn.value.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
|
||||
)
|
||||
|
||||
# Self-attention Output
|
||||
self_output: BertSelfOutput = layer.attention.output
|
||||
|
||||
self_output.dense.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
|
||||
)
|
||||
self_output.dense.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
|
||||
)
|
||||
|
||||
self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
|
||||
self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
|
||||
|
||||
# Intermediate
|
||||
intermediate: BertIntermediate = layer.intermediate
|
||||
|
||||
intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
|
||||
intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
|
||||
|
||||
# Output
|
||||
bert_output: BertOutput = layer.output
|
||||
|
||||
bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
|
||||
bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
|
||||
|
||||
bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
|
||||
bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
|
||||
|
||||
# Embeddings
|
||||
model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
|
||||
model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
|
||||
model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
|
||||
model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
|
||||
|
||||
# LM Head
|
||||
lm_head = model.cls.predictions.transform
|
||||
|
||||
lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
|
||||
lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
|
||||
|
||||
lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
|
||||
lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
|
||||
|
||||
model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
|
||||
|
||||
# Pooling
|
||||
model.bert.pooler = BertPooler(config=config)
|
||||
model.bert.pooler.dense.weight.data: BertPooler = get_encoder_array("_pooler_layer/kernel")
|
||||
model.bert.pooler.dense.bias.data: BertPooler = get_encoder_array("_pooler_layer/bias")
|
||||
|
||||
# Export final model
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
# Integration test - should load without any errors ;)
|
||||
new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
|
||||
print(new_model.eval())
|
||||
|
||||
print("Model conversion was done successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,69 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BigBird checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
|
||||
# Initialise PyTorch model
|
||||
config = BigBirdConfig.from_json_file(big_bird_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
|
||||
if is_trivia_qa:
|
||||
model = BigBirdForQuestionAnswering(config)
|
||||
else:
|
||||
model = BigBirdForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--big_bird_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(
|
||||
args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
|
||||
)
|
@ -1,169 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration
|
||||
|
||||
|
||||
INIT_COMMON = [
|
||||
# tf -> hf
|
||||
("/", "."),
|
||||
("layer_", "layers."),
|
||||
("kernel", "weight"),
|
||||
("beta", "bias"),
|
||||
("gamma", "weight"),
|
||||
("pegasus", "model"),
|
||||
]
|
||||
END_COMMON = [
|
||||
(".output.dense", ".fc2"),
|
||||
("intermediate.LayerNorm", "final_layer_norm"),
|
||||
("intermediate.dense", "fc1"),
|
||||
]
|
||||
|
||||
DECODER_PATTERNS = (
|
||||
INIT_COMMON
|
||||
+ [
|
||||
("attention.self.LayerNorm", "self_attn_layer_norm"),
|
||||
("attention.output.dense", "self_attn.out_proj"),
|
||||
("attention.self", "self_attn"),
|
||||
("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
|
||||
("attention.encdec_output.dense", "encoder_attn.out_proj"),
|
||||
("attention.encdec", "encoder_attn"),
|
||||
("key", "k_proj"),
|
||||
("value", "v_proj"),
|
||||
("query", "q_proj"),
|
||||
("decoder.LayerNorm", "decoder.layernorm_embedding"),
|
||||
]
|
||||
+ END_COMMON
|
||||
)
|
||||
|
||||
REMAINING_PATTERNS = (
|
||||
INIT_COMMON
|
||||
+ [
|
||||
("embeddings.word_embeddings", "shared.weight"),
|
||||
("embeddings.position_embeddings", "embed_positions.weight"),
|
||||
("attention.self.LayerNorm", "self_attn_layer_norm"),
|
||||
("attention.output.dense", "self_attn.output"),
|
||||
("attention.self", "self_attn.self"),
|
||||
("encoder.LayerNorm", "encoder.layernorm_embedding"),
|
||||
]
|
||||
+ END_COMMON
|
||||
)
|
||||
|
||||
KEYS_TO_IGNORE = [
|
||||
"encdec/key/bias",
|
||||
"encdec/query/bias",
|
||||
"encdec/value/bias",
|
||||
"self/key/bias",
|
||||
"self/query/bias",
|
||||
"self/value/bias",
|
||||
"encdec_output/dense/bias",
|
||||
"attention/output/dense/bias",
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k, patterns):
|
||||
for tf_name, hf_name in patterns:
|
||||
k = k.replace(tf_name, hf_name)
|
||||
return k
|
||||
|
||||
|
||||
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
|
||||
cfg = BigBirdPegasusConfig(**config_update)
|
||||
torch_model = BigBirdPegasusForConditionalGeneration(cfg)
|
||||
state_dict = torch_model.state_dict()
|
||||
mapping = {}
|
||||
|
||||
# separating decoder weights
|
||||
decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
|
||||
remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
|
||||
|
||||
for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
|
||||
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
|
||||
if any(conditions):
|
||||
continue
|
||||
patterns = DECODER_PATTERNS
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict:
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(i in k for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
|
||||
for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
|
||||
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
|
||||
if any(conditions):
|
||||
continue
|
||||
patterns = REMAINING_PATTERNS
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(i in k for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
if k != "pegasus/embeddings/position_embeddings":
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
|
||||
mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
|
||||
mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
|
||||
missing, extra = torch_model.load_state_dict(mapping, strict=False)
|
||||
unexpected_missing = [
|
||||
k
|
||||
for k in missing
|
||||
if k
|
||||
not in [
|
||||
"final_logits_bias",
|
||||
"model.encoder.embed_tokens.weight",
|
||||
"model.decoder.embed_tokens.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
]
|
||||
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
|
||||
assert extra == [], f"no matches found for the following tf keys {extra}"
|
||||
return torch_model
|
||||
|
||||
|
||||
def get_tf_weights_as_numpy(path) -> dict:
|
||||
init_vars = tf.train.list_variables(path)
|
||||
tf_weights = {}
|
||||
ignore_name = ["global_step"]
|
||||
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
|
||||
skip_key = any(pat in name for pat in ignore_name)
|
||||
if skip_key:
|
||||
continue
|
||||
array = tf.train.load_variable(path, name)
|
||||
tf_weights[name] = array
|
||||
return tf_weights
|
||||
|
||||
|
||||
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
|
||||
tf_weights = get_tf_weights_as_numpy(ckpt_path)
|
||||
torch_model = convert_bigbird_pegasus(tf_weights, config_update)
|
||||
torch_model.save_pretrained(save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
|
||||
parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
config_update = {}
|
||||
convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)
|
@ -1,292 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BioGptConfig, BioGptForCausalLM
|
||||
from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from transformers.utils import WEIGHTS_NAME, logging
|
||||
|
||||
|
||||
logging.set_verbosity_warning()
|
||||
|
||||
json_indent = 2
|
||||
|
||||
|
||||
# modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18
|
||||
class Dictionary:
|
||||
"""A mapping from symbols to consecutive integers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*, # begin keyword-only arguments
|
||||
bos="<s>",
|
||||
pad="<pad>",
|
||||
eos="</s>",
|
||||
unk="<unk>",
|
||||
extra_special_symbols=None,
|
||||
):
|
||||
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
||||
self.symbols = []
|
||||
self.count = []
|
||||
self.indices = {}
|
||||
self.bos_index = self.add_symbol(bos)
|
||||
self.pad_index = self.add_symbol(pad)
|
||||
self.eos_index = self.add_symbol(eos)
|
||||
self.unk_index = self.add_symbol(unk)
|
||||
if extra_special_symbols:
|
||||
for s in extra_special_symbols:
|
||||
self.add_symbol(s)
|
||||
self.nspecial = len(self.symbols)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.indices == other.indices
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < len(self.symbols):
|
||||
return self.symbols[idx]
|
||||
return self.unk_word
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of symbols in the dictionary"""
|
||||
return len(self.symbols)
|
||||
|
||||
def __contains__(self, sym):
|
||||
return sym in self.indices
|
||||
|
||||
@classmethod
|
||||
def load(cls, f):
|
||||
"""Loads the dictionary from a text file with the format:
|
||||
|
||||
```
|
||||
<symbol0> <count0>
|
||||
<symbol1> <count1>
|
||||
...
|
||||
```
|
||||
"""
|
||||
d = cls()
|
||||
d.add_from_file(f)
|
||||
return d
|
||||
|
||||
def add_symbol(self, word, n=1, overwrite=False):
|
||||
"""Adds a word to the dictionary"""
|
||||
if word in self.indices and not overwrite:
|
||||
idx = self.indices[word]
|
||||
self.count[idx] = self.count[idx] + n
|
||||
return idx
|
||||
else:
|
||||
idx = len(self.symbols)
|
||||
self.indices[word] = idx
|
||||
self.symbols.append(word)
|
||||
self.count.append(n)
|
||||
return idx
|
||||
|
||||
def _load_meta(self, lines):
|
||||
return 0
|
||||
|
||||
def add_from_file(self, f):
|
||||
"""
|
||||
Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
|
||||
"""
|
||||
if isinstance(f, str):
|
||||
try:
|
||||
with open(f, "r", encoding="utf-8") as fd:
|
||||
self.add_from_file(fd)
|
||||
except FileNotFoundError as fnfe:
|
||||
raise fnfe
|
||||
except UnicodeError:
|
||||
raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset")
|
||||
return
|
||||
|
||||
lines = f.readlines()
|
||||
indices_start_line = self._load_meta(lines)
|
||||
|
||||
for line in lines[indices_start_line:]:
|
||||
try:
|
||||
line, field = line.rstrip().rsplit(" ", 1)
|
||||
if field == "#fairseq:overwrite":
|
||||
overwrite = True
|
||||
line, field = line.rsplit(" ", 1)
|
||||
else:
|
||||
overwrite = False
|
||||
count = int(field)
|
||||
word = line
|
||||
if word in self and not overwrite:
|
||||
raise RuntimeError(
|
||||
f"Duplicate word found when loading Dictionary: '{word}'. "
|
||||
"Duplicate words can overwrite earlier ones by adding the "
|
||||
"#fairseq:overwrite flag at the end of the corresponding row "
|
||||
"in the dictionary file. If using the Camembert model, please "
|
||||
"download an updated copy of the model file."
|
||||
)
|
||||
self.add_symbol(word, n=count, overwrite=overwrite)
|
||||
except ValueError:
|
||||
raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'")
|
||||
|
||||
|
||||
def rewrite_dict_keys(d):
|
||||
# (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
|
||||
# e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
|
||||
d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
|
||||
keep_keys = "<s> <pad> </s> <unk>".split()
|
||||
# restore the special tokens
|
||||
for k in keep_keys:
|
||||
del d2[f"{k}</w>"]
|
||||
d2[k] = d[k] # restore
|
||||
return d2
|
||||
|
||||
|
||||
def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path):
|
||||
# prep
|
||||
if not os.path.exists(biogpt_checkpoint_path):
|
||||
raise ValueError(f"path {biogpt_checkpoint_path} does not exist!")
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
print(f"Writing results to {pytorch_dump_folder_path}")
|
||||
|
||||
# handle various types of models
|
||||
|
||||
checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
|
||||
if not os.path.isfile(checkpoint_file):
|
||||
raise ValueError(f"path to the file {checkpoint_file} does not exist!")
|
||||
chkpt = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
|
||||
|
||||
args = chkpt["cfg"]["model"]
|
||||
|
||||
# dicts
|
||||
dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt")
|
||||
if not os.path.isfile(dict_file):
|
||||
raise ValueError(f"path to the file {dict_file} does not exist!")
|
||||
src_dict = Dictionary.load(dict_file)
|
||||
src_vocab = rewrite_dict_keys(src_dict.indices)
|
||||
src_vocab_size = len(src_vocab)
|
||||
src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"])
|
||||
print(f"Generating {src_vocab_file} of {src_vocab_size} records")
|
||||
with open(src_vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# merges_file (bpecodes)
|
||||
bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes")
|
||||
if not os.path.isfile(bpecodes_file):
|
||||
raise ValueError(f"path to the file {bpecodes_file} does not exist!")
|
||||
|
||||
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
|
||||
shutil.copyfile(bpecodes_file, merges_file)
|
||||
|
||||
# model config
|
||||
biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")
|
||||
|
||||
model_conf = {
|
||||
"activation_dropout": args["activation_dropout"],
|
||||
"architectures": ["BioGptForCausalLM"],
|
||||
"attention_probs_dropout_prob": args["attention_dropout"],
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": args["activation_fn"],
|
||||
"hidden_dropout_prob": args["dropout"],
|
||||
"hidden_size": args["decoder_embed_dim"],
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": args["decoder_ffn_embed_dim"],
|
||||
"layer_norm_eps": 1e-12,
|
||||
"layerdrop": args["decoder_layerdrop"],
|
||||
"max_position_embeddings": args["max_target_positions"],
|
||||
"model_type": "biogpt",
|
||||
"num_attention_heads": args["decoder_attention_heads"],
|
||||
"num_hidden_layers": args["decoder_layers"],
|
||||
"pad_token_id": 1,
|
||||
"scale_embedding": not args["no_scale_embedding"],
|
||||
"tie_word_embeddings": args["share_decoder_input_output_embed"],
|
||||
"vocab_size": src_vocab_size,
|
||||
}
|
||||
|
||||
# good hparam defaults to start with
|
||||
|
||||
print(f"Generating {biogpt_model_config_file}")
|
||||
with open(biogpt_model_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# tokenizer config
|
||||
biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)
|
||||
|
||||
tokenizer_conf = {
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"model_max_length": 1024,
|
||||
"pad_token": "<pad>",
|
||||
"special_tokens_map_file": None,
|
||||
"tokenizer_class": "BioGptTokenizer",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
|
||||
print(f"Generating {biogpt_tokenizer_config_file}")
|
||||
with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# model
|
||||
model_state_dict = chkpt["model"]
|
||||
|
||||
# remove unneeded keys
|
||||
ignore_keys = [
|
||||
"decoder.version",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
model_state_dict.pop(k, None)
|
||||
|
||||
layer_names = list(model_state_dict.keys())
|
||||
for layer_name in layer_names:
|
||||
if layer_name.endswith("output_projection.weight"):
|
||||
model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name)
|
||||
else:
|
||||
model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name)
|
||||
|
||||
config = BioGptConfig.from_pretrained(pytorch_dump_folder_path)
|
||||
model_new = BioGptForCausalLM(config)
|
||||
|
||||
# check that it loads ok
|
||||
model_new.load_state_dict(model_state_dict)
|
||||
|
||||
# save
|
||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||
print(f"Generating {pytorch_weights_dump_path}")
|
||||
torch.save(model_state_dict, pytorch_weights_dump_path)
|
||||
|
||||
print("Conversion is done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--biogpt_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
|
||||
" bpecodes, etc."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path)
|
@ -1,177 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BiT checkpoints from the timm library."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from timm import create_model
|
||||
from timm.data import resolve_data_config
|
||||
from timm.data.transforms_factory import create_transform
|
||||
|
||||
from transformers import BitConfig, BitForImageClassification, BitImageProcessor
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_config(model_name):
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
conv_layer = "std_conv" if "bit" in model_name else False
|
||||
|
||||
# note that when using BiT as backbone for ViT-hybrid checkpoints,
|
||||
# one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same",
|
||||
# config.conv_layer = "std_conv_same"
|
||||
config = BitConfig(
|
||||
conv_layer=conv_layer,
|
||||
num_labels=1000,
|
||||
id2label=id2label,
|
||||
label2id=label2id,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "stem.conv" in name:
|
||||
name = name.replace("stem.conv", "bit.embedder.convolution")
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "layers")
|
||||
if "head.fc" in name:
|
||||
name = name.replace("head.fc", "classifier.1")
|
||||
if name.startswith("norm"):
|
||||
name = "bit." + name
|
||||
if "bit" not in name and "classifier" not in name:
|
||||
name = "bit.encoder." + name
|
||||
|
||||
return name
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BiT structure.
|
||||
"""
|
||||
|
||||
# define default BiT configuration
|
||||
config = get_config(model_name)
|
||||
|
||||
# load original model from timm
|
||||
timm_model = create_model(model_name, pretrained=True)
|
||||
timm_model.eval()
|
||||
|
||||
# load state_dict of original model
|
||||
state_dict = timm_model.state_dict()
|
||||
for key in state_dict.copy():
|
||||
val = state_dict.pop(key)
|
||||
state_dict[rename_key(key)] = val.squeeze() if "head" in key else val
|
||||
|
||||
# load HuggingFace model
|
||||
model = BitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# create image processor
|
||||
transform = create_transform(**resolve_data_config({}, model=timm_model))
|
||||
timm_transforms = transform.transforms
|
||||
|
||||
pillow_resamplings = {
|
||||
"bilinear": PILImageResampling.BILINEAR,
|
||||
"bicubic": PILImageResampling.BICUBIC,
|
||||
"nearest": PILImageResampling.NEAREST,
|
||||
}
|
||||
|
||||
processor = BitImageProcessor(
|
||||
do_resize=True,
|
||||
size={"shortest_edge": timm_transforms[0].size},
|
||||
resample=pillow_resamplings[timm_transforms[0].interpolation.value],
|
||||
do_center_crop=True,
|
||||
crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
|
||||
do_normalize=True,
|
||||
image_mean=timm_transforms[-1].mean.tolist(),
|
||||
image_std=timm_transforms[-1].std.tolist(),
|
||||
)
|
||||
|
||||
image = prepare_img()
|
||||
timm_pixel_values = transform(image).unsqueeze(0)
|
||||
pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
# verify pixel values
|
||||
assert torch.allclose(timm_pixel_values, pixel_values)
|
||||
|
||||
# verify logits
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
print("Logits:", logits[0, :3])
|
||||
print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
|
||||
timm_logits = timm_model(pixel_values)
|
||||
assert timm_logits.shape == outputs.logits.shape
|
||||
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model {model_name} and processor to the hub")
|
||||
model.push_to_hub(f"ybelkada/{model_name}")
|
||||
processor.push_to_hub(f"ybelkada/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="resnetv2_50x1_bitm",
|
||||
type=str,
|
||||
help="Name of the BiT timm model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model to the hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,114 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Blenderbot checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
PATTERNS = [
|
||||
["attention", "attn"],
|
||||
["encoder_attention", "encoder_attn"],
|
||||
["q_lin", "q_proj"],
|
||||
["k_lin", "k_proj"],
|
||||
["v_lin", "v_proj"],
|
||||
["out_lin", "out_proj"],
|
||||
["norm_embeddings", "layernorm_embedding"],
|
||||
["position_embeddings", "embed_positions"],
|
||||
["embeddings", "embed_tokens"],
|
||||
["ffn.lin", "fc"],
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k):
|
||||
if k == "embeddings.weight":
|
||||
return "shared.weight"
|
||||
|
||||
for parlai_name, hf_name in PATTERNS:
|
||||
k = k.replace(parlai_name, hf_name)
|
||||
|
||||
if k.startswith("encoder"):
|
||||
k = k.replace(".attn", ".self_attn")
|
||||
k = k.replace("norm1", "self_attn_layer_norm")
|
||||
k = k.replace("norm2", "final_layer_norm")
|
||||
elif k.startswith("decoder"):
|
||||
k = k.replace("norm1", "self_attn_layer_norm")
|
||||
k = k.replace("norm2", "encoder_attn_layer_norm")
|
||||
k = k.replace("norm3", "final_layer_norm")
|
||||
return k
|
||||
|
||||
|
||||
def rename_layernorm_keys(sd):
|
||||
keys = [
|
||||
"model.encoder.layernorm_embedding.weight",
|
||||
"model.encoder.layernorm_embedding.bias",
|
||||
"model.decoder.layernorm_embedding.weight",
|
||||
"model.decoder.layernorm_embedding.bias",
|
||||
]
|
||||
for k in keys:
|
||||
v = sd.pop(k)
|
||||
new_k = k.replace("layernorm_embedding", "layer_norm")
|
||||
assert new_k not in sd
|
||||
sd[new_k] = v
|
||||
|
||||
|
||||
IGNORE_KEYS = ["START"]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BERT structure.
|
||||
"""
|
||||
model = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
|
||||
sd = model["model"]
|
||||
cfg = BlenderbotConfig.from_json_file(config_json_path)
|
||||
m = BlenderbotForConditionalGeneration(cfg)
|
||||
valid_keys = m.model.state_dict().keys()
|
||||
failures = []
|
||||
mapping = {}
|
||||
for k, v in sd.items():
|
||||
if k in IGNORE_KEYS:
|
||||
continue
|
||||
|
||||
new_k = rename_state_dict_key(k)
|
||||
if new_k not in valid_keys:
|
||||
failures.append([k, new_k])
|
||||
else:
|
||||
mapping[new_k] = v
|
||||
if cfg.normalize_before: # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
|
||||
rename_layernorm_keys(sd)
|
||||
m.model.load_state_dict(mapping, strict=True)
|
||||
m.half()
|
||||
m.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
|
||||
parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
|
||||
parser.add_argument(
|
||||
"--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)
|
@ -1,191 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# git clone https://github.com/salesforce/BLIP.git
|
||||
from models.blip import blip_decoder
|
||||
from models.blip_itm import blip_itm
|
||||
from models.blip_vqa import blip_vqa
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from transformers import (
|
||||
BertTokenizer,
|
||||
BlipConfig,
|
||||
BlipForConditionalGeneration,
|
||||
BlipForImageTextRetrieval,
|
||||
BlipForQuestionAnswering,
|
||||
)
|
||||
|
||||
|
||||
def load_demo_image(image_size, device):
|
||||
img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
]
|
||||
)
|
||||
image = transform(raw_image).unsqueeze(0).to(device)
|
||||
return image
|
||||
|
||||
|
||||
def rename_key(key):
|
||||
if "visual_encoder" in key:
|
||||
key = re.sub("visual_encoder*", "vision_model.encoder", key)
|
||||
if "blocks" in key:
|
||||
key = re.sub(r"blocks", "layers", key)
|
||||
if "attn" in key:
|
||||
key = re.sub(r"attn", "self_attn", key)
|
||||
if "norm1" in key:
|
||||
key = re.sub(r"norm1", "layer_norm1", key)
|
||||
if "norm2" in key:
|
||||
key = re.sub(r"norm2", "layer_norm2", key)
|
||||
if "encoder.norm" in key:
|
||||
key = re.sub(r"encoder.norm", "post_layernorm", key)
|
||||
if "encoder.patch_embed.proj" in key:
|
||||
key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)
|
||||
|
||||
if "encoder.pos_embed" in key:
|
||||
key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
|
||||
if "encoder.cls_token" in key:
|
||||
key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)
|
||||
|
||||
if "self_attn" in key:
|
||||
key = re.sub(r"self_attn.proj", "self_attn.projection", key)
|
||||
|
||||
return key
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
if config_path is not None:
|
||||
config = BlipConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = BlipConfig(projection_dim=512, text_config={}, vision_config={})
|
||||
|
||||
hf_model = BlipForConditionalGeneration(config).eval()
|
||||
|
||||
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
|
||||
|
||||
pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
|
||||
pt_model = pt_model.eval()
|
||||
|
||||
modified_state_dict = pt_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_model.load_state_dict(modified_state_dict)
|
||||
|
||||
image_size = 384
|
||||
image = load_demo_image(image_size=image_size, device="cpu")
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
input_ids = tokenizer(["a picture of"]).input_ids
|
||||
|
||||
out = hf_model.generate(image, input_ids)
|
||||
|
||||
assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
|
||||
|
||||
out = hf_model.generate(image)
|
||||
|
||||
assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
# model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
|
||||
model_url = (
|
||||
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
|
||||
)
|
||||
|
||||
vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
|
||||
vqa_model.eval()
|
||||
|
||||
modified_state_dict = vqa_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_vqa_model = BlipForQuestionAnswering(config)
|
||||
|
||||
hf_vqa_model.load_state_dict(modified_state_dict)
|
||||
|
||||
question = ["How many dogs are in this image?"]
|
||||
question_input_ids = tokenizer(question, return_tensors="pt").input_ids
|
||||
|
||||
answer = hf_vqa_model.generate(question_input_ids, image)
|
||||
print(tokenizer.decode(answer[0]))
|
||||
|
||||
assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")
|
||||
|
||||
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"
|
||||
|
||||
itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
|
||||
itm_model.eval()
|
||||
|
||||
modified_state_dict = itm_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_itm_model = BlipForImageTextRetrieval(config)
|
||||
|
||||
question = ["A picture of a woman with a dog sitting in a beach"]
|
||||
question_input_ids = tokenizer(
|
||||
question,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=35,
|
||||
).input_ids
|
||||
|
||||
hf_itm_model.load_state_dict(modified_state_dict)
|
||||
hf_itm_model.eval()
|
||||
|
||||
out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
|
||||
out = hf_itm_model(question_input_ids, image, use_itm_head=False)
|
||||
|
||||
assert out[0].item() == 0.2110687494277954
|
||||
assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
|
@ -1,390 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert BLIP-2 checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# pip3 install salesforce-lavis
|
||||
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
|
||||
# to make sure we can compare both original and HF implementation in float32
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BertTokenizer,
|
||||
Blip2Config,
|
||||
Blip2ForConditionalGeneration,
|
||||
Blip2ForImageTextRetrieval,
|
||||
Blip2Processor,
|
||||
Blip2QFormerConfig,
|
||||
Blip2VisionConfig,
|
||||
BlipImageProcessor,
|
||||
OPTConfig,
|
||||
T5Config,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
|
||||
|
||||
def load_demo_image():
|
||||
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
|
||||
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, model_name):
|
||||
rename_keys = []
|
||||
# fmt: off
|
||||
|
||||
# vision encoder
|
||||
rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
|
||||
rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
|
||||
rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
|
||||
rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
|
||||
|
||||
# QFormer
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
|
||||
if "itm" in model_name:
|
||||
rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
|
||||
rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
|
||||
rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
|
||||
rename_keys.append(("text_proj.weight", "text_projection.weight"))
|
||||
rename_keys.append(("text_proj.bias", "text_projection.bias"))
|
||||
|
||||
# fmt: on
|
||||
return rename_keys
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def read_in_q_v_bias(state_dict, config):
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
# read in original q and v biases
|
||||
q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")
|
||||
|
||||
# next, set bias in the state dict
|
||||
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
|
||||
state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
|
||||
|
||||
|
||||
def get_blip2_config(model_name, eos_token_id):
|
||||
image_size = 364 if "coco" in model_name else 224
|
||||
vision_config = Blip2VisionConfig(image_size=image_size).to_dict()
|
||||
|
||||
# make sure the models have proper bos_token_id and eos_token_id set (important for generation)
|
||||
# seems like flan-T5 models don't have bos_token_id properly set?
|
||||
if "opt-2.7b" in model_name:
|
||||
text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
|
||||
elif "opt-6.7b" in model_name:
|
||||
text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
|
||||
elif "t5-xl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "t5-xxl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "itm" in model_name:
|
||||
text_config = {}
|
||||
else:
|
||||
raise ValueError("Model name not supported")
|
||||
|
||||
if "itm" in model_name:
|
||||
config = Blip2Config(
|
||||
vision_config=vision_config,
|
||||
qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
|
||||
)
|
||||
else:
|
||||
config = Blip2Config(vision_config=vision_config, text_config=text_config)
|
||||
|
||||
return config, image_size
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip2_checkpoint(
|
||||
model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to Transformers design.
|
||||
"""
|
||||
if "opt" in model_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
|
||||
elif "itm" in model_name:
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
|
||||
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
|
||||
|
||||
if "itm" in model_name:
|
||||
eos_token_id = None
|
||||
else:
|
||||
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
|
||||
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
|
||||
|
||||
if "itm" in model_name:
|
||||
hf_model = Blip2ForImageTextRetrieval(config).eval()
|
||||
else:
|
||||
hf_model = Blip2ForConditionalGeneration(config).eval()
|
||||
|
||||
model_name_to_original = {
|
||||
"blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
|
||||
"blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
|
||||
"blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
|
||||
"blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
|
||||
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
|
||||
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
|
||||
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
|
||||
"blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
|
||||
"blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
|
||||
}
|
||||
|
||||
name, type = model_name_to_original[model_name]
|
||||
|
||||
# load original model
|
||||
print("Loading original model...")
|
||||
original_model, vis_processors, _ = load_model_and_preprocess(
|
||||
name=name, model_type=type, is_eval=True, device=lavis_device
|
||||
)
|
||||
original_model.eval()
|
||||
print("Done!")
|
||||
|
||||
# update state dict keys
|
||||
state_dict = original_model.state_dict()
|
||||
rename_keys = create_rename_keys(config, model_name)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
|
||||
# some keys can be renamed efficiently
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("Qformer.bert"):
|
||||
key = key.replace("Qformer.bert", "qformer")
|
||||
if "attention.self" in key:
|
||||
key = key.replace("self", "attention")
|
||||
if "opt_proj" in key:
|
||||
key = key.replace("opt_proj", "language_projection")
|
||||
if "t5_proj" in key:
|
||||
key = key.replace("t5_proj", "language_projection")
|
||||
if key.startswith("opt"):
|
||||
key = key.replace("opt", "language")
|
||||
if key.startswith("t5"):
|
||||
key = key.replace("t5", "language")
|
||||
state_dict[key] = val
|
||||
|
||||
# read in qv biases
|
||||
read_in_q_v_bias(state_dict, config)
|
||||
|
||||
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
|
||||
assert len(missing_keys) == 0
|
||||
|
||||
if "itm" in model_name:
|
||||
unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
|
||||
assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
|
||||
else:
|
||||
assert unexpected_keys == ["qformer.embeddings.position_ids"]
|
||||
|
||||
image = load_demo_image()
|
||||
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
|
||||
|
||||
# create processor
|
||||
image_processor = BlipImageProcessor(
|
||||
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
|
||||
)
|
||||
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
|
||||
|
||||
# make sure processor creates exact same pixel values
|
||||
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
|
||||
|
||||
original_model.to(lavis_device)
|
||||
hf_model.to(hf_model_device)
|
||||
|
||||
if "itm" in model_name:
|
||||
caption = "a large fountain spewing water into the air"
|
||||
input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
|
||||
attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)
|
||||
|
||||
with torch.no_grad():
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
|
||||
)
|
||||
logits = hf_model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_image_text_matching_head=True,
|
||||
)
|
||||
|
||||
assert original_logits.shape == logits.logits_per_image.shape
|
||||
print("First values of original logits:", original_logits[0, :3])
|
||||
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||
|
||||
# assert values
|
||||
# cast to same type
|
||||
target_dtype = logits.logits_per_image.dtype
|
||||
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||
|
||||
original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
|
||||
itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
|
||||
assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
with torch.no_grad():
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
|
||||
)
|
||||
logits = hf_model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_image_text_matching_head=False,
|
||||
)
|
||||
|
||||
assert original_logits.shape == logits.logits_per_image.shape
|
||||
print("First values of original logits:", original_logits[0, :3])
|
||||
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||
|
||||
# assert values
|
||||
# cast to same type
|
||||
target_dtype = logits.logits_per_image.dtype
|
||||
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
else:
|
||||
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
|
||||
|
||||
with torch.no_grad():
|
||||
if "opt" in model_name:
|
||||
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
|
||||
logits = hf_model(pixel_values, input_ids).logits
|
||||
else:
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
|
||||
).logits
|
||||
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
|
||||
logits = hf_model(pixel_values, input_ids, labels=labels).logits
|
||||
|
||||
assert original_logits.shape == logits.shape
|
||||
print("First values of original logits:", original_logits[0, :3, :3])
|
||||
print("First values of HF logits:", logits[0, :3, :3])
|
||||
|
||||
# assert values
|
||||
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
print("Generating a caption...")
|
||||
prompt = "Question: what object is in this image? Answer:"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
|
||||
|
||||
set_seed(42)
|
||||
|
||||
original_outputs = original_model.generate(
|
||||
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
|
||||
)
|
||||
outputs = hf_model.generate(
|
||||
pixel_values,
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=5,
|
||||
max_length=30,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=1.0,
|
||||
temperature=1,
|
||||
)
|
||||
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
|
||||
output_text = [text.strip() for text in output_text]
|
||||
print("Original generation:", original_outputs)
|
||||
print("HF generation:", output_text)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
processor.push_to_hub(f"nielsr/{model_name}")
|
||||
hf_model.push_to_hub(f"nielsr/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
choices = [
|
||||
"blip2-opt-2.7b",
|
||||
"blip2-opt-6.7b",
|
||||
"blip2-opt-2.7b-coco",
|
||||
"blip2-opt-6.7b-coco",
|
||||
"blip2-flan-t5-xl",
|
||||
"blip2-flan-t5-xl-coco",
|
||||
"blip2-flan-t5-xxl",
|
||||
"blip2-itm-vit-g",
|
||||
"blip2-itm-vit-g-coco",
|
||||
]
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="blip2-opt-2.7b",
|
||||
choices=choices,
|
||||
type=str,
|
||||
help="Path to hf config.json of model to convert",
|
||||
)
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model and processor to the hub after converting",
|
||||
)
|
||||
# note: this script is tested on 2 GPUs, as models are compared in float32,
|
||||
# which requires quite some memory. Hence loading both on a
|
||||
# separate device is the easiest to compare
|
||||
parser.add_argument(
|
||||
"--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_blip2_checkpoint(
|
||||
args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
|
||||
)
|
@ -1,254 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BigScience BLOOM checkpoint."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BloomConfig, BloomModel
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
WEIGHTS_TO_AVERAGE_ENDSWITH = [
|
||||
"word_embeddings_layernorm.weight",
|
||||
"word_embeddings_layernorm.bias",
|
||||
"input_layernorm.weight",
|
||||
"input_layernorm.bias",
|
||||
"post_attention_layernorm.weight",
|
||||
"post_attention_layernorm.bias",
|
||||
"self_attention.dense.bias",
|
||||
"mlp.dense_4h_to_h.bias",
|
||||
"ln_f.weight",
|
||||
"ln_f.bias",
|
||||
]
|
||||
|
||||
WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
|
||||
"mlp.dense_4h_to_h.weight",
|
||||
"self_attention.dense.weight",
|
||||
]
|
||||
|
||||
|
||||
def layer_name_mapping(key, file):
|
||||
"""Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
|
||||
# Handle first and last layers
|
||||
layer_rename_map = {
|
||||
"word_embeddings.weight": "word_embeddings.weight",
|
||||
"word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
|
||||
"word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
|
||||
"weight": "ln_f.weight",
|
||||
"bias": "ln_f.bias",
|
||||
}
|
||||
|
||||
if key in layer_rename_map:
|
||||
return layer_rename_map[key]
|
||||
|
||||
# Handle transformer blocks
|
||||
layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
|
||||
layer_number -= 3
|
||||
return f"h.{layer_number}." + key
|
||||
|
||||
|
||||
def get_dtype_size(dtype):
|
||||
if dtype == torch.bool:
|
||||
return 1 / 8
|
||||
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
|
||||
if bit_search is None:
|
||||
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||
bit_size = int(bit_search.groups()[0])
|
||||
return bit_size // 8
|
||||
|
||||
|
||||
def convert_bloom_checkpoint_to_pytorch(
|
||||
bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
|
||||
):
|
||||
# Construct model
|
||||
if bloom_config_file == "":
|
||||
config = BloomConfig()
|
||||
else:
|
||||
config = BloomConfig.from_json_file(bloom_config_file)
|
||||
|
||||
if shard_model:
|
||||
file_names = os.listdir(bloom_checkpoint_path)
|
||||
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
|
||||
|
||||
index_dict = {"weight_map": {}, "metadata": {}}
|
||||
total_size = 0
|
||||
|
||||
missing_keys = None
|
||||
|
||||
config = BloomConfig()
|
||||
|
||||
for j, file in enumerate(file_names):
|
||||
print(f"Processing file: {file}")
|
||||
tensors = None
|
||||
|
||||
for i in range(pretraining_tp):
|
||||
# load all TP files
|
||||
f_name = file.replace("model_00", f"model_0{i}")
|
||||
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
|
||||
|
||||
# Rename keys in the transformers names
|
||||
keys = list(temp.keys())
|
||||
for key in keys:
|
||||
temp[layer_name_mapping(key, file)] = temp.pop(key)
|
||||
|
||||
if tensors is None:
|
||||
tensors = temp
|
||||
else:
|
||||
for key in tensors:
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
# We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
|
||||
tensors[key] += temp[key]
|
||||
else:
|
||||
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
|
||||
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
|
||||
# We concatenate these weights across TP ranks
|
||||
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
|
||||
|
||||
# Divide by the number of TP the weights we want to average
|
||||
for key in tensors:
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] = tensors[key] / pretraining_tp
|
||||
torch.save(
|
||||
tensors,
|
||||
os.path.join(
|
||||
pytorch_dump_folder_path,
|
||||
f"pytorch_model_{str(j + 1).zfill(5)}-of-{str(len(file_names)).zfill(5)}.bin",
|
||||
),
|
||||
)
|
||||
|
||||
for key in tensors:
|
||||
value = tensors[key]
|
||||
total_size += value.numel() * get_dtype_size(value.dtype)
|
||||
if key not in index_dict["weight_map"]:
|
||||
index_dict["weight_map"][key] = (
|
||||
f"pytorch_model_{str(j + 1).zfill(5)}-of-{str(len(file_names)).zfill(5)}.bin"
|
||||
)
|
||||
|
||||
config = BloomConfig()
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
index_dict["metadata"]["total_size"] = total_size
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
|
||||
json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
|
||||
f.write(json_config)
|
||||
else:
|
||||
model = BloomModel(config)
|
||||
|
||||
file_names = os.listdir(bloom_checkpoint_path)
|
||||
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
|
||||
|
||||
missing_keys = None
|
||||
for i, file in enumerate(file_names):
|
||||
tensors = None
|
||||
for i in range(pretraining_tp):
|
||||
# load all TP files
|
||||
f_name = file.replace("model_00", f"model_0{i}")
|
||||
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
|
||||
|
||||
# Rename keys in the transformers names
|
||||
keys = list(temp.keys())
|
||||
for key in keys:
|
||||
temp[layer_name_mapping(key, file)] = temp.pop(key)
|
||||
|
||||
if tensors is None:
|
||||
tensors = temp
|
||||
else:
|
||||
for key in tensors:
|
||||
# We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] += temp[key]
|
||||
else:
|
||||
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
|
||||
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
|
||||
# We concatenate these weights across TP ranks
|
||||
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
|
||||
|
||||
# Divide by the number of TP the weights we want to average
|
||||
for key in tensors:
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] = tensors[key] / pretraining_tp
|
||||
|
||||
other_keys = model.load_state_dict(tensors, strict=False)
|
||||
assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
|
||||
if missing_keys is None:
|
||||
missing_keys = set(other_keys.missing_keys)
|
||||
else:
|
||||
missing_keys = missing_keys.intersection(set(other_keys.missing_keys))
|
||||
|
||||
assert not missing_keys, f"The keys {missing_keys} are missing"
|
||||
|
||||
# Save pytorch-model
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
|
||||
if config.torch_dtype is not None:
|
||||
model = model.to(config.torch_dtype)
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print(f"Save configuration file to {pytorch_config_dump_path}")
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--bloom_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the Megatron-LM checkpoint path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bloom_config_file",
|
||||
default="",
|
||||
type=str,
|
||||
help=(
|
||||
"An optional config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard_model",
|
||||
action="store_true",
|
||||
help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretraining_tp",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_bloom_checkpoint_to_pytorch(
|
||||
args.bloom_checkpoint_path,
|
||||
args.bloom_config_file,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.shard_model,
|
||||
args.pretraining_tp,
|
||||
)
|
@ -1,145 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Bros checkpoints."""
|
||||
|
||||
import argparse
|
||||
|
||||
import bros # original repo
|
||||
import torch
|
||||
|
||||
from transformers import BrosConfig, BrosModel, BrosProcessor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_configs(model_name):
|
||||
bros_config = BrosConfig.from_pretrained(model_name)
|
||||
return bros_config
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"embeddings.bbox_sinusoid_emb.inv_freq",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if name == "embeddings.bbox_projection.weight":
|
||||
name = "bbox_embeddings.bbox_projection.weight"
|
||||
|
||||
if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
|
||||
name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"
|
||||
|
||||
if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
|
||||
name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, model):
|
||||
# rename keys
|
||||
for key in orig_state_dict.copy():
|
||||
val = orig_state_dict.pop(key)
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
# remove ignore keys
|
||||
remove_ignore_keys_(orig_state_dict)
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
|
||||
# load original model
|
||||
original_model = bros.BrosModel.from_pretrained(model_name).eval()
|
||||
|
||||
# load HuggingFace Model
|
||||
bros_config = get_configs(model_name)
|
||||
model = BrosModel.from_pretrained(model_name, config=bros_config)
|
||||
model.eval()
|
||||
|
||||
state_dict = original_model.state_dict()
|
||||
new_state_dict = convert_state_dict(state_dict, model)
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# verify results
|
||||
|
||||
# original BROS model require 4 points (8 float values) for each bbox, prepare bbox with [batch_size, seq_len, 8] shape
|
||||
bbox = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
||||
[0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
|
||||
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
|
||||
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
|
||||
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
|
||||
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
|
||||
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
processor = BrosProcessor.from_pretrained(model_name)
|
||||
|
||||
encoding = processor("His name is Rocco.", return_tensors="pt")
|
||||
encoding["bbox"] = bbox
|
||||
|
||||
original_hidden_states = original_model(**encoding).last_hidden_state
|
||||
# pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
last_hidden_states = model(**encoding).last_hidden_state
|
||||
|
||||
assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
print(f"Saving model and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
|
||||
processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="jinho8345/bros-base-uncased",
|
||||
required=False,
|
||||
type=str,
|
||||
help="Name of the original model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
required=False,
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether or not to push the converted model and processor to the 🤗 hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,59 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert T5 checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = T5Config.from_json_file(config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = T5ForConditionalGeneration(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
@ -1,65 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert CANINE checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path):
|
||||
# Initialize PyTorch model
|
||||
config = CanineConfig()
|
||||
model = CanineModel(config)
|
||||
model.eval()
|
||||
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_canine(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model (weights and configuration)
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
# Save tokenizer files
|
||||
tokenizer = CanineTokenizer()
|
||||
print(f"Save tokenizer files to {pytorch_dump_path}")
|
||||
tokenizer.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the TensorFlow checkpoint. Should end with model.ckpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a folder where the PyTorch model will be placed.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path)
|
@ -1,478 +0,0 @@
|
||||
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate import init_empty_weights
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
ChameleonConfig,
|
||||
ChameleonForConditionalGeneration,
|
||||
ChameleonImageProcessor,
|
||||
ChameleonProcessor,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from transformers import LlamaTokenizerFast
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Chameleon conversion supports only FastTokenizer and LlamaTokenizerFast can't be imported! "
|
||||
"Update your `tokenizers` library and re-run the tokenizer conversion."
|
||||
)
|
||||
|
||||
"""
|
||||
Sample usage:
|
||||
|
||||
```
|
||||
python src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/chameleon/weights --model_size 7B --output_dir /output/path
|
||||
```
|
||||
|
||||
Thereafter, models can be loaded via:
|
||||
|
||||
```py
|
||||
from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast
|
||||
|
||||
model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
|
||||
```
|
||||
|
||||
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
|
||||
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
|
||||
"""
|
||||
|
||||
NUM_SHARDS = {
|
||||
"7B": 1,
|
||||
"30B": 4,
|
||||
}
|
||||
|
||||
VOCAB_SIZE = 65536
|
||||
|
||||
|
||||
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
|
||||
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
|
||||
|
||||
|
||||
def read_json(path):
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(text, path):
|
||||
with open(path, "w") as f:
|
||||
json.dump(text, f)
|
||||
|
||||
|
||||
def write_model(model_path, input_base_path, model_size, chameleon_version=1):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
input_model_path = os.path.join(input_base_path, "models", model_size.lower())
|
||||
params_path = os.path.join(input_model_path, "params.json")
|
||||
consolidate_params_path = os.path.join(input_model_path, "consolidate_params.json")
|
||||
|
||||
params = read_json(params_path)
|
||||
if os.path.isfile(consolidate_params_path):
|
||||
params = {**params, **read_json(consolidate_params_path)}
|
||||
num_shards = NUM_SHARDS[model_size]
|
||||
model_parallel_size = params["model_parallel_size"]
|
||||
params = params.get("model", params)
|
||||
n_layers = params["n_layers"]
|
||||
n_heads = params["n_heads"]
|
||||
n_heads_per_shard = n_heads // num_shards
|
||||
dim = params["dim"]
|
||||
dims_per_head = dim // n_heads
|
||||
base = params.get("rope_theta", 10000.0)
|
||||
swin_norm = params["swin_norm"]
|
||||
if base > 10000.0:
|
||||
max_position_embeddings = 16384
|
||||
else:
|
||||
# Depending on the Chameleon version, the default max_position_embeddings has different values.
|
||||
if chameleon_version == 1:
|
||||
max_position_embeddings = 4096
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Version {chameleon_version} of chameleon is not supported yet. "
|
||||
"Current supported versions of chameleon are [1]."
|
||||
)
|
||||
|
||||
if params.get("n_kv_heads", None) is not None:
|
||||
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
|
||||
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
|
||||
key_value_dim = dim // num_key_value_heads
|
||||
else: # compatibility with other checkpoints
|
||||
num_key_value_heads = n_heads
|
||||
num_local_key_value_heads = n_heads_per_shard
|
||||
key_value_dim = dim
|
||||
|
||||
print(f"Fetching all parameters from the checkpoint at {input_model_path}.")
|
||||
# Load weights
|
||||
if num_shards == 1:
|
||||
# Not sharded
|
||||
# (The sharded implementation would also work, but this is simpler.)
|
||||
loaded = None
|
||||
for possible_name in ["consolidated.pth", "consolidated.00.pth"]:
|
||||
possible_path = os.path.join(input_model_path, possible_name)
|
||||
if os.path.exists(possible_path):
|
||||
loaded = torch.load(possible_path, map_location="cpu", weights_only=True)
|
||||
break
|
||||
assert loaded is not None
|
||||
else:
|
||||
# Sharded
|
||||
loaded = [
|
||||
torch.load(
|
||||
os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu", weights_only=True
|
||||
)
|
||||
for i in range(num_shards)
|
||||
]
|
||||
|
||||
# permute for sliced rotary
|
||||
def permute(w, n_heads, dim1=dim, dim2=dim):
|
||||
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||
|
||||
# Load weights to the state dict
|
||||
state_dict = {}
|
||||
for layer_i in range(n_layers):
|
||||
if num_shards == 1:
|
||||
# Unsharded
|
||||
state_dict.update(
|
||||
{
|
||||
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
|
||||
loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
|
||||
loaded[f"layers.{layer_i}.attention.wk.weight"],
|
||||
n_heads=num_key_value_heads,
|
||||
dim1=key_value_dim,
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
|
||||
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
|
||||
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
|
||||
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
|
||||
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": loaded[
|
||||
f"layers.{layer_i}.attention_norm.weight"
|
||||
],
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
|
||||
f"layers.{layer_i}.ffn_norm.weight"
|
||||
],
|
||||
}
|
||||
)
|
||||
# qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
|
||||
loaded[f"layers.{layer_i}.attention.q_normalization.weight"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(n_heads, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
|
||||
loaded[f"layers.{layer_i}.attention.q_normalization.bias"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(n_heads, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
|
||||
loaded[f"layers.{layer_i}.attention.k_normalization.weight"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(num_key_value_heads, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
|
||||
loaded[f"layers.{layer_i}.attention.k_normalization.bias"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(num_key_value_heads, 0)
|
||||
)
|
||||
|
||||
else:
|
||||
# Sharded
|
||||
state_dict.update(
|
||||
{
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": torch.stack(
|
||||
[l[f"layers.{layer_i}.attention_norm.weight"] for l in loaded]
|
||||
).mean(dim=0),
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": torch.stack(
|
||||
[l[f"layers.{layer_i}.ffn_norm.weight"] for l in loaded]
|
||||
).mean(dim=0),
|
||||
}
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(dim, dim),
|
||||
n_heads=n_heads,
|
||||
)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
|
||||
num_local_key_value_heads, dims_per_head, dim
|
||||
)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(key_value_dim, dim),
|
||||
n_heads=num_key_value_heads,
|
||||
dim1=key_value_dim,
|
||||
)
|
||||
|
||||
# qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(n_heads // num_shards, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(n_heads // num_shards, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(num_key_value_heads // num_shards, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(num_key_value_heads // num_shards, 0)
|
||||
)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
|
||||
num_local_key_value_heads, dims_per_head, dim
|
||||
)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(key_value_dim, dim)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
|
||||
)
|
||||
|
||||
if num_shards == 1:
|
||||
# Unsharded
|
||||
state_dict.update(
|
||||
{
|
||||
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
|
||||
"model.norm.weight": loaded["norm.weight"],
|
||||
"lm_head.weight": loaded["output.weight"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
state_dict.update(
|
||||
{
|
||||
"model.embed_tokens.weight": torch.cat(
|
||||
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
|
||||
),
|
||||
"model.norm.weight": torch.stack([loaded[i]["norm.weight"] for i in range(num_shards)]).mean(dim=0),
|
||||
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
|
||||
}
|
||||
)
|
||||
|
||||
# Load VQGAN weights
|
||||
vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt")
|
||||
vqgan_state_dict = torch.load(vqgan_path, map_location="cpu", weights_only=True)["state_dict"]
|
||||
for k, v in vqgan_state_dict.items():
|
||||
if "decoder" in k:
|
||||
continue # we dont do image generation yet
|
||||
state_dict[f"model.vqmodel.{k}"] = v
|
||||
|
||||
# Write configs
|
||||
ffn_dim_multiplier = params.get("ffn_dim_multiplier", 1)
|
||||
multiple_of = params.get("multiple_of", 256)
|
||||
|
||||
with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file:
|
||||
tokenizer_config = json.load(tokenizer_file)
|
||||
vocabulary_map = tokenizer_config["model"]["vocab"]
|
||||
vocabulary_map["<image>"] = vocabulary_map[
|
||||
"<reserved08707>"
|
||||
] # use a reserved token instead of adding a new one
|
||||
del vocabulary_map["<reserved08707>"]
|
||||
|
||||
for token in tokenizer_config["added_tokens"]:
|
||||
if token["content"] == "<reserved08707>":
|
||||
token["content"] = "<image>"
|
||||
|
||||
with open(os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), "w") as f:
|
||||
json.dump(tokenizer_config, f) # save the new file to init tokenizer later
|
||||
|
||||
vq_keys_to_replace = [
|
||||
("ch", "base_channels"),
|
||||
("out_ch", "out_channels"),
|
||||
("n_embed", "num_embeddings"),
|
||||
("ch_mult", "channel_multiplier"),
|
||||
("double_z", "double_latent"),
|
||||
("z_channels", "latent_channels"),
|
||||
]
|
||||
with open(os.path.join(input_base_path, "tokenizer/vqgan.yaml")) as vqgan_cfg_file:
|
||||
vq_config = yaml.safe_load(vqgan_cfg_file)["model"]["params"]
|
||||
vq_config.update(**vq_config["ddconfig"])
|
||||
for old, new in vq_keys_to_replace:
|
||||
vq_config[new] = vq_config[old]
|
||||
del vq_config["ddconfig"]
|
||||
del vq_config["ckpt_path"]
|
||||
del vq_config["lossconfig"]
|
||||
|
||||
config = ChameleonConfig(
|
||||
hidden_size=dim,
|
||||
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
|
||||
num_attention_heads=params["n_heads"],
|
||||
num_hidden_layers=params["n_layers"],
|
||||
rms_norm_eps=params["norm_eps"],
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
rope_theta=base,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
model_parallel_size=model_parallel_size,
|
||||
swin_norm=swin_norm,
|
||||
vq_config=vq_config,
|
||||
vocabulary_map=vocabulary_map,
|
||||
)
|
||||
with init_empty_weights():
|
||||
model = ChameleonForConditionalGeneration(config)
|
||||
|
||||
model.load_state_dict(state_dict, assign=True, strict=False)
|
||||
model.save_pretrained(model_path, safe_serialization=True)
|
||||
|
||||
# Load and save the processor
|
||||
tokenizer = LlamaTokenizerFast(
|
||||
tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False
|
||||
)
|
||||
tokenizer.sep_token_id = 8710 # assign <reserved08706> to sep so that we can append it after input text
|
||||
tokenizer.pad_token_id = 1 # assign <pad> to special pad_token
|
||||
image_processor = ChameleonImageProcessor()
|
||||
processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
processor.save_pretrained(model_path)
|
||||
|
||||
# Make space so we can load the model properly now.
|
||||
del state_dict
|
||||
del loaded
|
||||
del vqgan_state_dict
|
||||
gc.collect()
|
||||
|
||||
# Short inference on a few examples to check if generation makes sense
|
||||
# taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
|
||||
print("Loading the checkpoint in a Chameleon model...")
|
||||
print("*" * 100)
|
||||
model = ChameleonForConditionalGeneration.from_pretrained(
|
||||
model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
processor = ChameleonProcessor.from_pretrained(model_path)
|
||||
|
||||
prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
|
||||
image = Image.open(
|
||||
requests.get(
|
||||
"https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
|
||||
).raw
|
||||
)
|
||||
inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
|
||||
length = inputs.input_ids.shape[1]
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
|
||||
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
|
||||
|
||||
print(f"Generation for single-image: {generated_text}")
|
||||
print("*" * 100)
|
||||
|
||||
# Multi-image example
|
||||
prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
|
||||
image = Image.open(
|
||||
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
|
||||
)
|
||||
image_2 = Image.open(
|
||||
requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
|
||||
)
|
||||
|
||||
inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
length = inputs.input_ids.shape[1]
|
||||
out = model.generate(**inputs, max_new_tokens=50, do_sample=False)
|
||||
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
|
||||
|
||||
print(f"Generation for multi-image: {generated_text}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
help="Location of Chameleon weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
choices=["7B", "30B"],
|
||||
help=""
|
||||
" models correspond to the finetuned versions, and are specific to the Chameleon official release. For more details on Chameleon, check out the original repo: https://github.com/facebookresearch/chameleon",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_inference",
|
||||
action="store_true",
|
||||
help="Whether to load the model for generation to test it's converted correctly.",
|
||||
)
|
||||
# Different Chameleon versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
|
||||
parser.add_argument(
|
||||
"--chameleon_version",
|
||||
choices=[1],
|
||||
default=1,
|
||||
type=int,
|
||||
help="Version of the Chameleon model to convert",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
write_model(
|
||||
model_path=args.output_dir,
|
||||
input_base_path=args.input_dir,
|
||||
model_size=args.model_size,
|
||||
chameleon_version=args.chameleon_version,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,134 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import ChineseCLIPConfig, ChineseCLIPModel
|
||||
|
||||
|
||||
def copy_attn_layer(hf_attn_layer, pt_weights, prefix):
|
||||
q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0)
|
||||
q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0)
|
||||
|
||||
out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"]
|
||||
out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"]
|
||||
|
||||
hf_attn_layer.q_proj.weight.data = q_proj
|
||||
hf_attn_layer.q_proj.bias.data = q_proj_bias
|
||||
|
||||
hf_attn_layer.k_proj.weight.data = k_proj
|
||||
hf_attn_layer.k_proj.bias.data = k_proj_bias
|
||||
|
||||
hf_attn_layer.v_proj.weight.data = v_proj
|
||||
hf_attn_layer.v_proj.bias.data = v_proj_bias
|
||||
|
||||
hf_attn_layer.out_proj.weight.data = out_proj_weights
|
||||
hf_attn_layer.out_proj.bias.data = out_proj_bias
|
||||
|
||||
|
||||
def copy_mlp(hf_mlp, pt_weights, prefix):
|
||||
copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc")
|
||||
copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj")
|
||||
|
||||
|
||||
def copy_linear(hf_linear, pt_weights, prefix):
|
||||
hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data
|
||||
hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data
|
||||
|
||||
|
||||
def copy_layer(hf_layer, pt_weights, prefix):
|
||||
# copy layer norms
|
||||
copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1")
|
||||
copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2")
|
||||
|
||||
# copy MLP
|
||||
copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp")
|
||||
|
||||
# copy attn
|
||||
copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn")
|
||||
|
||||
|
||||
def copy_layers(hf_layers, pt_weights, prefix):
|
||||
for layer_id, hf_layer in enumerate(hf_layers):
|
||||
copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}")
|
||||
|
||||
|
||||
def copy_text_model_and_projection(hf_model, pt_weights):
|
||||
# copy projection
|
||||
hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T
|
||||
|
||||
# copy text encoder
|
||||
for name, param in hf_model.text_model.named_parameters():
|
||||
param.data = pt_weights[f"bert.{name}"].data
|
||||
|
||||
|
||||
def copy_vision_model_and_projection(hf_model, pt_weights):
|
||||
# copy projection
|
||||
hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T
|
||||
|
||||
# copy layer norms
|
||||
copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre")
|
||||
copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post")
|
||||
|
||||
# copy embeddings
|
||||
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data
|
||||
hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data
|
||||
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data
|
||||
|
||||
# copy encoder
|
||||
copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
|
||||
assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size."
|
||||
config = ChineseCLIPConfig.from_pretrained(config_path)
|
||||
|
||||
hf_model = ChineseCLIPModel(config).eval()
|
||||
|
||||
pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
|
||||
pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
|
||||
|
||||
copy_text_model_and_projection(hf_model, pt_weights)
|
||||
copy_vision_model_and_projection(hf_model, pt_weights)
|
||||
hf_model.logit_scale.data = pt_weights["logit_scale"].data
|
||||
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the output folder storing converted hf PyTorch model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
|
||||
print("The conversion is finished!")
|
@ -1,133 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
from laion_clap import CLAP_Module
|
||||
|
||||
from transformers import AutoFeatureExtractor, ClapConfig, ClapModel
|
||||
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"text_branch": "text_model",
|
||||
"audio_branch": "audio_model.audio_encoder",
|
||||
"attn": "attention.self",
|
||||
"self.proj": "output.dense",
|
||||
"attention.self_mask": "attn_mask",
|
||||
"mlp.fc1": "intermediate.dense",
|
||||
"mlp.fc2": "output.dense",
|
||||
"norm1": "layernorm_before",
|
||||
"norm2": "layernorm_after",
|
||||
"bn0": "batch_norm",
|
||||
}
|
||||
|
||||
processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")
|
||||
|
||||
|
||||
def init_clap(checkpoint_path, model_type, enable_fusion=False):
|
||||
model = CLAP_Module(
|
||||
amodel=model_type,
|
||||
enable_fusion=enable_fusion,
|
||||
)
|
||||
model.load_ckpt(checkpoint_path)
|
||||
return model
|
||||
|
||||
|
||||
def get_config_from_original(clap_model):
|
||||
audio_config = {
|
||||
"patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
|
||||
"depths": clap_model.model.audio_branch.depths,
|
||||
"hidden_size": clap_model.model.audio_projection[0].in_features,
|
||||
}
|
||||
|
||||
text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}
|
||||
|
||||
return ClapConfig(audio_config=audio_config, text_config=text_config)
|
||||
|
||||
|
||||
def rename_state_dict(state_dict):
|
||||
model_state_dict = {}
|
||||
|
||||
sequential_layers_pattern = r".*sequential.(\d+).*"
|
||||
text_projection_pattern = r".*_projection.(\d+).*"
|
||||
|
||||
for key, value in state_dict.items():
|
||||
# check if any key needs to be modified
|
||||
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in key:
|
||||
key = key.replace(key_to_modify, new_key)
|
||||
|
||||
if re.match(sequential_layers_pattern, key):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
# Because in CLAP they use `nn.Sequential`...
|
||||
transformers_projection_layer = 1 if projecton_layer == 0 else 2
|
||||
|
||||
key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")
|
||||
|
||||
if "audio" and "qkv" in key:
|
||||
# split qkv into query key and value
|
||||
mixed_qkv = value
|
||||
qkv_dim = mixed_qkv.size(0) // 3
|
||||
|
||||
query_layer = mixed_qkv[:qkv_dim]
|
||||
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
|
||||
value_layer = mixed_qkv[qkv_dim * 2 :]
|
||||
|
||||
model_state_dict[key.replace("qkv", "query")] = query_layer
|
||||
model_state_dict[key.replace("qkv", "key")] = key_layer
|
||||
model_state_dict[key.replace("qkv", "value")] = value_layer
|
||||
else:
|
||||
model_state_dict[key] = value
|
||||
|
||||
return model_state_dict
|
||||
|
||||
|
||||
def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
|
||||
clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)
|
||||
|
||||
clap_model.eval()
|
||||
state_dict = clap_model.model.state_dict()
|
||||
state_dict = rename_state_dict(state_dict)
|
||||
|
||||
transformers_config = get_config_from_original(clap_model)
|
||||
transformers_config.audio_config.enable_fusion = enable_fusion
|
||||
model = ClapModel(transformers_config)
|
||||
|
||||
# ignore the spectrogram embedding layer
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
transformers_config.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
|
||||
parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Whether to enable fusion or not")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_clap_checkpoint(
|
||||
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
|
||||
)
|
@ -1,156 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from clip import load
|
||||
|
||||
from transformers import CLIPConfig, CLIPModel
|
||||
|
||||
|
||||
def copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
||||
q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
|
||||
q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
|
||||
|
||||
out_proj_weights = pt_attn_layer.out_proj.weight
|
||||
out_proj_bias = pt_attn_layer.out_proj.bias
|
||||
|
||||
hf_attn_layer.q_proj.weight.data = q_proj
|
||||
hf_attn_layer.q_proj.bias.data = q_proj_bias
|
||||
|
||||
hf_attn_layer.k_proj.weight.data = k_proj
|
||||
hf_attn_layer.k_proj.bias.data = k_proj_bias
|
||||
|
||||
hf_attn_layer.v_proj.weight.data = v_proj
|
||||
hf_attn_layer.v_proj.bias.data = v_proj_bias
|
||||
|
||||
hf_attn_layer.out_proj.weight = out_proj_weights
|
||||
hf_attn_layer.out_proj.bias = out_proj_bias
|
||||
|
||||
|
||||
def copy_mlp(hf_mlp, pt_mlp):
|
||||
copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
|
||||
copy_linear(hf_mlp.fc2, pt_mlp.c_proj)
|
||||
|
||||
|
||||
def copy_linear(hf_linear, pt_linear):
|
||||
hf_linear.weight = pt_linear.weight
|
||||
hf_linear.bias = pt_linear.bias
|
||||
|
||||
|
||||
def copy_layer(hf_layer, pt_layer):
|
||||
# copy layer norms
|
||||
copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
|
||||
copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)
|
||||
|
||||
# copy MLP
|
||||
copy_mlp(hf_layer.mlp, pt_layer.mlp)
|
||||
|
||||
# copy attn
|
||||
copy_attn_layer(hf_layer.self_attn, pt_layer.attn)
|
||||
|
||||
|
||||
def copy_layers(hf_layers, pt_layers):
|
||||
for hf_layer, pt_layer in zip(hf_layers, pt_layers):
|
||||
copy_layer(hf_layer, pt_layer)
|
||||
|
||||
|
||||
def copy_encoder(hf_encoder, pt_model):
|
||||
# copy embeds
|
||||
hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
|
||||
hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding
|
||||
|
||||
# copy layer norm
|
||||
copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
|
||||
|
||||
# copy hidden layers
|
||||
copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)
|
||||
|
||||
|
||||
def copy_text_model_and_projection(hf_model, pt_model):
|
||||
# copy projection
|
||||
hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()
|
||||
|
||||
# copy text encoder
|
||||
copy_encoder(hf_model.text_model, pt_model)
|
||||
|
||||
|
||||
def copy_vison_model_and_projection(hf_model, pt_model):
|
||||
# copy projection
|
||||
hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous()
|
||||
|
||||
# copy layer norms
|
||||
copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
|
||||
copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)
|
||||
|
||||
# copy embeds
|
||||
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
|
||||
hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
|
||||
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data
|
||||
|
||||
# copy encoder
|
||||
copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
if config_path is not None:
|
||||
config = CLIPConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})
|
||||
|
||||
hf_model = CLIPModel(config).eval()
|
||||
|
||||
pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
|
||||
pt_model = pt_model.eval()
|
||||
|
||||
copy_text_model_and_projection(hf_model, pt_model)
|
||||
copy_vison_model_and_projection(hf_model, pt_model)
|
||||
hf_model.logit_scale = pt_model.logit_scale
|
||||
|
||||
# Use `eos_token` so the example is more meaningful
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[config.text_config.bos_token_id]
|
||||
+ list(range(3, 77))
|
||||
+ [config.text_config.eos_token_id]
|
||||
+ [config.text_config.pad_token_id]
|
||||
]
|
||||
)
|
||||
pixel_values = torch.randn(1, 3, 224, 224)
|
||||
|
||||
hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
|
||||
hf_logits_per_image = hf_outputs.logits_per_image
|
||||
hf_logits_per_text = hf_outputs.logits_per_text
|
||||
pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)
|
||||
|
||||
assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
|
||||
assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)
|
||||
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
|
@ -1,264 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert CLIPSeg checkpoints from the original repository. URL: https://github.com/timojl/clipseg."""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
CLIPSegConfig,
|
||||
CLIPSegForImageSegmentation,
|
||||
CLIPSegProcessor,
|
||||
CLIPSegTextConfig,
|
||||
CLIPSegVisionConfig,
|
||||
CLIPTokenizer,
|
||||
ViTImageProcessor,
|
||||
)
|
||||
|
||||
|
||||
def get_clipseg_config(model_name):
|
||||
text_config = CLIPSegTextConfig()
|
||||
vision_config = CLIPSegVisionConfig(patch_size=16)
|
||||
|
||||
use_complex_transposed_convolution = "refined" in model_name
|
||||
reduce_dim = 16 if "rd16" in model_name else 64
|
||||
|
||||
config = CLIPSegConfig.from_text_vision_configs(
|
||||
text_config,
|
||||
vision_config,
|
||||
use_complex_transposed_convolution=use_complex_transposed_convolution,
|
||||
reduce_dim=reduce_dim,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
# update prefixes
|
||||
if "clip_model" in name:
|
||||
name = name.replace("clip_model", "clip")
|
||||
if "transformer" in name:
|
||||
if "visual" in name:
|
||||
name = name.replace("visual.transformer", "vision_model")
|
||||
else:
|
||||
name = name.replace("transformer", "text_model")
|
||||
if "resblocks" in name:
|
||||
name = name.replace("resblocks", "encoder.layers")
|
||||
if "ln_1" in name:
|
||||
name = name.replace("ln_1", "layer_norm1")
|
||||
if "ln_2" in name:
|
||||
name = name.replace("ln_2", "layer_norm2")
|
||||
if "c_fc" in name:
|
||||
name = name.replace("c_fc", "fc1")
|
||||
if "c_proj" in name:
|
||||
name = name.replace("c_proj", "fc2")
|
||||
if "attn" in name and "self" not in name:
|
||||
name = name.replace("attn", "self_attn")
|
||||
# text encoder
|
||||
if "token_embedding" in name:
|
||||
name = name.replace("token_embedding", "text_model.embeddings.token_embedding")
|
||||
if "positional_embedding" in name and "visual" not in name:
|
||||
name = name.replace("positional_embedding", "text_model.embeddings.position_embedding.weight")
|
||||
if "ln_final" in name:
|
||||
name = name.replace("ln_final", "text_model.final_layer_norm")
|
||||
# vision encoder
|
||||
if "visual.class_embedding" in name:
|
||||
name = name.replace("visual.class_embedding", "vision_model.embeddings.class_embedding")
|
||||
if "visual.conv1" in name:
|
||||
name = name.replace("visual.conv1", "vision_model.embeddings.patch_embedding")
|
||||
if "visual.positional_embedding" in name:
|
||||
name = name.replace("visual.positional_embedding", "vision_model.embeddings.position_embedding.weight")
|
||||
if "visual.ln_pre" in name:
|
||||
name = name.replace("visual.ln_pre", "vision_model.pre_layrnorm")
|
||||
if "visual.ln_post" in name:
|
||||
name = name.replace("visual.ln_post", "vision_model.post_layernorm")
|
||||
# projection layers
|
||||
if "visual.proj" in name:
|
||||
name = name.replace("visual.proj", "visual_projection.weight")
|
||||
if "text_projection" in name:
|
||||
name = name.replace("text_projection", "text_projection.weight")
|
||||
# decoder
|
||||
if "trans_conv" in name:
|
||||
name = name.replace("trans_conv", "transposed_convolution")
|
||||
if "film_mul" in name or "film_add" in name or "reduce" in name or "transposed_convolution" in name:
|
||||
name = "decoder." + name
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "decoder.layers")
|
||||
if "linear1" in name:
|
||||
name = name.replace("linear1", "mlp.fc1")
|
||||
if "linear2" in name:
|
||||
name = name.replace("linear2", "mlp.fc2")
|
||||
if "norm1" in name and "layer_" not in name:
|
||||
name = name.replace("norm1", "layer_norm1")
|
||||
if "norm2" in name and "layer_" not in name:
|
||||
name = name.replace("norm2", "layer_norm2")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, config):
|
||||
for key in orig_state_dict.copy():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
if key.startswith("clip_model") and "attn.in_proj" in key:
|
||||
key_split = key.split(".")
|
||||
if "visual" in key:
|
||||
layer_num = int(key_split[4])
|
||||
dim = config.vision_config.hidden_size
|
||||
prefix = "vision_model"
|
||||
else:
|
||||
layer_num = int(key_split[3])
|
||||
dim = config.text_config.hidden_size
|
||||
prefix = "text_model"
|
||||
|
||||
if "weight" in key:
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[
|
||||
dim : dim * 2, :
|
||||
]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
|
||||
elif "self_attn" in key and "out_proj" not in key:
|
||||
key_split = key.split(".")
|
||||
layer_num = int(key_split[1])
|
||||
dim = config.reduce_dim
|
||||
if "weight" in key:
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[dim : dim * 2, :]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
|
||||
else:
|
||||
new_name = rename_key(key)
|
||||
if "visual_projection" in new_name or "text_projection" in new_name:
|
||||
val = val.T
|
||||
orig_state_dict[new_name] = val
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
return image
|
||||
|
||||
|
||||
def convert_clipseg_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub):
|
||||
config = get_clipseg_config(model_name)
|
||||
model = CLIPSegForImageSegmentation(config)
|
||||
model.eval()
|
||||
|
||||
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
|
||||
|
||||
# remove some keys
|
||||
for key in state_dict.copy():
|
||||
if key.startswith("model"):
|
||||
state_dict.pop(key, None)
|
||||
|
||||
# rename some keys
|
||||
state_dict = convert_state_dict(state_dict, config)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys != ["clip.text_model.embeddings.position_ids", "clip.vision_model.embeddings.position_ids"]:
|
||||
raise ValueError(f"Missing keys that are not expected: {missing_keys}")
|
||||
if unexpected_keys != ["decoder.reduce.weight", "decoder.reduce.bias"]:
|
||||
raise ValueError(f"Unexpected keys: {unexpected_keys}")
|
||||
|
||||
image_processor = ViTImageProcessor(size=352)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
|
||||
image = prepare_img()
|
||||
text = ["a glass", "something to fill", "wood", "a jar"]
|
||||
|
||||
inputs = processor(text=text, images=[image] * len(text), padding="max_length", return_tensors="pt")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify values
|
||||
expected_conditional = torch.tensor([0.1110, -0.1882, 0.1645])
|
||||
expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328])
|
||||
if model_name == "clipseg-rd64-refined":
|
||||
expected_masks_slice = torch.tensor(
|
||||
[[-10.0407, -9.9431, -10.2646], [-9.9751, -9.7064, -9.9586], [-9.6891, -9.5645, -9.9618]]
|
||||
)
|
||||
elif model_name == "clipseg-rd64":
|
||||
expected_masks_slice = torch.tensor(
|
||||
[[-7.2877, -7.2711, -7.2463], [-7.2652, -7.2780, -7.2520], [-7.2239, -7.2204, -7.2001]]
|
||||
)
|
||||
elif model_name == "clipseg-rd16":
|
||||
expected_masks_slice = torch.tensor(
|
||||
[[-6.3955, -6.4055, -6.4151], [-6.3911, -6.4033, -6.4100], [-6.3474, -6.3702, -6.3762]]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model name {model_name} not supported.")
|
||||
|
||||
assert torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)
|
||||
assert torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)
|
||||
assert torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
print(f"Saving model and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model and processor for {model_name} to the hub")
|
||||
model.push_to_hub(f"CIDAS/{model_name}")
|
||||
processor.push_to_hub(f"CIDAS/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="clipseg-rd64",
|
||||
type=str,
|
||||
choices=["clipseg-rd16", "clipseg-rd64", "clipseg-rd64-refined"],
|
||||
help=(
|
||||
"Name of the model. Supported models are: clipseg-rd64, clipseg-rd16 and clipseg-rd64-refined (rd meaning"
|
||||
" reduce dimension)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default="/Users/nielsrogge/Documents/CLIPSeg/clip_plus_rd64-uni.pth",
|
||||
type=str,
|
||||
help=(
|
||||
"Path to the original checkpoint. Note that the script assumes that the checkpoint includes both CLIP and"
|
||||
" the decoder weights."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_clipseg_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,234 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Weights conversion script for CLVP
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import ClvpConfig, ClvpModelForConditionalGeneration
|
||||
|
||||
|
||||
_MODELS = {
|
||||
"clvp": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/clvp2.pth",
|
||||
"decoder": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/autoregressive.pth",
|
||||
}
|
||||
|
||||
dim = 1024
|
||||
sub_dim = dim // 16
|
||||
|
||||
CLVP_ENCODERS_MAPPING = {
|
||||
"text_transformer.transformer.attn_layers": "text_encoder_model",
|
||||
"speech_transformer.transformer.attn_layers": "speech_encoder_model",
|
||||
"text_transformer.transformer.norm": "text_encoder_model.final_layer_norm",
|
||||
"speech_transformer.transformer.norm": "speech_encoder_model.final_layer_norm",
|
||||
"to_text_latent": "text_encoder_model.projection",
|
||||
"to_speech_latent": "speech_encoder_model.projection",
|
||||
"text_emb": "text_encoder_model.token_embedding",
|
||||
"speech_emb": "speech_encoder_model.token_embedding",
|
||||
"1.wrap.net.0": "mlp.fc1",
|
||||
"1.wrap.net.3": "mlp.fc2",
|
||||
"1.wrap": "self_attn",
|
||||
"to_out": "out_proj",
|
||||
"to_q": "q_proj",
|
||||
"to_k": "k_proj",
|
||||
"to_v": "v_proj",
|
||||
"temperature": "logit_scale",
|
||||
}
|
||||
|
||||
CLVP_DECODER_MAPPING = {
|
||||
"conditioning_encoder.init": "conditioning_encoder.mel_conv",
|
||||
"conditioning_encoder.attn": "conditioning_encoder.mel_attn_blocks",
|
||||
"mel_attn_blocks": "group_norms",
|
||||
".norm.weight": ".weight",
|
||||
".norm.bias": ".bias",
|
||||
"text_embedding": "conditioning_encoder.text_token_embedding",
|
||||
"text_pos_embedding.emb": "conditioning_encoder.text_position_embedding",
|
||||
"final_norm": "speech_decoder_model.final_norm",
|
||||
"mel_head": "speech_decoder_model.lm_head",
|
||||
"gpt.ln_f": "speech_decoder_model.model.decoder.layer_norm",
|
||||
"mel_embedding": "speech_decoder_model.model.decoder.input_embeds_layer",
|
||||
"mel_pos_embedding.emb": "speech_decoder_model.model.decoder.position_embeds_layer",
|
||||
"gpt.h": "speech_decoder_model.model.decoder.layers",
|
||||
"ln_1": "input_layernorm",
|
||||
"ln_2": "post_attention_layernorm",
|
||||
}
|
||||
|
||||
|
||||
def update_index(present_index):
|
||||
if present_index % 2 == 0:
|
||||
return int(present_index / 2)
|
||||
else:
|
||||
return int((present_index - 1) / 2)
|
||||
|
||||
|
||||
def convert_encoder_weights(original_weights):
|
||||
converted_weights = {}
|
||||
original_weights_keys = sorted(original_weights.keys())
|
||||
for original_key in original_weights_keys:
|
||||
updated_key = original_key
|
||||
# for input_rmsnorm.weight and post_attention_rmsnorm.weight
|
||||
if "0.0.g" in updated_key:
|
||||
present_index = updated_key.split(".")[4]
|
||||
if int(present_index) % 2 == 0:
|
||||
updated_key = updated_key.replace("0.0.g", "input_rmsnorm.weight")
|
||||
else:
|
||||
updated_key = updated_key.replace("0.0.g", "post_attention_rmsnorm.weight")
|
||||
|
||||
if "transformer.attn_layers.layers" in updated_key:
|
||||
present_index = updated_key.split(".")[4]
|
||||
updated_index = update_index(int(present_index))
|
||||
updated_key = updated_key.replace(
|
||||
f"transformer.attn_layers.layers.{present_index}", f"transformer.attn_layers.layers.{updated_index}"
|
||||
)
|
||||
|
||||
for k, v in CLVP_ENCODERS_MAPPING.items():
|
||||
if k in updated_key:
|
||||
updated_key = updated_key.replace(k, v)
|
||||
|
||||
converted_weights[updated_key] = original_weights.pop(original_key)
|
||||
|
||||
return converted_weights
|
||||
|
||||
|
||||
def convert_decoder_weights(original_weights):
|
||||
converted_weights = {}
|
||||
original_weights_keys = sorted(original_weights.keys())
|
||||
for original_key in original_weights_keys:
|
||||
updated_key = original_key
|
||||
if len(updated_key.split(".")) > 3:
|
||||
index, attr = updated_key.split(".")[2], updated_key.split(".")[-1]
|
||||
|
||||
# for decoder attention
|
||||
if "attn.c_attn" in updated_key:
|
||||
if attr == "weight":
|
||||
slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).T.split(split_size=dim, dim=0)
|
||||
else:
|
||||
slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.q_proj.{attr}"] = slice1
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.k_proj.{attr}"] = slice2
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.v_proj.{attr}"] = slice3
|
||||
continue
|
||||
|
||||
if "attn.c_proj" in updated_key:
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.out_proj.{attr}"] = (
|
||||
original_weights[updated_key].squeeze(-1).T
|
||||
)
|
||||
continue
|
||||
|
||||
if "attn.bias" in updated_key or "attn.masked_bias" in updated_key or "text_head" in updated_key:
|
||||
original_weights.pop(updated_key)
|
||||
continue
|
||||
|
||||
# conditional encoder attention
|
||||
if "qkv" in updated_key:
|
||||
if attr == "weight":
|
||||
slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).split(split_size=dim, dim=0)
|
||||
else:
|
||||
slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
|
||||
|
||||
indices = torch.arange(dim)
|
||||
index1, index2, index3 = (
|
||||
indices.unfold(0, sub_dim, sub_dim * 3).flatten(),
|
||||
indices[sub_dim:].unfold(0, sub_dim, sub_dim * 3).flatten(),
|
||||
indices[2 * sub_dim :].unfold(0, sub_dim, sub_dim * 3).flatten(),
|
||||
)
|
||||
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.q_proj.{attr}"] = torch.concatenate(
|
||||
[slice1[index1], slice2[index3], slice3[index2]],
|
||||
axis=0,
|
||||
)
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.k_proj.{attr}"] = torch.concatenate(
|
||||
[slice1[index2], slice2[index1], slice3[index3]],
|
||||
axis=0,
|
||||
)
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.v_proj.{attr}"] = torch.concatenate(
|
||||
[slice1[index3], slice2[index2], slice3[index1]],
|
||||
axis=0,
|
||||
)
|
||||
continue
|
||||
|
||||
if "proj_out" in updated_key:
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.out_proj.{attr}"] = original_weights[
|
||||
updated_key
|
||||
].squeeze(-1)
|
||||
continue
|
||||
|
||||
for k, v in CLVP_DECODER_MAPPING.items():
|
||||
if k in updated_key:
|
||||
updated_key = updated_key.replace(k, v)
|
||||
|
||||
converted_weights[updated_key] = original_weights.pop(original_key)
|
||||
|
||||
return converted_weights
|
||||
|
||||
|
||||
def _download(url: str, root: str):
|
||||
repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
|
||||
filename = f"{url.split('/')[-2]}/{url.split('/')[-1]}"
|
||||
hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
force_filename=root,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
|
||||
def convert_clvp_weights(checkpoint_path, pytorch_dump_folder_path):
|
||||
converted_checkpoint = {}
|
||||
|
||||
for each_model_name, each_model_url in _MODELS.items():
|
||||
each_model_path = os.path.join(checkpoint_path, each_model_url.split("/")[-1])
|
||||
if not os.path.exists(each_model_path):
|
||||
print(f"\n{each_model_name} was not found! Downloading it to {each_model_path}")
|
||||
_download(url=each_model_url, root=each_model_path)
|
||||
|
||||
if each_model_name == "clvp":
|
||||
clvp_checkpoint = torch.load(each_model_path, map_location="cpu", weights_only=True)
|
||||
else:
|
||||
decoder_checkpoint = torch.load(each_model_path, map_location="cpu", weights_only=True)
|
||||
|
||||
# Converting the weights
|
||||
converted_checkpoint.update(**convert_encoder_weights(clvp_checkpoint))
|
||||
converted_checkpoint.update(**convert_decoder_weights(decoder_checkpoint))
|
||||
|
||||
config = ClvpConfig.from_pretrained("susnato/clvp_dev")
|
||||
model = ClvpModelForConditionalGeneration(config)
|
||||
|
||||
model.load_state_dict(converted_checkpoint, strict=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Model saved at {pytorch_dump_folder_path}!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# # Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the output PyTorch model. (Please enter full path)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_clvp_weights(args.checkpoint_path, args.pytorch_dump_folder_path)
|
@ -1,214 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert ColPali weights from the original repository to the HF model format.
|
||||
|
||||
Original repository: https://github.com/illuin-tech/colpali.
|
||||
|
||||
NOTE: This script was originally run using `torch==2.5.1` and with:
|
||||
|
||||
```bash
|
||||
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
|
||||
--model_id vidore/colpali-v1.2-merged \
|
||||
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
|
||||
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
|
||||
--output_dir vidore/colpali-v1.2-hf-internal \
|
||||
--push_to_hub
|
||||
|
||||
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
|
||||
--model_id vidore/colpali-v1.3-merged \
|
||||
--revision 5b955e3415a7c5468ab33119d98d6d45c3a5b2c3 \
|
||||
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
|
||||
--output_dir vidore/colpali-v1.3-hf \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.colpali import ColPaliForRetrieval
|
||||
from transformers.models.colpali.configuration_colpali import ColPaliConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
ORIGINAL_DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
def rename_state_dict_keys(state_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
if key.startswith("custom_text_proj"):
|
||||
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
|
||||
if key.startswith("model."):
|
||||
new_key = key.replace("model.", "vlm.", 1)
|
||||
new_state_dict[new_key] = value
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
|
||||
directory_path = snapshot_download(
|
||||
repo_id=model_id,
|
||||
revision=revision,
|
||||
allow_patterns=["*.safetensors"],
|
||||
)
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
# Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
|
||||
if "lm_head.weight" not in original_state_dict:
|
||||
original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[
|
||||
"model.language_model.model.embed_tokens.weight"
|
||||
].clone()
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_colpali_weights_to_hf(
|
||||
model_id: str,
|
||||
output_dir: str,
|
||||
push_to_hub: bool,
|
||||
revision: Optional[str] = None,
|
||||
original_vlm_name_or_path: Optional[str] = None,
|
||||
):
|
||||
# Load the original model data
|
||||
original_config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
)
|
||||
if original_vlm_name_or_path is not None:
|
||||
original_config._name_or_path = original_vlm_name_or_path
|
||||
if hasattr(original_config, "architectures"):
|
||||
delattr(original_config, "architectures")
|
||||
|
||||
original_state_dict = load_original_state_dict(model_id, revision=revision)
|
||||
|
||||
# Format the state_dict keys
|
||||
original_state_dict = rename_state_dict_keys(original_state_dict)
|
||||
|
||||
# Create the new config
|
||||
config = ColPaliConfig(
|
||||
vlm_config=original_config,
|
||||
embedding_dim=128, # hardcoded in the original model
|
||||
)
|
||||
config.model_type = "colpali"
|
||||
config.is_composition = False
|
||||
|
||||
# Load the untrained model
|
||||
model = ColPaliForRetrieval(config=config).to("cpu").eval()
|
||||
print("Created model with new config and randomly initialized weights")
|
||||
|
||||
# NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision.
|
||||
# There are two ways to set the model's dtype:
|
||||
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
|
||||
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
|
||||
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
|
||||
# the new weights' dtypes match the original model.
|
||||
for param in model.parameters():
|
||||
param.data = param.data.to(ORIGINAL_DTYPE)
|
||||
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
|
||||
|
||||
# Load the original weights
|
||||
model.load_state_dict(original_state_dict)
|
||||
print("Loaded original model weights")
|
||||
|
||||
# Tie the weights (following ColPali's `__init__`` step)
|
||||
if model.vlm.language_model._tied_weights_keys is not None:
|
||||
model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys]
|
||||
|
||||
# Sanity check: ensure all keys are the same
|
||||
state_dict_keys_old = set(original_state_dict.keys())
|
||||
state_dict_keys_new = set(model.state_dict().keys())
|
||||
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
|
||||
if disjoint_keys:
|
||||
raise ValueError(f"Incompatible keys: {disjoint_keys}")
|
||||
|
||||
# Save the model
|
||||
if push_to_hub:
|
||||
model.push_to_hub(output_dir, private=True)
|
||||
print(f"Model pushed to the hub at `{output_dir}`")
|
||||
else:
|
||||
Path(output_dir).mkdir(exist_ok=True, parents=True)
|
||||
model.save_pretrained(output_dir)
|
||||
print(f"Model saved to `{output_dir}`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""
|
||||
This script converts the original ColPali model to the HF model format.
|
||||
|
||||
Example usage:
|
||||
```bash
|
||||
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
|
||||
--model_id vidore/colpali-v1.2-merged \
|
||||
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
|
||||
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
|
||||
--output_dir vidore/colpali-v1.2-hf \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
help="Model ID of the original model to convert",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
help="Revision of the model to download",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--original_vlm_name_or_path",
|
||||
help="Name or path of the original VLM backbone model",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_colpali_weights_to_hf(
|
||||
model_id=args.model_id,
|
||||
output_dir=args.output_dir,
|
||||
push_to_hub=args.push_to_hub,
|
||||
revision=args.revision,
|
||||
original_vlm_name_or_path=args.original_vlm_name_or_path,
|
||||
)
|
@ -1,212 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert ColQwen2 weights from the original repository to the HF model format.
|
||||
|
||||
Don't forget to manually upload the processor-related files to the HF model repository
|
||||
after running this script.
|
||||
|
||||
Original repository: https://github.com/illuin-tech/colqwen2.
|
||||
|
||||
NOTE: This script was originally run using `torch==2.5.1` and with:
|
||||
|
||||
```bash
|
||||
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
|
||||
--model_id vidore/colqwen2-v1.0-merged \
|
||||
--revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
|
||||
--original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
|
||||
--output_dir vidore/colqwen2-v1.0-hf-internal \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.colqwen2 import ColQwen2ForRetrieval
|
||||
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
ORIGINAL_DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
|
||||
directory_path = snapshot_download(
|
||||
repo_id=model_id,
|
||||
revision=revision,
|
||||
allow_patterns=["*.safetensors"],
|
||||
)
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
# Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
|
||||
if "lm_head.weight" not in original_state_dict:
|
||||
original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def rename_state_dict_keys(state_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
new_state_dict: dict[str, Any] = {}
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith("custom_text_proj"):
|
||||
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
|
||||
else:
|
||||
# The original ColQwen2 inherits from Qwen2VL, so we simply need to add the `vlm.` prefix
|
||||
# to all remaining keys.
|
||||
if key.startswith("model."):
|
||||
key = key.replace("model.", "model.language_model.")
|
||||
if key.startswith("visual."):
|
||||
key = key.replace("visual.", "model.visual.")
|
||||
new_key = "vlm." + key
|
||||
new_state_dict[new_key] = value
|
||||
return new_state_dict
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_colqwen2_weights_to_hf(
|
||||
model_id: str,
|
||||
output_dir: str,
|
||||
push_to_hub: bool,
|
||||
revision: Optional[str] = None,
|
||||
original_vlm_name_or_path: Optional[str] = None,
|
||||
):
|
||||
# Load the original model data
|
||||
original_config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
)
|
||||
if original_vlm_name_or_path is not None:
|
||||
original_config._name_or_path = original_vlm_name_or_path
|
||||
if hasattr(original_config, "architectures"):
|
||||
delattr(original_config, "architectures")
|
||||
|
||||
original_state_dict = load_original_state_dict(model_id, revision=revision)
|
||||
|
||||
# Format the state_dict keys
|
||||
original_state_dict = rename_state_dict_keys(original_state_dict)
|
||||
|
||||
# Create the new config
|
||||
config = ColQwen2Config(
|
||||
vlm_config=original_config,
|
||||
embedding_dim=128, # hardcoded in the original model
|
||||
)
|
||||
config.model_type = "colqwen2"
|
||||
config.is_composition = False
|
||||
|
||||
# Load the untrained model
|
||||
model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
|
||||
print("Created model with new config and randomly initialized weights")
|
||||
|
||||
# NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
|
||||
# There are two ways to set the model's dtype:
|
||||
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
|
||||
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
|
||||
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
|
||||
# the new weights' dtypes match the original model.
|
||||
for param in model.parameters():
|
||||
param.data = param.data.to(ORIGINAL_DTYPE)
|
||||
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
|
||||
|
||||
# Load the original weights
|
||||
model.load_state_dict(original_state_dict)
|
||||
print("Loaded original model weights")
|
||||
|
||||
# # Sanity check: ensure all keys are the same
|
||||
state_dict_keys_old = set(original_state_dict.keys())
|
||||
state_dict_keys_new = set(model.state_dict().keys())
|
||||
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
|
||||
if disjoint_keys:
|
||||
raise ValueError(f"Incompatible keys: {disjoint_keys}")
|
||||
|
||||
# Save the model
|
||||
if push_to_hub:
|
||||
model.push_to_hub(output_dir, private=True)
|
||||
print(f"Model pushed to the hub at `{output_dir}`")
|
||||
else:
|
||||
Path(output_dir).mkdir(exist_ok=True, parents=True)
|
||||
model.save_pretrained(output_dir)
|
||||
print(f"Model saved to `{output_dir}`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""
|
||||
This script converts the original ColQwen2 model to the HF model format.
|
||||
|
||||
Don't forget to manually upload the processor-related files to the HF model repository
|
||||
after running this script.
|
||||
|
||||
Example usage:
|
||||
```bash
|
||||
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
|
||||
--model_id vidore/colqwen2-v1.0-merged \
|
||||
--revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
|
||||
--original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
|
||||
--output_dir vidore/colqwen2-v1.0-hf-internal \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
help="Model ID of the original model to convert",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
help="Revision of the model to download",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--original_vlm_name_or_path",
|
||||
help="Name or path of the original VLM backbone model",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_colqwen2_weights_to_hf(
|
||||
model_id=args.model_id,
|
||||
output_dir=args.output_dir,
|
||||
push_to_hub=args.push_to_hub,
|
||||
revision=args.revision,
|
||||
original_vlm_name_or_path=args.original_vlm_name_or_path,
|
||||
)
|
@ -1,324 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Conditional DETR checkpoints."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
ConditionalDetrConfig,
|
||||
ConditionalDetrForObjectDetection,
|
||||
ConditionalDetrForSegmentation,
|
||||
ConditionalDetrImageProcessor,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
rename_keys = []
|
||||
for i in range(6):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append(
|
||||
(f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
|
||||
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"transformer.decoder.layers.{i}.cross_attn.out_proj.weight",
|
||||
f"decoder.layers.{i}.encoder_attn.out_proj.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"transformer.decoder.layers.{i}.cross_attn.out_proj.bias",
|
||||
f"decoder.layers.{i}.encoder_attn.out_proj.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))
|
||||
|
||||
# q, k, v projections in self/cross-attention in decoder for conditional DETR
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_qcontent_proj.weight", f"decoder.layers.{i}.sa_qcontent_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_kcontent_proj.weight", f"decoder.layers.{i}.sa_kcontent_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_qpos_proj.weight", f"decoder.layers.{i}.sa_qpos_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_kpos_proj.weight", f"decoder.layers.{i}.sa_kpos_proj.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.weight", f"decoder.layers.{i}.sa_v_proj.weight"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qcontent_proj.weight", f"decoder.layers.{i}.ca_qcontent_proj.weight")
|
||||
)
|
||||
# rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.weight", f"decoder.layers.{i}.ca_qpos_proj.weight"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_kcontent_proj.weight", f"decoder.layers.{i}.ca_kcontent_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_kpos_proj.weight", f"decoder.layers.{i}.ca_kpos_proj.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.weight", f"decoder.layers.{i}.ca_v_proj.weight"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.weight", f"decoder.layers.{i}.ca_qpos_sine_proj.weight")
|
||||
)
|
||||
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_qcontent_proj.bias", f"decoder.layers.{i}.sa_qcontent_proj.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_kcontent_proj.bias", f"decoder.layers.{i}.sa_kcontent_proj.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_qpos_proj.bias", f"decoder.layers.{i}.sa_qpos_proj.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_kpos_proj.bias", f"decoder.layers.{i}.sa_kpos_proj.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.bias", f"decoder.layers.{i}.sa_v_proj.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qcontent_proj.bias", f"decoder.layers.{i}.ca_qcontent_proj.bias")
|
||||
)
|
||||
# rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.bias", f"decoder.layers.{i}.ca_qpos_proj.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_kcontent_proj.bias", f"decoder.layers.{i}.ca_kcontent_proj.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.ca_kpos_proj.bias", f"decoder.layers.{i}.ca_kpos_proj.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.bias", f"decoder.layers.{i}.ca_v_proj.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.bias", f"decoder.layers.{i}.ca_qpos_sine_proj.bias")
|
||||
)
|
||||
|
||||
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
|
||||
# for conditional DETR, also convert reference point head and query scale MLP
|
||||
rename_keys.extend(
|
||||
[
|
||||
("input_proj.weight", "input_projection.weight"),
|
||||
("input_proj.bias", "input_projection.bias"),
|
||||
("query_embed.weight", "query_position_embeddings.weight"),
|
||||
("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
|
||||
("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
|
||||
("class_embed.weight", "class_labels_classifier.weight"),
|
||||
("class_embed.bias", "class_labels_classifier.bias"),
|
||||
("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
|
||||
("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
|
||||
("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
|
||||
("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
|
||||
("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
|
||||
("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
|
||||
("transformer.decoder.ref_point_head.layers.0.weight", "decoder.ref_point_head.layers.0.weight"),
|
||||
("transformer.decoder.ref_point_head.layers.0.bias", "decoder.ref_point_head.layers.0.bias"),
|
||||
("transformer.decoder.ref_point_head.layers.1.weight", "decoder.ref_point_head.layers.1.weight"),
|
||||
("transformer.decoder.ref_point_head.layers.1.bias", "decoder.ref_point_head.layers.1.bias"),
|
||||
("transformer.decoder.query_scale.layers.0.weight", "decoder.query_scale.layers.0.weight"),
|
||||
("transformer.decoder.query_scale.layers.0.bias", "decoder.query_scale.layers.0.bias"),
|
||||
("transformer.decoder.query_scale.layers.1.weight", "decoder.query_scale.layers.1.weight"),
|
||||
("transformer.decoder.query_scale.layers.1.bias", "decoder.query_scale.layers.1.bias"),
|
||||
("transformer.decoder.layers.0.ca_qpos_proj.weight", "decoder.layers.0.ca_qpos_proj.weight"),
|
||||
("transformer.decoder.layers.0.ca_qpos_proj.bias", "decoder.layers.0.ca_qpos_proj.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def rename_key(state_dict, old, new):
|
||||
val = state_dict.pop(old)
|
||||
state_dict[new] = val
|
||||
|
||||
|
||||
def rename_backbone_keys(state_dict):
|
||||
new_state_dict = OrderedDict()
|
||||
for key, value in state_dict.items():
|
||||
if "backbone.0.body" in key:
|
||||
new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model")
|
||||
new_state_dict[new_key] = value
|
||||
else:
|
||||
new_state_dict[key] = value
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def read_in_q_k_v(state_dict, is_panoptic=False):
|
||||
prefix = ""
|
||||
if is_panoptic:
|
||||
prefix = "conditional_detr."
|
||||
|
||||
# first: transformer encoder
|
||||
for i in range(6):
|
||||
# read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_conditional_detr_checkpoint(model_name, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our CONDITIONAL_DETR structure.
|
||||
"""
|
||||
|
||||
# load default config
|
||||
config = ConditionalDetrConfig()
|
||||
# set backbone and dilation attributes
|
||||
if "resnet101" in model_name:
|
||||
config.backbone = "resnet101"
|
||||
if "dc5" in model_name:
|
||||
config.dilation = True
|
||||
is_panoptic = "panoptic" in model_name
|
||||
if is_panoptic:
|
||||
config.num_labels = 250
|
||||
else:
|
||||
config.num_labels = 91
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "coco-detection-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
# load image processor
|
||||
format = "coco_panoptic" if is_panoptic else "coco_detection"
|
||||
image_processor = ConditionalDetrImageProcessor(format=format)
|
||||
|
||||
# prepare image
|
||||
img = prepare_img()
|
||||
encoding = image_processor(images=img, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
logger.info(f"Converting model {model_name}...")
|
||||
|
||||
# load original model from torch hub
|
||||
conditional_detr = torch.hub.load("DeppMeng/ConditionalDETR", model_name, pretrained=True).eval()
|
||||
state_dict = conditional_detr.state_dict()
|
||||
# rename keys
|
||||
for src, dest in rename_keys:
|
||||
if is_panoptic:
|
||||
src = "conditional_detr." + src
|
||||
rename_key(state_dict, src, dest)
|
||||
state_dict = rename_backbone_keys(state_dict)
|
||||
# query, key and value matrices need special treatment
|
||||
read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
|
||||
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
|
||||
prefix = "conditional_detr.model." if is_panoptic else "model."
|
||||
for key in state_dict.copy():
|
||||
if is_panoptic:
|
||||
if (
|
||||
key.startswith("conditional_detr")
|
||||
and not key.startswith("class_labels_classifier")
|
||||
and not key.startswith("bbox_predictor")
|
||||
):
|
||||
val = state_dict.pop(key)
|
||||
state_dict["conditional_detr.model" + key[4:]] = val
|
||||
elif "class_labels_classifier" in key or "bbox_predictor" in key:
|
||||
val = state_dict.pop(key)
|
||||
state_dict["conditional_detr." + key] = val
|
||||
elif key.startswith("bbox_attention") or key.startswith("mask_head"):
|
||||
continue
|
||||
else:
|
||||
val = state_dict.pop(key)
|
||||
state_dict[prefix + key] = val
|
||||
else:
|
||||
if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
|
||||
val = state_dict.pop(key)
|
||||
state_dict[prefix + key] = val
|
||||
# finally, create HuggingFace model and load state dict
|
||||
model = ConditionalDetrForSegmentation(config) if is_panoptic else ConditionalDetrForObjectDetection(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
model.push_to_hub(repo_id=model_name, organization="DepuMeng", commit_message="Add model")
|
||||
# verify our conversion
|
||||
original_outputs = conditional_detr(pixel_values)
|
||||
outputs = model(pixel_values)
|
||||
assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
|
||||
assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
|
||||
if is_panoptic:
|
||||
assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
|
||||
|
||||
# Save model and image processor
|
||||
logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="conditional_detr_resnet50",
|
||||
type=str,
|
||||
help="Name of the CONDITIONAL_DETR model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_conditional_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user