Mirror of https://github.com/huggingface/transformers.git (synced 2025-11-04 03:44:37 +08:00)

Compare commits: build_cach...v4.55.2

17 Commits
| SHA1 |
|---|
| acf295aec3 |
| aaa3169aa2 |
| ea2eee0bc8 |
| 956be23fff |
| 79a9ffc520 |
| 99404c7098 |
| 0d6908038c |
| b8e97fbfd2 |
| 586b6e693b |
| 95ae07d11f |
| 0d9032ae71 |
| 1d42803aac |
| 382717e543 |
| cc98f42d22 |
| d2f7266367 |
| daab2db33f |
| 06f8004e5c |
@@ -511,6 +511,8 @@
         title: GPT2
       - local: model_doc/gpt_bigcode
         title: GPTBigCode
+      - local: model_doc/gpt_oss
+        title: GptOss
       - local: model_doc/gptsan-japanese
         title: GPTSAN Japanese
       - local: model_doc/gpt-sw3

@@ -617,8 +619,6 @@
         title: OLMoE
       - local: model_doc/open-llama
         title: Open-Llama
-      - local: model_doc/openai_moe
-        title: OpenAIMoe
       - local: model_doc/opt
         title: OPT
       - local: model_doc/pegasus
@@ -65,6 +65,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 
 [[autodoc]] HqqConfig
 
+## Mxfp4Config
+
+[[autodoc]] Mxfp4Config
+
 ## FbgemmFp8Config
 
 [[autodoc]] FbgemmFp8Config
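The `Mxfp4Config` documented above is the quantization config added alongside GPT-OSS. As a minimal, hedged sketch of how such a config is typically passed to `from_pretrained`: the checkpoint path below is an illustrative placeholder, and `dequantize=True` (a flag the mxfp4 hunks later in this diff check via `quantization_config.dequantize`) unpacks the MXFP4 weights instead of running the Triton kernels.

```python
# Hedged sketch: passing an MXFP4 quantization config to from_pretrained.
# The checkpoint id below is an illustrative placeholder, not taken from this diff.
from transformers import AutoModelForCausalLM, Mxfp4Config

quantization_config = Mxfp4Config(dequantize=True)  # dequantize instead of using the Triton MXFP4 kernels
model = AutoModelForCausalLM.from_pretrained(
    "path/to/gpt-oss-checkpoint",  # placeholder
    quantization_config=quantization_config,
    device_map="auto",
)
```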
@@ -24,11 +24,11 @@ rendered properly in your Markdown viewer.
     </div>
 </div>
 
-# OpenAIMoE
+# GptOss
 
 ## Overview
 
-The OpenAIMoE model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
+The GptOss model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
 <INSERT SHORT SUMMARY HERE>
 
 The abstract from the paper is the following:

@@ -43,16 +43,16 @@ This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface
 The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
 
 
-## OpenAIMoeConfig
+## GptOssConfig
 
-[[autodoc]] OpenAIMoeConfig
+[[autodoc]] GptOssConfig
 
-## OpenAIMoeModel
+## GptOssModel
 
-[[autodoc]] OpenAIMoeModel
+[[autodoc]] GptOssModel
     - forward
 
-## OpenAIMoeForCausalLM
+## GptOssForCausalLM
 
-[[autodoc]] OpenAIMoeForCausalLM
+[[autodoc]] GptOssForCausalLM
     - forward
@@ -60,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 Array = Any
 Dataset = datasets.arrow_dataset.Dataset
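Every example-script hunk below makes the same change: the development pin `check_min_version("4.55.0.dev0")` becomes the release pin `check_min_version("4.55.0")`. For context, a minimal sketch of the kind of guard this performs; this illustrates the idea only and is not the actual `transformers.utils` implementation.

```python
# Hedged sketch of a minimal version guard, illustrating what check_min_version("4.55.0") enforces.
# This is not the transformers implementation, just the idea behind it.
from packaging import version

import transformers


def check_min_version(min_version: str) -> None:
    # Compare the installed transformers version against the required minimum and fail loudly.
    if version.parse(transformers.__version__) < version.parse(min_version):
        raise ImportError(
            f"This example requires transformers>={min_version}, "
            f"but version {transformers.__version__} is installed."
        )


check_min_version("4.55.0")
```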
@@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

@@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 Array = Any
 Dataset = datasets.arrow_dataset.Dataset

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 
 logger = logging.getLogger(__name__)
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "datasets[audio]>=1.14.0",
 #     "evaluate",
 #     "librosa",

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "torch>=1.5.0",
 #     "torchvision>=0.6.0",
 #     "datasets>=1.8.0",

@@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate>=0.12.0",
 #     "torch>=1.5.0",
 #     "torchvision>=0.6.0",

@@ -68,7 +68,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate>=0.12.0",
 #     "torch>=1.5.0",
 #     "torchvision>=0.6.0",

@@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = get_logger(__name__)

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "torch>=1.5.0",
 #     "torchvision>=0.6.0",
 #     "datasets>=1.8.0",

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "torch>=1.5.0",
 #     "torchvision>=0.6.0",
 #     "datasets>=1.8.0",

@@ -56,7 +56,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "torch>=1.5.0",
 #     "torchvision>=0.6.0",
 #     "datasets>=1.8.0",

@@ -61,7 +61,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "timm",
 #     "datasets",

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "timm",
 #     "datasets",

@@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "accelerate >= 0.12.0",
 #     "torch >= 1.3",

@@ -69,7 +69,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "accelerate >= 0.12.0",
 #     "torch >= 1.3",

@@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = get_logger(__name__)

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "accelerate >= 0.12.0",
 #     "torch >= 1.3",

@@ -72,7 +72,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "accelerate >= 0.12.0",
 #     "torch >= 1.3",

@@ -74,7 +74,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = get_logger(__name__)

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "accelerate >= 0.12.0",
 #     "torch >= 1.3",

@@ -68,7 +68,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "accelerate >= 0.12.0",
 #     "torch >= 1.3",

@@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = get_logger(__name__)
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "accelerate >= 0.12.0",
 #     "torch >= 1.3",

@@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "sentencepiece != 0.1.92",
 #     "protobuf",

@@ -57,7 +57,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = logging.getLogger(__name__)

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "sentencepiece != 0.1.92",
 #     "protobuf",

@@ -65,7 +65,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = get_logger(__name__)
 # You should update this to your particular problem to have better documentation of `model_type`

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "timm",
 #     "datasets>=4.0",

@@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "albumentations >= 1.4.16",
 #     "timm",
 #     "datasets>=4.0",

@@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logging.basicConfig(level=logging.INFO)
 logger = get_logger(__name__)

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "datasets >= 2.0.0",
 #     "torch >= 1.3",
 #     "accelerate",

@@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "datasets >= 2.0.0",
 #     "torch >= 1.3",
 #     "accelerate",

@@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = get_logger(__name__)

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "datasets[audio] >= 1.12.0",
 #     "torch >= 1.5",
 #     "torchaudio",

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "datasets[audio] >= 1.18.0",
 #     "torch >= 1.5",
 #     "torchaudio",

@@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "datasets[audio] >= 1.18.0",
 #     "torch >= 1.5",
 #     "torchaudio",

@@ -64,7 +64,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "datasets[audio] >= 1.18.0",
 #     "torch >= 1.5",
 #     "torchaudio",

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "datasets >= 1.8.0",
 #     "sentencepiece != 0.1.92",

@@ -67,7 +67,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "datasets >= 1.8.0",
 #     "sentencepiece != 0.1.92",

@@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "datasets >= 1.8.0",
 #     "sentencepiece != 0.1.92",

@@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "datasets >= 1.8.0",
 #     "sentencepiece != 0.1.92",

@@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -14,7 +14,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "datasets >= 1.8.0",
 #     "sentencepiece != 0.1.92",

@@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = get_logger(__name__)

@@ -16,7 +16,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "datasets >= 1.8.0",
 #     "sentencepiece != 0.1.92",

@@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -16,7 +16,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.21.0",
 #     "sentencepiece != 0.1.92",
 #     "protobuf",

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.21.0",
 #     "sentencepiece != 0.1.92",
 #     "protobuf",

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "seqeval",
 #     "datasets >= 1.8.0",

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "seqeval",
 #     "datasets >= 1.8.0",

@@ -67,7 +67,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "datasets >= 1.8.0",
 #     "sentencepiece != 0.1.92",

@@ -66,7 +66,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -15,7 +15,7 @@
 
 # /// script
 # dependencies = [
-#     "transformers @ git+https://github.com/huggingface/transformers.git",
+#     "transformers==4.55.2",
 #     "accelerate >= 0.12.0",
 #     "datasets >= 1.8.0",
 #     "sentencepiece != 0.1.92",

@@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version(
     "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = logging.getLogger(__name__)

@@ -61,7 +61,7 @@ except (ModuleNotFoundError, ImportError):
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 logger = logging.getLogger(__name__)

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 
 # region Checking dependencies
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 task_to_keys = {
     "cola": ("sentence", None),

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 
 # region Dependencies and constants
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.55.0.dev0")
+check_min_version("4.55.0")
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
setup.py (2 changed lines)

@@ -463,7 +463,7 @@ install_requires = [
 
 setup(
     name="transformers",
-    version="4.55.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.55.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
     author_email="transformers@huggingface.co",
     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
@@ -18,7 +18,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
 
-__version__ = "4.55.0.dev0"
+__version__ = "4.55.2"
 
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -57,6 +57,7 @@ from ..utils import (
     is_torchdynamo_exporting,
     logging,
 )
+from ..modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids
 from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
 from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .candidate_generator import (
@@ -677,30 +678,24 @@ class GenerationMixin(ContinuousMixin):
         if encoder_attention_mask is not None:
             model_inputs["attention_mask"] = encoder_attention_mask
 
         # 7. Prepare kwargs for flash attention to avoid recomputations
         if "flash" in self.config._attn_implementation and self._supports_attention_backend:
-            tensor_kws = {"dtype": torch.int32, "device": self.device}
-            pos = model_inputs["position_ids"][:, -1]
-
-            cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
-            max_length_k = int(pos.max()) + 1
-
-            bs, seq_len = input_ids.size()
-            q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
-            cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
-            max_length_q = int(q_len.max())
+            (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
+                model_inputs["position_ids"], is_packed_sequence=False
+            )
 
             model_inputs.update(
                 cu_seq_lens_q=cu_seq_lens_q.to(self.device),
                 cu_seq_lens_k=cu_seq_lens_k.to(self.device),
                 max_length_q=max_length_q,
                 max_length_k=max_length_k,
             )
-        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+
+        # 8. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
         for key, value in kwargs.items():
             if key not in model_inputs:
                 model_inputs[key] = value
 
-        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+        # 9. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
         model_inputs.pop("labels", None)
         return model_inputs
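The hunk above replaces a hand-rolled computation of the flash-attention bookkeeping tensors with the `prepare_fa_kwargs_from_position_ids` helper. For reference, a standalone sketch of the same bookkeeping, mirroring the deleted lines rather than the helper's actual internals; the function name and signature are illustrative only.

```python
# Hedged sketch of the cu_seqlens bookkeeping the removed block performed by hand.
# It mirrors the deleted code above; it is not the implementation of prepare_fa_kwargs_from_position_ids.
import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor, seq_len: int):
    # position_ids: (batch, query_len); the last column holds each sequence's largest position.
    pos = position_ids[:, -1]
    tensor_kws = {"dtype": torch.int32, "device": position_ids.device}

    # Cumulative key lengths, one entry per sequence, prefixed with 0 for the flash-attention varlen API.
    cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
    max_length_k = int(pos.max()) + 1

    # During decoding (seq_len == 1) every query contributes a single token; otherwise queries span the prompt.
    bs = position_ids.size(0)
    q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
    cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
    max_length_q = int(q_len.max())

    return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)
```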
@@ -49,7 +49,7 @@ FP4_VALUES = [
 
 # Copied from GPT_OSS repo and vllm
 def quantize_to_mxfp4(w):
-    from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
+    downcast_to_mxfp = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp
 
     w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
     w, w_scale = swizzle_mxfp4(w, w_scale)
@@ -57,9 +57,13 @@ def quantize_to_mxfp4(w):
 
 
 def swizzle_mxfp4(w, w_scale):
-    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
-    from triton_kernels.tensor_details import layout
-    from triton_kernels.tensor_details.layout import StridedLayout
+    FP4, convert_layout, wrap_torch_tensor = (
+        triton_kernels_hub.tensor.FP4,
+        triton_kernels_hub.tensor.convert_layout,
+        triton_kernels_hub.tensor.wrap_torch_tensor,
+    )
+    layout = triton_kernels_hub.tensor_details.layout
+    StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout
 
     value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
     w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
@@ -173,8 +177,12 @@ class Mxfp4GptOssExperts(nn.Module):
         self.down_proj_precision_config = None
 
     def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
-        from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
-        from triton_kernels.swiglu import swiglu_fn
+        FnSpecs, FusedActivation, matmul_ogs = (
+            triton_kernels_hub.matmul_ogs.FnSpecs,
+            triton_kernels_hub.matmul_ogs.FusedActivation,
+            triton_kernels_hub.matmul_ogs.matmul_ogs,
+        )
+        swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
 
         with torch.cuda.device(hidden_states.device):
             act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2)
@@ -211,7 +219,12 @@ def routing_torch_dist(
 ):
     import os
 
-    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch
+    GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
+        triton_kernels_hub.routing.GatherIndx,
+        triton_kernels_hub.routing.RoutingData,
+        triton_kernels_hub.routing.ScatterIndx,
+        triton_kernels_hub.routing.compute_expt_data_torch,
+    )
 
     with torch.cuda.device(logits.device):
         world_size = torch.distributed.get_world_size()
@@ -274,13 +287,16 @@ def mlp_forward(self, hidden_states):
     if dist.is_available() and dist.is_initialized():
         routing = routing_torch_dist
     else:
-        from triton_kernels.routing import routing
+        routing = triton_kernels_hub.routing.routing
+
         routing = routing
     batch_size = hidden_states.shape[0]
     hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
     router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
-    routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
+
+    with torch.cuda.device(router_logits.device):
+        routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
 
     routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
     routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
     return routed_out, router_logits
@@ -334,8 +350,11 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
 
 
 def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs):
-    from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
-
+    PrecisionConfig, FlexCtx, InFlexData = (
+        triton_kernels_hub.matmul_ogs.PrecisionConfig,
+        triton_kernels_hub.matmul_ogs.FlexCtx,
+        triton_kernels_hub.matmul_ogs.InFlexData,
+    )
     from ..integrations.tensor_parallel import shard_and_distribute_module
 
     model = kwargs.get("model", None)
@@ -447,6 +466,11 @@ def replace_with_mxfp4_linear(
 ):
     if quantization_config.dequantize:
         return model
+    else:
+        from kernels import get_kernel
+
+        global triton_kernels_hub
+        triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
 
     modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
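Taken together, the mxfp4 hunks above replace direct `import triton_kernels...` statements with attribute lookups on a kernel module fetched from the Hub. A condensed sketch of that pattern; the repo id and attribute paths are copied from the diff, and error handling is omitted.

```python
# Hedged sketch of the lazy kernel-loading pattern used throughout the mxfp4 integration above.
from kernels import get_kernel

# Fetch the Triton kernels from the Hub once, at the point where mxfp4 is actually enabled.
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")

# Symbols are then resolved as attributes at call time instead of via module-level imports,
# so triton_kernels is not a hard install-time dependency of transformers.
routing = triton_kernels_hub.routing.routing
matmul_ogs = triton_kernels_hub.matmul_ogs.matmul_ogs
swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
```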
@@ -10,20 +10,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import os
 
 import torch
 import torch.nn.functional as F
 
 from ..utils.import_utils import is_torch_npu_available
 
 
 if is_torch_npu_available():
-    import math
-
     import torch_npu
-    from einops import rearrange, repeat
-    from torch_npu import npu_rotary_mul
+    from torch_npu import npu_fusion_attention, npu_rotary_mul
 
 
 # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -52,117 +48,6 @@ def is_npu_fa2_top_left_aligned_causal_mask():
     return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
 
 
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-class IndexFirstAxis(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, indices):
-        ctx.save_for_backward(indices)
-        assert input.ndim >= 2
-        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-        second_dim = other_shape.numel()
-        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
-        # return input[indices]
-        return torch.gather(
-            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
-        ).reshape(-1, *other_shape)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (indices,) = ctx.saved_tensors
-        assert grad_output.ndim >= 2
-        other_shape = grad_output.shape[1:]
-        grad_output = rearrange(grad_output, "b ... -> b (...)")
-        grad_input = torch.zeros(
-            [ctx.first_axis_dim, grad_output.shape[1]],
-            device=grad_output.device,
-            dtype=grad_output.dtype,
-        )
-        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
-        # grad_input[indices] = grad_output
-        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
-        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
-
-
-index_first_axis = IndexFirstAxis.apply
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-class IndexPutFirstAxis(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, values, indices, first_axis_dim):
-        ctx.save_for_backward(indices)
-        assert indices.ndim == 1
-        assert values.ndim >= 2
-        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
-        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
-        output[indices] = values
-        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (indices,) = ctx.saved_tensors
-        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
-        grad_values = grad_output[indices]
-        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
-        return grad_values, None, None
-
-
-index_put_first_axis = IndexPutFirstAxis.apply
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-def pad_input(hidden_states, indices, batch, seqlen):
-    """
-    Arguments:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
-        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
-        batch: int, batch size for the padded sequence.
-        seqlen: int, maximum sequence length for the padded sequence.
-    Return:
-        hidden_states: (batch, seqlen, ...)
-    """
-    # dim = hidden_states.shape[-1]
-    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
-    # output[indices] = hidden_states
-    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
-    return rearrange(output, "(b s) ... -> b s ...", b=batch)
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-def unpad_input(hidden_states, attention_mask, unused_mask=None):
-    """
-    Arguments:
-        hidden_states: (batch, seqlen, ...)
-        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
-    Return:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
-        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
-        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
-        max_seqlen_in_batch: int
-        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
-    """
-    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
-    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
-    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
-    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
-    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
-    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
 | 
			
		||||
    # so we write custom forward and backward to make it a bit faster.
 | 
			
		||||
    return (
 | 
			
		||||
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
 | 
			
		||||
        indices,
 | 
			
		||||
        cu_seqlens,
 | 
			
		||||
        max_seqlen_in_batch,
 | 
			
		||||
        used_seqlens_in_batch,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
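# A minimal usage sketch (not part of the diff): the copied `unpad_input` / `pad_input`
# helpers above round-trip a padded batch through the flattened layout expected by varlen
# attention kernels. The shapes below are illustrative assumptions.
import torch

hidden = torch.randn(2, 4, 8)                      # (batch, seqlen, dim)
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])  # 1 = valid token, 0 = padding
flat, indices, cu_seqlens, max_len, seqused = unpad_input(hidden, mask)
# flat: (5, 8) holds only the valid tokens; cu_seqlens: tensor([0, 2, 5], dtype=torch.int32)
restored = pad_input(flat, indices, 2, 4)          # back to (2, 4, 8), padded slots zero-filled
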
def npu_flash_attn_func(
    q,
    k,
@ -179,11 +64,11 @@ def npu_flash_attn_func(

    if not causal:
        head_num = q.shape[2]
        output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
        output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
    else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[2]
        output = torch_npu.npu_fusion_attention(
        output = npu_fusion_attention(
            q,
            k,
            v,
@ -218,7 +103,7 @@ def npu_flash_attn_varlen_func(

    if not causal:
        head_num = q.shape[1]
        output = torch_npu.npu_fusion_attention(
        output = npu_fusion_attention(
            q,
            k,
            v,
@ -234,7 +119,7 @@ def npu_flash_attn_varlen_func(
    else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[1]
        output = torch_npu.npu_fusion_attention(
        output = npu_fusion_attention(
            q,
            k,
            v,

@ -1,4 +1,4 @@
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,17 +14,15 @@
import inspect
import os
import warnings
from functools import partial
from typing import Optional, TypedDict

import torch
import torch.nn.functional as F

from transformers.utils.import_utils import is_kernels_available

from .utils import (
    is_flash_attn_2_available,
    is_flash_attn_3_available,
    is_flash_attn_greater_or_equal,
    is_flash_attn_greater_or_equal_2_10,
    is_torch_npu_available,
    logging,
@ -34,18 +32,135 @@ from .utils import (
logger = logging.get_logger(__name__)


def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:])
    return reshaped[indices]
# TODO Deprecate when all models have the attention interface
def flash_attn_supports_top_left_mask():
    if is_flash_attn_3_available():
        return False
    if is_flash_attn_2_available():
        return not is_flash_attn_greater_or_equal_2_10()

    from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask

    return is_npu_fa2_top_left_aligned_causal_mask()


def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
# TODO Deprecate when all models have the attention interface
def is_flash_attn_available():
    return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()


# `globals()` is not compatible with dynamo, hence we have to define them in global scope ourselves
_flash_fn = None
_flash_varlen_fn = None
_pad_fn = None
_unpad_fn = None

# function that processes kwargs, generalized to handle any supported kwarg within the function
_process_flash_kwargs_fn = None
# exceptions where hf API doesn't match the original flash attention API
_hf_api_to_flash_mapping = {
    "dropout": "dropout_p",
    "sliding_window": "window_size",
}


def _lazy_imports(implementation: Optional[str]):
    """
    FA3-compatible unpad_input function.
    Lazy loads the respective flash attention implementations.

    Return:
        flash_attn_func: The base flash attention function.
        flash_attn_varlen_func: The flash attention function supporting variable sequence lengths,
                                e.g. for padding-free training.
        pad_input: The function to pad inputs into one sequence and returning the respective kwargs.
        unpad_input: The function to unpad outputs based on the kwargs (from pad_input).
    """
    is_fa2 = is_flash_attn_2_available()
    is_fa3 = is_flash_attn_3_available()
    if implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3):
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        from flash_attn.bert_padding import pad_input, unpad_input
    else:
        pad_input, unpad_input = _pad_input, _unpad_input
        if implementation == "flash_attention_3" or (implementation is None and is_fa3):
            from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
        elif is_torch_npu_available():
            from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
            from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
        # Kernels fallback
        else:
            flash_attn_func = getattr(implementation, "flash_attn_func", None)
            flash_attn_varlen_func = getattr(implementation, "flash_attn_varlen_func", None)
            if flash_attn_varlen_func is None or flash_attn_func is None:
                raise ValueError(
                    f"Could not find the currently requested flash attention implementation at `{implementation}`."
                    f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`."
                )

    return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input


def _lazy_define_process_function(flash_function):
    """
    Depending on the version and kernel some features are not supported. Due to limitations in
    `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported
    within `_process_flash_attention_kwargs`.

    NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`.
          This might be confusing for kwargs that we use in any case, e.g. `is_causal`.
    """
    global _process_flash_kwargs_fn, _hf_api_to_flash_mapping

    flash_parameters = inspect.signature(flash_function).parameters
    process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters

    supports_mapping = {}
    for param in process_parameters:
        fa_param = _hf_api_to_flash_mapping.get(param, param)
        supports_mapping[fa_param] = fa_param in flash_parameters

    return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)


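# A minimal sketch (not part of the diff) of what `_lazy_define_process_function` produces:
# HF kwarg names are translated through `_hf_api_to_flash_mapping` and checked against the
# kernel's signature. `fake_varlen_fn` is a hypothetical kernel used only for illustration.
import inspect

def fake_varlen_fn(q, k, v, dropout_p=0.0, softcap=None):
    ...

hf_to_fa = {"dropout": "dropout_p", "sliding_window": "window_size"}
fa_params = inspect.signature(fake_varlen_fn).parameters
supports = {hf_to_fa.get(p, p): hf_to_fa.get(p, p) in fa_params for p in ("dropout", "sliding_window", "softcap")}
# supports == {"dropout_p": True, "window_size": False, "softcap": True}
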
def lazy_import_flash_attention(implementation: Optional[str]):
    """
    Lazy loading flash attention and returning the respective functions + flags back

    NOTE: For fullgraph, this needs to be called before compile, as no fullgraph can
          work without preloading. See `_check_and_adjust_attn_implementation` in `modeling_utils`.
    """
    global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
    if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
        _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)

    global _process_flash_kwargs_fn
    if _process_flash_kwargs_fn is None:
        _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)

    return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn


def _index_first_axis(tensor, indices):
    """
    A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
    after flattening the first two dimensions of the tensor. This is functionally equivalent to
    FA2's `index_first_axis` and replaces the need to import it.
    """
    # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
    # two dimensions to get (total_tokens, ...) before indexing.
    reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
    return reshaped_tensor[indices]


def _unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.

    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
@ -69,14 +184,16 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
    )


def _fa3_pad_input(hidden_states, indices, batch, seqlen):
def _pad_input(hidden_states, indices, batch, seqlen):
    """
    FA3-compatible pad_input function.
    pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.

    Return:
        hidden_states: (batch, seqlen, ...)
    """
@ -89,9 +206,11 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
@ -125,6 +244,7 @@ def _upad_input(
    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
    tensors for query, key, value tensors.

    Arguments:
        query_layer (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
@ -138,6 +258,7 @@ def _upad_input(
            Target length.
        unpad_input_func:
            The function to use for unpadding the input tensors.

    Return:
        query_layer (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
@ -190,12 +311,79 @@ def _upad_input(
    )


def _prepare_from_posids(query, key, value, position_ids):
def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = True):
    """
    This function returns all the necessary kwargs to call `flash_attn_varlen_func`
    extracted from position_ids. The `position_ids` can be either packed sequence or
    the usual padded position ids, for example in inference time.

    Arguments:
        position_ids (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        is_packed_sequence (`bool`, *optional*, defaults to `True`):
            Whether the input position ids are a packed sequence or not.

    Return:
        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into
            ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    # If the lengths are not equal, most probably we are in decoding stage with cache
    # In that case the position ids will not always start with `0` and we need a better way to infer
    # cumulative seq lengths.
    if not is_packed_sequence:
        tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}

        last_position_ids = position_ids[:, -1]
        q_len = (
            torch.ones(position_ids.size(0), **tensor_kwargs)
            if position_ids.shape[-1] == 1
            else last_position_ids.add(1)
        )
        cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kwargs), q_len.cumsum(0).to(torch.int32)], 0)
        cu_seq_lens_k = torch.cat(
            [torch.zeros(1, **tensor_kwargs), last_position_ids.add(1).cumsum(0).to(torch.int32)], 0
        )

        max_length_q = int(q_len.max())
        max_length_k = int(last_position_ids.max()) + 1
    else:
        position_ids = position_ids.flatten()
        indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

        cu_seq_lens_q = torch.cat(
            (
                indices_q[position_ids == 0],
                torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
            )
        )
        cu_seq_lens_k = cu_seq_lens_q

        # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
        # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
        # for some models (e.g. qwen2-vl).
        max_length_q = cu_seq_lens_q.diff().max()
        # NOTE: With torch compile, this will cause a graph break if you don't set
        # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
        # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
        # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
        # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
        max_length_q = max_length_q.item()
        max_length_k = max_length_q

    return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)


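# A worked example (not part of the diff) of the packed-sequence branch above: one flattened
# row holding three documents of lengths 3, 2 and 4.
import torch

position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
(cu_q, cu_k), (max_q, max_k) = prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence=True)
# cu_q == cu_k == tensor([0, 3, 5, 9], dtype=torch.int32) and max_q == max_k == 4:
# sequence boundaries are read off the positions that restart at 0.
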
def _prepare_from_posids(query, key, value, position_ids, query_length):
    """
    This function returns necessary arguments to call `flash_attn_varlen_func`.
    All three query, key, value states will be flattened.
    Cumulative lengths of each example in the batch will be extracted from position_ids.
    NOTE: ideally cumulative lengths should be prepared at the data collator stage

    Arguments:
        query (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
@ -205,6 +393,9 @@ def _prepare_from_posids(query, key, value, position_ids):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        position_ids (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Sequence length of the input queries.

    Return:
        query (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
@ -219,123 +410,152 @@ def _prepare_from_posids(query, key, value, position_ids):
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    kv_length = key.shape[1]
    is_packed_sequence = query_length == kv_length

    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    value = value.contiguous().view(-1, value.size(-2), value.size(-1))

    position_ids = position_ids.flatten()
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

    cu_seq_lens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
        position_ids, is_packed_sequence=is_packed_sequence
    )
    # NOTE: With torch compile, this will cause a graph break if you don't set
    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
    # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
    # for some models (e.g. qwen2-vl).
    max_length = cu_seq_lens.diff().max().item()
    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))

    return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))


def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
    warnings.warn(
        "prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids",
        "The function `_prepare_flash_attention_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_from_posids` instead.",
        FutureWarning,
    )
    return _prepare_from_posids(query, key, value, position_ids)


def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None):
def _is_packed_sequence(position_ids, batch_size):
    """
    Check whether the position ids indicate packed sequences or not:
        1. Position ids exist
        2. Flattened sequences only are supported
        3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
    """
    if position_ids is None:
        return False

    increasing_position_sequences = (
        torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
    )
    return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()


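# A minimal sketch (not part of the diff): `_is_packed_sequence` only fires for a single
# flattened row whose position ids restart, e.g. two documents packed back to back.
import torch

packed = torch.tensor([[0, 1, 2, 0, 1]])  # positions restart -> treated as packed
padded = torch.tensor([[0, 1, 2, 3, 4]])  # strictly increasing -> not packed
assert bool(_is_packed_sequence(packed, batch_size=1)) is True
assert bool(_is_packed_sequence(padded, batch_size=1)) is False
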
def fa_peft_integration_check(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    target_dtype: Optional[torch.dtype] = None,
):
    """
    PEFT usually casts the layer norms in float32 for training stability reasons,
    therefore the input hidden states get silently cast to float32. Hence, we need to
    cast them back to float16 / bfloat16 just to be sure everything works as expected.
    This might slow down training & inference so it is recommended to not cast the LayerNorms!
    """
    if target_dtype and q.dtype == torch.float32:
        logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
        q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
    return q, k, v


def _lazy_imports(impl: Optional[str]):
    # returns funcs and pad/unpad based on impl
    is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
    is_fa3 = is_flash_attn_3_available()
    if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
        try:
            from flash_attn import flash_attn_func, flash_attn_varlen_func
            from flash_attn.bert_padding import pad_input, unpad_input

            return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False

        except ImportError as e:
            if not globals().get("use_remote_fa2", None):
                use_remote_fa2 = (
                    input(
                        "Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? "
                    )
                    .strip()
                    .lower()
                )
                globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"}
            if globals()["use_remote_fa2"]:
                if not is_kernels_available():
                    raise ImportError("You need to install kernels: `pip install kernels`")
                from kernels import get_kernel

                impl = get_kernel("kernels-community/flash-attn")
                pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
                return (
                    getattr(impl, "flash_attn_func", None),
                    getattr(impl, "flash_attn_varlen_func"),
                    pad_input,
                    unpad_input,
                    True,
                )

            else:
                raise ImportError(
                    "Failed to import flash attention 2, please install it or use another implementation."
                ) from e
    if impl == "flash_attention_3" or (impl is None and is_fa3):
        from flash_attn_interface import flash_attn_func, flash_attn_varlen_func

        pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
        return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True
    else:
        pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
        return (
            getattr(impl, "flash_attn_func", None),
            getattr(impl, "flash_attn_varlen_func"),
            pad_input,
            unpad_input,
            True,
        )


_flash_supports_window = None


def is_flash_attn_available():
    return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()


def flash_attn_supports_top_left_mask():
    if is_flash_attn_3_available():
        return False
    if is_flash_attn_2_available():
        return not is_flash_attn_greater_or_equal_2_10()

    from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask

    return is_npu_fa2_top_left_aligned_causal_mask()


class FlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for Flash Attention with Compile.

    Attributes:
        cumulative_seqlens_q (`torch.LongTensor`, *optional*)
            Gets cumulative sequence length for query state.
        cumulative_seqlens_k (`torch.LongTensor`, *optional*)
            Gets cumulative sequence length for key state.
        max_length_q (`int`, *optional*):
            Maximum sequence length for query state.
        max_length_k (`int`, *optional*):
            Maximum sequence length for key state.
    """

    cumulative_seqlens_q: Optional[torch.LongTensor]
    cumulative_seqlens_k: Optional[torch.LongTensor]
    max_length_q: Optional[int]
    max_length_k: Optional[int]


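# A minimal sketch (not part of the diff): callers that already know their sequence boundaries
# (e.g. a packing data collator) can forward these kwargs directly; flash attention kernels
# expect the cumulative lengths as int32. The model call below is a hypothetical call site.
import torch

fa_kwargs = FlashAttentionKwargs(
    cumulative_seqlens_q=torch.tensor([0, 3, 5, 9], dtype=torch.int32),
    cumulative_seqlens_k=torch.tensor([0, 3, 5, 9], dtype=torch.int32),
    max_length_q=4,
    max_length_k=4,
)
# model(input_ids, position_ids=position_ids, **fa_kwargs)
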
def _process_flash_attention_kwargs(
    query_length: int,
    key_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
    s_aux: Optional[torch.Tensor] = None,
    supports_mapping: Optional[dict[str, bool]] = None,
    **kwargs,
):
    """
    Returns a set of kwargs that are passed down to the according flash attention function based on
    requested features and whether it is supported - depends on the version and kernel implementation
    which is dynamically configured at `lazy_import_flash_attention`. The (un)supported features can be
    inspected in `supports_mapping`, see `_lazy_define_process_function` for more details.

    Args:
        query_length (`int`):
            Length of the query states
        key_length (`int`):
            Length of the key states
        is_causal (`bool`):
            Whether we perform causal (decoder) attention or full attention.
        dropout (`float`):
            Attention dropout.
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`.
        sliding_window (`int`, *optional*):
            The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back.
        use_top_left_mask (`bool`):
            Deprecated behavior of older versions of flash attention requiring different masking.
        softcap (`float`, *optional*):
            Softcap for the attention logits, used e.g. in gemma2.
        deterministic (`bool`, *optional*):
            Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
        s_aux (`torch.Tensor`, *optional*):
            Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
    Return:
        flash_kwargs (`dict`):
            A dict of kwargs that are requested and supported.
    """
    flash_kwargs = {
        "causal": is_causal and not (use_top_left_mask and query_length == 1),
        "softmax_scale": softmax_scale,
    }

    if supports_mapping["dropout_p"]:
        flash_kwargs["dropout_p"] = dropout

    if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window:
        flash_kwargs["window_size"] = (sliding_window, sliding_window)

    if supports_mapping["deterministic"]:
        flash_kwargs["deterministic"] = (
            deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
        )

    if supports_mapping["softcap"] and softcap is not None:
        flash_kwargs["softcap"] = softcap

    # Only within kernel implementation atm
    if supports_mapping["s_aux"] and s_aux is not None:
        flash_kwargs["s_aux"] = s_aux

    return flash_kwargs


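# A minimal sketch (not part of the diff): with a supports_mapping for a kernel that accepts
# dropout and sliding windows but no softcap, the unsupported request is silently dropped
# rather than raising. The mapping below is an illustrative assumption.
supports = {"dropout_p": True, "window_size": True, "deterministic": False, "softcap": False, "s_aux": False}
flash_kwargs = _process_flash_attention_kwargs(
    query_length=128,
    key_length=4096,
    is_causal=True,
    dropout=0.1,
    sliding_window=1024,
    softcap=30.0,  # requested but unsupported -> omitted from the result
    supports_mapping=supports,
)
# flash_kwargs == {"causal": True, "softmax_scale": None, "dropout_p": 0.1, "window_size": (1024, 1024)}
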
def _flash_attention_forward(
@ -360,100 +580,121 @@ def _flash_attention_forward(
    implementation: Optional[str] = None,
    **kwargs,
):
    if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")):
        flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation)
        globals()["_flash_fn"] = flash_fn
        globals()["_flash_varlen_fn"] = flash_varlen_fn
        globals()["_pad_fn"] = pad_fn
        globals()["_unpad_fn"] = unpad_fn
        globals()["_is_fa3"] = is_fa3
        flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters
        globals()["_flash_supports_window"] = flash_supports_window
    else:
        flash_fn = globals()["_flash_fn"]
        flash_varlen_fn = globals()["_flash_varlen_fn"]
        pad_fn = globals()["_pad_fn"]
        unpad_fn = globals()["_unpad_fn"]
        is_fa3 = globals()["_is_fa3"]
        flash_supports_window = globals()["_flash_supports_window"]
    """
    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
    first unpad the input, then computes the attention scores and pad the final attention scores.

    causal = is_causal and not (use_top_left_mask and query_length == 1)
    use_sw = (
        (_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window
    (Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to Flash Attention API
        key_states (`torch.Tensor`):
            Input key states to be passed to Flash Attention API
        value_states (`torch.Tensor`):
            Input value states to be passed to Flash Attention API
        attention_mask (`torch.Tensor`, *optional*):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
        implementation (`str`, *optional*):
            The attention implementation to use. If None, will default to the one based on the environment.
    """
    (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
        implementation
    )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {}
    if not is_fa3:
        flash_kwargs["dropout_p"] = dropout
    if is_flash_attn_greater_or_equal("2.4.1"):
        det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
        flash_kwargs["deterministic"] = det
    if softcap is not None:
        flash_kwargs["softcap"] = softcap
    if "s_aux" in kwargs:
        flash_kwargs["s_aux"] = kwargs.get("s_aux")

    # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
    query_states, key_states, value_states = fa_peft_integration_check(
        query_states, key_states, value_states, target_dtype
    )
    use_mask = position_ids is not None or all(
        k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]

    # Extract the flash attention kwargs that have been requested (and are supported by the implementation)
    flash_kwargs = process_flash_kwargs_fn(
        query_length=query_length,
        key_length=key_states.size(1),
        is_causal=is_causal,
        dropout=dropout,
        softmax_scale=softmax_scale,
        sliding_window=sliding_window,
        use_top_left_mask=use_top_left_mask,
        softcap=softcap,
        deterministic=deterministic,
        **kwargs,
    )

    # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
    # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
    # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
    # use `flash_varlen_fn` knowing we already have all the necessary kwargs.
    #
    # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
    # See #39121 for more information.
    is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0))
    is_fa_with_varlen_kwargs = all(
        kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
    )

    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
        q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
            query_states, key_states, value_states, attention_mask, query_length, unpad_fn
        )
        # TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p

        # TODO for now this is required to work with
        # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
        if "mps" in str(q.device):
            cu_k = cu_k.clone()
            cu_seq_lens_k = cu_seq_lens_k.clone()

        out_unpad = flash_varlen_fn(
            q,
            k,
            v,
            cu_seqlens_q=cu_q.to(torch.int32),
            cu_seqlens_k=cu_k.to(torch.int32),
            max_seqlen_q=mq,
            max_seqlen_k=mk,
            softmax_scale=softmax_scale,
            causal=causal,
            cu_seqlens_q=cu_seq_lens_q,
            cu_seqlens_k=cu_seq_lens_k,
            max_seqlen_q=max_length_q,
            max_seqlen_k=max_length_k,
            **flash_kwargs,
        )
        if isinstance(out_unpad, tuple):
            out_unpad = out_unpad[0]
        out = pad_fn(out_unpad, idx, query_states.shape[0], query_length)
    elif use_mask:

        out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length)

    # Padding free, i.e. sequences flattened into one total sequence
    elif is_fa_with_varlen_kwargs or is_fa_with_position_ids:
        if cu_seq_lens_q is None or cu_seq_lens_k is None:
            if position_ids is None:
                raise ValueError(
                    "Position ids should be passed if the attention mask is not passed and the cu_seq_lens are not passed."
                )
            q, k, v, idx, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
                query_states, key_states, value_states, position_ids
            q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids(
                query_states, key_states, value_states, position_ids, query_length=query_length
            )
        else:
            q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
            k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
            v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
            mq, mk = max_length_q, max_length_k
            cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k

        # TODO for now this is required to work with
        # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
        if "mps" in str(q.device):
            cu_k = cu_k.clone()
            cu_seq_lens_k = cu_seq_lens_k.clone()

        out = flash_varlen_fn(
            q,
            k,
            v,
            cu_seqlens_q=cu_q.to(torch.int32),
            cu_seqlens_k=cu_k.to(torch.int32),
            max_seqlen_q=mq,
            max_seqlen_k=mk,
            softmax_scale=softmax_scale,
            causal=causal,
            cu_seqlens_q=cu_seq_lens_q,
            cu_seqlens_k=cu_seq_lens_k,
            max_seqlen_q=max_length_q,
            max_seqlen_k=max_length_k,
            **flash_kwargs,
        )
        if isinstance(out, tuple):
            out = out[0]
        out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1))
    else:
        out = flash_fn(
            query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
        )

    return out[0] if isinstance(out, tuple) else out
        out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1))

    # No padding
    else:
        out = flash_fn(query_states, key_states, value_states, **flash_kwargs)
        if isinstance(out, tuple):
            out = out[0]

    return out

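# A minimal sketch (not part of the diff) of the three paths through the rewritten forward:
# a padded batch takes the unpad -> varlen -> pad route, packed position ids or precomputed
# cu_seqlens take the padding-free varlen route, and a dense batch calls `flash_fn` directly.
# The dense call below assumes flash-attn is installed and a CUDA device is available.
import torch

q = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")  # (batch, seq, heads, head_dim)
k = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")
out = _flash_attention_forward(q, k, v, attention_mask=None, query_length=16, is_causal=True)
# out: (1, 16, 8, 64)
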
@ -74,6 +74,7 @@ from .integrations.tensor_parallel import (
)
from .loss.loss_utils import LOSS_MAPPING
from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from .modeling_flash_attention_utils import lazy_import_flash_attention
from .pytorch_utils import (  # noqa: F401
    Conv1D,
    apply_chunking_to_forward,
@ -2126,7 +2127,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    _pp_plan = None

    # This flag signals that the model can be used as an efficient backend in TGI and vLLM
    # In practice, it means that they support attention interface functions, fully pass the kwargs
    # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
    # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
    _supports_attention_backend = False
    _can_record_outputs = None
@ -2740,6 +2741,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                    if attention_wrapper is None:
                        attention_wrapper = flash_attention_forward
                    kernel_function = partial(attention_wrapper, implementation=kernel)
                    lazy_import_flash_attention(kernel)
                elif kernel_name is not None:
                    kernel_function = getattr(kernel, kernel_name)
                ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
@ -2755,7 +2757,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                attn_implementation = "sdpa"  # Try to fallback to sdpa in this case
            return attn_implementation
        else:
            return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
            attn_implementation = self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)

            # preload flash attention here to allow compile with fullgraph
            if applicable_attn_implementation.startswith("flash_attention"):
                lazy_import_flash_attention(applicable_attn_implementation)

            return attn_implementation

    def get_correct_attn_implementation(self, _requested_attention: str, is_init_check: bool = False) -> str:
        requested_attention = "sdpa" if _requested_attention is None else _requested_attention

@ -1,269 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import gc
import os
import re
from typing import Optional

import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open

from transformers import (
    Aimv2Config,
    Aimv2Model,
    Aimv2VisionConfig,
    Aimv2VisionModel,
    AutoImageProcessor,
    AutoProcessor,
)


ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL = {
    # Embeddings
    r"preprocessor.patchifier.proj": r"embeddings.patch_embed",
    r"preprocessor.pos_embed": r"embeddings.position_embedding.weight",
    r"preprocessor.patchifier.norm.weight": r"embeddings.rms_norm.weight",
    # Encoder Layers
    r"trunk.blocks.(\d+).attn.qkv": r"encoder.layers.\1.attention.qkv",
    r"trunk.blocks.(\d+).attn.proj": r"encoder.layers.\1.attention.out_proj",
    r"trunk.blocks.(\d+).mlp.fc1": r"encoder.layers.\1.ffn.gate_proj",
    r"trunk.blocks.(\d+).mlp.fc2": r"encoder.layers.\1.ffn.down_proj",
    r"trunk.blocks.(\d+).mlp.fc3": r"encoder.layers.\1.ffn.up_proj",
    # Normalization Layers
    r"trunk.blocks.(\d+).norm_1": r"encoder.layers.\1.rms_norm1",
    r"trunk.blocks.(\d+).norm_2": r"encoder.layers.\1.rms_norm2",
    # Final Norm
    r"trunk.post_trunk_norm": r"rms_norm",
}

ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    # Vision Embeddings
    r"image_encoder.preprocessor.patchifier.proj": r"vision_model.embeddings.patch_embed",
    r"image_encoder.preprocessor.pos_embed": r"vision_model.embeddings.position_embedding.weight",
    r"image_encoder.preprocessor.patchifier.norm.weight": r"vision_model.embeddings.rms_norm.weight",
    # Vision Encoder Layers
    r"image_encoder.trunk.blocks.(\d+).attn.qkv": r"vision_model.encoder.layers.\1.attention.qkv",
    r"image_encoder.trunk.blocks.(\d+).attn.proj": r"vision_model.encoder.layers.\1.attention.out_proj",
    r"image_encoder.trunk.blocks.(\d+).mlp.fc1": r"vision_model.encoder.layers.\1.ffn.gate_proj",
    r"image_encoder.trunk.blocks.(\d+).mlp.fc2": r"vision_model.encoder.layers.\1.ffn.down_proj",
    r"image_encoder.trunk.blocks.(\d+).mlp.fc3": r"vision_model.encoder.layers.\1.ffn.up_proj",
    # Normalization Layers
    r"image_encoder.trunk.blocks.(\d+).norm_1": r"vision_model.encoder.layers.\1.rms_norm1",
    r"image_encoder.trunk.blocks.(\d+).norm_2": r"vision_model.encoder.layers.\1.rms_norm2",
    r"image_encoder.trunk.post_trunk_norm": r"vision_model.rms_norm",
    r"image_projector": r"visual_projection",
    # Vision Head
    r"image_encoder.head.cls_token": r"vision_model.head.cls_token",
    r"image_encoder.head.k": r"vision_model.head.k_proj",
    r"image_encoder.head.v": r"vision_model.head.v_proj",
    r"image_encoder.head.linear": r"vision_model.head.output_proj",
    # Text Embeddings
    r"text_encoder.preprocessor.text_embedding.weight": r"text_model.embeddings.token_embedding.weight",
    r"text_encoder.preprocessor.positional_embedding": r"text_model.embeddings.position_embedding.weight",
    # Text Encoder Layers
    r"text_encoder.trunk.blocks.(\d+).attn.qkv": r"text_model.encoder.layers.\1.attention.qkv",
    r"text_encoder.trunk.blocks.(\d+).attn.proj": r"text_model.encoder.layers.\1.attention.out_proj",
    r"text_encoder.trunk.blocks.(\d+).mlp.fc1": r"text_model.encoder.layers.\1.ffn.gate_proj",
    r"text_encoder.trunk.blocks.(\d+).mlp.fc2": r"text_model.encoder.layers.\1.ffn.down_proj",
    r"text_encoder.trunk.blocks.(\d+).mlp.fc3": r"text_model.encoder.layers.\1.ffn.up_proj",
    # Text Normalization Layers
    r"text_encoder.trunk.blocks.(\d+).norm_1": r"text_model.encoder.layers.\1.rms_norm1",
    r"text_encoder.trunk.blocks.(\d+).norm_2": r"text_model.encoder.layers.\1.rms_norm2",
    r"text_encoder.trunk.post_trunk_norm": r"text_model.rms_norm",
    r"text_projector": r"text_projection",
    r"log_logit_scale": r"logit_scale",
}


def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
    # Download only the model.safetensors file
    directory_path = snapshot_download(
        repo_id=model_id,
        revision=revision,
        allow_patterns=["model.safetensors"],
    )

    original_state_dict = {}
    safetensor_path = f"{directory_path}/model.safetensors"

    with safe_open(safetensor_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            original_state_dict[key] = f.get_tensor(key)

    return original_state_dict

def convert_old_keys_to_new_keys(state_dict_keys: dict, ORIGINAL_TO_CONVERTED_KEY_MAPPING: dict):
    """Converts state dict keys from the old format to the new format."""

    output_dict = {}
    if state_dict_keys is not None:
        old_text = "\n".join(state_dict_keys)
        new_text = old_text
        for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
            if replacement is None:
                new_text = re.sub(pattern, "", new_text)  # an empty line
                continue
            new_text = re.sub(pattern, replacement, new_text)
        output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
    return output_dict

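
# Illustrative sketch (not part of the original script): convert_old_keys_to_new_keys applies the
# regex mappings above to all keys at once. For a hypothetical vision-tower key:
# >>> convert_old_keys_to_new_keys(
# ...     ["trunk.blocks.3.attn.qkv.weight"], ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL
# ... )
# {'trunk.blocks.3.attn.qkv.weight': 'encoder.layers.3.attention.qkv.weight'}
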
def split_qkv_tensor(key, tensor):
    """Splits a qkv tensor into separate q, k, v tensors and updates the key accordingly."""

    new_keys = ["q_proj", "k_proj", "v_proj"]
    split_size = tensor.shape[0] // 3
    split_tensors = torch.split(tensor, split_size, dim=0)

    return {key.replace("qkv", new_key): split_tensors[i] for i, new_key in enumerate(new_keys)}

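
# Illustrative sketch (not part of the original script): splitting a hypothetical fused qkv weight.
# >>> out = split_qkv_tensor("encoder.layers.0.attention.qkv.weight", torch.randn(12, 4))
# >>> sorted(out)
# ['encoder.layers.0.attention.k_proj.weight', 'encoder.layers.0.attention.q_proj.weight', 'encoder.layers.0.attention.v_proj.weight']
# >>> out["encoder.layers.0.attention.q_proj.weight"].shape
# torch.Size([4, 4])
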
def get_model_config_mapping(model_id: str):
    """Determines the correct model, config, and key mappings based on the checkpoint name."""

    if model_id == "apple/aimv2-large-patch14-224-lit":
        return Aimv2Model, Aimv2Config, ORIGINAL_TO_CONVERTED_KEY_MAPPING
    else:
        return Aimv2VisionModel, Aimv2VisionConfig, ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL

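
# Illustrative note (not part of the original script): the LiT checkpoint is the only one converted
# with the full text + vision classes; every other AIMv2 checkpoint is converted as vision-only.
#   get_model_config_mapping("apple/aimv2-large-patch14-224-lit") -> (Aimv2Model, Aimv2Config, ...)
#   get_model_config_mapping("apple/aimv2-large-patch14-224")     -> (Aimv2VisionModel, Aimv2VisionConfig, ...)
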
def write_model(
    hf_repo_id: str,
    output_dir: str,
    safe_serialization: bool = True,
):
    """
    Converts a model checkpoint to Hugging Face format and saves it.

    Args:
        hf_repo_id (str): The Hugging Face repo ID to load from.
        output_dir (str): The directory to save the converted model.
        safe_serialization (bool): Whether to use safe serialization.

    Returns:
        model: The reloaded Hugging Face model.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Get the appropriate model, config, and key mapping
    model_class, config_class, key_mapping = get_model_config_mapping(hf_repo_id)

    # Load config and original state dict
    config = config_class.from_pretrained(hf_repo_id)

    # Checkpoint `apple/aimv2-large-patch14-224-lit` uses AttentionPoolingHead hence set the required attr in config.
    if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
        config.use_head = False

    if hf_repo_id == "apple/aimv2-large-patch14-native":
        config.is_native = True

    original_state_dict = load_original_state_dict(hf_repo_id)

    print("Converting model...")

    state_dict = {}
    result = convert_old_keys_to_new_keys(original_state_dict, key_mapping)
    all_keys = list(original_state_dict.keys())

    for key in all_keys:
        value = original_state_dict[key]
        new_key = result.pop(key)

        if "qkv" in new_key:
 | 
			
		||||
            qkv_state_dict = split_qkv_tensor(new_key, value)
 | 
			
		||||
            state_dict.update(qkv_state_dict)
 | 
			
		||||
        else:
 | 
			
		||||
            state_dict[new_key] = value
 | 
			
		||||
 | 
			
		||||
        # Check if position embeddings exist before squeezing
 | 
			
		||||
        if new_key.endswith("position_embedding.weight"):
 | 
			
		||||
            state_dict[new_key] = value.squeeze(0)

    print(f"Loading the checkpoint in a {model_class.__name__}.")
    model = model_class(config)
    model.load_state_dict(state_dict, strict=True, assign=True)
    print("Checkpoint loaded successfully.")

    print("Saving the model.")
    model.save_pretrained(output_dir, safe_serialization=safe_serialization)
    del state_dict, model
    gc.collect()

    print("Reloading the model to check if it's saved correctly.")
    model = model_class.from_pretrained(output_dir, device_map="auto")
    print("Model reloaded successfully.")
    return model


def write_image_processor(hf_repo_id: str, output_dir: str):
    if hf_repo_id == "apple/aimv2-large-patch14-224-lit":
        image_processor = AutoProcessor.from_pretrained(hf_repo_id, use_fast=True)
    else:
        image_processor = AutoImageProcessor.from_pretrained(hf_repo_id, use_fast=True)
    image_processor.save_pretrained(output_dir)
    return image_processor

def main():
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--hf_repo_id",
 | 
			
		||||
        default="apple/aimv2-large-patch14-224",
 | 
			
		||||
        help="Location of official weights from apple on HF",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--output_dir",
 | 
			
		||||
        default="aimv2_model",
 | 
			
		||||
        help="Location to write the converted model and processor",
 | 
			
		||||
    )
 | 
			
		||||
    # Note: `type=bool` would treat any non-empty string (including "False") as True,
    # so a boolean flag action is used instead.
    parser.add_argument(
        "--safe_serialization",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Whether or not to save using `safetensors`.",
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--push_to_hub",
 | 
			
		||||
        action=argparse.BooleanOptionalAction,
 | 
			
		||||
        help="Whether or not to push the converted model to the huggingface hub.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--hub_repo_id",
 | 
			
		||||
        default=None,
 | 
			
		||||
        help="Huggingface hub repo to write the converted model and processor",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    model = write_model(
 | 
			
		||||
        hf_repo_id=args.hf_repo_id,
 | 
			
		||||
        output_dir=args.output_dir,
 | 
			
		||||
        safe_serialization=args.safe_serialization,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    image_processor = write_image_processor(
 | 
			
		||||
        hf_repo_id=args.hf_repo_id,
 | 
			
		||||
        output_dir=args.output_dir,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if args.push_to_hub:
 | 
			
		||||
        print("Pushing to hub...")
 | 
			
		||||
        model.push_to_hub(args.hub_repo_id)
 | 
			
		||||
        image_processor.push_to_hub(args.hub_repo_id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
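
# Illustrative invocation of the converter above (the script name is assumed; the flags match the
# argument parser defined in main()):
#   python convert_aimv2_original_pytorch_to_hf.py \
#       --hf_repo_id apple/aimv2-large-patch14-224 \
#       --output_dir aimv2_model \
#       --push_to_hub --hub_repo_id <your-namespace>/aimv2-large-patch14-224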
@ -1,62 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALBERT checkpoint."""

import argparse

import torch

from ...utils import logging
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = AlbertConfig.from_json_file(albert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = AlbertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--albert_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained ALBERT model. \n"
            "This specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
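
# Illustrative invocation of the ALBERT converter above (script name and paths are assumed):
#   python convert_albert_original_tf_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path ./albert_base/model.ckpt-best \
#       --albert_config_file ./albert_base/albert_config.json \
#       --pytorch_dump_path ./albert_base_pytorch_model.bin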
@ -1,389 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2023 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert ALIGN checkpoints from the original repository."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import align
 | 
			
		||||
import numpy as np
 | 
			
		||||
import requests
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
import torch
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from tokenizer import Tokenizer
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    AlignConfig,
 | 
			
		||||
    AlignModel,
 | 
			
		||||
    AlignProcessor,
 | 
			
		||||
    BertConfig,
 | 
			
		||||
    BertTokenizer,
 | 
			
		||||
    EfficientNetConfig,
 | 
			
		||||
    EfficientNetImageProcessor,
 | 
			
		||||
)
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess(image):
    image = tf.image.resize(image, (346, 346))
    image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
    return image

def get_align_config():
    vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
    vision_config.image_size = 289
    vision_config.hidden_dim = 640
    vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
    vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
    vision_config.depthwise_padding = []

    text_config = BertConfig()
    config = AlignConfig.from_text_vision_configs(
        text_config=text_config, vision_config=vision_config, projection_dim=640
    )
    return config

# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im

def get_processor():
    image_processor = EfficientNetImageProcessor(
        do_center_crop=True,
        rescale_factor=1 / 127.5,
        rescale_offset=True,
        do_normalize=False,
        include_top=False,
        resample=Image.BILINEAR,
    )
    tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
    tokenizer.model_max_length = 64
    processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
    return processor

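
# Illustrative sketch (not part of the original script): the processor built above is exercised the
# same way further down in convert_align_checkpoint, e.g.
#   processor = get_processor()
#   inputs = processor(
#       images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
#   )
#   # `inputs` then feeds directly into AlignModel(**inputs) for the output comparison.
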
# here we list all keys to be renamed (original name on the left, our name on the right)
 | 
			
		||||
def rename_keys(original_param_names):
 | 
			
		||||
    # EfficientNet image encoder
 | 
			
		||||
    block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
 | 
			
		||||
    block_names = list(set(block_names))
 | 
			
		||||
    block_names = sorted(block_names)
 | 
			
		||||
    num_blocks = len(block_names)
 | 
			
		||||
    block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}
 | 
			
		||||
 | 
			
		||||
    rename_keys = []
 | 
			
		||||
    rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
 | 
			
		||||
    rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
 | 
			
		||||
    rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
 | 
			
		||||
    rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
 | 
			
		||||
    rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))
 | 
			
		||||
 | 
			
		||||
    for b in block_names:
 | 
			
		||||
        hf_b = block_name_mapping[b]
 | 
			
		||||
        rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
 | 
			
		||||
        rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
 | 
			
		||||
        rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
 | 
			
		||||
        rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
 | 
			
		||||
        rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
 | 
			
		||||
        rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
 | 
			
		||||
        rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
 | 
			
		||||
        rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    key_mapping = {}
 | 
			
		||||
    for item in rename_keys:
 | 
			
		||||
        if item[0] in original_param_names:
 | 
			
		||||
            key_mapping[item[0]] = "vision_model." + item[1]
 | 
			
		||||
 | 
			
		||||
    # BERT text encoder
 | 
			
		||||
    rename_keys = []
 | 
			
		||||
    old = "tf_bert_model/bert"
 | 
			
		||||
    new = "text_model"
 | 
			
		||||
    for i in range(12):
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.attention.self.query.weight",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/attention/self/query/bias:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.attention.self.query.bias",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.attention.self.key.weight",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/attention/self/key/bias:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.attention.self.key.bias",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.attention.self.value.weight",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/attention/self/value/bias:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.attention.self.value.bias",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.attention.output.dense.weight",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.attention.output.dense.bias",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.intermediate.dense.weight",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (
 | 
			
		||||
                f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0",
 | 
			
		||||
                f"{new}.encoder.layer.{i}.intermediate.dense.bias",
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias")
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight"))
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight"))
 | 
			
		||||
    rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias"))
 | 
			
		||||
 | 
			
		||||
    rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight"))
 | 
			
		||||
    rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias"))
 | 
			
		||||
    rename_keys.append(("dense/kernel:0", "text_projection.weight"))
 | 
			
		||||
    rename_keys.append(("dense/bias:0", "text_projection.bias"))
    rename_keys.append(("temperature:0", "temperature"))
 | 
			
		||||
 | 
			
		||||
    for item in rename_keys:
 | 
			
		||||
        if item[0] in original_param_names:
 | 
			
		||||
            key_mapping[item[0]] = item[1]
 | 
			
		||||
    return key_mapping
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def replace_params(hf_params, tf_params, key_mapping):
 | 
			
		||||
 | 
			
		||||
    for key, value in tf_params.items():
 | 
			
		||||
        if key not in key_mapping:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        hf_key = key_mapping[key]
 | 
			
		||||
        if "_conv" in key and "kernel" in key:
 | 
			
		||||
            new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
 | 
			
		||||
        elif "embeddings" in key:
 | 
			
		||||
            new_hf_value = torch.from_numpy(value)
 | 
			
		||||
        elif "depthwise_kernel" in key:
 | 
			
		||||
            new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
 | 
			
		||||
        elif "kernel" in key:
 | 
			
		||||
            new_hf_value = torch.from_numpy(np.transpose(value))
 | 
			
		||||
        elif "temperature" in key:
 | 
			
		||||
            new_hf_value = value
 | 
			
		||||
        elif "bn/gamma" in key or "bn/beta" in key:
 | 
			
		||||
            new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
 | 
			
		||||
        else:
 | 
			
		||||
            new_hf_value = torch.from_numpy(value)
 | 
			
		||||
 | 
			
		||||
        # Replace HF parameters with original TF model parameters
 | 
			
		||||
        hf_params[hf_key].copy_(new_hf_value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to our ALIGN structure.
 | 
			
		||||
    """
 | 
			
		||||
    # Load original model
 | 
			
		||||
    seq_length = 64
 | 
			
		||||
    tok = Tokenizer(seq_length)
 | 
			
		||||
    original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
 | 
			
		||||
    original_model.compile()
 | 
			
		||||
    original_model.load_weights(checkpoint_path)
 | 
			
		||||
 | 
			
		||||
    tf_params = original_model.trainable_variables
 | 
			
		||||
    tf_non_train_params = original_model.non_trainable_variables
 | 
			
		||||
    tf_params = {param.name: param.numpy() for param in tf_params}
 | 
			
		||||
    for param in tf_non_train_params:
 | 
			
		||||
        tf_params[param.name] = param.numpy()
 | 
			
		||||
    tf_param_names = list(tf_params.keys())
 | 
			
		||||
 | 
			
		||||
    # Load HuggingFace model
 | 
			
		||||
    config = get_align_config()
 | 
			
		||||
    hf_model = AlignModel(config).eval()
 | 
			
		||||
    hf_params = hf_model.state_dict()
 | 
			
		||||
 | 
			
		||||
    # Create src-to-dst parameter name mapping dictionary
 | 
			
		||||
    print("Converting parameters...")
 | 
			
		||||
    key_mapping = rename_keys(tf_param_names)
 | 
			
		||||
    replace_params(hf_params, tf_params, key_mapping)
 | 
			
		||||
 | 
			
		||||
    # Initialize processor
 | 
			
		||||
    processor = get_processor()
 | 
			
		||||
    inputs = processor(
 | 
			
		||||
        images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # HF model inference
 | 
			
		||||
    hf_model.eval()
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
        outputs = hf_model(**inputs)
 | 
			
		||||
 | 
			
		||||
    hf_image_features = outputs.image_embeds.detach().numpy()
 | 
			
		||||
    hf_text_features = outputs.text_embeds.detach().numpy()
 | 
			
		||||
 | 
			
		||||
    # Original model inference
 | 
			
		||||
    original_model.trainable = False
 | 
			
		||||
    tf_image_processor = EfficientNetImageProcessor(
 | 
			
		||||
        do_center_crop=True,
 | 
			
		||||
        do_rescale=False,
 | 
			
		||||
        do_normalize=False,
 | 
			
		||||
        include_top=False,
 | 
			
		||||
        resample=Image.BILINEAR,
 | 
			
		||||
    )
 | 
			
		||||
    image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"]
 | 
			
		||||
    text = tok(tf.constant(["A picture of a cat"]))
 | 
			
		||||
 | 
			
		||||
    image_features = original_model.image_encoder(image, training=False)
 | 
			
		||||
    text_features = original_model.text_encoder(text, training=False)
 | 
			
		||||
 | 
			
		||||
    image_features = tf.nn.l2_normalize(image_features, axis=-1)
 | 
			
		||||
    text_features = tf.nn.l2_normalize(text_features, axis=-1)
 | 
			
		||||
 | 
			
		||||
    # Check whether original and HF model outputs match  -> np.allclose
 | 
			
		||||
    if not np.allclose(image_features, hf_image_features, atol=1e-3):
 | 
			
		||||
        raise ValueError("The predicted image features are not the same.")
 | 
			
		||||
    if not np.allclose(text_features, hf_text_features, atol=1e-3):
 | 
			
		||||
        raise ValueError("The predicted text features are not the same.")
 | 
			
		||||
    print("Model outputs match!")
 | 
			
		||||
 | 
			
		||||
    if save_model:
 | 
			
		||||
        # Create folder to save model
 | 
			
		||||
        if not os.path.isdir(pytorch_dump_folder_path):
 | 
			
		||||
            os.mkdir(pytorch_dump_folder_path)
 | 
			
		||||
        # Save converted model and image processor
 | 
			
		||||
        hf_model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
        processor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
    if push_to_hub:
 | 
			
		||||
        # Push model and image processor to hub
 | 
			
		||||
        print("Pushing converted ALIGN to the hub...")
 | 
			
		||||
        processor.push_to_hub("align-base")
 | 
			
		||||
        hf_model.push_to_hub("align-base")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--checkpoint_path",
 | 
			
		||||
        default="./weights/model-weights",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Path to the pretrained TF ALIGN checkpoint.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path",
 | 
			
		||||
        default="hf_model",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Path to the output PyTorch model directory.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument("--save_model", action="store_true", help="Save model to local")
 | 
			
		||||
    parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)
 | 
			
		||||
@ -1,162 +0,0 @@
 | 
			
		||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import argparse
 | 
			
		||||
import glob
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import snapshot_download
 | 
			
		||||
from safetensors import safe_open
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    AddedToken,
 | 
			
		||||
    AriaForConditionalGeneration,
 | 
			
		||||
    AriaProcessor,
 | 
			
		||||
    AutoConfig,
 | 
			
		||||
    AutoTokenizer,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
EPILOG_TXT = """Example:
 | 
			
		||||
    python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
 | 
			
		||||
 | 
			
		||||
Example for creating the old state dict file with Python:
 | 
			
		||||
 | 
			
		||||
    import torch
 | 
			
		||||
    from aria.model.language_model.aria_llama import AriaTextForCausalLM
 | 
			
		||||
 | 
			
		||||
    # load model
 | 
			
		||||
    kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
 | 
			
		||||
    model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", **kwargs)
 | 
			
		||||
 | 
			
		||||
    # load vision tower
 | 
			
		||||
    model.get_vision_tower().load_model()
 | 
			
		||||
 | 
			
		||||
    # Save state dict
 | 
			
		||||
    torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
KEYS_TO_MODIFY_MAPPING = {
 | 
			
		||||
    "vision_tower.vision_model": "vision_tower",
 | 
			
		||||
    "ln_ffn": "layer_norm",
 | 
			
		||||
    "ffn": "feed_forward",
 | 
			
		||||
    "ln_kv": "layer_norm_kv",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_original_state_dict(model_id):
    directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])

    original_state_dict = {}
    for path in glob.glob(f"{directory_path}/*"):
        if path.endswith(".safetensors"):
            with safe_open(path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    original_state_dict[key] = f.get_tensor(key)

    return original_state_dict

def convert_state_dict_to_hf(state_dict):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.endswith(".inv_freq"):
            continue
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        new_state_dict[key] = value
    new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
    new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))

    return new_state_dict

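
# Illustrative sketch (not part of the original script): how KEYS_TO_MODIFY_MAPPING rewrites a key
# inside convert_state_dict_to_hf (the key below is hypothetical):
#   "vision_tower.vision_model.encoder.ln_ffn.weight" -> "vision_tower.encoder.layer_norm.weight"
# "ln_ffn" is listed before the bare "ffn" entry, so the more specific pattern is applied first.
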
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
 | 
			
		||||
    torch.set_default_dtype(torch.float16)
 | 
			
		||||
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(
 | 
			
		||||
        text_model_id,
 | 
			
		||||
        extra_special_tokens={
 | 
			
		||||
            "image_token": "<|img|>",
 | 
			
		||||
            "pad_token": "<pad>",
 | 
			
		||||
        },
 | 
			
		||||
    )
 | 
			
		||||
    tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
 | 
			
		||||
    tokenizer.add_special_tokens({"pad_token": "<pad>"})
 | 
			
		||||
    tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
 | 
			
		||||
 | 
			
		||||
    processor = AriaProcessor.from_pretrained(
 | 
			
		||||
        text_model_id,
 | 
			
		||||
        tokenizer=tokenizer,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    config = AutoConfig.from_pretrained(text_model_id)
 | 
			
		||||
    config.vision_config.hidden_size = 1152
 | 
			
		||||
    config.vision_config.attention_heads = 16
 | 
			
		||||
    config.pad_token_id = 2
 | 
			
		||||
    config.image_token_id = 9
 | 
			
		||||
    config.intermediate_size = config.moe_intermediate_size
 | 
			
		||||
    config.auto_map = {
 | 
			
		||||
        "AutoConfig": "modeling_aria.AriaConfig",
 | 
			
		||||
        "AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    with torch.device("meta"):
 | 
			
		||||
        model = AriaForConditionalGeneration(config)
 | 
			
		||||
 | 
			
		||||
    state_dict = load_original_state_dict(old_state_dict_id)
 | 
			
		||||
 | 
			
		||||
    state_dict = convert_state_dict_to_hf(state_dict)
 | 
			
		||||
    model.load_state_dict(state_dict, strict=False, assign=True)
 | 
			
		||||
 | 
			
		||||
    # print("Saving models")
 | 
			
		||||
    # model.save_pretrained("local_aria", safe_serialization=False)
 | 
			
		||||
    # processor.save_pretrained("local_aria")
 | 
			
		||||
    print("Pushing to hub")
 | 
			
		||||
    model.push_to_hub(output_hub_path, create_pr=True)
 | 
			
		||||
    processor.push_to_hub(output_hub_path, create_pr=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    parser = argparse.ArgumentParser(
 | 
			
		||||
        epilog=EPILOG_TXT,
 | 
			
		||||
        formatter_class=argparse.RawDescriptionHelpFormatter,
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--text_model_id",
 | 
			
		||||
        default="rhymes-ai/Aria",
 | 
			
		||||
        help="Hub location of the text model",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--vision_model_id",
 | 
			
		||||
        default="rhymes-ai/Aria",
 | 
			
		||||
        help="Hub location of the vision model",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--output_hub_path",
 | 
			
		||||
        default="rhymes-ai/Aria",
 | 
			
		||||
        help="Location on the hub of the converted model",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--old_state_dict_id",
 | 
			
		||||
        default="rhymes-ai/Aria",
 | 
			
		||||
        help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
@ -1,279 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2022 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torchaudio
 | 
			
		||||
from datasets import load_dataset
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
 | 
			
		||||
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_audio_spectrogram_transformer_config(model_name):
 | 
			
		||||
    config = ASTConfig()
 | 
			
		||||
 | 
			
		||||
    if "10-10" in model_name:
 | 
			
		||||
        pass
 | 
			
		||||
    elif "speech-commands" in model_name:
 | 
			
		||||
        config.max_length = 128
 | 
			
		||||
    elif "12-12" in model_name:
 | 
			
		||||
        config.time_stride = 12
 | 
			
		||||
        config.frequency_stride = 12
 | 
			
		||||
    elif "14-14" in model_name:
 | 
			
		||||
        config.time_stride = 14
 | 
			
		||||
        config.frequency_stride = 14
 | 
			
		||||
    elif "16-16" in model_name:
 | 
			
		||||
        config.time_stride = 16
 | 
			
		||||
        config.frequency_stride = 16
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Model not supported")
 | 
			
		||||
 | 
			
		||||
    repo_id = "huggingface/label-files"
 | 
			
		||||
    if "speech-commands" in model_name:
 | 
			
		||||
        config.num_labels = 35
 | 
			
		||||
        filename = "speech-commands-v2-id2label.json"
 | 
			
		||||
    else:
 | 
			
		||||
        config.num_labels = 527
 | 
			
		||||
        filename = "audioset-id2label.json"
 | 
			
		||||
 | 
			
		||||
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
    id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
    config.id2label = id2label
 | 
			
		||||
    config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
 | 
			
		||||
    return config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_key(name):
 | 
			
		||||
    if "module.v" in name:
 | 
			
		||||
        name = name.replace("module.v", "audio_spectrogram_transformer")
 | 
			
		||||
    if "cls_token" in name:
 | 
			
		||||
        name = name.replace("cls_token", "embeddings.cls_token")
 | 
			
		||||
    if "dist_token" in name:
 | 
			
		||||
        name = name.replace("dist_token", "embeddings.distillation_token")
 | 
			
		||||
    if "pos_embed" in name:
 | 
			
		||||
        name = name.replace("pos_embed", "embeddings.position_embeddings")
 | 
			
		||||
    if "patch_embed.proj" in name:
 | 
			
		||||
        name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
 | 
			
		||||
    # transformer blocks
 | 
			
		||||
    if "blocks" in name:
 | 
			
		||||
        name = name.replace("blocks", "encoder.layer")
 | 
			
		||||
    if "attn.proj" in name:
 | 
			
		||||
        name = name.replace("attn.proj", "attention.output.dense")
 | 
			
		||||
    if "attn" in name:
 | 
			
		||||
        name = name.replace("attn", "attention.self")
 | 
			
		||||
    if "norm1" in name:
 | 
			
		||||
        name = name.replace("norm1", "layernorm_before")
 | 
			
		||||
    if "norm2" in name:
 | 
			
		||||
        name = name.replace("norm2", "layernorm_after")
 | 
			
		||||
    if "mlp.fc1" in name:
 | 
			
		||||
        name = name.replace("mlp.fc1", "intermediate.dense")
 | 
			
		||||
    if "mlp.fc2" in name:
 | 
			
		||||
        name = name.replace("mlp.fc2", "output.dense")
 | 
			
		||||
    # final layernorm
 | 
			
		||||
    if "audio_spectrogram_transformer.norm" in name:
 | 
			
		||||
        name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm")
 | 
			
		||||
    # classifier head
 | 
			
		||||
    if "module.mlp_head.0" in name:
 | 
			
		||||
        name = name.replace("module.mlp_head.0", "classifier.layernorm")
 | 
			
		||||
    if "module.mlp_head.1" in name:
 | 
			
		||||
        name = name.replace("module.mlp_head.1", "classifier.dense")
 | 
			
		||||
 | 
			
		||||
    return name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_state_dict(orig_state_dict, config):
 | 
			
		||||
    for key in orig_state_dict.copy():
 | 
			
		||||
        val = orig_state_dict.pop(key)
 | 
			
		||||
 | 
			
		||||
        if "qkv" in key:
 | 
			
		||||
            key_split = key.split(".")
 | 
			
		||||
            layer_num = int(key_split[3])
 | 
			
		||||
            dim = config.hidden_size
 | 
			
		||||
            if "weight" in key:
 | 
			
		||||
                orig_state_dict[
 | 
			
		||||
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight"
 | 
			
		||||
                ] = val[:dim, :]
 | 
			
		||||
                orig_state_dict[
 | 
			
		||||
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight"
 | 
			
		||||
                ] = val[dim : dim * 2, :]
 | 
			
		||||
                orig_state_dict[
 | 
			
		||||
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight"
 | 
			
		||||
                ] = val[-dim:, :]
 | 
			
		||||
            else:
 | 
			
		||||
                orig_state_dict[
 | 
			
		||||
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias"
 | 
			
		||||
                ] = val[:dim]
 | 
			
		||||
                orig_state_dict[
 | 
			
		||||
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias"
 | 
			
		||||
                ] = val[dim : dim * 2]
 | 
			
		||||
                orig_state_dict[
 | 
			
		||||
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias"
 | 
			
		||||
                ] = val[-dim:]
 | 
			
		||||
        else:
 | 
			
		||||
            orig_state_dict[rename_key(key)] = val
 | 
			
		||||
 | 
			
		||||
    return orig_state_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
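# Illustrative note (not part of the original script): convert_state_dict above slices each fused
# qkv projection of shape (3 * hidden_size, hidden_size) into separate query/key/value weights:
#   query.weight = val[:dim, :]
#   key.weight   = val[dim : dim * 2, :]
#   value.weight = val[-dim:, :]
# with dim = config.hidden_size; the corresponding biases are sliced the same way.
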
def remove_keys(state_dict):
    ignore_keys = [
        "module.v.head.weight",
        "module.v.head.bias",
        "module.v.head_dist.weight",
        "module.v.head_dist.bias",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)

@torch.no_grad()
 | 
			
		||||
def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure.
 | 
			
		||||
    """
 | 
			
		||||
    config = get_audio_spectrogram_transformer_config(model_name)
 | 
			
		||||
 | 
			
		||||
    model_name_to_url = {
 | 
			
		||||
        "ast-finetuned-audioset-10-10-0.4593": (
 | 
			
		||||
            "https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1"
 | 
			
		||||
        ),
 | 
			
		||||
        "ast-finetuned-audioset-10-10-0.450": (
 | 
			
		||||
            "https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1"
 | 
			
		||||
        ),
 | 
			
		||||
        "ast-finetuned-audioset-10-10-0.448": (
 | 
			
		||||
            "https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1"
 | 
			
		||||
        ),
 | 
			
		||||
        "ast-finetuned-audioset-10-10-0.448-v2": (
 | 
			
		||||
            "https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1"
 | 
			
		||||
        ),
 | 
			
		||||
        "ast-finetuned-audioset-12-12-0.447": (
 | 
			
		||||
            "https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1"
 | 
			
		||||
        ),
 | 
			
		||||
        "ast-finetuned-audioset-14-14-0.443": (
 | 
			
		||||
            "https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1"
 | 
			
		||||
        ),
 | 
			
		||||
        "ast-finetuned-audioset-16-16-0.442": (
 | 
			
		||||
            "https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1"
 | 
			
		||||
        ),
 | 
			
		||||
        "ast-finetuned-speech-commands-v2": (
 | 
			
		||||
            "https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1"
 | 
			
		||||
        ),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # load original state_dict
 | 
			
		||||
    checkpoint_url = model_name_to_url[model_name]
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
    # remove some keys
    remove_keys(state_dict)
    # rename some keys
    new_state_dict = convert_state_dict(state_dict, config)

    # load 🤗 model
    model = ASTForAudioClassification(config)
    model.eval()

    model.load_state_dict(new_state_dict)

    # verify outputs on dummy input
    # source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62
    mean = -4.2677393 if "speech-commands" not in model_name else -6.845978
    std = 4.5689974 if "speech-commands" not in model_name else 5.5654526
    max_length = 1024 if "speech-commands" not in model_name else 128
    feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length)

    if "speech-commands" in model_name:
        # TODO: Convert dataset to Parquet
        dataset = load_dataset("google/speech_commands", "v0.02", split="validation")
        waveform = dataset[0]["audio"]["array"]
    else:
        filepath = hf_hub_download(
            repo_id="nielsr/audio-spectogram-transformer-checkpoint",
            filename="sample_audio.flac",
            repo_type="dataset",
        )

        waveform, _ = torchaudio.load(filepath)
        waveform = waveform.squeeze().numpy()

    inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")

    # forward pass
    outputs = model(**inputs)
    logits = outputs.logits

    if model_name == "ast-finetuned-audioset-10-10-0.4593":
        expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602])
    elif model_name == "ast-finetuned-audioset-10-10-0.450":
        expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718])
    elif model_name == "ast-finetuned-audioset-10-10-0.448":
        expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344])
    elif model_name == "ast-finetuned-audioset-10-10-0.448-v2":
        expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917])
    elif model_name == "ast-finetuned-audioset-12-12-0.447":
        expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843])
    elif model_name == "ast-finetuned-audioset-14-14-0.443":
        expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413])
    elif model_name == "ast-finetuned-audioset-16-16-0.442":
        expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470])
    elif model_name == "ast-finetuned-speech-commands-v2":
        expected_slice = torch.tensor([6.1589, -8.0566, -8.7984])
    else:
        raise ValueError("Unknown model name")
    if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4):
        raise ValueError("Logits don't match")
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving feature extractor to {pytorch_dump_folder_path}")
        feature_extractor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print("Pushing model and feature extractor to the hub...")
        model.push_to_hub(f"MIT/{model_name}")
        feature_extractor.push_to_hub(f"MIT/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="ast-finetuned-audioset-10-10-0.4593",
        type=str,
        help="Name of the Audio Spectrogram Transformer model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
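# Illustrative invocation of the AST converter above (not part of the original script; the file name
# below is a placeholder, while the flags and default model name come from the argparse setup above):
#   python convert_audio_spectrogram_transformer_checkpoint.py \
#       --model_name ast-finetuned-audioset-10-10-0.4593 \
#       --pytorch_dump_folder_path ./ast-converted \
#       --push_to_hub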
@ -1,273 +0,0 @@
# coding=utf-8
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""

import argparse
import json
import os
import re
from os import path
from typing import Optional, Union

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

from transformers import AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME

from .configuration_bamba import BambaConfig


def convert_state_dict_from_mamba_ssm(original_sd: dict) -> dict[str, torch.Tensor]:
    state_dict = {}

    for orig_k, param in original_sd.items():
        k = orig_k.replace("backbone", "model")

        # for embeddings
        k = k.replace("embedding", "embed_tokens")

        # for mixer
        k = k.replace("mixer", "mamba")

        # for final layernorm
        k = k.replace("norm_f", "final_layernorm")

        # for block layernorm
        k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
        k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)

        # for mlp
        k = k.replace("mlp.fc2", "feed_forward.down_proj")

        if "mlp.fc1" in k:
            param, param2 = torch.chunk(param, 2, dim=0)
            k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
            state_dict[k2] = param2
            k = k.replace("mlp.fc1", "feed_forward.up_proj")

        if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
            "out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
        ):
            # then this must be a mamba
            pass
        else:
            # for attn
            # - because mixer was replaced to mamba above
            k = k.replace("mamba.out_proj", "self_attn.o_proj")
            if "mamba.in_proj" in k:
                m, n = param.shape
                d = (m - n) // 2
                param, param2, param3 = torch.split(param, [n, d, d], dim=0)
                k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
                state_dict[k2] = param2
                k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
                state_dict[k2] = param3
                k = k.replace("mamba.in_proj", "self_attn.q_proj")

        state_dict[k] = param

    return state_dict


# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
def convert_ssm_config_to_hf_config(
    config_ssm: dict,
    **kwargs,
) -> BambaConfig:
    """Convert a config from mamba_ssm to a BambaConfig from here."""
    hf_config: BambaConfig = BambaConfig(**kwargs)

    hf_config.architectures = ["BambaForCausalLM"]

    # Set important values from config and recalculate other resulting entries
    hf_config.hidden_size = config_ssm["d_model"]
    hf_config.intermediate_size = config_ssm["d_intermediate"]
    hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
    hf_config.num_hidden_layers = config_ssm["n_layer"]
    hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]

    # currently this script assumes config_ssm belongs to v2
    if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
        raise ValueError("Conversion script only supports Mamba2")

    # Set attention values
    attn_cfg = config_ssm.get("attn_cfg")
    if attn_cfg:
        assert attn_cfg["causal"], "Only support causal attention."
        assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
        assert not attn_cfg["out_proj_bias"], "Only support no out bias."
        hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
        hf_config.num_attention_heads = attn_cfg["num_heads"]
        hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]

    attention_layer_indices = config_ssm.get("attn_layer_idx")
    if attention_layer_indices:
        hf_config.attn_layer_indices = attention_layer_indices

    # Padded vocab size, mostly of 16 but 32 is also very common in different models
    vocab_size = config_ssm["vocab_size"]
    pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
    if (vocab_size % pad_vocab_size_multiple) != 0:
        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    hf_config.vocab_size = vocab_size

    return hf_config


def save_single_safetensor(
    state_dict: dict,
    save_directory: str,
    metadata: dict,
):
    save_file(
        state_dict,
        os.path.join(save_directory, SAFE_WEIGHTS_NAME),
        metadata,
    )


def save_sharded_safetensors(
    state_dict: dict,
    save_directory: str,
    metadata: dict,
    max_shard_size: Union[int, str] = "5GB",
):
    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
        ".safetensors", "{suffix}.safetensors"
    )
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
    )
    index = {
        "metadata": state_dict_split.metadata,
        "weight_map": state_dict_split.tensor_to_filename,
    }
    # Save the index
    with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
        f.write(content)

    filename_to_tensors = state_dict_split.filename_to_tensors.items()
    for shard_file, tensors in filename_to_tensors:
        shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)


# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
    mamba_ssm_checkpoint_path: str,
    precision: str,
    output_dir: str,
    tokenizer_path: Optional[str] = None,
    save_model: Union[bool, str] = True,
) -> None:
    # load tokenizer if provided, this will be used to set the
    # token_ids in the config file
    token_ids = {}
    if tokenizer_path:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        for key in [
            "bos_token_id",
            "eos_token_id",
            "pad_token_id",
        ]:
            id = getattr(tokenizer, key, None)
            if id:
                token_ids[key] = id

    # there are some configs unsettable by mamba_ssn config, so
    # if there are changes from the defaults, have to pass them into
    # the function
    unsettables = {
        "mamba_d_head": 64,
        "mamba_d_state": 128,
        "mamba_n_groups": 1,
        "rms_norm_eps": 1e-5,
    }

    # Load and save config based on name
    config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
    with open(config_path, "r", encoding="utf-8") as json_file:
        config = json.load(json_file)

    # convert the config
    hf_config = convert_ssm_config_to_hf_config(
        config_ssm=config,
        **token_ids,
        **unsettables,
    )
    hf_config.save_pretrained(output_dir)

    # Load state dict of the original model and transfer to hf model
    state_dict = torch.load(
        path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
        map_location="cpu",
        weights_only=True,
    )
    # FIXME: allow other parameters to pass in
    state_dict = convert_state_dict_from_mamba_ssm(state_dict)

    # Save new model to pytorch_dump_path
    dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)

    save_file_fn = None
    if isinstance(save_model, bool) and save_model:
        save_file_fn = save_single_safetensor
    elif isinstance(save_model, str) and save_model == "sharded":
        save_file_fn = save_sharded_safetensors

    if save_file_fn:
        save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--mamba_ssm_checkpoint_directory",
        type=str,
        required=True,
        help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
    )
    parser.add_argument(
        "-p",
        "--precision",
        type=str,
        default="fp16",
        required=True,
        choices=("fp32", "fp16", "bf16"),
        help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
    )
    parser.add_argument(
        "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
    )
    parser.add_argument(
        "-t",
        "--tokenizer_model_path",
        type=str,
        default=None,
        required=False,
        help="Path to the tokenizer file.",
    )
    args = parser.parse_args()

    convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
        args.mamba_ssm_checkpoint_directory,
        args.precision,
        args.output_dir,
        save_model="sharded",
    )
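# Illustrative invocation of the Bamba converter above (not part of the original script; the file name
# and paths are placeholders, while the flags come from the argparse setup above):
#   python convert_bamba_ssm_checkpoint.py -i /path/to/mamba_ssm_checkpoint_dir -p bf16 -o ./bamba-hf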
@ -1,263 +0,0 @@
"""Convert Bark checkpoint."""

import argparse
import os
from pathlib import Path

import torch
from bark.generation import _load_model as _bark_load_model
from huggingface_hub import hf_hub_download

from transformers import EncodecConfig, EncodecModel, set_seed
from transformers.models.bark.configuration_bark import (
    BarkCoarseConfig,
    BarkConfig,
    BarkFineConfig,
    BarkSemanticConfig,
)
from transformers.models.bark.generation_configuration_bark import (
    BarkCoarseGenerationConfig,
    BarkFineGenerationConfig,
    BarkGenerationConfig,
    BarkSemanticGenerationConfig,
)
from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

set_seed(770)


new_layer_name_dict = {
    "c_attn": "att_proj",
    "c_proj": "out_proj",
    "c_fc": "in_proj",
    "transformer.": "",
    "h.": "layers.",
    "ln_1": "layernorm_1",
    "ln_2": "layernorm_2",
    "ln_f": "layernorm_final",
    "wpe": "position_embeds_layer",
    "wte": "input_embeds_layer",
}


REMOTE_MODEL_PATHS = {
    "text_small": {
        "repo_id": "suno/bark",
        "file_name": "text.pt",
    },
    "coarse_small": {
        "repo_id": "suno/bark",
        "file_name": "coarse.pt",
    },
    "fine_small": {
        "repo_id": "suno/bark",
        "file_name": "fine.pt",
    },
    "text": {
        "repo_id": "suno/bark",
        "file_name": "text_2.pt",
    },
    "coarse": {
        "repo_id": "suno/bark",
        "file_name": "coarse_2.pt",
    },
    "fine": {
        "repo_id": "suno/bark",
        "file_name": "fine_2.pt",
    },
}

CUR_PATH = os.path.dirname(os.path.abspath(__file__))
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")


def _get_ckpt_path(model_type, use_small=False):
    key = model_type
    if use_small:
        key += "_small"
    return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])


def _download(from_hf_path, file_name):
    os.makedirs(CACHE_DIR, exist_ok=True)
    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)


def _load_model(ckpt_path, device, use_small=False, model_type="text"):
    if model_type == "text":
        ModelClass = BarkSemanticModel
        ConfigClass = BarkSemanticConfig
        GenerationConfigClass = BarkSemanticGenerationConfig
    elif model_type == "coarse":
        ModelClass = BarkCoarseModel
        ConfigClass = BarkCoarseConfig
        GenerationConfigClass = BarkCoarseGenerationConfig
    elif model_type == "fine":
        ModelClass = BarkFineModel
        ConfigClass = BarkFineConfig
        GenerationConfigClass = BarkFineGenerationConfig
    else:
        raise NotImplementedError()
    model_key = f"{model_type}_small" if use_small else model_type
    model_info = REMOTE_MODEL_PATHS[model_key]
    if not os.path.exists(ckpt_path):
        logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
        _download(model_info["repo_id"], model_info["file_name"])
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
    # this is a hack
    model_args = checkpoint["model_args"]
    if "input_vocab_size" not in model_args:
        model_args["input_vocab_size"] = model_args["vocab_size"]
        model_args["output_vocab_size"] = model_args["vocab_size"]
        del model_args["vocab_size"]

    # convert Bark model arguments to HF Bark model arguments
    model_args["num_heads"] = model_args.pop("n_head")
    model_args["hidden_size"] = model_args.pop("n_embd")
    model_args["num_layers"] = model_args.pop("n_layer")

    model_config = ConfigClass(**checkpoint["model_args"])
    model = ModelClass(config=model_config)
    model_generation_config = GenerationConfigClass()

    model.generation_config = model_generation_config
    state_dict = checkpoint["model"]
    # fixup checkpoint
    unwanted_prefix = "_orig_mod."
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            # replace part of the key with corresponding layer name in HF implementation
            new_k = k[len(unwanted_prefix) :]
            for old_layer_name, new_layer_name in new_layer_name_dict.items():
                new_k = new_k.replace(old_layer_name, new_layer_name)

            state_dict[new_k] = state_dict.pop(k)

    extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
    extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")}
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")}
    if len(extra_keys) != 0:
        raise ValueError(f"extra keys found: {extra_keys}")
    if len(missing_keys) != 0:
        raise ValueError(f"missing keys: {missing_keys}")
    model.load_state_dict(state_dict, strict=False)
    n_params = model.num_parameters(exclude_embeddings=True)
    val_loss = checkpoint["best_val_loss"].item()
    logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params, {round(val_loss, 3)} loss")
    model.eval()
    model.to(device)
    del checkpoint, state_dict

    return model


def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"):
    if model_type not in ("text", "coarse", "fine"):
        raise NotImplementedError()

    device = "cpu"  # do conversion on cpu

    ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
    model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small)

    # load bark initial model
    bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small)

    if model_type == "text":
        bark_model = bark_model["model"]

    if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params():
        raise ValueError("initial and new models don't have the same number of parameters")

    # check if same output as the bark model
    batch_size = 5
    sequence_length = 10

    if model_type in ["text", "coarse"]:
        vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int)
        output_old_model = bark_model(vec)[0]

        output_new_model_total = model(vec)

        # take last logits
        output_new_model = output_new_model_total.logits[:, [-1], :]

    else:
        prediction_codebook_channel = 3
        n_codes_total = 8
        vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int)

        output_new_model_total = model(prediction_codebook_channel, vec)
        output_old_model = bark_model(prediction_codebook_channel, vec)

        output_new_model = output_new_model_total.logits

    # output difference should come from the difference of self-attention implementation design
    if output_new_model.shape != output_old_model.shape:
        raise ValueError("initial and new outputs don't have the same shape")
    if (output_new_model - output_old_model).abs().max().item() > 1e-3:
        raise ValueError("initial and new outputs are not equal")

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)


def load_whole_bark_model(
    semantic_path,
    coarse_path,
    fine_path,
    append_text,
    hub_path,
    folder_path,
):
    pytorch_dump_folder_path = os.path.join(folder_path, append_text)

    semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json"))
    coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json"))
    fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json"))
    codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz")

    semantic = BarkSemanticModel.from_pretrained(semantic_path)
    coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path)
    fineAcoustic = BarkFineModel.from_pretrained(fine_path)
    codec = EncodecModel.from_pretrained("facebook/encodec_24khz")

    bark_config = BarkConfig.from_sub_model_configs(
        semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig
    )

    bark_generation_config = BarkGenerationConfig.from_sub_model_configs(
        semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config
    )

    bark = BarkModel(bark_config)

    bark.semantic = semantic
    bark.coarse_acoustics = coarseAcoustic
    bark.fine_acoustics = fineAcoustic
    bark.codec_model = codec

    bark.generation_config = bark_generation_config

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters

    parser.add_argument("model_type", type=str, help="text, coarse or fine.")
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.")

    args = parser.parse_args()

    load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small)
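# Illustrative invocation of the Bark converter above (not part of the original script; the file name
# and output path are placeholders, while the positional arguments and flag come from the argparse setup above):
#   python convert_bark_checkpoint.py text ./bark-semantic-small --is_small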
@ -1,156 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BART checkpoint."""

import argparse
import os
from pathlib import Path

import fairseq
import torch
from packaging import version
from torch import nn

from transformers import (
    BartConfig,
    BartForConditionalGeneration,
    BartForSequenceClassification,
    BartModel,
    BartTokenizer,
)
from transformers.utils import logging


FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
    raise Exception("requires fairseq >= 0.9.0")


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SAMPLE_TEXT = " Hello world! cécé herlolip"

mnli_rename_keys = [
    ("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
    ("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
    ("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
    ("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
]


def remove_ignore_keys_(state_dict):
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


def load_xsum_checkpoint(checkpoint_path):
    """Checkpoint path should end in model.pt"""
    sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
    hub_interface.model.load_state_dict(sd["model"])
    return hub_interface


def make_linear_from_emb(emb):
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer


@torch.no_grad()
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
    """
    Copy/paste/tweak model's weights to our BART structure.
    """
    if not os.path.exists(checkpoint_path):
        bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
    else:
        bart = load_xsum_checkpoint(checkpoint_path)

    bart.model.upgrade_state_dict(bart.model.state_dict())
    if hf_checkpoint_name is None:
        hf_checkpoint_name = checkpoint_path.replace(".", "-")
    config = BartConfig.from_pretrained(hf_checkpoint_name)
    tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
    tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
    if not torch.eq(tokens, tokens2).all():
        raise ValueError(
            f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
        )

    if checkpoint_path == "bart.large.mnli":
        state_dict = bart.state_dict()
        remove_ignore_keys_(state_dict)
        state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
        for src, dest in mnli_rename_keys:
            rename_key(state_dict, src, dest)
        model = BartForSequenceClassification(config).eval()
        model.load_state_dict(state_dict)
        fairseq_output = bart.predict("mnli", tokens, return_logits=True)
        new_model_outputs = model(tokens)[0]  # logits
    else:  # no classification heads to worry about
        state_dict = bart.model.state_dict()
        remove_ignore_keys_(state_dict)
        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
        fairseq_output = bart.extract_features(tokens)
        if hf_checkpoint_name == "facebook/bart-large":
            model = BartModel(config).eval()
            model.load_state_dict(state_dict)
            new_model_outputs = model(tokens).model[0]
        else:
            model = BartForConditionalGeneration(config).eval()  # an existing summarization ckpt
            model.model.load_state_dict(state_dict)
            if hasattr(model, "lm_head"):
                model.lm_head = make_linear_from_emb(model.model.shared)
            new_model_outputs = model.model(tokens)[0]

    # Check results
    if fairseq_output.shape != new_model_outputs.shape:
        raise ValueError(
            f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
        )
    if (fairseq_output != new_model_outputs).any().item():
        raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
    )
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
    )
    args = parser.parse_args()
    convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
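# Illustrative invocation of the BART converter above (not part of the original script; the file name,
# output path and --hf_config value are placeholders, while the arguments come from the argparse setup above):
#   python convert_bart_checkpoint.py bart.large.cnn ./bart-large-cnn --hf_config facebook/bart-large-cnn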
@ -1,373 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2021 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert BEiT checkpoints from the unilm repository."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import torch
 | 
			
		||||
from datasets import load_dataset
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    BeitConfig,
 | 
			
		||||
    BeitForImageClassification,
 | 
			
		||||
    BeitForMaskedImageModeling,
 | 
			
		||||
    BeitForSemanticSegmentation,
 | 
			
		||||
    BeitImageProcessor,
 | 
			
		||||
)
 | 
			
		||||
from transformers.image_utils import PILImageResampling
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# here we list all keys to be renamed (original name on the left, our name on the right)
 | 
			
		||||
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
 | 
			
		||||
    prefix = "backbone." if is_semantic else ""
 | 
			
		||||
 | 
			
		||||
    rename_keys = []
 | 
			
		||||
    for i in range(config.num_hidden_layers):
 | 
			
		||||
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
 | 
			
		||||
        rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
 | 
			
		||||
        rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append(
 | 
			
		||||
            (f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
 | 
			
		||||
        )
 | 
			
		||||
        rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
 | 
			
		||||
        rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
 | 
			
		||||
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
 | 
			
		||||
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
 | 
			
		||||
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
 | 
			
		||||
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
 | 
			
		||||
 | 
			
		||||
    # projection layer + position embeddings
 | 
			
		||||
    rename_keys.extend(
 | 
			
		||||
        [
 | 
			
		||||
            (f"{prefix}cls_token", "beit.embeddings.cls_token"),
 | 
			
		||||
            (f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
 | 
			
		||||
            (f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
 | 
			
		||||
        ]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if has_lm_head:
 | 
			
		||||
        # mask token + shared relative position bias + layernorm
 | 
			
		||||
        rename_keys.extend(
 | 
			
		||||
            [
 | 
			
		||||
                ("mask_token", "beit.embeddings.mask_token"),
 | 
			
		||||
                (
 | 
			
		||||
                    "rel_pos_bias.relative_position_bias_table",
 | 
			
		||||
                    "beit.encoder.relative_position_bias.relative_position_bias_table",
 | 
			
		||||
                ),
 | 
			
		||||
                (
 | 
			
		||||
                    "rel_pos_bias.relative_position_index",
 | 
			
		||||
                    "beit.encoder.relative_position_bias.relative_position_index",
 | 
			
		||||
                ),
 | 
			
		||||
                ("norm.weight", "layernorm.weight"),
 | 
			
		||||
                ("norm.bias", "layernorm.bias"),
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
    elif is_semantic:
 | 
			
		||||
        # semantic segmentation classification heads
 | 
			
		||||
        rename_keys.extend(
 | 
			
		||||
            [
 | 
			
		||||
                ("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
 | 
			
		||||
                ("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
 | 
			
		||||
                ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
 | 
			
		||||
                ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        # layernorm + classification head
 | 
			
		||||
        rename_keys.extend(
 | 
			
		||||
            [
 | 
			
		||||
                ("fc_norm.weight", "beit.pooler.layernorm.weight"),
 | 
			
		||||
                ("fc_norm.bias", "beit.pooler.layernorm.bias"),
 | 
			
		||||
                ("head.weight", "classifier.weight"),
 | 
			
		||||
                ("head.bias", "classifier.bias"),
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return rename_keys
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# we split up the matrix of each encoder layer into queries, keys and values
 | 
			
		||||
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
 | 
			
		||||
    for i in range(config.num_hidden_layers):
 | 
			
		||||
        prefix = "backbone." if is_semantic else ""
 | 
			
		||||
        # queries, keys and values
 | 
			
		||||
        in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
 | 
			
		||||
        q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
 | 
			
		||||
        v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
 | 
			
		||||
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
 | 
			
		||||
            : config.hidden_size, :
 | 
			
		||||
        ]
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
 | 
			
		||||
            config.hidden_size : config.hidden_size * 2, :
 | 
			
		||||
        ]
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
 | 
			
		||||
            -config.hidden_size :, :
 | 
			
		||||
        ]
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
 | 
			
		||||
 | 
			
		||||
        # gamma_1 and gamma_2
 | 
			
		||||
        # we call them lambda because otherwise they are renamed when using .from_pretrained
 | 
			
		||||
        gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
 | 
			
		||||
        gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
 | 
			
		||||
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
 | 
			
		||||
 | 
			
		||||
        # relative_position bias table + index
 | 
			
		||||
        if not has_lm_head:
 | 
			
		||||
            # each layer has its own relative position bias
 | 
			
		||||
            table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
 | 
			
		||||
            index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
 | 
			
		||||
 | 
			
		||||
            state_dict[
 | 
			
		||||
                f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
 | 
			
		||||
            ] = table
 | 
			
		||||
            state_dict[
 | 
			
		||||
                f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
 | 
			
		||||
            ] = index
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_key(dct, old, new):
 | 
			
		||||
    val = dct.pop(old)
 | 
			
		||||
    dct[new] = val
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# We will verify our results on an image of cute cats
 | 
			
		||||
def prepare_img():
 | 
			
		||||
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 | 
			
		||||
    im = Image.open(requests.get(url, stream=True).raw)
 | 
			
		||||
    return im
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to our BEiT structure.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # define default BEiT configuration
 | 
			
		||||
    config = BeitConfig()
 | 
			
		||||
    has_lm_head = False
 | 
			
		||||
    is_semantic = False
 | 
			
		||||
    repo_id = "huggingface/label-files"
 | 
			
		||||
    # set config parameters based on URL
 | 
			
		||||
    if checkpoint_url[-9:-4] == "pt22k":
 | 
			
		||||
        # masked image modeling
 | 
			
		||||
        config.use_shared_relative_position_bias = True
 | 
			
		||||
        config.use_mask_token = True
 | 
			
		||||
        has_lm_head = True
 | 
			
		||||
    elif checkpoint_url[-9:-4] == "ft22k":
 | 
			
		||||
        # intermediate fine-tuning on ImageNet-22k
 | 
			
		||||
        config.use_relative_position_bias = True
 | 
			
		||||
        config.num_labels = 21841
 | 
			
		||||
        filename = "imagenet-22k-id2label.json"
 | 
			
		||||
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
        id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
        # this dataset contains 21843 labels but the model only has 21841
 | 
			
		||||
        # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
 | 
			
		||||
        del id2label[9205]
 | 
			
		||||
        del id2label[15027]
 | 
			
		||||
        config.id2label = id2label
 | 
			
		||||
        config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
    elif checkpoint_url[-8:-4] == "to1k":
 | 
			
		||||
        # fine-tuning on ImageNet-1k
 | 
			
		||||
        config.use_relative_position_bias = True
 | 
			
		||||
        config.num_labels = 1000
 | 
			
		||||
        filename = "imagenet-1k-id2label.json"
 | 
			
		||||
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
        id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
        config.id2label = id2label
 | 
			
		||||
        config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
        if "384" in checkpoint_url:
 | 
			
		||||
            config.image_size = 384
 | 
			
		||||
        if "512" in checkpoint_url:
 | 
			
		||||
            config.image_size = 512
 | 
			
		||||
    elif "ade20k" in checkpoint_url:
 | 
			
		||||
        # fine-tuning
 | 
			
		||||
        config.use_relative_position_bias = True
 | 
			
		||||
        config.num_labels = 150
 | 
			
		||||
        filename = "ade20k-id2label.json"
 | 
			
		||||
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
        id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
        config.id2label = id2label
 | 
			
		||||
        config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
        config.image_size = 640
 | 
			
		||||
        is_semantic = True
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
 | 
			
		||||
 | 
			
		||||
    # size of the architecture
 | 
			
		||||
    if "base" in checkpoint_url:
 | 
			
		||||
        pass
 | 
			
		||||
    elif "large" in checkpoint_url:
 | 
			
		||||
        config.hidden_size = 1024
 | 
			
		||||
        config.intermediate_size = 4096
 | 
			
		||||
        config.num_hidden_layers = 24
 | 
			
		||||
        config.num_attention_heads = 16
 | 
			
		||||
        if "ade20k" in checkpoint_url:
 | 
			
		||||
            config.image_size = 640
 | 
			
		||||
            config.out_indices = [7, 11, 15, 23]
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
 | 
			
		||||
 | 
			
		||||
    # load state_dict of original model, remove and rename some keys
 | 
			
		||||
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
 | 
			
		||||
    state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
 | 
			
		||||
 | 
			
		||||
    rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
 | 
			
		||||
    for src, dest in rename_keys:
 | 
			
		||||
        rename_key(state_dict, src, dest)
 | 
			
		||||
    read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
 | 
			
		||||
    if is_semantic:
 | 
			
		||||
        # add prefix to decoder keys
 | 
			
		||||
        for key, val in state_dict.copy().items():
 | 
			
		||||
            val = state_dict.pop(key)
 | 
			
		||||
            if key.startswith("backbone.fpn"):
 | 
			
		||||
                key = key.replace("backbone.fpn", "fpn")
 | 
			
		||||
            state_dict[key] = val
 | 
			
		||||
 | 
			
		||||
    # load HuggingFace model
 | 
			
		||||
    if checkpoint_url[-9:-4] == "pt22k":
 | 
			
		||||
        model = BeitForMaskedImageModeling(config)
 | 
			
		||||
    elif "ade20k" in checkpoint_url:
 | 
			
		||||
        model = BeitForSemanticSegmentation(config)
 | 
			
		||||
    else:
 | 
			
		||||
        model = BeitForImageClassification(config)
 | 
			
		||||
    model.eval()
 | 
			
		||||
    model.load_state_dict(state_dict)
 | 
			
		||||
 | 
			
		||||
    # Check outputs on an image
 | 
			
		||||
    if is_semantic:
 | 
			
		||||
        image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
 | 
			
		||||
        ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
 | 
			
		||||
        image = Image.open(ds[0]["file"])
 | 
			
		||||
    else:
 | 
			
		||||
        image_processor = BeitImageProcessor(
 | 
			
		||||
            size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
 | 
			
		||||
        )
 | 
			
		||||
        image = prepare_img()
 | 
			
		||||
 | 
			
		||||
    encoding = image_processor(images=image, return_tensors="pt")
 | 
			
		||||
    pixel_values = encoding["pixel_values"]
 | 
			
		||||
 | 
			
		||||
    outputs = model(pixel_values)
 | 
			
		||||
    logits = outputs.logits
 | 
			
		||||
 | 
			
		||||
    # verify logits
 | 
			
		||||
    expected_shape = torch.Size([1, 1000])
 | 
			
		||||
    if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
 | 
			
		||||
        expected_shape = torch.Size([1, 196, 8192])
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
 | 
			
		||||
        expected_shape = torch.Size([1, 196, 8192])
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
 | 
			
		||||
        expected_shape = torch.Size([1, 21841])
 | 
			
		||||
        expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
 | 
			
		||||
        expected_class_idx = 2397
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
 | 
			
		||||
        expected_shape = torch.Size([1, 21841])
 | 
			
		||||
        expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
 | 
			
		||||
        expected_class_idx = 2396
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
 | 
			
		||||
        expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
 | 
			
		||||
        expected_class_idx = 285
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
 | 
			
		||||
        expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
 | 
			
		||||
        expected_class_idx = 281
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
 | 
			
		||||
        expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
 | 
			
		||||
        expected_class_idx = 761
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
 | 
			
		||||
        expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
 | 
			
		||||
        expected_class_idx = 761
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
 | 
			
		||||
        expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
 | 
			
		||||
        expected_class_idx = 761
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
 | 
			
		||||
        expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
 | 
			
		||||
        expected_class_idx = 761
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
 | 
			
		||||
        expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
 | 
			
		||||
        expected_class_idx = 761
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
 | 
			
		||||
        expected_shape = (1, 150, 160, 160)
 | 
			
		||||
        expected_logits = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
 | 
			
		||||
                [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
 | 
			
		||||
                [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
 | 
			
		||||
        expected_shape = (1, 150, 160, 160)
 | 
			
		||||
        expected_logits = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
 | 
			
		||||
                [[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
 | 
			
		||||
                [[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Can't verify logits as model is not supported")
 | 
			
		||||
 | 
			
		||||
    if logits.shape != expected_shape:
 | 
			
		||||
        raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
 | 
			
		||||
    if not has_lm_head:
 | 
			
		||||
        if is_semantic:
 | 
			
		||||
            if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
 | 
			
		||||
                raise ValueError("First elements of logits not as expected")
 | 
			
		||||
        else:
 | 
			
		||||
            print("Predicted class idx:", logits.argmax(-1).item())
 | 
			
		||||
 | 
			
		||||
            if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
 | 
			
		||||
                raise ValueError("First elements of logits not as expected")
 | 
			
		||||
            if logits.argmax(-1).item() != expected_class_idx:
 | 
			
		||||
                raise ValueError("Predicted class index not as expected")
 | 
			
		||||
 | 
			
		||||
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
 | 
			
		||||
    print(f"Saving model to {pytorch_dump_folder_path}")
 | 
			
		||||
    model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
    print(f"Saving image processor to {pytorch_dump_folder_path}")
 | 
			
		||||
    image_processor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--checkpoint_url",
 | 
			
		||||
        default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="URL to the original PyTorch checkpoint (.pth file).",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
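    # Hedged follow-up sketch (not part of the original script): for an image-classification
    # checkpoint, the folder written to --pytorch_dump_folder_path (hypothetical path below)
    # can be reloaded with the regular transformers API, e.g.
    #
    #     from transformers import BeitForImageClassification, BeitImageProcessor
    #
    #     image_processor = BeitImageProcessor.from_pretrained("/path/to/beit_dump")
    #     model = BeitForImageClassification.from_pretrained("/path/to/beit_dump")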
 | 
			
		||||
@@ -1,246 +0,0 @@
 | 
			
		||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now
 | 
			
		||||
deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert
 | 
			
		||||
 | 
			
		||||
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
 | 
			
		||||
weight names to the original names, so the model can be imported with Huggingface/transformer.
 | 
			
		||||
 | 
			
		||||
You may adapt this script to include classification/MLM/NSP/etc. heads.
 | 
			
		||||
 | 
			
		||||
Note: This script is only working with an older version of the TensorFlow models repository (<= v2.3.0).
 | 
			
		||||
      Models trained with never versions are not compatible with this script.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import BertConfig, BertModel
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
 | 
			
		||||
    tf_path = os.path.abspath(tf_checkpoint_path)
 | 
			
		||||
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
 | 
			
		||||
    # Load weights from TF model
 | 
			
		||||
    init_vars = tf.train.list_variables(tf_path)
 | 
			
		||||
    names = []
 | 
			
		||||
    arrays = []
 | 
			
		||||
    layer_depth = []
 | 
			
		||||
    for full_name, shape in init_vars:
 | 
			
		||||
        # logger.info(f"Loading TF weight {name} with shape {shape}")
 | 
			
		||||
        name = full_name.split("/")
 | 
			
		||||
        if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
 | 
			
		||||
            logger.info(f"Skipping non-model layer {full_name}")
 | 
			
		||||
            continue
 | 
			
		||||
        if "optimizer" in full_name:
 | 
			
		||||
            logger.info(f"Skipping optimization layer {full_name}")
 | 
			
		||||
            continue
 | 
			
		||||
        if name[0] == "model":
 | 
			
		||||
            # ignore initial 'model'
 | 
			
		||||
            name = name[1:]
 | 
			
		||||
        # figure out how many levels deep the name is
 | 
			
		||||
        depth = 0
 | 
			
		||||
        for _name in name:
 | 
			
		||||
            if _name.startswith("layer_with_weights"):
 | 
			
		||||
                depth += 1
 | 
			
		||||
            else:
 | 
			
		||||
                break
 | 
			
		||||
        layer_depth.append(depth)
 | 
			
		||||
        # read data
 | 
			
		||||
        array = tf.train.load_variable(tf_path, full_name)
 | 
			
		||||
        names.append("/".join(name))
 | 
			
		||||
        arrays.append(array)
 | 
			
		||||
    logger.info(f"Read a total of {len(arrays):,} layers")
 | 
			
		||||
 | 
			
		||||
    # Sanity check
 | 
			
		||||
    if len(set(layer_depth)) != 1:
 | 
			
		||||
        raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
 | 
			
		||||
    layer_depth = list(set(layer_depth))[0]
 | 
			
		||||
    if layer_depth != 1:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            "The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
 | 
			
		||||
            " heads."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # convert layers
 | 
			
		||||
    logger.info("Converting weights...")
 | 
			
		||||
    for full_name, array in zip(names, arrays):
 | 
			
		||||
        name = full_name.split("/")
 | 
			
		||||
        pointer = model
 | 
			
		||||
        trace = []
 | 
			
		||||
        for i, m_name in enumerate(name):
 | 
			
		||||
            if m_name == ".ATTRIBUTES":
 | 
			
		||||
                # variable names end with .ATTRIBUTES/VARIABLE_VALUE
 | 
			
		||||
                break
 | 
			
		||||
            if m_name.startswith("layer_with_weights"):
 | 
			
		||||
                layer_num = int(m_name.split("-")[-1])
 | 
			
		||||
                if layer_num <= 2:
 | 
			
		||||
                    # embedding layers
 | 
			
		||||
                    # layer_num 0: word_embeddings
 | 
			
		||||
                    # layer_num 1: position_embeddings
 | 
			
		||||
                    # layer_num 2: token_type_embeddings
 | 
			
		||||
                    continue
 | 
			
		||||
                elif layer_num == 3:
 | 
			
		||||
                    # embedding LayerNorm
 | 
			
		||||
                    trace.extend(["embeddings", "LayerNorm"])
 | 
			
		||||
                    pointer = getattr(pointer, "embeddings")
 | 
			
		||||
                    pointer = getattr(pointer, "LayerNorm")
 | 
			
		||||
                elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
 | 
			
		||||
                    # encoder layers
 | 
			
		||||
                    trace.extend(["encoder", "layer", str(layer_num - 4)])
 | 
			
		||||
                    pointer = getattr(pointer, "encoder")
 | 
			
		||||
                    pointer = getattr(pointer, "layer")
 | 
			
		||||
                    pointer = pointer[layer_num - 4]
 | 
			
		||||
                elif layer_num == config.num_hidden_layers + 4:
 | 
			
		||||
                    # pooler layer
 | 
			
		||||
                    trace.extend(["pooler", "dense"])
 | 
			
		||||
                    pointer = getattr(pointer, "pooler")
 | 
			
		||||
                    pointer = getattr(pointer, "dense")
 | 
			
		||||
            elif m_name == "embeddings":
 | 
			
		||||
                trace.append("embeddings")
 | 
			
		||||
                pointer = getattr(pointer, "embeddings")
 | 
			
		||||
                if layer_num == 0:
 | 
			
		||||
                    trace.append("word_embeddings")
 | 
			
		||||
                    pointer = getattr(pointer, "word_embeddings")
 | 
			
		||||
                elif layer_num == 1:
 | 
			
		||||
                    trace.append("position_embeddings")
 | 
			
		||||
                    pointer = getattr(pointer, "position_embeddings")
 | 
			
		||||
                elif layer_num == 2:
 | 
			
		||||
                    trace.append("token_type_embeddings")
 | 
			
		||||
                    pointer = getattr(pointer, "token_type_embeddings")
 | 
			
		||||
                else:
 | 
			
		||||
                    raise ValueError(f"Unknown embedding layer with name {full_name}")
 | 
			
		||||
                trace.append("weight")
 | 
			
		||||
                pointer = getattr(pointer, "weight")
 | 
			
		||||
            elif m_name == "_attention_layer":
 | 
			
		||||
                # self-attention layer
 | 
			
		||||
                trace.extend(["attention", "self"])
 | 
			
		||||
                pointer = getattr(pointer, "attention")
 | 
			
		||||
                pointer = getattr(pointer, "self")
 | 
			
		||||
            elif m_name == "_attention_layer_norm":
 | 
			
		||||
                # output attention norm
 | 
			
		||||
                trace.extend(["attention", "output", "LayerNorm"])
 | 
			
		||||
                pointer = getattr(pointer, "attention")
 | 
			
		||||
                pointer = getattr(pointer, "output")
 | 
			
		||||
                pointer = getattr(pointer, "LayerNorm")
 | 
			
		||||
            elif m_name == "_attention_output_dense":
 | 
			
		||||
                # output attention dense
 | 
			
		||||
                trace.extend(["attention", "output", "dense"])
 | 
			
		||||
                pointer = getattr(pointer, "attention")
 | 
			
		||||
                pointer = getattr(pointer, "output")
 | 
			
		||||
                pointer = getattr(pointer, "dense")
 | 
			
		||||
            elif m_name == "_output_dense":
 | 
			
		||||
                # output dense
 | 
			
		||||
                trace.extend(["output", "dense"])
 | 
			
		||||
                pointer = getattr(pointer, "output")
 | 
			
		||||
                pointer = getattr(pointer, "dense")
 | 
			
		||||
            elif m_name == "_output_layer_norm":
 | 
			
		||||
                # output layer norm
 | 
			
		||||
                trace.extend(["output", "LayerNorm"])
 | 
			
		||||
                pointer = getattr(pointer, "output")
 | 
			
		||||
                pointer = getattr(pointer, "LayerNorm")
 | 
			
		||||
            elif m_name == "_key_dense":
 | 
			
		||||
                # attention key
 | 
			
		||||
                trace.append("key")
 | 
			
		||||
                pointer = getattr(pointer, "key")
 | 
			
		||||
            elif m_name == "_query_dense":
 | 
			
		||||
                # attention query
 | 
			
		||||
                trace.append("query")
 | 
			
		||||
                pointer = getattr(pointer, "query")
 | 
			
		||||
            elif m_name == "_value_dense":
 | 
			
		||||
                # attention value
 | 
			
		||||
                trace.append("value")
 | 
			
		||||
                pointer = getattr(pointer, "value")
 | 
			
		||||
            elif m_name == "_intermediate_dense":
 | 
			
		||||
                # attention intermediate dense
 | 
			
		||||
                trace.extend(["intermediate", "dense"])
 | 
			
		||||
                pointer = getattr(pointer, "intermediate")
 | 
			
		||||
                pointer = getattr(pointer, "dense")
 | 
			
		||||
            elif m_name == "_output_layer_norm":
 | 
			
		||||
                # output layer norm
 | 
			
		||||
                trace.append("output")
 | 
			
		||||
                pointer = getattr(pointer, "output")
 | 
			
		||||
            # weights & biases
 | 
			
		||||
            elif m_name in ["bias", "beta"]:
 | 
			
		||||
                trace.append("bias")
 | 
			
		||||
                pointer = getattr(pointer, "bias")
 | 
			
		||||
            elif m_name in ["kernel", "gamma"]:
 | 
			
		||||
                trace.append("weight")
 | 
			
		||||
                pointer = getattr(pointer, "weight")
 | 
			
		||||
            else:
 | 
			
		||||
                logger.warning(f"Ignored {m_name}")
 | 
			
		||||
        # for certain layers reshape is necessary
 | 
			
		||||
        trace = ".".join(trace)
 | 
			
		||||
        if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
 | 
			
		||||
            r"(\S+)\.attention\.output\.dense\.weight", trace
 | 
			
		||||
        ):
 | 
			
		||||
            array = array.reshape(pointer.data.shape)
 | 
			
		||||
        if "kernel" in full_name:
 | 
			
		||||
            array = array.transpose()
 | 
			
		||||
        if pointer.shape == array.shape:
 | 
			
		||||
            pointer.data = torch.from_numpy(array)
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
 | 
			
		||||
                f" {array.shape}"
 | 
			
		||||
            )
 | 
			
		||||
        logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
 | 
			
		||||
    # Instantiate model
 | 
			
		||||
    logger.info(f"Loading model based on config from {config_path}...")
 | 
			
		||||
    config = BertConfig.from_json_file(config_path)
 | 
			
		||||
    model = BertModel(config)
 | 
			
		||||
 | 
			
		||||
    # Load weights from checkpoint
 | 
			
		||||
    logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
 | 
			
		||||
    load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
 | 
			
		||||
 | 
			
		||||
    # Save pytorch-model
 | 
			
		||||
    logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
 | 
			
		||||
    torch.save(model.state_dict(), pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--bert_config_file",
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="The config json file corresponding to the BERT model. This specifies the model architecture.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path",
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="Path to the output PyTorch model (must include filename).",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
 | 
			
		||||
@@ -1,62 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2018 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert BERT checkpoint."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
 | 
			
		||||
    # Initialise PyTorch model
 | 
			
		||||
    config = BertConfig.from_json_file(bert_config_file)
 | 
			
		||||
    print(f"Building PyTorch model from configuration: {config}")
 | 
			
		||||
    model = BertForPreTraining(config)
 | 
			
		||||
 | 
			
		||||
    # Load weights from tf checkpoint
 | 
			
		||||
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)
 | 
			
		||||
 | 
			
		||||
    # Save pytorch-model
 | 
			
		||||
    print(f"Save PyTorch model to {pytorch_dump_path}")
 | 
			
		||||
    torch.save(model.state_dict(), pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--bert_config_file",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help=(
 | 
			
		||||
            "The config json file corresponding to the pre-trained BERT model. \n"
 | 
			
		||||
            "This specifies the model architecture."
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
 | 
			
		||||
@@ -1,112 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2018 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import BertModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
 | 
			
		||||
    """
 | 
			
		||||
    Args:
 | 
			
		||||
        model: BertModel Pytorch model instance to be converted
 | 
			
		||||
        ckpt_dir: Tensorflow model directory
 | 
			
		||||
        model_name: model name
 | 
			
		||||
 | 
			
		||||
    Currently supported HF models:
 | 
			
		||||
 | 
			
		||||
        - Y BertModel
 | 
			
		||||
        - N BertForMaskedLM
 | 
			
		||||
        - N BertForPreTraining
 | 
			
		||||
        - N BertForMultipleChoice
 | 
			
		||||
        - N BertForNextSentencePrediction
 | 
			
		||||
        - N BertForSequenceClassification
 | 
			
		||||
        - N BertForQuestionAnswering
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
 | 
			
		||||
 | 
			
		||||
    var_map = (
 | 
			
		||||
        ("layer.", "layer_"),
 | 
			
		||||
        ("word_embeddings.weight", "word_embeddings"),
 | 
			
		||||
        ("position_embeddings.weight", "position_embeddings"),
 | 
			
		||||
        ("token_type_embeddings.weight", "token_type_embeddings"),
 | 
			
		||||
        (".", "/"),
 | 
			
		||||
        ("LayerNorm/weight", "LayerNorm/gamma"),
 | 
			
		||||
        ("LayerNorm/bias", "LayerNorm/beta"),
 | 
			
		||||
        ("weight", "kernel"),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if not os.path.isdir(ckpt_dir):
 | 
			
		||||
        os.makedirs(ckpt_dir)
 | 
			
		||||
 | 
			
		||||
    state_dict = model.state_dict()
 | 
			
		||||
 | 
			
		||||
    def to_tf_var_name(name: str):
 | 
			
		||||
        for patt, repl in iter(var_map):
 | 
			
		||||
            name = name.replace(patt, repl)
 | 
			
		||||
        return f"bert/{name}"
 | 
			
		||||
 | 
			
		||||
    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
 | 
			
		||||
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
 | 
			
		||||
        tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
 | 
			
		||||
        session.run(tf.variables_initializer([tf_var]))
 | 
			
		||||
        session.run(tf_var)
 | 
			
		||||
        return tf_var
 | 
			
		||||
 | 
			
		||||
    tf.reset_default_graph()
 | 
			
		||||
    with tf.Session() as session:
 | 
			
		||||
        for var_name in state_dict:
 | 
			
		||||
            tf_name = to_tf_var_name(var_name)
 | 
			
		||||
            torch_tensor = state_dict[var_name].numpy()
 | 
			
		||||
            if any(x in var_name for x in tensors_to_transpose):
 | 
			
		||||
                torch_tensor = torch_tensor.T
 | 
			
		||||
            tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
 | 
			
		||||
            tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
 | 
			
		||||
            tf_weight = session.run(tf_var)
 | 
			
		||||
            print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")
 | 
			
		||||
 | 
			
		||||
        saver = tf.train.Saver(tf.trainable_variables())
 | 
			
		||||
        saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(raw_args=None):
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
 | 
			
		||||
    parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
 | 
			
		||||
    args = parser.parse_args(raw_args)
 | 
			
		||||
 | 
			
		||||
    model = BertModel.from_pretrained(
 | 
			
		||||
        pretrained_model_name_or_path=args.model_name,
 | 
			
		||||
        state_dict=torch.load(args.pytorch_model_path, weights_only=True),
 | 
			
		||||
        cache_dir=args.cache_dir,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
@@ -1,188 +0,0 @@
 | 
			
		||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT
 | 
			
		||||
model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository:
 | 
			
		||||
 | 
			
		||||
https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import BertConfig, BertForMaskedLM
 | 
			
		||||
from transformers.models.bert.modeling_bert import (
 | 
			
		||||
    BertIntermediate,
 | 
			
		||||
    BertLayer,
 | 
			
		||||
    BertOutput,
 | 
			
		||||
    BertPooler,
 | 
			
		||||
    BertSelfAttention,
 | 
			
		||||
    BertSelfOutput,
 | 
			
		||||
)
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
 | 
			
		||||
    def get_masked_lm_array(name: str):
 | 
			
		||||
        full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
 | 
			
		||||
        array = tf.train.load_variable(tf_checkpoint_path, full_name)
 | 
			
		||||
 | 
			
		||||
        if "kernel" in name:
 | 
			
		||||
            array = array.transpose()
 | 
			
		||||
 | 
			
		||||
        return torch.from_numpy(array)
 | 
			
		||||
 | 
			
		||||
    def get_encoder_array(name: str):
 | 
			
		||||
        full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
 | 
			
		||||
        array = tf.train.load_variable(tf_checkpoint_path, full_name)
 | 
			
		||||
 | 
			
		||||
        if "kernel" in name:
 | 
			
		||||
            array = array.transpose()
 | 
			
		||||
 | 
			
		||||
        return torch.from_numpy(array)
 | 
			
		||||
 | 
			
		||||
    def get_encoder_layer_array(layer_index: int, name: str):
 | 
			
		||||
        full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
 | 
			
		||||
        array = tf.train.load_variable(tf_checkpoint_path, full_name)
 | 
			
		||||
 | 
			
		||||
        if "kernel" in name:
 | 
			
		||||
            array = array.transpose()
 | 
			
		||||
 | 
			
		||||
        return torch.from_numpy(array)
 | 
			
		||||
 | 
			
		||||
    def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape):
 | 
			
		||||
        full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
 | 
			
		||||
        array = tf.train.load_variable(tf_checkpoint_path, full_name)
 | 
			
		||||
        array = array.reshape(original_shape)
 | 
			
		||||
 | 
			
		||||
        if "kernel" in name:
 | 
			
		||||
            array = array.transpose()
 | 
			
		||||
 | 
			
		||||
        return torch.from_numpy(array)
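    # Examples of the full TF variable names these helpers resolve to (comments only,
    # not part of the original script; derived from the f-strings above):
    #   get_masked_lm_array("dense/kernel")
    #       -> "masked_lm/dense/kernel/.ATTRIBUTES/VARIABLE_VALUE"
    #   get_encoder_attention_layer_array(0, "_query_dense/kernel", ...)
    #       -> "encoder/_transformer_layers/0/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE"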
 | 
			
		||||
 | 
			
		||||
    print(f"Loading model based on config from {config_path}...")
 | 
			
		||||
    config = BertConfig.from_json_file(config_path)
 | 
			
		||||
    model = BertForMaskedLM(config)
 | 
			
		||||
 | 
			
		||||
    # Layers
 | 
			
		||||
    for layer_index in range(0, config.num_hidden_layers):
 | 
			
		||||
        layer: BertLayer = model.bert.encoder.layer[layer_index]
 | 
			
		||||
 | 
			
		||||
        # Self-attention
 | 
			
		||||
        self_attn: BertSelfAttention = layer.attention.self
 | 
			
		||||
 | 
			
		||||
        self_attn.query.weight.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_attn.query.bias.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_attn.key.weight.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_attn.key.bias.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_attn.value.weight.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_attn.value.bias.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Self-attention Output
 | 
			
		||||
        self_output: BertSelfOutput = layer.attention.output
 | 
			
		||||
 | 
			
		||||
        self_output.dense.weight.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_output.dense.bias.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
 | 
			
		||||
        self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
 | 
			
		||||
 | 
			
		||||
        # Intermediate
 | 
			
		||||
        intermediate: BertIntermediate = layer.intermediate
 | 
			
		||||
 | 
			
		||||
        intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
 | 
			
		||||
        intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
 | 
			
		||||
 | 
			
		||||
        # Output
 | 
			
		||||
        bert_output: BertOutput = layer.output
 | 
			
		||||
 | 
			
		||||
        bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
 | 
			
		||||
        bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
 | 
			
		||||
 | 
			
		||||
        bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
 | 
			
		||||
        bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
 | 
			
		||||
 | 
			
		||||
    # Embeddings
 | 
			
		||||
    model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
 | 
			
		||||
    model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
 | 
			
		||||
    model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
 | 
			
		||||
    model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
 | 
			
		||||
 | 
			
		||||
    # LM Head
 | 
			
		||||
    lm_head = model.cls.predictions.transform
 | 
			
		||||
 | 
			
		||||
    lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
 | 
			
		||||
    lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
 | 
			
		||||
 | 
			
		||||
    lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
 | 
			
		||||
    lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
 | 
			
		||||
 | 
			
		||||
    model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
 | 
			
		||||
 | 
			
		||||
    # Pooling
 | 
			
		||||
    model.bert.pooler = BertPooler(config=config)
 | 
			
		||||
    model.bert.pooler.dense.weight.data = get_encoder_array("_pooler_layer/kernel")
    model.bert.pooler.dense.bias.data = get_encoder_array("_pooler_layer/bias")
 | 
			
		||||
 | 
			
		||||
    # Export final model
 | 
			
		||||
    model.save_pretrained(pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
    # Integration test - should load without any errors ;)
 | 
			
		||||
    new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
 | 
			
		||||
    print(new_model.eval())
 | 
			
		||||
 | 
			
		||||
    print("Model conversion was done successfully!")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--bert_config_file",
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="The config json file corresponding to the BERT model. This specifies the model architecture.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path",
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="Path to the output PyTorch model.",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
 | 
			
		||||
@@ -1,69 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2021 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert BigBird checkpoint."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
 | 
			
		||||
    # Initialise PyTorch model
 | 
			
		||||
    config = BigBirdConfig.from_json_file(big_bird_config_file)
 | 
			
		||||
    print(f"Building PyTorch model from configuration: {config}")
 | 
			
		||||
 | 
			
		||||
    if is_trivia_qa:
 | 
			
		||||
        model = BigBirdForQuestionAnswering(config)
 | 
			
		||||
    else:
 | 
			
		||||
        model = BigBirdForPreTraining(config)
 | 
			
		||||
 | 
			
		||||
    # Load weights from tf checkpoint
 | 
			
		||||
    load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)
 | 
			
		||||
 | 
			
		||||
    # Save pytorch-model
 | 
			
		||||
    print(f"Save PyTorch model to {pytorch_dump_path}")
 | 
			
		||||
    model.save_pretrained(pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--big_bird_config_file",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help=(
 | 
			
		||||
            "The config json file corresponding to the pre-trained BERT model. \n"
 | 
			
		||||
            "This specifies the model architecture."
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_tf_checkpoint_to_pytorch(
 | 
			
		||||
        args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
 | 
			
		||||
    )
 | 
			
		||||
@@ -1,169 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2021 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
import torch
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
INIT_COMMON = [
 | 
			
		||||
    # tf -> hf
 | 
			
		||||
    ("/", "."),
 | 
			
		||||
    ("layer_", "layers."),
 | 
			
		||||
    ("kernel", "weight"),
 | 
			
		||||
    ("beta", "bias"),
 | 
			
		||||
    ("gamma", "weight"),
 | 
			
		||||
    ("pegasus", "model"),
 | 
			
		||||
]
 | 
			
		||||
END_COMMON = [
 | 
			
		||||
    (".output.dense", ".fc2"),
 | 
			
		||||
    ("intermediate.LayerNorm", "final_layer_norm"),
 | 
			
		||||
    ("intermediate.dense", "fc1"),
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
DECODER_PATTERNS = (
 | 
			
		||||
    INIT_COMMON
 | 
			
		||||
    + [
 | 
			
		||||
        ("attention.self.LayerNorm", "self_attn_layer_norm"),
 | 
			
		||||
        ("attention.output.dense", "self_attn.out_proj"),
 | 
			
		||||
        ("attention.self", "self_attn"),
 | 
			
		||||
        ("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
 | 
			
		||||
        ("attention.encdec_output.dense", "encoder_attn.out_proj"),
 | 
			
		||||
        ("attention.encdec", "encoder_attn"),
 | 
			
		||||
        ("key", "k_proj"),
 | 
			
		||||
        ("value", "v_proj"),
 | 
			
		||||
        ("query", "q_proj"),
 | 
			
		||||
        ("decoder.LayerNorm", "decoder.layernorm_embedding"),
 | 
			
		||||
    ]
 | 
			
		||||
    + END_COMMON
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
REMAINING_PATTERNS = (
 | 
			
		||||
    INIT_COMMON
 | 
			
		||||
    + [
 | 
			
		||||
        ("embeddings.word_embeddings", "shared.weight"),
 | 
			
		||||
        ("embeddings.position_embeddings", "embed_positions.weight"),
 | 
			
		||||
        ("attention.self.LayerNorm", "self_attn_layer_norm"),
 | 
			
		||||
        ("attention.output.dense", "self_attn.output"),
 | 
			
		||||
        ("attention.self", "self_attn.self"),
 | 
			
		||||
        ("encoder.LayerNorm", "encoder.layernorm_embedding"),
 | 
			
		||||
    ]
 | 
			
		||||
    + END_COMMON
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
KEYS_TO_IGNORE = [
 | 
			
		||||
    "encdec/key/bias",
 | 
			
		||||
    "encdec/query/bias",
 | 
			
		||||
    "encdec/value/bias",
 | 
			
		||||
    "self/key/bias",
 | 
			
		||||
    "self/query/bias",
 | 
			
		||||
    "self/value/bias",
 | 
			
		||||
    "encdec_output/dense/bias",
 | 
			
		||||
    "attention/output/dense/bias",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_state_dict_key(k, patterns):
 | 
			
		||||
    for tf_name, hf_name in patterns:
 | 
			
		||||
        k = k.replace(tf_name, hf_name)
 | 
			
		||||
    return k
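# Worked example (comments only, not part of the original script): assuming a TF key of
# the form "pegasus/decoder/layer_0/attention/self/query/kernel", applying the
# DECODER_PATTERNS in order gives
#   pegasus/decoder/layer_0/attention/self/query/kernel
#   -> model.decoder.layers.0.attention.self.query.weight   (after INIT_COMMON)
#   -> model.decoder.layers.0.self_attn.q_proj.weight       (after the decoder-specific patterns)
# which matches the BigBirdPegasusForConditionalGeneration state-dict naming.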
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
 | 
			
		||||
    cfg = BigBirdPegasusConfig(**config_update)
 | 
			
		||||
    torch_model = BigBirdPegasusForConditionalGeneration(cfg)
 | 
			
		||||
    state_dict = torch_model.state_dict()
 | 
			
		||||
    mapping = {}
 | 
			
		||||
 | 
			
		||||
    # separating decoder weights
 | 
			
		||||
    decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
 | 
			
		||||
    remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
 | 
			
		||||
 | 
			
		||||
    for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
 | 
			
		||||
        conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
 | 
			
		||||
        if any(conditions):
 | 
			
		||||
            continue
 | 
			
		||||
        patterns = DECODER_PATTERNS
 | 
			
		||||
        new_k = rename_state_dict_key(k, patterns)
 | 
			
		||||
        if new_k not in state_dict:
 | 
			
		||||
            raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
 | 
			
		||||
        if any(i in k for i in ["dense", "query", "key", "value"]):
 | 
			
		||||
            v = v.T
 | 
			
		||||
        mapping[new_k] = torch.from_numpy(v)
 | 
			
		||||
        assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
 | 
			
		||||
 | 
			
		||||
    for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
 | 
			
		||||
        conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
 | 
			
		||||
        if any(conditions):
 | 
			
		||||
            continue
 | 
			
		||||
        patterns = REMAINING_PATTERNS
 | 
			
		||||
        new_k = rename_state_dict_key(k, patterns)
 | 
			
		||||
        if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
 | 
			
		||||
            raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
 | 
			
		||||
        if any(i in k for i in ["dense", "query", "key", "value"]):
 | 
			
		||||
            v = v.T
 | 
			
		||||
        mapping[new_k] = torch.from_numpy(v)
 | 
			
		||||
        if k != "pegasus/embeddings/position_embeddings":
 | 
			
		||||
            assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
 | 
			
		||||
 | 
			
		||||
    mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
 | 
			
		||||
    mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
 | 
			
		||||
    missing, extra = torch_model.load_state_dict(mapping, strict=False)
 | 
			
		||||
    unexpected_missing = [
 | 
			
		||||
        k
 | 
			
		||||
        for k in missing
 | 
			
		||||
        if k
 | 
			
		||||
        not in [
 | 
			
		||||
            "final_logits_bias",
 | 
			
		||||
            "model.encoder.embed_tokens.weight",
 | 
			
		||||
            "model.decoder.embed_tokens.weight",
 | 
			
		||||
            "lm_head.weight",
 | 
			
		||||
        ]
 | 
			
		||||
    ]
 | 
			
		||||
    assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
 | 
			
		||||
    assert extra == [], f"no matches found for the following tf keys {extra}"
 | 
			
		||||
    return torch_model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_tf_weights_as_numpy(path) -> dict:
 | 
			
		||||
    init_vars = tf.train.list_variables(path)
 | 
			
		||||
    tf_weights = {}
 | 
			
		||||
    ignore_name = ["global_step"]
 | 
			
		||||
    for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
 | 
			
		||||
        skip_key = any(pat in name for pat in ignore_name)
 | 
			
		||||
        if skip_key:
 | 
			
		||||
            continue
 | 
			
		||||
        array = tf.train.load_variable(path, name)
 | 
			
		||||
        tf_weights[name] = array
 | 
			
		||||
    return tf_weights
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
 | 
			
		||||
    tf_weights = get_tf_weights_as_numpy(ckpt_path)
 | 
			
		||||
    torch_model = convert_bigbird_pegasus(tf_weights, config_update)
 | 
			
		||||
    torch_model.save_pretrained(save_dir)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
 | 
			
		||||
    parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    config_update = {}
 | 
			
		||||
    convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)
 | 
			
		||||
@@ -1,292 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2022 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
import shutil
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import BioGptConfig, BioGptForCausalLM
 | 
			
		||||
from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES
 | 
			
		||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
 | 
			
		||||
from transformers.utils import WEIGHTS_NAME, logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_warning()
 | 
			
		||||
 | 
			
		||||
json_indent = 2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18
 | 
			
		||||
class Dictionary:
 | 
			
		||||
    """A mapping from symbols to consecutive integers"""
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        *,  # begin keyword-only arguments
 | 
			
		||||
        bos="<s>",
 | 
			
		||||
        pad="<pad>",
 | 
			
		||||
        eos="</s>",
 | 
			
		||||
        unk="<unk>",
 | 
			
		||||
        extra_special_symbols=None,
 | 
			
		||||
    ):
 | 
			
		||||
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
 | 
			
		||||
        self.symbols = []
 | 
			
		||||
        self.count = []
 | 
			
		||||
        self.indices = {}
 | 
			
		||||
        self.bos_index = self.add_symbol(bos)
 | 
			
		||||
        self.pad_index = self.add_symbol(pad)
 | 
			
		||||
        self.eos_index = self.add_symbol(eos)
 | 
			
		||||
        self.unk_index = self.add_symbol(unk)
 | 
			
		||||
        if extra_special_symbols:
 | 
			
		||||
            for s in extra_special_symbols:
 | 
			
		||||
                self.add_symbol(s)
 | 
			
		||||
        self.nspecial = len(self.symbols)
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, other):
 | 
			
		||||
        return self.indices == other.indices
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, idx):
 | 
			
		||||
        if idx < len(self.symbols):
 | 
			
		||||
            return self.symbols[idx]
 | 
			
		||||
        return self.unk_word
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        """Returns the number of symbols in the dictionary"""
 | 
			
		||||
        return len(self.symbols)
 | 
			
		||||
 | 
			
		||||
    def __contains__(self, sym):
 | 
			
		||||
        return sym in self.indices
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def load(cls, f):
 | 
			
		||||
        """Loads the dictionary from a text file with the format:
 | 
			
		||||
 | 
			
		||||
        ```
 | 
			
		||||
        <symbol0> <count0>
 | 
			
		||||
        <symbol1> <count1>
 | 
			
		||||
        ...
 | 
			
		||||
        ```
 | 
			
		||||
        """
 | 
			
		||||
        d = cls()
 | 
			
		||||
        d.add_from_file(f)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def add_symbol(self, word, n=1, overwrite=False):
 | 
			
		||||
        """Adds a word to the dictionary"""
 | 
			
		||||
        if word in self.indices and not overwrite:
 | 
			
		||||
            idx = self.indices[word]
 | 
			
		||||
            self.count[idx] = self.count[idx] + n
 | 
			
		||||
            return idx
 | 
			
		||||
        else:
 | 
			
		||||
            idx = len(self.symbols)
 | 
			
		||||
            self.indices[word] = idx
 | 
			
		||||
            self.symbols.append(word)
 | 
			
		||||
            self.count.append(n)
 | 
			
		||||
            return idx
 | 
			
		||||
 | 
			
		||||
    def _load_meta(self, lines):
 | 
			
		||||
        return 0
 | 
			
		||||
 | 
			
		||||
    def add_from_file(self, f):
 | 
			
		||||
        """
 | 
			
		||||
        Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(f, str):
 | 
			
		||||
            try:
 | 
			
		||||
                with open(f, "r", encoding="utf-8") as fd:
 | 
			
		||||
                    self.add_from_file(fd)
 | 
			
		||||
            except FileNotFoundError as fnfe:
 | 
			
		||||
                raise fnfe
 | 
			
		||||
            except UnicodeError:
 | 
			
		||||
                raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset")
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        lines = f.readlines()
 | 
			
		||||
        indices_start_line = self._load_meta(lines)
 | 
			
		||||
 | 
			
		||||
        for line in lines[indices_start_line:]:
 | 
			
		||||
            try:
 | 
			
		||||
                line, field = line.rstrip().rsplit(" ", 1)
 | 
			
		||||
                if field == "#fairseq:overwrite":
 | 
			
		||||
                    overwrite = True
 | 
			
		||||
                    line, field = line.rsplit(" ", 1)
 | 
			
		||||
                else:
 | 
			
		||||
                    overwrite = False
 | 
			
		||||
                count = int(field)
 | 
			
		||||
                word = line
 | 
			
		||||
                if word in self and not overwrite:
 | 
			
		||||
                    raise RuntimeError(
 | 
			
		||||
                        f"Duplicate word found when loading Dictionary: '{word}'. "
 | 
			
		||||
                        "Duplicate words can overwrite earlier ones by adding the "
 | 
			
		||||
                        "#fairseq:overwrite flag at the end of the corresponding row "
 | 
			
		||||
                        "in the dictionary file. If using the Camembert model, please "
 | 
			
		||||
                        "download an updated copy of the model file."
 | 
			
		||||
                    )
 | 
			
		||||
                self.add_symbol(word, n=count, overwrite=overwrite)
 | 
			
		||||
            except ValueError:
 | 
			
		||||
                raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rewrite_dict_keys(d):
 | 
			
		||||
    # (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
 | 
			
		||||
    # e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
 | 
			
		||||
    d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
 | 
			
		||||
    keep_keys = "<s> <pad> </s> <unk>".split()
 | 
			
		||||
    # restore the special tokens
 | 
			
		||||
    for k in keep_keys:
 | 
			
		||||
        del d2[f"{k}</w>"]
 | 
			
		||||
        d2[k] = d[k]  # restore
 | 
			
		||||
    return d2
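# Worked example (comments only, not part of the original script), including the special
# tokens that the function restores:
#   rewrite_dict_keys({"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "le@@": 5, "er": 7})
#   == {"le": 5, "er</w>": 7, "<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}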
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path):
 | 
			
		||||
    # prep
 | 
			
		||||
    if not os.path.exists(biogpt_checkpoint_path):
 | 
			
		||||
        raise ValueError(f"path {biogpt_checkpoint_path} does not exist!")
 | 
			
		||||
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)
 | 
			
		||||
    print(f"Writing results to {pytorch_dump_folder_path}")
 | 
			
		||||
 | 
			
		||||
    # handle various types of models
 | 
			
		||||
 | 
			
		||||
    checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
 | 
			
		||||
    if not os.path.isfile(checkpoint_file):
 | 
			
		||||
        raise ValueError(f"path to the file {checkpoint_file} does not exist!")
 | 
			
		||||
    chkpt = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
 | 
			
		||||
 | 
			
		||||
    args = chkpt["cfg"]["model"]
 | 
			
		||||
 | 
			
		||||
    # dicts
 | 
			
		||||
    dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt")
 | 
			
		||||
    if not os.path.isfile(dict_file):
 | 
			
		||||
        raise ValueError(f"path to the file {dict_file} does not exist!")
 | 
			
		||||
    src_dict = Dictionary.load(dict_file)
 | 
			
		||||
    src_vocab = rewrite_dict_keys(src_dict.indices)
 | 
			
		||||
    src_vocab_size = len(src_vocab)
 | 
			
		||||
    src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"])
 | 
			
		||||
    print(f"Generating {src_vocab_file} of {src_vocab_size} records")
 | 
			
		||||
    with open(src_vocab_file, "w", encoding="utf-8") as f:
 | 
			
		||||
        f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
 | 
			
		||||
 | 
			
		||||
    # merges_file (bpecodes)
 | 
			
		||||
    bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes")
 | 
			
		||||
    if not os.path.isfile(bpecodes_file):
 | 
			
		||||
        raise ValueError(f"path to the file {bpecodes_file} does not exist!")
 | 
			
		||||
 | 
			
		||||
    merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
 | 
			
		||||
    shutil.copyfile(bpecodes_file, merges_file)
 | 
			
		||||
 | 
			
		||||
    # model config
 | 
			
		||||
    biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")
 | 
			
		||||
 | 
			
		||||
    model_conf = {
 | 
			
		||||
        "activation_dropout": args["activation_dropout"],
 | 
			
		||||
        "architectures": ["BioGptForCausalLM"],
 | 
			
		||||
        "attention_probs_dropout_prob": args["attention_dropout"],
 | 
			
		||||
        "bos_token_id": 0,
 | 
			
		||||
        "eos_token_id": 2,
 | 
			
		||||
        "hidden_act": args["activation_fn"],
 | 
			
		||||
        "hidden_dropout_prob": args["dropout"],
 | 
			
		||||
        "hidden_size": args["decoder_embed_dim"],
 | 
			
		||||
        "initializer_range": 0.02,
 | 
			
		||||
        "intermediate_size": args["decoder_ffn_embed_dim"],
 | 
			
		||||
        "layer_norm_eps": 1e-12,
 | 
			
		||||
        "layerdrop": args["decoder_layerdrop"],
 | 
			
		||||
        "max_position_embeddings": args["max_target_positions"],
 | 
			
		||||
        "model_type": "biogpt",
 | 
			
		||||
        "num_attention_heads": args["decoder_attention_heads"],
 | 
			
		||||
        "num_hidden_layers": args["decoder_layers"],
 | 
			
		||||
        "pad_token_id": 1,
 | 
			
		||||
        "scale_embedding": not args["no_scale_embedding"],
 | 
			
		||||
        "tie_word_embeddings": args["share_decoder_input_output_embed"],
 | 
			
		||||
        "vocab_size": src_vocab_size,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # good hparam defaults to start with
 | 
			
		||||
 | 
			
		||||
    print(f"Generating {biogpt_model_config_file}")
 | 
			
		||||
    with open(biogpt_model_config_file, "w", encoding="utf-8") as f:
 | 
			
		||||
        f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))
 | 
			
		||||
 | 
			
		||||
    # tokenizer config
 | 
			
		||||
    biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)
 | 
			
		||||
 | 
			
		||||
    tokenizer_conf = {
 | 
			
		||||
        "bos_token": "<s>",
 | 
			
		||||
        "eos_token": "</s>",
 | 
			
		||||
        "model_max_length": 1024,
 | 
			
		||||
        "pad_token": "<pad>",
 | 
			
		||||
        "special_tokens_map_file": None,
 | 
			
		||||
        "tokenizer_class": "BioGptTokenizer",
 | 
			
		||||
        "unk_token": "<unk>",
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    print(f"Generating {biogpt_tokenizer_config_file}")
 | 
			
		||||
    with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f:
 | 
			
		||||
        f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))
 | 
			
		||||
 | 
			
		||||
    # model
 | 
			
		||||
    model_state_dict = chkpt["model"]
 | 
			
		||||
 | 
			
		||||
    # remove unneeded keys
 | 
			
		||||
    ignore_keys = [
 | 
			
		||||
        "decoder.version",
 | 
			
		||||
    ]
 | 
			
		||||
    for k in ignore_keys:
 | 
			
		||||
        model_state_dict.pop(k, None)
 | 
			
		||||
 | 
			
		||||
    layer_names = list(model_state_dict.keys())
 | 
			
		||||
    for layer_name in layer_names:
 | 
			
		||||
        if layer_name.endswith("output_projection.weight"):
 | 
			
		||||
            model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name)
 | 
			
		||||
        else:
 | 
			
		||||
            model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name)
 | 
			
		||||
 | 
			
		||||
    config = BioGptConfig.from_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
    model_new = BioGptForCausalLM(config)
 | 
			
		||||
 | 
			
		||||
    # check that it loads ok
 | 
			
		||||
    model_new.load_state_dict(model_state_dict)
 | 
			
		||||
 | 
			
		||||
    # save
 | 
			
		||||
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
 | 
			
		||||
    print(f"Generating {pytorch_weights_dump_path}")
 | 
			
		||||
    torch.save(model_state_dict, pytorch_weights_dump_path)
 | 
			
		||||
 | 
			
		||||
    print("Conversion is done!")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--biogpt_checkpoint_path",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help=(
 | 
			
		||||
            "Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
 | 
			
		||||
            " bpecodes, etc."
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path)
 | 
			
		||||
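
The dump folder produced by the script above is a standard Transformers checkpoint. A minimal usage sketch, assuming the converted checkpoint was written to `./biogpt-hf` (a hypothetical path) and using only the public `transformers` API:

# Sketch: sanity-check a converted BioGPT checkpoint ("./biogpt-hf" is an assumed dump folder).
import torch

from transformers import BioGptForCausalLM, BioGptTokenizer

tokenizer = BioGptTokenizer.from_pretrained("./biogpt-hf")
model = BioGptForCausalLM.from_pretrained("./biogpt-hf")
model.eval()

inputs = tokenizer("COVID-19 is", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
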
@ -1,177 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BiT checkpoints from the timm library."""

import argparse
import json
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from transformers import BitConfig, BitForImageClassification, BitImageProcessor
from transformers.image_utils import PILImageResampling
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_config(model_name):
    repo_id = "huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    conv_layer = "std_conv" if "bit" in model_name else False

    # note that when using BiT as backbone for ViT-hybrid checkpoints,
    # one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same",
    # config.conv_layer = "std_conv_same"
    config = BitConfig(
        conv_layer=conv_layer,
        num_labels=1000,
        id2label=id2label,
        label2id=label2id,
    )

    return config


def rename_key(name):
    if "stem.conv" in name:
        name = name.replace("stem.conv", "bit.embedder.convolution")
    if "blocks" in name:
        name = name.replace("blocks", "layers")
    if "head.fc" in name:
        name = name.replace("head.fc", "classifier.1")
    if name.startswith("norm"):
        name = "bit." + name
    if "bit" not in name and "classifier" not in name:
        name = "bit.encoder." + name

    return name


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@torch.no_grad()
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our BiT structure.
    """

    # define default BiT configuration
    config = get_config(model_name)

    # load original model from timm
    timm_model = create_model(model_name, pretrained=True)
    timm_model.eval()

    # load state_dict of original model
    state_dict = timm_model.state_dict()
    for key in state_dict.copy():
        val = state_dict.pop(key)
        state_dict[rename_key(key)] = val.squeeze() if "head" in key else val

    # load HuggingFace model
    model = BitForImageClassification(config)
    model.eval()
    model.load_state_dict(state_dict)

    # create image processor
    transform = create_transform(**resolve_data_config({}, model=timm_model))
    timm_transforms = transform.transforms

    pillow_resamplings = {
        "bilinear": PILImageResampling.BILINEAR,
        "bicubic": PILImageResampling.BICUBIC,
        "nearest": PILImageResampling.NEAREST,
    }

    processor = BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": timm_transforms[0].size},
        resample=pillow_resamplings[timm_transforms[0].interpolation.value],
        do_center_crop=True,
        crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
        do_normalize=True,
        image_mean=timm_transforms[-1].mean.tolist(),
        image_std=timm_transforms[-1].std.tolist(),
    )

    image = prepare_img()
    timm_pixel_values = transform(image).unsqueeze(0)
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # verify pixel values
    assert torch.allclose(timm_pixel_values, pixel_values)

    # verify logits
    with torch.no_grad():
        outputs = model(pixel_values)
        logits = outputs.logits

    print("Logits:", logits[0, :3])
    print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
    timm_logits = timm_model(pixel_values)
    assert timm_logits.shape == outputs.logits.shape
    assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print(f"Pushing model {model_name} and processor to the hub")
        model.push_to_hub(f"ybelkada/{model_name}")
        processor.push_to_hub(f"ybelkada/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="resnetv2_50x1_bitm",
        type=str,
        help="Name of the BiT timm model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the model to the hub.",
    )

    args = parser.parse_args()
    convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
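
Since the BiT script above saves both the model and the image processor, the resulting folder can be used directly for classification. A minimal sketch, assuming the dump folder is `./bit-converted` (a hypothetical path; a Hub id such as `google/bit-50` would work the same way):

# Sketch: classify an image with a converted BiT checkpoint ("./bit-converted" is an assumed folder).
import requests
import torch
from PIL import Image

from transformers import BitForImageClassification, BitImageProcessor

checkpoint = "./bit-converted"  # or a Hub id such as "google/bit-50"
processor = BitImageProcessor.from_pretrained(checkpoint)
model = BitForImageClassification.from_pretrained(checkpoint)
model.eval()

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
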
@ -1,114 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Blenderbot checkpoint."""

import argparse

import torch

from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

PATTERNS = [
    ["attention", "attn"],
    ["encoder_attention", "encoder_attn"],
    ["q_lin", "q_proj"],
    ["k_lin", "k_proj"],
    ["v_lin", "v_proj"],
    ["out_lin", "out_proj"],
    ["norm_embeddings", "layernorm_embedding"],
    ["position_embeddings", "embed_positions"],
    ["embeddings", "embed_tokens"],
    ["ffn.lin", "fc"],
]


def rename_state_dict_key(k):
    if k == "embeddings.weight":
        return "shared.weight"

    for parlai_name, hf_name in PATTERNS:
        k = k.replace(parlai_name, hf_name)

    if k.startswith("encoder"):
        k = k.replace(".attn", ".self_attn")
        k = k.replace("norm1", "self_attn_layer_norm")
        k = k.replace("norm2", "final_layer_norm")
    elif k.startswith("decoder"):
        k = k.replace("norm1", "self_attn_layer_norm")
        k = k.replace("norm2", "encoder_attn_layer_norm")
        k = k.replace("norm3", "final_layer_norm")
    return k


def rename_layernorm_keys(sd):
    keys = [
        "model.encoder.layernorm_embedding.weight",
        "model.encoder.layernorm_embedding.bias",
        "model.decoder.layernorm_embedding.weight",
        "model.decoder.layernorm_embedding.bias",
    ]
    for k in keys:
        v = sd.pop(k)
        new_k = k.replace("layernorm_embedding", "layer_norm")
        assert new_k not in sd
        sd[new_k] = v


IGNORE_KEYS = ["START"]


@torch.no_grad()
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
    """
    Copy/paste/tweak model's weights to our BERT structure.
    """
    model = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    sd = model["model"]
    cfg = BlenderbotConfig.from_json_file(config_json_path)
    m = BlenderbotForConditionalGeneration(cfg)
    valid_keys = m.model.state_dict().keys()
    failures = []
    mapping = {}
    for k, v in sd.items():
        if k in IGNORE_KEYS:
            continue

        new_k = rename_state_dict_key(k)
        if new_k not in valid_keys:
            failures.append([k, new_k])
        else:
            mapping[new_k] = v
    if cfg.normalize_before:  # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
        rename_layernorm_keys(sd)
    m.model.load_state_dict(mapping, strict=True)
    m.half()
    m.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
    parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
    parser.add_argument(
        "--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
    )
    args = parser.parse_args()
    convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)
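
The Blenderbot script above only writes the model weights, not a tokenizer, so a tokenizer has to come from elsewhere. A minimal sketch, assuming the default `hf_blenderbot` save directory and borrowing the tokenizer from the `facebook/blenderbot-3B` Hub checkpoint (an assumption, not something the script produces):

# Sketch: generate a reply from a converted Blenderbot checkpoint ("hf_blenderbot" matches the script default).
from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizer

model = BlenderbotForConditionalGeneration.from_pretrained("hf_blenderbot")
tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B")  # tokenizer is not produced by the script

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
reply_ids = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True))
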
@ -1,191 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import re

import requests
import torch

# git clone https://github.com/salesforce/BLIP.git
from models.blip import blip_decoder
from models.blip_itm import blip_itm
from models.blip_vqa import blip_vqa
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from transformers import (
    BertTokenizer,
    BlipConfig,
    BlipForConditionalGeneration,
    BlipForImageTextRetrieval,
    BlipForQuestionAnswering,
)


def load_demo_image(image_size, device):
    img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )
    image = transform(raw_image).unsqueeze(0).to(device)
    return image


def rename_key(key):
    if "visual_encoder" in key:
        key = re.sub("visual_encoder*", "vision_model.encoder", key)
    if "blocks" in key:
        key = re.sub(r"blocks", "layers", key)
    if "attn" in key:
        key = re.sub(r"attn", "self_attn", key)
    if "norm1" in key:
        key = re.sub(r"norm1", "layer_norm1", key)
    if "norm2" in key:
        key = re.sub(r"norm2", "layer_norm2", key)
    if "encoder.norm" in key:
        key = re.sub(r"encoder.norm", "post_layernorm", key)
    if "encoder.patch_embed.proj" in key:
        key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)

    if "encoder.pos_embed" in key:
        key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
    if "encoder.cls_token" in key:
        key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)

    if "self_attn" in key:
        key = re.sub(r"self_attn.proj", "self_attn.projection", key)

    return key


@torch.no_grad()
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    if config_path is not None:
        config = BlipConfig.from_pretrained(config_path)
    else:
        config = BlipConfig(projection_dim=512, text_config={}, vision_config={})

    hf_model = BlipForConditionalGeneration(config).eval()

    model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"

    pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
    pt_model = pt_model.eval()

    modified_state_dict = pt_model.state_dict()
    for key in modified_state_dict.copy():
        value = modified_state_dict.pop(key)
        renamed_key = rename_key(key)
        modified_state_dict[renamed_key] = value

    hf_model.load_state_dict(modified_state_dict)

    image_size = 384
    image = load_demo_image(image_size=image_size, device="cpu")
    tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
    input_ids = tokenizer(["a picture of"]).input_ids

    out = hf_model.generate(image, input_ids)

    assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]

    out = hf_model.generate(image)

    assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]

    if pytorch_dump_folder_path is not None:
        hf_model.save_pretrained(pytorch_dump_folder_path)

    # model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
    model_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
    )

    vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
    vqa_model.eval()

    modified_state_dict = vqa_model.state_dict()
    for key in modified_state_dict.copy():
        value = modified_state_dict.pop(key)
        renamed_key = rename_key(key)
        modified_state_dict[renamed_key] = value

    hf_vqa_model = BlipForQuestionAnswering(config)

    hf_vqa_model.load_state_dict(modified_state_dict)

    question = ["How many dogs are in this image?"]
    question_input_ids = tokenizer(question, return_tensors="pt").input_ids

    answer = hf_vqa_model.generate(question_input_ids, image)
    print(tokenizer.decode(answer[0]))

    assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"
    if pytorch_dump_folder_path is not None:
        hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")

    model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"

    itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
    itm_model.eval()

    modified_state_dict = itm_model.state_dict()
    for key in modified_state_dict.copy():
        value = modified_state_dict.pop(key)
        renamed_key = rename_key(key)
        modified_state_dict[renamed_key] = value

    hf_itm_model = BlipForImageTextRetrieval(config)

    question = ["A picture of a woman with a dog sitting in a beach"]
    question_input_ids = tokenizer(
        question,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=35,
    ).input_ids

    hf_itm_model.load_state_dict(modified_state_dict)
    hf_itm_model.eval()

    out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
    out = hf_itm_model(question_input_ids, image, use_itm_head=False)

    assert out[0].item() == 0.2110687494277954
    assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127

    if pytorch_dump_folder_path is not None:
        hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    args = parser.parse_args()

    convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
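
The BLIP script above saves the captioning model but no processor, so preprocessing has to be taken from an existing checkpoint. A minimal captioning sketch, assuming the converted folder is `./blip-converted` (hypothetical) and that the processor of the `Salesforce/blip-image-captioning-base` Hub checkpoint is compatible (an assumption):

# Sketch: caption the demo image with a converted BLIP checkpoint ("./blip-converted" is an assumed folder).
import requests
import torch
from PIL import Image

from transformers import BlipForConditionalGeneration, BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")  # processor is not produced by the script
model = BlipForConditionalGeneration.from_pretrained("./blip-converted")
model.eval()

image = Image.open(
    requests.get("https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg", stream=True).raw
).convert("RGB")
inputs = processor(images=image, text="a picture of", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))
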
@ -1,390 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert BLIP-2 checkpoints from the original repository.

URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
"""

import argparse

import requests
import torch

# pip3 install salesforce-lavis
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
# to make sure we can compare both original and HF implementation in float32
from lavis.models import load_model_and_preprocess
from PIL import Image

from transformers import (
    AutoTokenizer,
    BertTokenizer,
    Blip2Config,
    Blip2ForConditionalGeneration,
    Blip2ForImageTextRetrieval,
    Blip2Processor,
    Blip2QFormerConfig,
    Blip2VisionConfig,
    BlipImageProcessor,
    OPTConfig,
    T5Config,
    set_seed,
)
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD


def load_demo_image():
    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

    return image


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, model_name):
    rename_keys = []
    # fmt: off

    # vision encoder
    rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
    rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
    rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
    rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
    rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
    rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))

    for i in range(config.vision_config.num_hidden_layers):
        rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
        rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
        rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
        rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
        rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
        rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
        rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
        rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
        rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
        rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
        rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))

    # QFormer
    rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
    rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
    if "itm" in model_name:
        rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
        rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
        rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
        rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
        rename_keys.append(("text_proj.weight", "text_projection.weight"))
        rename_keys.append(("text_proj.bias", "text_projection.bias"))

    # fmt: on
    return rename_keys


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


def read_in_q_v_bias(state_dict, config):
    for i in range(config.vision_config.num_hidden_layers):
        # read in original q and v biases
        q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
        v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")

        # next, set bias in the state dict
        qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
        state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias


def get_blip2_config(model_name, eos_token_id):
    image_size = 364 if "coco" in model_name else 224
    vision_config = Blip2VisionConfig(image_size=image_size).to_dict()

    # make sure the models have proper bos_token_id and eos_token_id set (important for generation)
    # seems like flan-T5 models don't have bos_token_id properly set?
    if "opt-2.7b" in model_name:
        text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
    elif "opt-6.7b" in model_name:
        text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
    elif "t5-xl" in model_name:
        text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
    elif "t5-xxl" in model_name:
        text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
    elif "itm" in model_name:
        text_config = {}
    else:
        raise ValueError("Model name not supported")

    if "itm" in model_name:
        config = Blip2Config(
            vision_config=vision_config,
            qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
        )
    else:
        config = Blip2Config(vision_config=vision_config, text_config=text_config)

    return config, image_size


@torch.no_grad()
def convert_blip2_checkpoint(
    model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
):
    """
    Copy/paste/tweak model's weights to Transformers design.
    """
    if "opt" in model_name:
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
    elif "itm" in model_name:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
    else:
        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")

    if "itm" in model_name:
        eos_token_id = None
    else:
        eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
    config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)

    if "itm" in model_name:
        hf_model = Blip2ForImageTextRetrieval(config).eval()
    else:
        hf_model = Blip2ForConditionalGeneration(config).eval()

    model_name_to_original = {
        "blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
        "blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
        "blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
        "blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
        "blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
        "blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
        "blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
        "blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
        "blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
    }

    name, type = model_name_to_original[model_name]

    # load original model
    print("Loading original model...")
    original_model, vis_processors, _ = load_model_and_preprocess(
        name=name, model_type=type, is_eval=True, device=lavis_device
    )
    original_model.eval()
    print("Done!")

    # update state dict keys
    state_dict = original_model.state_dict()
    rename_keys = create_rename_keys(config, model_name)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)

    # some keys can be renamed efficiently
    for key, val in state_dict.copy().items():
        val = state_dict.pop(key)
        if key.startswith("Qformer.bert"):
            key = key.replace("Qformer.bert", "qformer")
        if "attention.self" in key:
            key = key.replace("self", "attention")
        if "opt_proj" in key:
            key = key.replace("opt_proj", "language_projection")
        if "t5_proj" in key:
            key = key.replace("t5_proj", "language_projection")
        if key.startswith("opt"):
            key = key.replace("opt", "language")
        if key.startswith("t5"):
            key = key.replace("t5", "language")
        state_dict[key] = val

    # read in qv biases
    read_in_q_v_bias(state_dict, config)

    missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
    assert len(missing_keys) == 0

    if "itm" in model_name:
        unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
        assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
    else:
        assert unexpected_keys == ["qformer.embeddings.position_ids"]

    image = load_demo_image()
    original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)

    # create processor
    image_processor = BlipImageProcessor(
        size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
    )
    processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)

    # make sure processor creates exact same pixel values
    assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))

    original_model.to(lavis_device)
    hf_model.to(hf_model_device)

    if "itm" in model_name:
        caption = "a large fountain spewing water into the air"
        input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
        attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)

        with torch.no_grad():
            original_logits = original_model(
                {"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
            )
            logits = hf_model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_image_text_matching_head=True,
            )

        assert original_logits.shape == logits.logits_per_image.shape
        print("First values of original logits:", original_logits[0, :3])
        print("First values of HF logits:", logits.logits_per_image[0, :3])

        # assert values
        # cast to same type
        target_dtype = logits.logits_per_image.dtype
        assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)

        original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
        itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
        assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
        print("Looks ok!")

        with torch.no_grad():
            original_logits = original_model(
                {"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
            )
            logits = hf_model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_image_text_matching_head=False,
            )

        assert original_logits.shape == logits.logits_per_image.shape
        print("First values of original logits:", original_logits[0, :3])
        print("First values of HF logits:", logits.logits_per_image[0, :3])

        # assert values
        # cast to same type
        target_dtype = logits.logits_per_image.dtype
        assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
        print("Looks ok!")

    else:
        input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)

        with torch.no_grad():
            if "opt" in model_name:
                original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
                logits = hf_model(pixel_values, input_ids).logits
            else:
                original_logits = original_model(
                    {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
                ).logits
                labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
                logits = hf_model(pixel_values, input_ids, labels=labels).logits

        assert original_logits.shape == logits.shape
        print("First values of original logits:", original_logits[0, :3, :3])
        print("First values of HF logits:", logits[0, :3, :3])

        # assert values
        assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
        print("Looks ok!")

        print("Generating a caption...")
        prompt = "Question: what object is in this image? Answer:"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)

        set_seed(42)

        original_outputs = original_model.generate(
            {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
        )
        outputs = hf_model.generate(
            pixel_values,
            input_ids,
            do_sample=True,
            num_beams=5,
            max_length=30,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.0,
            length_penalty=1.0,
            temperature=1,
        )
        output_text = processor.batch_decode(outputs, skip_special_tokens=True)
        output_text = [text.strip() for text in output_text]
        print("Original generation:", original_outputs)
        print("HF generation:", output_text)

    if pytorch_dump_folder_path is not None:
        processor.save_pretrained(pytorch_dump_folder_path)
        hf_model.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        processor.push_to_hub(f"nielsr/{model_name}")
        hf_model.push_to_hub(f"nielsr/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    choices = [
        "blip2-opt-2.7b",
        "blip2-opt-6.7b",
        "blip2-opt-2.7b-coco",
        "blip2-opt-6.7b-coco",
        "blip2-flan-t5-xl",
        "blip2-flan-t5-xl-coco",
        "blip2-flan-t5-xxl",
        "blip2-itm-vit-g",
        "blip2-itm-vit-g-coco",
    ]
    parser.add_argument(
        "--model_name",
        default="blip2-opt-2.7b",
        choices=choices,
        type=str,
        help="Path to hf config.json of model to convert",
    )
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the model and processor to the hub after converting",
    )
    # note: this script is tested on 2 GPUs, as models are compared in float32,
    # which requires quite some memory. Hence loading both on a
    # separate device is the easiest to compare
    parser.add_argument(
        "--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
    )
    parser.add_argument(
        "--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
    )

    args = parser.parse_args()

    convert_blip2_checkpoint(
        args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
    )
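
Because the BLIP-2 script above saves both the processor and the model to the dump folder, that folder is all that is needed for visual question answering. A minimal sketch, assuming a converted OPT-based checkpoint was written to `./blip2-opt-2.7b-converted` (a hypothetical path):

# Sketch: prompt a converted BLIP-2 checkpoint ("./blip2-opt-2.7b-converted" is an assumed dump folder).
import requests
import torch
from PIL import Image

from transformers import Blip2ForConditionalGeneration, Blip2Processor

checkpoint = "./blip2-opt-2.7b-converted"
processor = Blip2Processor.from_pretrained(checkpoint)
model = Blip2ForConditionalGeneration.from_pretrained(checkpoint)
model.eval()

image = Image.open(
    requests.get("https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png", stream=True).raw
).convert("RGB")
inputs = processor(images=image, text="Question: what object is in this image? Answer:", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(out, skip_special_tokens=True)[0].strip())
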
@ -1,254 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigScience BLOOM checkpoint."""

import argparse
import json
import os
import re

import torch

from transformers import BloomConfig, BloomModel
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging


logging.set_verbosity_info()

WEIGHTS_TO_AVERAGE_ENDSWITH = [
    "word_embeddings_layernorm.weight",
    "word_embeddings_layernorm.bias",
    "input_layernorm.weight",
    "input_layernorm.bias",
    "post_attention_layernorm.weight",
    "post_attention_layernorm.bias",
    "self_attention.dense.bias",
    "mlp.dense_4h_to_h.bias",
    "ln_f.weight",
    "ln_f.bias",
]

WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
    "mlp.dense_4h_to_h.weight",
    "self_attention.dense.weight",
]


def layer_name_mapping(key, file):
    """Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
    # Handle first and last layers
    layer_rename_map = {
        "word_embeddings.weight": "word_embeddings.weight",
        "word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
        "word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
        "weight": "ln_f.weight",
        "bias": "ln_f.bias",
    }

    if key in layer_rename_map:
        return layer_rename_map[key]

    # Handle transformer blocks
    layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
    layer_number -= 3
    return f"h.{layer_number}." + key


def get_dtype_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


def convert_bloom_checkpoint_to_pytorch(
    bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
):
    # Construct model
    if bloom_config_file == "":
        config = BloomConfig()
    else:
        config = BloomConfig.from_json_file(bloom_config_file)

    if shard_model:
        file_names = os.listdir(bloom_checkpoint_path)
        file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))

        index_dict = {"weight_map": {}, "metadata": {}}
        total_size = 0

        missing_keys = None

        config = BloomConfig()

        for j, file in enumerate(file_names):
            print(f"Processing file: {file}")
            tensors = None

            for i in range(pretraining_tp):
                # load all TP files
                f_name = file.replace("model_00", f"model_0{i}")
                temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)

                # Rename keys in the transformers names
                keys = list(temp.keys())
                for key in keys:
                    temp[layer_name_mapping(key, file)] = temp.pop(key)

                if tensors is None:
                    tensors = temp
                else:
                    for key in tensors:
                        if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                            # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
                            tensors[key] += temp[key]
                        else:
                            # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
                            cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
                            # We concatenate these weights across TP ranks
                            tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)

            # Divide by the number of TP the weights we want to average
            for key in tensors:
                if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                    tensors[key] = tensors[key] / pretraining_tp
            torch.save(
                tensors,
                os.path.join(
                    pytorch_dump_folder_path,
                    f"pytorch_model_{str(j + 1).zfill(5)}-of-{str(len(file_names)).zfill(5)}.bin",
                ),
            )

            for key in tensors:
                value = tensors[key]
                total_size += value.numel() * get_dtype_size(value.dtype)
                if key not in index_dict["weight_map"]:
                    index_dict["weight_map"][key] = (
                        f"pytorch_model_{str(j + 1).zfill(5)}-of-{str(len(file_names)).zfill(5)}.bin"
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
        config = BloomConfig()
 | 
			
		||||
        pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
 | 
			
		||||
        index_dict["metadata"]["total_size"] = total_size
 | 
			
		||||
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
 | 
			
		||||
            f.write(config.to_json_string())
 | 
			
		||||
        with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
 | 
			
		||||
            json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
 | 
			
		||||
            f.write(json_config)
 | 
			
		||||
    else:
 | 
			
		||||
        model = BloomModel(config)
 | 
			
		||||
 | 
			
		||||
        file_names = os.listdir(bloom_checkpoint_path)
 | 
			
		||||
        file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
 | 
			
		||||
 | 
			
		||||
        missing_keys = None
 | 
			
		||||
        for i, file in enumerate(file_names):
 | 
			
		||||
            tensors = None
 | 
			
		||||
            for i in range(pretraining_tp):
 | 
			
		||||
                # load all TP files
 | 
			
		||||
                f_name = file.replace("model_00", f"model_0{i}")
 | 
			
		||||
                temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
 | 
			
		||||
 | 
			
		||||
                # Rename keys in the transformers names
 | 
			
		||||
                keys = list(temp.keys())
 | 
			
		||||
                for key in keys:
 | 
			
		||||
                    temp[layer_name_mapping(key, file)] = temp.pop(key)
 | 
			
		||||
 | 
			
		||||
                if tensors is None:
 | 
			
		||||
                    tensors = temp
 | 
			
		||||
                else:
 | 
			
		||||
                    for key in tensors:
 | 
			
		||||
                        # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
 | 
			
		||||
                        if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
 | 
			
		||||
                            tensors[key] += temp[key]
 | 
			
		||||
                        else:
 | 
			
		||||
                            # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
 | 
			
		||||
                            cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
 | 
			
		||||
                            # We concatenate these weights across TP ranks
 | 
			
		||||
                            tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
 | 
			
		||||
 | 
			
		||||
            # Divide by the number of TP the weights we want to average
 | 
			
		||||
            for key in tensors:
 | 
			
		||||
                if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
 | 
			
		||||
                    tensors[key] = tensors[key] / pretraining_tp
 | 
			
		||||
 | 
			
		||||
            other_keys = model.load_state_dict(tensors, strict=False)
 | 
			
		||||
            assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
 | 
			
		||||
            if missing_keys is None:
 | 
			
		||||
                missing_keys = set(other_keys.missing_keys)
 | 
			
		||||
            else:
 | 
			
		||||
                missing_keys = missing_keys.intersection(set(other_keys.missing_keys))
 | 
			
		||||
 | 
			
		||||
        assert not missing_keys, f"The keys {missing_keys} are missing"
 | 
			
		||||
 | 
			
		||||
        # Save pytorch-model
 | 
			
		||||
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)
 | 
			
		||||
        pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
 | 
			
		||||
        pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
 | 
			
		||||
        print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
 | 
			
		||||
        if config.torch_dtype is not None:
 | 
			
		||||
            model = model.to(config.torch_dtype)
 | 
			
		||||
        torch.save(model.state_dict(), pytorch_weights_dump_path)
 | 
			
		||||
        print(f"Save configuration file to {pytorch_config_dump_path}")
 | 
			
		||||
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
 | 
			
		||||
            f.write(config.to_json_string())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--bloom_checkpoint_path",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="Path to the Megatron-LM checkpoint path.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--bloom_config_file",
 | 
			
		||||
        default="",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help=(
 | 
			
		||||
            "An optional config json file corresponding to the pre-trained model. \n"
 | 
			
		||||
            "This specifies the model architecture."
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--shard_model",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pretraining_tp",
 | 
			
		||||
        default=4,
 | 
			
		||||
        type=int,
 | 
			
		||||
        help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_bloom_checkpoint_to_pytorch(
 | 
			
		||||
        args.bloom_checkpoint_path,
 | 
			
		||||
        args.bloom_config_file,
 | 
			
		||||
        args.pytorch_dump_folder_path,
 | 
			
		||||
        args.shard_model,
 | 
			
		||||
        args.pretraining_tp,
 | 
			
		||||
    )
 | 
			
		||||
@ -1,145 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2023 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert Bros checkpoints."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import bros  # original repo
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import BrosConfig, BrosModel, BrosProcessor
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_configs(model_name):
 | 
			
		||||
    bros_config = BrosConfig.from_pretrained(model_name)
 | 
			
		||||
    return bros_config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def remove_ignore_keys_(state_dict):
 | 
			
		||||
    ignore_keys = [
 | 
			
		||||
        "embeddings.bbox_sinusoid_emb.inv_freq",
 | 
			
		||||
    ]
 | 
			
		||||
    for k in ignore_keys:
 | 
			
		||||
        state_dict.pop(k, None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_key(name):
 | 
			
		||||
    if name == "embeddings.bbox_projection.weight":
 | 
			
		||||
        name = "bbox_embeddings.bbox_projection.weight"
 | 
			
		||||
 | 
			
		||||
    if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
 | 
			
		||||
        name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"
 | 
			
		||||
 | 
			
		||||
    if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
 | 
			
		||||
        name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"
 | 
			
		||||
 | 
			
		||||
    return name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_state_dict(orig_state_dict, model):
 | 
			
		||||
    # rename keys
 | 
			
		||||
    for key in orig_state_dict.copy():
 | 
			
		||||
        val = orig_state_dict.pop(key)
 | 
			
		||||
        orig_state_dict[rename_key(key)] = val
 | 
			
		||||
 | 
			
		||||
    # remove ignore keys
 | 
			
		||||
    remove_ignore_keys_(orig_state_dict)
 | 
			
		||||
 | 
			
		||||
    return orig_state_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
 | 
			
		||||
    # load original model
 | 
			
		||||
    original_model = bros.BrosModel.from_pretrained(model_name).eval()
 | 
			
		||||
 | 
			
		||||
    # load HuggingFace Model
 | 
			
		||||
    bros_config = get_configs(model_name)
 | 
			
		||||
    model = BrosModel.from_pretrained(model_name, config=bros_config)
 | 
			
		||||
    model.eval()
 | 
			
		||||
 | 
			
		||||
    state_dict = original_model.state_dict()
 | 
			
		||||
    new_state_dict = convert_state_dict(state_dict, model)
 | 
			
		||||
    model.load_state_dict(new_state_dict)
 | 
			
		||||
 | 
			
		||||
    # verify results
 | 
			
		||||
 | 
			
		||||
    # original BROS model require 4 points (8 float values) for each bbox, prepare bbox with [batch_size, seq_len, 8] shape
 | 
			
		||||
    bbox = torch.tensor(
 | 
			
		||||
        [
 | 
			
		||||
            [
 | 
			
		||||
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
 | 
			
		||||
                [0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
 | 
			
		||||
                [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
 | 
			
		||||
                [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
 | 
			
		||||
                [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
 | 
			
		||||
                [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
 | 
			
		||||
                [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        ]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    processor = BrosProcessor.from_pretrained(model_name)
 | 
			
		||||
 | 
			
		||||
    encoding = processor("His name is Rocco.", return_tensors="pt")
 | 
			
		||||
    encoding["bbox"] = bbox
 | 
			
		||||
 | 
			
		||||
    original_hidden_states = original_model(**encoding).last_hidden_state
 | 
			
		||||
    # pixel_values = processor(image, return_tensors="pt").pixel_values
 | 
			
		||||
 | 
			
		||||
    last_hidden_states = model(**encoding).last_hidden_state
 | 
			
		||||
 | 
			
		||||
    assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)
 | 
			
		||||
 | 
			
		||||
    if pytorch_dump_folder_path is not None:
 | 
			
		||||
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
 | 
			
		||||
        model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
        processor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
    if push_to_hub:
 | 
			
		||||
        model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
 | 
			
		||||
        processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--model_name",
 | 
			
		||||
        default="jinho8345/bros-base-uncased",
 | 
			
		||||
        required=False,
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Name of the original model you'd like to convert.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path",
 | 
			
		||||
        default=None,
 | 
			
		||||
        required=False,
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Path to the output PyTorch model directory.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--push_to_hub",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        help="Whether or not to push the converted model and processor to the 🤗 hub.",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
 | 
			
		||||
@ -1,59 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert T5 checkpoint."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
 | 
			
		||||
    # Initialise PyTorch model
 | 
			
		||||
    config = T5Config.from_json_file(config_file)
 | 
			
		||||
    print(f"Building PyTorch model from configuration: {config}")
 | 
			
		||||
    model = T5ForConditionalGeneration(config)
 | 
			
		||||
 | 
			
		||||
    # Load weights from tf checkpoint
 | 
			
		||||
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)
 | 
			
		||||
 | 
			
		||||
    # Save pytorch-model
 | 
			
		||||
    print(f"Save PyTorch model to {pytorch_dump_path}")
 | 
			
		||||
    model.save_pretrained(pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--config_file",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help=(
 | 
			
		||||
            "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
 | 
			
		||||
@ -1,65 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2021 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert CANINE checkpoint."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path):
 | 
			
		||||
    # Initialize PyTorch model
 | 
			
		||||
    config = CanineConfig()
 | 
			
		||||
    model = CanineModel(config)
 | 
			
		||||
    model.eval()
 | 
			
		||||
 | 
			
		||||
    print(f"Building PyTorch model from configuration: {config}")
 | 
			
		||||
 | 
			
		||||
    # Load weights from tf checkpoint
 | 
			
		||||
    load_tf_weights_in_canine(model, config, tf_checkpoint_path)
 | 
			
		||||
 | 
			
		||||
    # Save pytorch-model (weights and configuration)
 | 
			
		||||
    print(f"Save PyTorch model to {pytorch_dump_path}")
 | 
			
		||||
    model.save_pretrained(pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
    # Save tokenizer files
 | 
			
		||||
    tokenizer = CanineTokenizer()
 | 
			
		||||
    print(f"Save tokenizer files to {pytorch_dump_path}")
 | 
			
		||||
    tokenizer.save_pretrained(pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tf_checkpoint_path",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="Path to the TensorFlow checkpoint. Should end with model.ckpt",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="Path to a folder where the PyTorch model will be placed.",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path)
 | 
			
		||||
@ -1,478 +0,0 @@
 | 
			
		||||
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import argparse
 | 
			
		||||
import gc
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import torch
 | 
			
		||||
import yaml
 | 
			
		||||
from accelerate import init_empty_weights
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    ChameleonConfig,
 | 
			
		||||
    ChameleonForConditionalGeneration,
 | 
			
		||||
    ChameleonImageProcessor,
 | 
			
		||||
    ChameleonProcessor,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from transformers import LlamaTokenizerFast
 | 
			
		||||
except ImportError:
 | 
			
		||||
    raise ValueError(
 | 
			
		||||
        "Chameleon conversion supports only FastTokenizer and LlamaTokenizerFast can't be imported! "
 | 
			
		||||
        "Update your `tokenizers` library and re-run the tokenizer conversion."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
Sample usage:
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
python src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py \
 | 
			
		||||
    --input_dir /path/to/downloaded/chameleon/weights --model_size 7B --output_dir /output/path
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Thereafter, models can be loaded via:
 | 
			
		||||
 | 
			
		||||
```py
 | 
			
		||||
from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast
 | 
			
		||||
 | 
			
		||||
model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
 | 
			
		||||
tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
 | 
			
		||||
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
NUM_SHARDS = {
 | 
			
		||||
    "7B": 1,
 | 
			
		||||
    "30B": 4,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
VOCAB_SIZE = 65536
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
 | 
			
		||||
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def read_json(path):
 | 
			
		||||
    with open(path, "r") as f:
 | 
			
		||||
        return json.load(f)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def write_json(text, path):
 | 
			
		||||
    with open(path, "w") as f:
 | 
			
		||||
        json.dump(text, f)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def write_model(model_path, input_base_path, model_size, chameleon_version=1):
 | 
			
		||||
    os.makedirs(model_path, exist_ok=True)
 | 
			
		||||
    input_model_path = os.path.join(input_base_path, "models", model_size.lower())
 | 
			
		||||
    params_path = os.path.join(input_model_path, "params.json")
 | 
			
		||||
    consolidate_params_path = os.path.join(input_model_path, "consolidate_params.json")
 | 
			
		||||
 | 
			
		||||
    params = read_json(params_path)
 | 
			
		||||
    if os.path.isfile(consolidate_params_path):
 | 
			
		||||
        params = {**params, **read_json(consolidate_params_path)}
 | 
			
		||||
    num_shards = NUM_SHARDS[model_size]
 | 
			
		||||
    model_parallel_size = params["model_parallel_size"]
 | 
			
		||||
    params = params.get("model", params)
 | 
			
		||||
    n_layers = params["n_layers"]
 | 
			
		||||
    n_heads = params["n_heads"]
 | 
			
		||||
    n_heads_per_shard = n_heads // num_shards
 | 
			
		||||
    dim = params["dim"]
 | 
			
		||||
    dims_per_head = dim // n_heads
 | 
			
		||||
    base = params.get("rope_theta", 10000.0)
 | 
			
		||||
    swin_norm = params["swin_norm"]
 | 
			
		||||
    if base > 10000.0:
 | 
			
		||||
        max_position_embeddings = 16384
 | 
			
		||||
    else:
 | 
			
		||||
        # Depending on the Chameleon version, the default max_position_embeddings has different values.
 | 
			
		||||
        if chameleon_version == 1:
 | 
			
		||||
            max_position_embeddings = 4096
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError(
 | 
			
		||||
                f"Version {chameleon_version} of chameleon is not supported yet. "
 | 
			
		||||
                "Current supported versions of chameleon are [1]."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    if params.get("n_kv_heads", None) is not None:
 | 
			
		||||
        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
 | 
			
		||||
        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
 | 
			
		||||
        key_value_dim = dim // num_key_value_heads
 | 
			
		||||
    else:  # compatibility with other checkpoints
 | 
			
		||||
        num_key_value_heads = n_heads
 | 
			
		||||
        num_local_key_value_heads = n_heads_per_shard
 | 
			
		||||
        key_value_dim = dim
 | 
			
		||||
 | 
			
		||||
    print(f"Fetching all parameters from the checkpoint at {input_model_path}.")
 | 
			
		||||
    # Load weights
 | 
			
		||||
    if num_shards == 1:
 | 
			
		||||
        # Not sharded
 | 
			
		||||
        # (The sharded implementation would also work, but this is simpler.)
 | 
			
		||||
        loaded = None
 | 
			
		||||
        for possible_name in ["consolidated.pth", "consolidated.00.pth"]:
 | 
			
		||||
            possible_path = os.path.join(input_model_path, possible_name)
 | 
			
		||||
            if os.path.exists(possible_path):
 | 
			
		||||
                loaded = torch.load(possible_path, map_location="cpu", weights_only=True)
 | 
			
		||||
                break
 | 
			
		||||
        assert loaded is not None
 | 
			
		||||
    else:
 | 
			
		||||
        # Sharded
 | 
			
		||||
        loaded = [
 | 
			
		||||
            torch.load(
 | 
			
		||||
                os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu", weights_only=True
 | 
			
		||||
            )
 | 
			
		||||
            for i in range(num_shards)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    # permute for sliced rotary
 | 
			
		||||
    def permute(w, n_heads, dim1=dim, dim2=dim):
 | 
			
		||||
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
 | 
			
		||||
 | 
			
		||||
    # Load weights to the state dict
 | 
			
		||||
    state_dict = {}
 | 
			
		||||
    for layer_i in range(n_layers):
 | 
			
		||||
        if num_shards == 1:
 | 
			
		||||
            # Unsharded
 | 
			
		||||
            state_dict.update(
 | 
			
		||||
                {
 | 
			
		||||
                    f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
 | 
			
		||||
                        loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
 | 
			
		||||
                    ),
 | 
			
		||||
                    f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
 | 
			
		||||
                        loaded[f"layers.{layer_i}.attention.wk.weight"],
 | 
			
		||||
                        n_heads=num_key_value_heads,
 | 
			
		||||
                        dim1=key_value_dim,
 | 
			
		||||
                    ),
 | 
			
		||||
                    f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
 | 
			
		||||
                    f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
 | 
			
		||||
                    f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
 | 
			
		||||
                    f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
 | 
			
		||||
                    f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
 | 
			
		||||
                    f"model.layers.{layer_i}.input_layernorm.weight": loaded[
 | 
			
		||||
                        f"layers.{layer_i}.attention_norm.weight"
 | 
			
		||||
                    ],
 | 
			
		||||
                    f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
 | 
			
		||||
                        f"layers.{layer_i}.ffn_norm.weight"
 | 
			
		||||
                    ],
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
            # qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
 | 
			
		||||
                loaded[f"layers.{layer_i}.attention.q_normalization.weight"]
 | 
			
		||||
                .view(dims_per_head // 2, 2)
 | 
			
		||||
                .t()
 | 
			
		||||
                .reshape(1, -1)
 | 
			
		||||
                .repeat_interleave(n_heads, 0)
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
 | 
			
		||||
                loaded[f"layers.{layer_i}.attention.q_normalization.bias"]
 | 
			
		||||
                .view(dims_per_head // 2, 2)
 | 
			
		||||
                .t()
 | 
			
		||||
                .reshape(1, -1)
 | 
			
		||||
                .repeat_interleave(n_heads, 0)
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
 | 
			
		||||
                loaded[f"layers.{layer_i}.attention.k_normalization.weight"]
 | 
			
		||||
                .view(dims_per_head // 2, 2)
 | 
			
		||||
                .t()
 | 
			
		||||
                .reshape(1, -1)
 | 
			
		||||
                .repeat_interleave(num_key_value_heads, 0)
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
 | 
			
		||||
                loaded[f"layers.{layer_i}.attention.k_normalization.bias"]
 | 
			
		||||
                .view(dims_per_head // 2, 2)
 | 
			
		||||
                .t()
 | 
			
		||||
                .reshape(1, -1)
 | 
			
		||||
                .repeat_interleave(num_key_value_heads, 0)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            # Sharded
 | 
			
		||||
            state_dict.update(
 | 
			
		||||
                {
 | 
			
		||||
                    f"model.layers.{layer_i}.input_layernorm.weight": torch.stack(
 | 
			
		||||
                        [l[f"layers.{layer_i}.attention_norm.weight"] for l in loaded]
 | 
			
		||||
                    ).mean(dim=0),
 | 
			
		||||
                    f"model.layers.{layer_i}.post_attention_layernorm.weight": torch.stack(
 | 
			
		||||
                        [l[f"layers.{layer_i}.ffn_norm.weight"] for l in loaded]
 | 
			
		||||
                    ).mean(dim=0),
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
 | 
			
		||||
                torch.cat(
 | 
			
		||||
                    [
 | 
			
		||||
                        loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
 | 
			
		||||
                        for i in range(num_shards)
 | 
			
		||||
                    ],
 | 
			
		||||
                    dim=0,
 | 
			
		||||
                ).reshape(dim, dim),
 | 
			
		||||
                n_heads=n_heads,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
 | 
			
		||||
                torch.cat(
 | 
			
		||||
                    [
 | 
			
		||||
                        loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
 | 
			
		||||
                            num_local_key_value_heads, dims_per_head, dim
 | 
			
		||||
                        )
 | 
			
		||||
                        for i in range(num_shards)
 | 
			
		||||
                    ],
 | 
			
		||||
                    dim=0,
 | 
			
		||||
                ).reshape(key_value_dim, dim),
 | 
			
		||||
                n_heads=num_key_value_heads,
 | 
			
		||||
                dim1=key_value_dim,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
 | 
			
		||||
                torch.cat([l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0) for l in loaded])
 | 
			
		||||
                .view(num_shards, dims_per_head // 2, 2)
 | 
			
		||||
                .transpose(1, 2)
 | 
			
		||||
                .reshape(num_shards, -1)
 | 
			
		||||
                .repeat_interleave(n_heads // num_shards, 0)
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
 | 
			
		||||
                torch.cat([l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0) for l in loaded])
 | 
			
		||||
                .view(num_shards, dims_per_head // 2, 2)
 | 
			
		||||
                .transpose(1, 2)
 | 
			
		||||
                .reshape(num_shards, -1)
 | 
			
		||||
                .repeat_interleave(n_heads // num_shards, 0)
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
 | 
			
		||||
                torch.cat([l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0) for l in loaded])
 | 
			
		||||
                .view(num_shards, dims_per_head // 2, 2)
 | 
			
		||||
                .transpose(1, 2)
 | 
			
		||||
                .reshape(num_shards, -1)
 | 
			
		||||
                .repeat_interleave(num_key_value_heads // num_shards, 0)
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
 | 
			
		||||
                torch.cat([l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0) for l in loaded])
 | 
			
		||||
                .view(num_shards, dims_per_head // 2, 2)
 | 
			
		||||
                .transpose(1, 2)
 | 
			
		||||
                .reshape(num_shards, -1)
 | 
			
		||||
                .repeat_interleave(num_key_value_heads // num_shards, 0)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
 | 
			
		||||
                [
 | 
			
		||||
                    loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
 | 
			
		||||
                        num_local_key_value_heads, dims_per_head, dim
 | 
			
		||||
                    )
 | 
			
		||||
                    for i in range(num_shards)
 | 
			
		||||
                ],
 | 
			
		||||
                dim=0,
 | 
			
		||||
            ).reshape(key_value_dim, dim)
 | 
			
		||||
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
 | 
			
		||||
                [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
 | 
			
		||||
                [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
 | 
			
		||||
                [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
 | 
			
		||||
                [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    if num_shards == 1:
 | 
			
		||||
        # Unsharded
 | 
			
		||||
        state_dict.update(
 | 
			
		||||
            {
 | 
			
		||||
                "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
 | 
			
		||||
                "model.norm.weight": loaded["norm.weight"],
 | 
			
		||||
                "lm_head.weight": loaded["output.weight"],
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        state_dict.update(
 | 
			
		||||
            {
 | 
			
		||||
                "model.embed_tokens.weight": torch.cat(
 | 
			
		||||
                    [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
 | 
			
		||||
                ),
 | 
			
		||||
                "model.norm.weight": torch.stack([loaded[i]["norm.weight"] for i in range(num_shards)]).mean(dim=0),
 | 
			
		||||
                "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # Load VQGAN weights
 | 
			
		||||
    vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt")
 | 
			
		||||
    vqgan_state_dict = torch.load(vqgan_path, map_location="cpu", weights_only=True)["state_dict"]
 | 
			
		||||
    for k, v in vqgan_state_dict.items():
 | 
			
		||||
        if "decoder" in k:
 | 
			
		||||
            continue  # we dont do image generation yet
 | 
			
		||||
        state_dict[f"model.vqmodel.{k}"] = v
 | 
			
		||||
 | 
			
		||||
    # Write configs
 | 
			
		||||
    ffn_dim_multiplier = params.get("ffn_dim_multiplier", 1)
 | 
			
		||||
    multiple_of = params.get("multiple_of", 256)
 | 
			
		||||
 | 
			
		||||
    with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file:
 | 
			
		||||
        tokenizer_config = json.load(tokenizer_file)
 | 
			
		||||
        vocabulary_map = tokenizer_config["model"]["vocab"]
 | 
			
		||||
        vocabulary_map["<image>"] = vocabulary_map[
 | 
			
		||||
            "<reserved08707>"
 | 
			
		||||
        ]  # use a reserved token instead of adding a new one
 | 
			
		||||
        del vocabulary_map["<reserved08707>"]
 | 
			
		||||
 | 
			
		||||
        for token in tokenizer_config["added_tokens"]:
 | 
			
		||||
            if token["content"] == "<reserved08707>":
 | 
			
		||||
                token["content"] = "<image>"
 | 
			
		||||
 | 
			
		||||
    with open(os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), "w") as f:
 | 
			
		||||
        json.dump(tokenizer_config, f)  # save the new file to init tokenizer later
 | 
			
		||||
 | 
			
		||||
    vq_keys_to_replace = [
 | 
			
		||||
        ("ch", "base_channels"),
 | 
			
		||||
        ("out_ch", "out_channels"),
 | 
			
		||||
        ("n_embed", "num_embeddings"),
 | 
			
		||||
        ("ch_mult", "channel_multiplier"),
 | 
			
		||||
        ("double_z", "double_latent"),
 | 
			
		||||
        ("z_channels", "latent_channels"),
 | 
			
		||||
    ]
 | 
			
		||||
    with open(os.path.join(input_base_path, "tokenizer/vqgan.yaml")) as vqgan_cfg_file:
 | 
			
		||||
        vq_config = yaml.safe_load(vqgan_cfg_file)["model"]["params"]
 | 
			
		||||
        vq_config.update(**vq_config["ddconfig"])
 | 
			
		||||
        for old, new in vq_keys_to_replace:
 | 
			
		||||
            vq_config[new] = vq_config[old]
 | 
			
		||||
        del vq_config["ddconfig"]
 | 
			
		||||
        del vq_config["ckpt_path"]
 | 
			
		||||
        del vq_config["lossconfig"]
 | 
			
		||||
 | 
			
		||||
    config = ChameleonConfig(
 | 
			
		||||
        hidden_size=dim,
 | 
			
		||||
        intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
 | 
			
		||||
        num_attention_heads=params["n_heads"],
 | 
			
		||||
        num_hidden_layers=params["n_layers"],
 | 
			
		||||
        rms_norm_eps=params["norm_eps"],
 | 
			
		||||
        num_key_value_heads=num_key_value_heads,
 | 
			
		||||
        vocab_size=VOCAB_SIZE,
 | 
			
		||||
        rope_theta=base,
 | 
			
		||||
        max_position_embeddings=max_position_embeddings,
 | 
			
		||||
        model_parallel_size=model_parallel_size,
 | 
			
		||||
        swin_norm=swin_norm,
 | 
			
		||||
        vq_config=vq_config,
 | 
			
		||||
        vocabulary_map=vocabulary_map,
 | 
			
		||||
    )
 | 
			
		||||
    with init_empty_weights():
 | 
			
		||||
        model = ChameleonForConditionalGeneration(config)
 | 
			
		||||
 | 
			
		||||
    model.load_state_dict(state_dict, assign=True, strict=False)
 | 
			
		||||
    model.save_pretrained(model_path, safe_serialization=True)
 | 
			
		||||
 | 
			
		||||
    # Load and save the processor
 | 
			
		||||
    tokenizer = LlamaTokenizerFast(
 | 
			
		||||
        tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False
 | 
			
		||||
    )
 | 
			
		||||
    tokenizer.sep_token_id = 8710  # assign <reserved08706> to sep so that we can append it after input text
 | 
			
		||||
    tokenizer.pad_token_id = 1  # assign <pad> to special pad_token
 | 
			
		||||
    image_processor = ChameleonImageProcessor()
 | 
			
		||||
    processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer)
 | 
			
		||||
    processor.save_pretrained(model_path)
 | 
			
		||||
 | 
			
		||||
    # Make space so we can load the model properly now.
 | 
			
		||||
    del state_dict
 | 
			
		||||
    del loaded
 | 
			
		||||
    del vqgan_state_dict
 | 
			
		||||
    gc.collect()
 | 
			
		||||
 | 
			
		||||
    # Short inference on a few examples to check if generation makes sense
 | 
			
		||||
    # taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
 | 
			
		||||
    print("Loading the checkpoint in a Chameleon model...")
 | 
			
		||||
    print("*" * 100)
 | 
			
		||||
    model = ChameleonForConditionalGeneration.from_pretrained(
 | 
			
		||||
        model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
 | 
			
		||||
    )
 | 
			
		||||
    processor = ChameleonProcessor.from_pretrained(model_path)
 | 
			
		||||
 | 
			
		||||
    prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
 | 
			
		||||
    image = Image.open(
 | 
			
		||||
        requests.get(
 | 
			
		||||
            "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
 | 
			
		||||
        ).raw
 | 
			
		||||
    )
 | 
			
		||||
    inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
 | 
			
		||||
    length = inputs.input_ids.shape[1]
 | 
			
		||||
 | 
			
		||||
    out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
 | 
			
		||||
    generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
 | 
			
		||||
 | 
			
		||||
    print(f"Generation for single-image: {generated_text}")
 | 
			
		||||
    print("*" * 100)
 | 
			
		||||
 | 
			
		||||
    # Multi-image example
 | 
			
		||||
    prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
 | 
			
		||||
    image = Image.open(
 | 
			
		||||
        requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
 | 
			
		||||
    )
 | 
			
		||||
    image_2 = Image.open(
 | 
			
		||||
        requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
 | 
			
		||||
    length = inputs.input_ids.shape[1]
 | 
			
		||||
    out = model.generate(**inputs, max_new_tokens=50, do_sample=False)
 | 
			
		||||
    generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
 | 
			
		||||
 | 
			
		||||
    print(f"Generation for multi-image: {generated_text}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--input_dir",
 | 
			
		||||
        help="Location of Chameleon weights",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--model_size",
 | 
			
		||||
        choices=["7B", "30B"],
 | 
			
		||||
        help=""
 | 
			
		||||
        " models correspond to the finetuned versions, and are specific to the Chameleon official release. For more details on Chameleon, check out the original repo: https://github.com/facebookresearch/chameleon",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--output_dir",
 | 
			
		||||
        help="Location to write HF model",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--test_inference",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        help="Whether to load the model for generation to test it's converted correctly.",
 | 
			
		||||
    )
 | 
			
		||||
    # Different Chameleon versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--chameleon_version",
 | 
			
		||||
        choices=[1],
 | 
			
		||||
        default=1,
 | 
			
		||||
        type=int,
 | 
			
		||||
        help="Version of the Chameleon model to convert",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    write_model(
 | 
			
		||||
        model_path=args.output_dir,
 | 
			
		||||
        input_base_path=args.input_dir,
 | 
			
		||||
        model_size=args.model_size,
 | 
			
		||||
        chameleon_version=args.chameleon_version,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
@ -1,134 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import ChineseCLIPConfig, ChineseCLIPModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_attn_layer(hf_attn_layer, pt_weights, prefix):
 | 
			
		||||
    q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0)
 | 
			
		||||
    q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0)
 | 
			
		||||
 | 
			
		||||
    out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"]
 | 
			
		||||
    out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"]
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.q_proj.weight.data = q_proj
 | 
			
		||||
    hf_attn_layer.q_proj.bias.data = q_proj_bias
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.k_proj.weight.data = k_proj
 | 
			
		||||
    hf_attn_layer.k_proj.bias.data = k_proj_bias
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.v_proj.weight.data = v_proj
 | 
			
		||||
    hf_attn_layer.v_proj.bias.data = v_proj_bias
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.out_proj.weight.data = out_proj_weights
 | 
			
		||||
    hf_attn_layer.out_proj.bias.data = out_proj_bias
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_mlp(hf_mlp, pt_weights, prefix):
 | 
			
		||||
    copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc")
 | 
			
		||||
    copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_linear(hf_linear, pt_weights, prefix):
 | 
			
		||||
    hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data
 | 
			
		||||
    hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_layer(hf_layer, pt_weights, prefix):
 | 
			
		||||
    # copy layer norms
 | 
			
		||||
    copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1")
 | 
			
		||||
    copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2")
 | 
			
		||||
 | 
			
		||||
    # copy MLP
 | 
			
		||||
    copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp")
 | 
			
		||||
 | 
			
		||||
    # copy attn
 | 
			
		||||
    copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_layers(hf_layers, pt_weights, prefix):
 | 
			
		||||
    for layer_id, hf_layer in enumerate(hf_layers):
 | 
			
		||||
        copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_text_model_and_projection(hf_model, pt_weights):
 | 
			
		||||
    # copy projection
 | 
			
		||||
    hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T
 | 
			
		||||
 | 
			
		||||
    # copy text encoder
 | 
			
		||||
    for name, param in hf_model.text_model.named_parameters():
        param.data = pt_weights[f"bert.{name}"].data


def copy_vision_model_and_projection(hf_model, pt_weights):
    # copy projection
    hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T

    # copy layer norms
    copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre")
    copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post")

    # copy embeddings
    hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data
    hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data
    hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data

    # copy encoder
    copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks")


@torch.no_grad()
def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.
    """

    assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size."
    config = ChineseCLIPConfig.from_pretrained(config_path)

    hf_model = ChineseCLIPModel(config).eval()

    pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
    pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}

    copy_text_model_and_projection(hf_model, pt_weights)
    copy_vision_model_and_projection(hf_model, pt_weights)
    hf_model.logit_scale.data = pt_weights["logit_scale"].data

    hf_model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output folder storing converted hf PyTorch model.",
    )
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint."
    )
    parser.add_argument(
        "--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert."
    )
    args = parser.parse_args()

    convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
    print("The conversion is finished!")
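
# A hypothetical invocation of the entry point above (the script name and all paths are placeholders,
# not taken from the original repository):
#
#   python convert_chinese_clip_original_pytorch_to_hf.py \
#       --checkpoint_path ./clip_cn_vit-b-16.pt \
#       --config_path ./config.json \
#       --pytorch_dump_folder_path ./chinese-clip-vit-base-patch16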
@ -1,133 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import re

from laion_clap import CLAP_Module

from transformers import AutoFeatureExtractor, ClapConfig, ClapModel


KEYS_TO_MODIFY_MAPPING = {
    "text_branch": "text_model",
    "audio_branch": "audio_model.audio_encoder",
    "attn": "attention.self",
    "self.proj": "output.dense",
    "attention.self_mask": "attn_mask",
    "mlp.fc1": "intermediate.dense",
    "mlp.fc2": "output.dense",
    "norm1": "layernorm_before",
    "norm2": "layernorm_after",
    "bn0": "batch_norm",
}

processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")


def init_clap(checkpoint_path, model_type, enable_fusion=False):
    model = CLAP_Module(
        amodel=model_type,
        enable_fusion=enable_fusion,
    )
    model.load_ckpt(checkpoint_path)
    return model


def get_config_from_original(clap_model):
    audio_config = {
        "patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
        "depths": clap_model.model.audio_branch.depths,
        "hidden_size": clap_model.model.audio_projection[0].in_features,
    }

    text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}

    return ClapConfig(audio_config=audio_config, text_config=text_config)


def rename_state_dict(state_dict):
    model_state_dict = {}

    sequential_layers_pattern = r".*sequential.(\d+).*"
    text_projection_pattern = r".*_projection.(\d+).*"

    for key, value in state_dict.items():
        # check if any key needs to be modified
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        if re.match(sequential_layers_pattern, key):
            # replace sequential layers with list
            sequential_layer = re.match(sequential_layers_pattern, key).group(1)

            key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
        elif re.match(text_projection_pattern, key):
            projection_layer = int(re.match(text_projection_pattern, key).group(1))

            # Because in CLAP they use `nn.Sequential`...
            transformers_projection_layer = 1 if projection_layer == 0 else 2

            key = key.replace(f"_projection.{projection_layer}.", f"_projection.linear{transformers_projection_layer}.")

        if "audio" in key and "qkv" in key:
            # split qkv into query key and value
            mixed_qkv = value
            qkv_dim = mixed_qkv.size(0) // 3

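            # the audio branch stores the query, key and value projections stacked along the first
            # dimension of a single fused `qkv` tensor, so it is split into three equal chunks below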
            query_layer = mixed_qkv[:qkv_dim]
            key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
            value_layer = mixed_qkv[qkv_dim * 2 :]

            model_state_dict[key.replace("qkv", "query")] = query_layer
            model_state_dict[key.replace("qkv", "key")] = key_layer
            model_state_dict[key.replace("qkv", "value")] = value_layer
        else:
            model_state_dict[key] = value

    return model_state_dict


def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
    clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)

    clap_model.eval()
    state_dict = clap_model.model.state_dict()
    state_dict = rename_state_dict(state_dict)

    transformers_config = get_config_from_original(clap_model)
    transformers_config.audio_config.enable_fusion = enable_fusion
    model = ClapModel(transformers_config)

    # ignore the spectrogram embedding layer
    model.load_state_dict(state_dict, strict=False)

    model.save_pretrained(pytorch_dump_folder_path)
    transformers_config.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the original CLAP checkpoint")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
    parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Name of the audio model (amodel) used by the original CLAP checkpoint")
    args = parser.parse_args()

    convert_clap_checkpoint(
        args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
    )
@ -1,156 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch
from clip import load

from transformers import CLIPConfig, CLIPModel


def copy_attn_layer(hf_attn_layer, pt_attn_layer):
    q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
    q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)

    out_proj_weights = pt_attn_layer.out_proj.weight
    out_proj_bias = pt_attn_layer.out_proj.bias

    hf_attn_layer.q_proj.weight.data = q_proj
    hf_attn_layer.q_proj.bias.data = q_proj_bias

    hf_attn_layer.k_proj.weight.data = k_proj
    hf_attn_layer.k_proj.bias.data = k_proj_bias

    hf_attn_layer.v_proj.weight.data = v_proj
    hf_attn_layer.v_proj.bias.data = v_proj_bias

    hf_attn_layer.out_proj.weight = out_proj_weights
    hf_attn_layer.out_proj.bias = out_proj_bias


def copy_mlp(hf_mlp, pt_mlp):
    copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
    copy_linear(hf_mlp.fc2, pt_mlp.c_proj)


def copy_linear(hf_linear, pt_linear):
    hf_linear.weight = pt_linear.weight
    hf_linear.bias = pt_linear.bias


def copy_layer(hf_layer, pt_layer):
    # copy layer norms
    copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
    copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)

    # copy MLP
    copy_mlp(hf_layer.mlp, pt_layer.mlp)

    # copy attn
    copy_attn_layer(hf_layer.self_attn, pt_layer.attn)


def copy_layers(hf_layers, pt_layers):
    for hf_layer, pt_layer in zip(hf_layers, pt_layers):
        copy_layer(hf_layer, pt_layer)


def copy_encoder(hf_encoder, pt_model):
    # copy embeds
    hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
    hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding

    # copy layer norm
    copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)

    # copy hidden layers
    copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)


def copy_text_model_and_projection(hf_model, pt_model):
    # copy projection
    hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()

    # copy text encoder
    copy_encoder(hf_model.text_model, pt_model)


def copy_vision_model_and_projection(hf_model, pt_model):
    # copy projection
    hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous()

    # copy layer norms
    copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
    copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)

    # copy embeds
    hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
    hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
    hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data

    # copy encoder
    copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)


@torch.no_grad()
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    if config_path is not None:
        config = CLIPConfig.from_pretrained(config_path)
    else:
        config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})

    hf_model = CLIPModel(config).eval()

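    # `load` comes from OpenAI's `clip` package and returns the model together with its preprocessing
    # transform; only the model is needed here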
    pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
    pt_model = pt_model.eval()

    copy_text_model_and_projection(hf_model, pt_model)
    copy_vision_model_and_projection(hf_model, pt_model)
    hf_model.logit_scale = pt_model.logit_scale

    # Use `eos_token` so the example is more meaningful
    input_ids = torch.tensor(
        [
            [config.text_config.bos_token_id]
            + list(range(3, 77))
            + [config.text_config.eos_token_id]
            + [config.text_config.pad_token_id]
        ]
    )
    pixel_values = torch.randn(1, 3, 224, 224)

    hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
    hf_logits_per_image = hf_outputs.logits_per_image
    hf_logits_per_text = hf_outputs.logits_per_text
    pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)

    assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
    assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)

    hf_model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    args = parser.parse_args()

    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
@ -1,264 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert CLIPSeg checkpoints from the original repository. URL: https://github.com/timojl/clipseg."""

import argparse

import requests
import torch
from PIL import Image

from transformers import (
    CLIPSegConfig,
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
    CLIPSegTextConfig,
    CLIPSegVisionConfig,
    CLIPTokenizer,
    ViTImageProcessor,
)


def get_clipseg_config(model_name):
    text_config = CLIPSegTextConfig()
    vision_config = CLIPSegVisionConfig(patch_size=16)

    use_complex_transposed_convolution = "refined" in model_name
    reduce_dim = 16 if "rd16" in model_name else 64

    config = CLIPSegConfig.from_text_vision_configs(
        text_config,
        vision_config,
        use_complex_transposed_convolution=use_complex_transposed_convolution,
        reduce_dim=reduce_dim,
    )
    return config


def rename_key(name):
    # update prefixes
    if "clip_model" in name:
        name = name.replace("clip_model", "clip")
    if "transformer" in name:
        if "visual" in name:
            name = name.replace("visual.transformer", "vision_model")
        else:
            name = name.replace("transformer", "text_model")
    if "resblocks" in name:
        name = name.replace("resblocks", "encoder.layers")
    if "ln_1" in name:
        name = name.replace("ln_1", "layer_norm1")
    if "ln_2" in name:
        name = name.replace("ln_2", "layer_norm2")
    if "c_fc" in name:
        name = name.replace("c_fc", "fc1")
    if "c_proj" in name:
        name = name.replace("c_proj", "fc2")
    if "attn" in name and "self" not in name:
        name = name.replace("attn", "self_attn")
    # text encoder
    if "token_embedding" in name:
        name = name.replace("token_embedding", "text_model.embeddings.token_embedding")
    if "positional_embedding" in name and "visual" not in name:
        name = name.replace("positional_embedding", "text_model.embeddings.position_embedding.weight")
    if "ln_final" in name:
        name = name.replace("ln_final", "text_model.final_layer_norm")
    # vision encoder
    if "visual.class_embedding" in name:
        name = name.replace("visual.class_embedding", "vision_model.embeddings.class_embedding")
    if "visual.conv1" in name:
        name = name.replace("visual.conv1", "vision_model.embeddings.patch_embedding")
    if "visual.positional_embedding" in name:
        name = name.replace("visual.positional_embedding", "vision_model.embeddings.position_embedding.weight")
    if "visual.ln_pre" in name:
        name = name.replace("visual.ln_pre", "vision_model.pre_layrnorm")
    if "visual.ln_post" in name:
        name = name.replace("visual.ln_post", "vision_model.post_layernorm")
    # projection layers
    if "visual.proj" in name:
        name = name.replace("visual.proj", "visual_projection.weight")
    if "text_projection" in name:
        name = name.replace("text_projection", "text_projection.weight")
    # decoder
    if "trans_conv" in name:
        name = name.replace("trans_conv", "transposed_convolution")
    if "film_mul" in name or "film_add" in name or "reduce" in name or "transposed_convolution" in name:
        name = "decoder." + name
    if "blocks" in name:
        name = name.replace("blocks", "decoder.layers")
    if "linear1" in name:
        name = name.replace("linear1", "mlp.fc1")
    if "linear2" in name:
        name = name.replace("linear2", "mlp.fc2")
    if "norm1" in name and "layer_" not in name:
        name = name.replace("norm1", "layer_norm1")
    if "norm2" in name and "layer_" not in name:
        name = name.replace("norm2", "layer_norm2")

    return name


def convert_state_dict(orig_state_dict, config):
    for key in orig_state_dict.copy():
        val = orig_state_dict.pop(key)

        if key.startswith("clip_model") and "attn.in_proj" in key:
            key_split = key.split(".")
            if "visual" in key:
                layer_num = int(key_split[4])
                dim = config.vision_config.hidden_size
                prefix = "vision_model"
            else:
                layer_num = int(key_split[3])
                dim = config.text_config.hidden_size
                prefix = "text_model"

            if "weight" in key:
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[
                    dim : dim * 2, :
                ]
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
            else:
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
        elif "self_attn" in key and "out_proj" not in key:
            key_split = key.split(".")
            layer_num = int(key_split[1])
            dim = config.reduce_dim
            if "weight" in key:
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[dim : dim * 2, :]
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
            else:
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
        else:
            new_name = rename_key(key)
            if "visual_projection" in new_name or "text_projection" in new_name:
                val = val.T
            orig_state_dict[new_name] = val

    return orig_state_dict


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image


def convert_clipseg_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub):
    config = get_clipseg_config(model_name)
    model = CLIPSegForImageSegmentation(config)
    model.eval()

    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # remove some keys
    for key in state_dict.copy():
        if key.startswith("model"):
            state_dict.pop(key, None)

    # rename some keys
    state_dict = convert_state_dict(state_dict, config)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

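    # only these known mismatches are tolerated: the `position_ids` buffers (which the HF model creates
    # itself) and the unused `decoder.reduce.*` parameters; anything else signals a conversion error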
    if missing_keys != ["clip.text_model.embeddings.position_ids", "clip.vision_model.embeddings.position_ids"]:
        raise ValueError(f"Missing keys that are not expected: {missing_keys}")
    if unexpected_keys != ["decoder.reduce.weight", "decoder.reduce.bias"]:
        raise ValueError(f"Unexpected keys: {unexpected_keys}")

    image_processor = ViTImageProcessor(size=352)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)

    image = prepare_img()
    text = ["a glass", "something to fill", "wood", "a jar"]

    inputs = processor(text=text, images=[image] * len(text), padding="max_length", return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # verify values
    expected_conditional = torch.tensor([0.1110, -0.1882, 0.1645])
    expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328])
    if model_name == "clipseg-rd64-refined":
        expected_masks_slice = torch.tensor(
            [[-10.0407, -9.9431, -10.2646], [-9.9751, -9.7064, -9.9586], [-9.6891, -9.5645, -9.9618]]
        )
    elif model_name == "clipseg-rd64":
        expected_masks_slice = torch.tensor(
            [[-7.2877, -7.2711, -7.2463], [-7.2652, -7.2780, -7.2520], [-7.2239, -7.2204, -7.2001]]
        )
    elif model_name == "clipseg-rd16":
        expected_masks_slice = torch.tensor(
            [[-6.3955, -6.4055, -6.4151], [-6.3911, -6.4033, -6.4100], [-6.3474, -6.3702, -6.3762]]
        )
    else:
        raise ValueError(f"Model name {model_name} not supported.")

    assert torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)
    assert torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)
    assert torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print(f"Pushing model and processor for {model_name} to the hub")
        model.push_to_hub(f"CIDAS/{model_name}")
        processor.push_to_hub(f"CIDAS/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="clipseg-rd64",
        type=str,
        choices=["clipseg-rd16", "clipseg-rd64", "clipseg-rd64-refined"],
        help=(
            "Name of the model. Supported models are: clipseg-rd64, clipseg-rd16 and clipseg-rd64-refined (rd meaning"
            " reduce dimension)"
        ),
    )
    parser.add_argument(
        "--checkpoint_path",
        default="/Users/nielsrogge/Documents/CLIPSeg/clip_plus_rd64-uni.pth",
        type=str,
        help=(
            "Path to the original checkpoint. Note that the script assumes that the checkpoint includes both CLIP and"
            " the decoder weights."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_clipseg_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
@ -1,234 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Weights conversion script for CLVP
"""

import argparse
import os

import torch
from huggingface_hub import hf_hub_download

from transformers import ClvpConfig, ClvpModelForConditionalGeneration


_MODELS = {
    "clvp": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/clvp2.pth",
    "decoder": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/autoregressive.pth",
}

dim = 1024
sub_dim = dim // 16

CLVP_ENCODERS_MAPPING = {
    "text_transformer.transformer.attn_layers": "text_encoder_model",
    "speech_transformer.transformer.attn_layers": "speech_encoder_model",
    "text_transformer.transformer.norm": "text_encoder_model.final_layer_norm",
    "speech_transformer.transformer.norm": "speech_encoder_model.final_layer_norm",
    "to_text_latent": "text_encoder_model.projection",
    "to_speech_latent": "speech_encoder_model.projection",
    "text_emb": "text_encoder_model.token_embedding",
    "speech_emb": "speech_encoder_model.token_embedding",
    "1.wrap.net.0": "mlp.fc1",
    "1.wrap.net.3": "mlp.fc2",
    "1.wrap": "self_attn",
    "to_out": "out_proj",
    "to_q": "q_proj",
    "to_k": "k_proj",
    "to_v": "v_proj",
    "temperature": "logit_scale",
}

CLVP_DECODER_MAPPING = {
    "conditioning_encoder.init": "conditioning_encoder.mel_conv",
    "conditioning_encoder.attn": "conditioning_encoder.mel_attn_blocks",
    "mel_attn_blocks": "group_norms",
    ".norm.weight": ".weight",
    ".norm.bias": ".bias",
    "text_embedding": "conditioning_encoder.text_token_embedding",
    "text_pos_embedding.emb": "conditioning_encoder.text_position_embedding",
    "final_norm": "speech_decoder_model.final_norm",
    "mel_head": "speech_decoder_model.lm_head",
    "gpt.ln_f": "speech_decoder_model.model.decoder.layer_norm",
    "mel_embedding": "speech_decoder_model.model.decoder.input_embeds_layer",
    "mel_pos_embedding.emb": "speech_decoder_model.model.decoder.position_embeds_layer",
    "gpt.h": "speech_decoder_model.model.decoder.layers",
    "ln_1": "input_layernorm",
    "ln_2": "post_attention_layernorm",
}


def update_index(present_index):
    if present_index % 2 == 0:
        return int(present_index / 2)
    else:
        return int((present_index - 1) / 2)


def convert_encoder_weights(original_weights):
    converted_weights = {}
    original_weights_keys = sorted(original_weights.keys())
    for original_key in original_weights_keys:
        updated_key = original_key
        # for input_rmsnorm.weight and post_attention_rmsnorm.weight
        if "0.0.g" in updated_key:
            present_index = updated_key.split(".")[4]
            if int(present_index) % 2 == 0:
                updated_key = updated_key.replace("0.0.g", "input_rmsnorm.weight")
            else:
                updated_key = updated_key.replace("0.0.g", "post_attention_rmsnorm.weight")

        if "transformer.attn_layers.layers" in updated_key:
            present_index = updated_key.split(".")[4]
            updated_index = update_index(int(present_index))
            updated_key = updated_key.replace(
                f"transformer.attn_layers.layers.{present_index}", f"transformer.attn_layers.layers.{updated_index}"
            )

        for k, v in CLVP_ENCODERS_MAPPING.items():
            if k in updated_key:
                updated_key = updated_key.replace(k, v)

        converted_weights[updated_key] = original_weights.pop(original_key)

    return converted_weights


def convert_decoder_weights(original_weights):
    converted_weights = {}
    original_weights_keys = sorted(original_weights.keys())
    for original_key in original_weights_keys:
        updated_key = original_key
        if len(updated_key.split(".")) > 3:
            index, attr = updated_key.split(".")[2], updated_key.split(".")[-1]

        # for decoder attention
        if "attn.c_attn" in updated_key:
            if attr == "weight":
                slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).T.split(split_size=dim, dim=0)
            else:
                slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
            converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.q_proj.{attr}"] = slice1
            converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.k_proj.{attr}"] = slice2
            converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.v_proj.{attr}"] = slice3
            continue

        if "attn.c_proj" in updated_key:
            converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.out_proj.{attr}"] = (
                original_weights[updated_key].squeeze(-1).T
            )
            continue

        if "attn.bias" in updated_key or "attn.masked_bias" in updated_key or "text_head" in updated_key:
            original_weights.pop(updated_key)
            continue

        # conditional encoder attention
        if "qkv" in updated_key:
            if attr == "weight":
                slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).split(split_size=dim, dim=0)
            else:
                slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)

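            # the fused qkv tensor appears to interleave per-head blocks of size `sub_dim`; windows of
            # `sub_dim` taken with a stride of `3 * sub_dim` pick those blocks apart so the slices can
            # be regrouped into contiguous q/k/v projections below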
            indices = torch.arange(dim)
            index1, index2, index3 = (
                indices.unfold(0, sub_dim, sub_dim * 3).flatten(),
                indices[sub_dim:].unfold(0, sub_dim, sub_dim * 3).flatten(),
                indices[2 * sub_dim :].unfold(0, sub_dim, sub_dim * 3).flatten(),
            )

            converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.q_proj.{attr}"] = torch.concatenate(
                [slice1[index1], slice2[index3], slice3[index2]],
                axis=0,
            )
            converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.k_proj.{attr}"] = torch.concatenate(
                [slice1[index2], slice2[index1], slice3[index3]],
                axis=0,
            )
            converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.v_proj.{attr}"] = torch.concatenate(
                [slice1[index3], slice2[index2], slice3[index1]],
                axis=0,
            )
            continue

        if "proj_out" in updated_key:
            converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.out_proj.{attr}"] = original_weights[
                updated_key
            ].squeeze(-1)
            continue

        for k, v in CLVP_DECODER_MAPPING.items():
            if k in updated_key:
                updated_key = updated_key.replace(k, v)

        converted_weights[updated_key] = original_weights.pop(original_key)

    return converted_weights


def _download(url: str, root: str):
    repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
    filename = f"{url.split('/')[-2]}/{url.split('/')[-1]}"
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        force_filename=root,
        local_dir_use_symlinks=False,
    )


def convert_clvp_weights(checkpoint_path, pytorch_dump_folder_path):
    converted_checkpoint = {}

    for each_model_name, each_model_url in _MODELS.items():
        each_model_path = os.path.join(checkpoint_path, each_model_url.split("/")[-1])
        if not os.path.exists(each_model_path):
            print(f"\n{each_model_name} was not found! Downloading it to {each_model_path}")
            _download(url=each_model_url, root=each_model_path)

        if each_model_name == "clvp":
            clvp_checkpoint = torch.load(each_model_path, map_location="cpu", weights_only=True)
        else:
            decoder_checkpoint = torch.load(each_model_path, map_location="cpu", weights_only=True)

    # Converting the weights
    converted_checkpoint.update(**convert_encoder_weights(clvp_checkpoint))
    converted_checkpoint.update(**convert_decoder_weights(decoder_checkpoint))

    config = ClvpConfig.from_pretrained("susnato/clvp_dev")
    model = ClvpModelForConditionalGeneration(config)

    model.load_state_dict(converted_checkpoint, strict=True)
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Model saved at {pytorch_dump_folder_path}!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output PyTorch model. (Please enter full path)",
    )
    args = parser.parse_args()

    convert_clvp_weights(args.checkpoint_path, args.pytorch_dump_folder_path)
@ -1,214 +0,0 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert ColPali weights from the original repository to the HF model format.

Original repository: https://github.com/illuin-tech/colpali.

NOTE: This script was originally run using `torch==2.5.1` and with:

```bash
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
    --model_id vidore/colpali-v1.2-merged \
    --revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
    --original_vlm_name_or_path google/paligemma-3b-mix-448 \
    --output_dir vidore/colpali-v1.2-hf-internal \
    --push_to_hub

python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
    --model_id vidore/colpali-v1.3-merged \
    --revision 5b955e3415a7c5468ab33119d98d6d45c3a5b2c3 \
    --original_vlm_name_or_path google/paligemma-3b-mix-448 \
    --output_dir vidore/colpali-v1.3-hf \
    --push_to_hub
```
"""

import argparse
import glob
from pathlib import Path
from typing import Any, Optional

import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open

from transformers import AutoConfig
from transformers.models.colpali import ColPaliForRetrieval
from transformers.models.colpali.configuration_colpali import ColPaliConfig
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


ORIGINAL_DTYPE = torch.bfloat16


def rename_state_dict_keys(state_dict: dict[str, Any]) -> dict[str, Any]:
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        if key.startswith("custom_text_proj"):
            new_key = key.replace("custom_text_proj", "embedding_proj_layer")
        if key.startswith("model."):
            new_key = key.replace("model.", "vlm.", 1)
        new_state_dict[new_key] = value
    return new_state_dict


def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
    directory_path = snapshot_download(
        repo_id=model_id,
        revision=revision,
        allow_patterns=["*.safetensors"],
    )

    original_state_dict = {}
    for path in glob.glob(f"{directory_path}/*"):
        if path.endswith(".safetensors"):
            with safe_open(path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    original_state_dict[key] = f.get_tensor(key)

    # Some weights are tied, so `lm_head` is not saved. Let's clone it to load the state dict.
    if "lm_head.weight" not in original_state_dict:
        original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[
            "model.language_model.model.embed_tokens.weight"
        ].clone()

    return original_state_dict


@torch.no_grad()
def convert_colpali_weights_to_hf(
    model_id: str,
    output_dir: str,
    push_to_hub: bool,
    revision: Optional[str] = None,
    original_vlm_name_or_path: Optional[str] = None,
):
    # Load the original model data
    original_config = AutoConfig.from_pretrained(
        model_id,
        revision=revision,
    )
    if original_vlm_name_or_path is not None:
        original_config._name_or_path = original_vlm_name_or_path
    if hasattr(original_config, "architectures"):
        delattr(original_config, "architectures")

    original_state_dict = load_original_state_dict(model_id, revision=revision)

    # Format the state_dict keys
    original_state_dict = rename_state_dict_keys(original_state_dict)

    # Create the new config
    config = ColPaliConfig(
        vlm_config=original_config,
        embedding_dim=128,  # hardcoded in the original model
    )
    config.model_type = "colpali"
    config.is_composition = False

    # Load the untrained model
    model = ColPaliForRetrieval(config=config).to("cpu").eval()
    print("Created model with new config and randomly initialized weights")

    # NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision.
    # There are two ways to set the model's dtype:
    # - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
    # - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
    # The following snippet allows a fine-grained control over the model's dtype, making sure that all
    # the new weights' dtypes match the original model.
    for param in model.parameters():
        param.data = param.data.to(ORIGINAL_DTYPE)
    print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")

    # Load the original weights
    model.load_state_dict(original_state_dict)
    print("Loaded original model weights")

    # Tie the weights (following ColPali's `__init__` step)
    if model.vlm.language_model._tied_weights_keys is not None:
        model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys]

    # Sanity check: ensure all keys are the same
    state_dict_keys_old = set(original_state_dict.keys())
    state_dict_keys_new = set(model.state_dict().keys())
    disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
    if disjoint_keys:
        raise ValueError(f"Incompatible keys: {disjoint_keys}")

    # Save the model
    if push_to_hub:
        model.push_to_hub(output_dir, private=True)
        print(f"Model pushed to the hub at `{output_dir}`")
    else:
        Path(output_dir).mkdir(exist_ok=True, parents=True)
        model.save_pretrained(output_dir)
        print(f"Model saved to `{output_dir}`")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
        This script converts the original ColPali model to the HF model format.

        Example usage:
        ```bash
        python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
            --model_id vidore/colpali-v1.2-merged \
            --revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
            --original_vlm_name_or_path google/paligemma-3b-mix-448 \
            --output_dir vidore/colpali-v1.2-hf \
            --push_to_hub
        ```
        """
    )
    parser.add_argument(
        "--model_id",
        help="Model ID of the original model to convert",
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer",
    )
    parser.add_argument(
        "--push_to_hub",
        help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--revision",
        help="Revision of the model to download",
        default=None,
    )
    parser.add_argument(
        "--original_vlm_name_or_path",
        help="Name or path of the original VLM backbone model",
        default=None,
    )
    args = parser.parse_args()

    convert_colpali_weights_to_hf(
        model_id=args.model_id,
        output_dir=args.output_dir,
        push_to_hub=args.push_to_hub,
        revision=args.revision,
        original_vlm_name_or_path=args.original_vlm_name_or_path,
    )
@ -1,212 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2025 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""
 | 
			
		||||
Convert ColQwen2 weights from the original repository to the HF model format.
 | 
			
		||||
 | 
			
		||||
Don't forget to manually upload the processor-related files to the HF model repository
 | 
			
		||||
after running this script.
 | 
			
		||||
 | 
			
		||||
Original repository: https://github.com/illuin-tech/colqwen2.
 | 
			
		||||
 | 
			
		||||
NOTE: This script was originally run using `torch==2.5.1` and with:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
 | 
			
		||||
    --model_id vidore/colqwen2-v1.0-merged \
 | 
			
		||||
    --revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
 | 
			
		||||
    --original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
 | 
			
		||||
    --output_dir vidore/colqwen2-v1.0-hf-internal \
 | 
			
		||||
    --push_to_hub
 | 
			
		||||
```
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import glob
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Any, Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import snapshot_download
 | 
			
		||||
from safetensors import safe_open
 | 
			
		||||
 | 
			
		||||
from transformers import AutoConfig
 | 
			
		||||
from transformers.models.colqwen2 import ColQwen2ForRetrieval
 | 
			
		||||
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
ORIGINAL_DTYPE = torch.bfloat16
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
 | 
			
		||||
    directory_path = snapshot_download(
 | 
			
		||||
        repo_id=model_id,
 | 
			
		||||
        revision=revision,
 | 
			
		||||
        allow_patterns=["*.safetensors"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    original_state_dict = {}
 | 
			
		||||
    for path in glob.glob(f"{directory_path}/*"):
 | 
			
		||||
        if path.endswith(".safetensors"):
 | 
			
		||||
            with safe_open(path, framework="pt", device="cpu") as f:
 | 
			
		||||
                for key in f.keys():
 | 
			
		||||
                    original_state_dict[key] = f.get_tensor(key)
 | 
			
		||||
 | 
			
		||||
    # Some weights are tied, so `lm_head` is not saved. Clone it from the token embeddings so the state dict loads fully.
 | 
			
		||||
    if "lm_head.weight" not in original_state_dict:
 | 
			
		||||
        original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
 | 
			
		||||
 | 
			
		||||
    return original_state_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_state_dict_keys(state_dict: dict[str, Any]) -> dict[str, Any]:
 | 
			
		||||
    new_state_dict: dict[str, Any] = {}
 | 
			
		||||
    for key, value in state_dict.items():
 | 
			
		||||
        if key.startswith("custom_text_proj"):
 | 
			
		||||
            new_key = key.replace("custom_text_proj", "embedding_proj_layer")
 | 
			
		||||
        else:
 | 
			
		||||
            # The original ColQwen2 inherits from Qwen2VL, so we simply need to add the `vlm.` prefix
 | 
			
		||||
            # to all remaining keys.
 | 
			
		||||
            if key.startswith("model."):
 | 
			
		||||
                key = key.replace("model.", "model.language_model.")
 | 
			
		||||
            if key.startswith("visual."):
 | 
			
		||||
                key = key.replace("visual.", "model.visual.")
 | 
			
		||||
            new_key = "vlm." + key
 | 
			
		||||
        new_state_dict[new_key] = value
 | 
			
		||||
    return new_state_dict
 | 
			
		||||
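# Illustrative sketch, not part of the original script; the key names below are
# hypothetical examples of how the renaming above behaves:
#   "custom_text_proj.weight"        -> "embedding_proj_layer.weight"
#   "model.embed_tokens.weight"      -> "vlm.model.language_model.embed_tokens.weight"
#   "visual.patch_embed.proj.weight" -> "vlm.model.visual.patch_embed.proj.weight"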
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_colqwen2_weights_to_hf(
 | 
			
		||||
    model_id: str,
 | 
			
		||||
    output_dir: str,
 | 
			
		||||
    push_to_hub: bool,
 | 
			
		||||
    revision: Optional[str] = None,
 | 
			
		||||
    original_vlm_name_or_path: Optional[str] = None,
 | 
			
		||||
):
 | 
			
		||||
    # Load the original model data
 | 
			
		||||
    original_config = AutoConfig.from_pretrained(
 | 
			
		||||
        model_id,
 | 
			
		||||
        revision=revision,
 | 
			
		||||
    )
 | 
			
		||||
    if original_vlm_name_or_path is not None:
 | 
			
		||||
        original_config._name_or_path = original_vlm_name_or_path
 | 
			
		||||
    if hasattr(original_config, "architectures"):
 | 
			
		||||
        delattr(original_config, "architectures")
 | 
			
		||||
 | 
			
		||||
    original_state_dict = load_original_state_dict(model_id, revision=revision)
 | 
			
		||||
 | 
			
		||||
    # Format the state_dict keys
 | 
			
		||||
    original_state_dict = rename_state_dict_keys(original_state_dict)
 | 
			
		||||
 | 
			
		||||
    # Create the new config
 | 
			
		||||
    config = ColQwen2Config(
 | 
			
		||||
        vlm_config=original_config,
 | 
			
		||||
        embedding_dim=128,  # hardcoded in the original model
 | 
			
		||||
    )
 | 
			
		||||
    config.model_type = "colqwen2"
 | 
			
		||||
    config.is_composition = False
 | 
			
		||||
 | 
			
		||||
    # Load the untrained model
 | 
			
		||||
    model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
 | 
			
		||||
    print("Created model with new config and randomly initialized weights")
 | 
			
		||||
 | 
			
		||||
    # NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
 | 
			
		||||
    # There are two ways to set the model's dtype:
 | 
			
		||||
    # - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
 | 
			
		||||
    # - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
 | 
			
		||||
    # The following snippet allows fine-grained control over the model's dtype, making sure that all
 | 
			
		||||
    # the new weights' dtypes match the original model.
 | 
			
		||||
    for param in model.parameters():
 | 
			
		||||
        param.data = param.data.to(ORIGINAL_DTYPE)
 | 
			
		||||
    print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
 | 
			
		||||
 | 
			
		||||
    # Load the original weights
 | 
			
		||||
    model.load_state_dict(original_state_dict)
 | 
			
		||||
    print("Loaded original model weights")
 | 
			
		||||
 | 
			
		||||
    # Sanity check: ensure all keys are the same
 | 
			
		||||
    state_dict_keys_old = set(original_state_dict.keys())
 | 
			
		||||
    state_dict_keys_new = set(model.state_dict().keys())
 | 
			
		||||
    disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
 | 
			
		||||
    if disjoint_keys:
 | 
			
		||||
        raise ValueError(f"Incompatible keys: {disjoint_keys}")
 | 
			
		||||
 | 
			
		||||
    # Save the model
 | 
			
		||||
    if push_to_hub:
 | 
			
		||||
        model.push_to_hub(output_dir, private=True)
 | 
			
		||||
        print(f"Model pushed to the hub at `{output_dir}`")
 | 
			
		||||
    else:
 | 
			
		||||
        Path(output_dir).mkdir(exist_ok=True, parents=True)
 | 
			
		||||
        model.save_pretrained(output_dir)
 | 
			
		||||
        print(f"Model saved to `{output_dir}`")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser(
 | 
			
		||||
        description="""
 | 
			
		||||
        This script converts the original ColQwen2 model to the HF model format.
 | 
			
		||||
 | 
			
		||||
        Don't forget to manually upload the processor-related files to the HF model repository
 | 
			
		||||
        after running this script.
 | 
			
		||||
 | 
			
		||||
        Example usage:
 | 
			
		||||
        ```bash
 | 
			
		||||
        python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
 | 
			
		||||
            --model_id vidore/colqwen2-v1.0-merged \
 | 
			
		||||
            --revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
 | 
			
		||||
            --original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
 | 
			
		||||
            --output_dir vidore/colqwen2-v1.0-hf-internal \
 | 
			
		||||
            --push_to_hub
 | 
			
		||||
        ```
 | 
			
		||||
        """
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--model_id",
 | 
			
		||||
        help="Model ID of the original model to convert",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--output_dir",
 | 
			
		||||
        help="Location to write HF model and tokenizer",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--push_to_hub",
 | 
			
		||||
        help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        default=False,
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--revision",
 | 
			
		||||
        help="Revision of the model to download",
 | 
			
		||||
        default=None,
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--original_vlm_name_or_path",
 | 
			
		||||
        help="Name or path of the original VLM backbone model",
 | 
			
		||||
        default=None,
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    convert_colqwen2_weights_to_hf(
 | 
			
		||||
        model_id=args.model_id,
 | 
			
		||||
        output_dir=args.output_dir,
 | 
			
		||||
        push_to_hub=args.push_to_hub,
 | 
			
		||||
        revision=args.revision,
 | 
			
		||||
        original_vlm_name_or_path=args.original_vlm_name_or_path,
 | 
			
		||||
    )
 | 
			
		||||
@ -1,324 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2022 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert Conditional DETR checkpoints."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    ConditionalDetrConfig,
 | 
			
		||||
    ConditionalDetrForObjectDetection,
 | 
			
		||||
    ConditionalDetrForSegmentation,
 | 
			
		||||
    ConditionalDetrImageProcessor,
 | 
			
		||||
)
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
# here we list all keys to be renamed (original name on the left, our name on the right)
 | 
			
		||||
rename_keys = []
 | 
			
		||||
for i in range(6):
 | 
			
		||||
    # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
 | 
			
		||||
    rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
 | 
			
		||||
    rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
 | 
			
		||||
    rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias"))
 | 
			
		||||
    rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight"))
 | 
			
		||||
    rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
 | 
			
		||||
    # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"transformer.decoder.layers.{i}.cross_attn.out_proj.weight",
 | 
			
		||||
            f"decoder.layers.{i}.encoder_attn.out_proj.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"transformer.decoder.layers.{i}.cross_attn.out_proj.bias",
 | 
			
		||||
            f"decoder.layers.{i}.encoder_attn.out_proj.bias",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias"))
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight"))
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))
 | 
			
		||||
 | 
			
		||||
    # q, k, v projections in self/cross-attention in decoder for conditional DETR
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.sa_qcontent_proj.weight", f"decoder.layers.{i}.sa_qcontent_proj.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.sa_kcontent_proj.weight", f"decoder.layers.{i}.sa_kcontent_proj.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.sa_qpos_proj.weight", f"decoder.layers.{i}.sa_qpos_proj.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.sa_kpos_proj.weight", f"decoder.layers.{i}.sa_kpos_proj.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.weight", f"decoder.layers.{i}.sa_v_proj.weight"))
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.ca_qcontent_proj.weight", f"decoder.layers.{i}.ca_qcontent_proj.weight")
 | 
			
		||||
    )
 | 
			
		||||
    # rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.weight", f"decoder.layers.{i}.ca_qpos_proj.weight"))
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.ca_kcontent_proj.weight", f"decoder.layers.{i}.ca_kcontent_proj.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.ca_kpos_proj.weight", f"decoder.layers.{i}.ca_kpos_proj.weight")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.weight", f"decoder.layers.{i}.ca_v_proj.weight"))
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.weight", f"decoder.layers.{i}.ca_qpos_sine_proj.weight")
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.sa_qcontent_proj.bias", f"decoder.layers.{i}.sa_qcontent_proj.bias")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.sa_kcontent_proj.bias", f"decoder.layers.{i}.sa_kcontent_proj.bias")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.sa_qpos_proj.bias", f"decoder.layers.{i}.sa_qpos_proj.bias"))
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.sa_kpos_proj.bias", f"decoder.layers.{i}.sa_kpos_proj.bias"))
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.bias", f"decoder.layers.{i}.sa_v_proj.bias"))
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.ca_qcontent_proj.bias", f"decoder.layers.{i}.ca_qcontent_proj.bias")
 | 
			
		||||
    )
 | 
			
		||||
    # rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.bias", f"decoder.layers.{i}.ca_qpos_proj.bias"))
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.ca_kcontent_proj.bias", f"decoder.layers.{i}.ca_kcontent_proj.bias")
 | 
			
		||||
    )
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.ca_kpos_proj.bias", f"decoder.layers.{i}.ca_kpos_proj.bias"))
 | 
			
		||||
    rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.bias", f"decoder.layers.{i}.ca_v_proj.bias"))
 | 
			
		||||
    rename_keys.append(
 | 
			
		||||
        (f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.bias", f"decoder.layers.{i}.ca_qpos_sine_proj.bias")
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
 | 
			
		||||
# for conditional DETR, also convert reference point head and query scale MLP
 | 
			
		||||
rename_keys.extend(
 | 
			
		||||
    [
 | 
			
		||||
        ("input_proj.weight", "input_projection.weight"),
 | 
			
		||||
        ("input_proj.bias", "input_projection.bias"),
 | 
			
		||||
        ("query_embed.weight", "query_position_embeddings.weight"),
 | 
			
		||||
        ("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
 | 
			
		||||
        ("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
 | 
			
		||||
        ("class_embed.weight", "class_labels_classifier.weight"),
 | 
			
		||||
        ("class_embed.bias", "class_labels_classifier.bias"),
 | 
			
		||||
        ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
 | 
			
		||||
        ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
 | 
			
		||||
        ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
 | 
			
		||||
        ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
 | 
			
		||||
        ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
 | 
			
		||||
        ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
 | 
			
		||||
        ("transformer.decoder.ref_point_head.layers.0.weight", "decoder.ref_point_head.layers.0.weight"),
 | 
			
		||||
        ("transformer.decoder.ref_point_head.layers.0.bias", "decoder.ref_point_head.layers.0.bias"),
 | 
			
		||||
        ("transformer.decoder.ref_point_head.layers.1.weight", "decoder.ref_point_head.layers.1.weight"),
 | 
			
		||||
        ("transformer.decoder.ref_point_head.layers.1.bias", "decoder.ref_point_head.layers.1.bias"),
 | 
			
		||||
        ("transformer.decoder.query_scale.layers.0.weight", "decoder.query_scale.layers.0.weight"),
 | 
			
		||||
        ("transformer.decoder.query_scale.layers.0.bias", "decoder.query_scale.layers.0.bias"),
 | 
			
		||||
        ("transformer.decoder.query_scale.layers.1.weight", "decoder.query_scale.layers.1.weight"),
 | 
			
		||||
        ("transformer.decoder.query_scale.layers.1.bias", "decoder.query_scale.layers.1.bias"),
 | 
			
		||||
        ("transformer.decoder.layers.0.ca_qpos_proj.weight", "decoder.layers.0.ca_qpos_proj.weight"),
 | 
			
		||||
        ("transformer.decoder.layers.0.ca_qpos_proj.bias", "decoder.layers.0.ca_qpos_proj.bias"),
 | 
			
		||||
    ]
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_key(state_dict, old, new):
 | 
			
		||||
    val = state_dict.pop(old)
 | 
			
		||||
    state_dict[new] = val
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_backbone_keys(state_dict):
 | 
			
		||||
    new_state_dict = OrderedDict()
 | 
			
		||||
    for key, value in state_dict.items():
 | 
			
		||||
        if "backbone.0.body" in key:
 | 
			
		||||
            new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model")
 | 
			
		||||
            new_state_dict[new_key] = value
 | 
			
		||||
        else:
 | 
			
		||||
            new_state_dict[key] = value
 | 
			
		||||
 | 
			
		||||
    return new_state_dict
 | 
			
		||||
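# Illustrative sketch, not part of the original script; the key name is a hypothetical
# example of the renaming above:
#   "backbone.0.body.conv1.weight" -> "backbone.conv_encoder.model.conv1.weight"
# All keys without the "backbone.0.body" prefix are kept unchanged.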
 | 
			
		||||
 | 
			
		||||
def read_in_q_k_v(state_dict, is_panoptic=False):
 | 
			
		||||
    prefix = ""
 | 
			
		||||
    if is_panoptic:
 | 
			
		||||
        prefix = "conditional_detr."
 | 
			
		||||
 | 
			
		||||
    # first: transformer encoder
 | 
			
		||||
    for i in range(6):
 | 
			
		||||
        # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
 | 
			
		||||
        in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
 | 
			
		||||
        in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
 | 
			
		||||
        # next, add query, keys and values (in that order) to the state dict
 | 
			
		||||
        state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
 | 
			
		||||
        state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
 | 
			
		||||
        state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
 | 
			
		||||
        state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
 | 
			
		||||
        state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
 | 
			
		||||
        state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
 | 
			
		||||
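# Note on the slicing above (added for clarity, not part of the original script): the
# fused `in_proj_weight` has shape (3 * 256, 256) for these checkpoints, with the query,
# key and value projections stacked row-wise, so rows [0:256] are q, [256:512] are k and
# the last 256 rows are v; the fused bias vector is split the same way.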
 | 
			
		||||
 | 
			
		||||
# We will verify our results on an image of cute cats
 | 
			
		||||
def prepare_img():
 | 
			
		||||
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 | 
			
		||||
    im = Image.open(requests.get(url, stream=True).raw)
 | 
			
		||||
 | 
			
		||||
    return im
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_conditional_detr_checkpoint(model_name, pytorch_dump_folder_path):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to our CONDITIONAL_DETR structure.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # load default config
 | 
			
		||||
    config = ConditionalDetrConfig()
 | 
			
		||||
    # set backbone and dilation attributes
 | 
			
		||||
    if "resnet101" in model_name:
 | 
			
		||||
        config.backbone = "resnet101"
 | 
			
		||||
    if "dc5" in model_name:
 | 
			
		||||
        config.dilation = True
 | 
			
		||||
    is_panoptic = "panoptic" in model_name
 | 
			
		||||
    if is_panoptic:
 | 
			
		||||
        config.num_labels = 250
 | 
			
		||||
    else:
 | 
			
		||||
        config.num_labels = 91
 | 
			
		||||
        repo_id = "huggingface/label-files"
 | 
			
		||||
        filename = "coco-detection-id2label.json"
 | 
			
		||||
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
        id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
        config.id2label = id2label
 | 
			
		||||
        config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
 | 
			
		||||
    # load image processor
 | 
			
		||||
    format = "coco_panoptic" if is_panoptic else "coco_detection"
 | 
			
		||||
    image_processor = ConditionalDetrImageProcessor(format=format)
 | 
			
		||||
 | 
			
		||||
    # prepare image
 | 
			
		||||
    img = prepare_img()
 | 
			
		||||
    encoding = image_processor(images=img, return_tensors="pt")
 | 
			
		||||
    pixel_values = encoding["pixel_values"]
 | 
			
		||||
 | 
			
		||||
    logger.info(f"Converting model {model_name}...")
 | 
			
		||||
 | 
			
		||||
    # load original model from torch hub
 | 
			
		||||
    conditional_detr = torch.hub.load("DeppMeng/ConditionalDETR", model_name, pretrained=True).eval()
 | 
			
		||||
    state_dict = conditional_detr.state_dict()
 | 
			
		||||
    # rename keys
 | 
			
		||||
    for src, dest in rename_keys:
 | 
			
		||||
        if is_panoptic:
 | 
			
		||||
            src = "conditional_detr." + src
 | 
			
		||||
        rename_key(state_dict, src, dest)
 | 
			
		||||
    state_dict = rename_backbone_keys(state_dict)
 | 
			
		||||
    # query, key and value matrices need special treatment
 | 
			
		||||
    read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
 | 
			
		||||
    # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
 | 
			
		||||
    prefix = "conditional_detr.model." if is_panoptic else "model."
 | 
			
		||||
    for key in state_dict.copy():
 | 
			
		||||
        if is_panoptic:
 | 
			
		||||
            if (
 | 
			
		||||
                key.startswith("conditional_detr")
 | 
			
		||||
                and not key.startswith("class_labels_classifier")
 | 
			
		||||
                and not key.startswith("bbox_predictor")
 | 
			
		||||
            ):
 | 
			
		||||
                val = state_dict.pop(key)
 | 
			
		||||
                state_dict["conditional_detr.model" + key[4:]] = val
 | 
			
		||||
            elif "class_labels_classifier" in key or "bbox_predictor" in key:
 | 
			
		||||
                val = state_dict.pop(key)
 | 
			
		||||
                state_dict["conditional_detr." + key] = val
 | 
			
		||||
            elif key.startswith("bbox_attention") or key.startswith("mask_head"):
 | 
			
		||||
                continue
 | 
			
		||||
            else:
 | 
			
		||||
                val = state_dict.pop(key)
 | 
			
		||||
                state_dict[prefix + key] = val
 | 
			
		||||
        else:
 | 
			
		||||
            if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
 | 
			
		||||
                val = state_dict.pop(key)
 | 
			
		||||
                state_dict[prefix + key] = val
 | 
			
		||||
    # finally, create HuggingFace model and load state dict
 | 
			
		||||
    model = ConditionalDetrForSegmentation(config) if is_panoptic else ConditionalDetrForObjectDetection(config)
 | 
			
		||||
    model.load_state_dict(state_dict)
 | 
			
		||||
    model.eval()
 | 
			
		||||
    model.push_to_hub(repo_id=model_name, organization="DepuMeng", commit_message="Add model")
 | 
			
		||||
    # verify our conversion
 | 
			
		||||
    original_outputs = conditional_detr(pixel_values)
 | 
			
		||||
    outputs = model(pixel_values)
 | 
			
		||||
    assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
 | 
			
		||||
    assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
 | 
			
		||||
    if is_panoptic:
 | 
			
		||||
        assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
 | 
			
		||||
 | 
			
		||||
    # Save model and image processor
 | 
			
		||||
    logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
 | 
			
		||||
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
 | 
			
		||||
    model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
    image_processor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--model_name",
 | 
			
		||||
        default="conditional_detr_resnet50",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Name of the CONDITIONAL_DETR model you'd like to convert.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_conditional_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
 | 
			
		||||
@ -1,57 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ConvBERT checkpoint."""

import argparse

from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert
from transformers.utils import logging


logging.set_verbosity_info()


def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path):
    conf = ConvBertConfig.from_json_file(convbert_config_file)
    model = ConvBertModel(conf)

    model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path)
    model.save_pretrained(pytorch_dump_path)

    tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True)
    tf_model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint."
    )
    parser.add_argument(
        "--convbert_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained ConvBERT model. \n"
            "This specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path)
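# Hypothetical invocation of the script above (file name and paths are placeholders,
# not taken from the original repository):
#   python convert_convbert_checkpoint.py \
#       --tf_checkpoint_path /path/to/tf1_checkpoint \
#       --convbert_config_file /path/to/convbert_config.json \
#       --pytorch_dump_path /path/to/output_dir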
@ -1,242 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2022 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert ConvNext checkpoints from the original repository.
 | 
			
		||||
 | 
			
		||||
URL: https://github.com/facebookresearch/ConvNeXt"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
from transformers import ConvNextConfig, ConvNextForImageClassification, ConvNextImageProcessor
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_convnext_config(checkpoint_url):
 | 
			
		||||
    config = ConvNextConfig()
 | 
			
		||||
 | 
			
		||||
    if "tiny" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 9, 3]
 | 
			
		||||
        hidden_sizes = [96, 192, 384, 768]
 | 
			
		||||
    if "small" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 27, 3]
 | 
			
		||||
        hidden_sizes = [96, 192, 384, 768]
 | 
			
		||||
    if "base" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 27, 3]
 | 
			
		||||
        hidden_sizes = [128, 256, 512, 1024]
 | 
			
		||||
    if "large" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 27, 3]
 | 
			
		||||
        hidden_sizes = [192, 384, 768, 1536]
 | 
			
		||||
    if "xlarge" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 27, 3]
 | 
			
		||||
        hidden_sizes = [256, 512, 1024, 2048]
 | 
			
		||||
 | 
			
		||||
    if "1k" in checkpoint_url:
 | 
			
		||||
        num_labels = 1000
 | 
			
		||||
        filename = "imagenet-1k-id2label.json"
 | 
			
		||||
        expected_shape = (1, 1000)
 | 
			
		||||
    else:
 | 
			
		||||
        num_labels = 21841
 | 
			
		||||
        filename = "imagenet-22k-id2label.json"
 | 
			
		||||
        expected_shape = (1, 21841)
 | 
			
		||||
 | 
			
		||||
    repo_id = "huggingface/label-files"
 | 
			
		||||
    config.num_labels = num_labels
 | 
			
		||||
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
    id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
    if "1k" not in checkpoint_url:
 | 
			
		||||
        # this dataset contains 21843 labels but the model only has 21841
 | 
			
		||||
        # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
 | 
			
		||||
        del id2label[9205]
 | 
			
		||||
        del id2label[15027]
 | 
			
		||||
    config.id2label = id2label
 | 
			
		||||
    config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
    config.hidden_sizes = hidden_sizes
 | 
			
		||||
    config.depths = depths
 | 
			
		||||
 | 
			
		||||
    return config, expected_shape
 | 
			
		||||
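# Illustrative sketch, not part of the original script: for the default
# ".../convnext_tiny_1k_224_ema.pth" checkpoint URL this returns a config with
# depths=[3, 3, 9, 3], hidden_sizes=[96, 192, 384, 768] and num_labels=1000, together
# with expected_shape=(1, 1000); the id2label mapping is downloaded from the
# "huggingface/label-files" dataset repository.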
 | 
			
		||||
 | 
			
		||||
def rename_key(name):
 | 
			
		||||
    if "downsample_layers.0.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings")
 | 
			
		||||
    if "downsample_layers.0.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.0.1", "embeddings.norm")  # we rename to layernorm later on
 | 
			
		||||
    if "downsample_layers.1.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0")
 | 
			
		||||
    if "downsample_layers.1.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1")
 | 
			
		||||
    if "downsample_layers.2.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0")
 | 
			
		||||
    if "downsample_layers.2.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1")
 | 
			
		||||
    if "downsample_layers.3.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0")
 | 
			
		||||
    if "downsample_layers.3.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1")
 | 
			
		||||
    if "stages" in name and "downsampling_layer" not in name:
 | 
			
		||||
        # stages.0.0. for instance should be renamed to stages.0.layers.0.
 | 
			
		||||
        name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
 | 
			
		||||
    if "stages" in name:
 | 
			
		||||
        name = name.replace("stages", "encoder.stages")
 | 
			
		||||
    if "norm" in name:
 | 
			
		||||
        name = name.replace("norm", "layernorm")
 | 
			
		||||
    if "gamma" in name:
 | 
			
		||||
        name = name.replace("gamma", "layer_scale_parameter")
 | 
			
		||||
    if "head" in name:
 | 
			
		||||
        name = name.replace("head", "classifier")
 | 
			
		||||
 | 
			
		||||
    return name
 | 
			
		||||
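# Illustrative sketch, not part of the original script; the key names are hypothetical
# examples of the renaming above:
#   "downsample_layers.0.0.weight" -> "embeddings.patch_embeddings.weight"
#   "stages.2.1.gamma"             -> "encoder.stages.2.layers.1.layer_scale_parameter"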
 | 
			
		||||
 | 
			
		||||
# We will verify our results on an image of cute cats
 | 
			
		||||
def prepare_img():
 | 
			
		||||
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 | 
			
		||||
    im = Image.open(requests.get(url, stream=True).raw)
 | 
			
		||||
    return im
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_convnext_checkpoint(checkpoint_url, pytorch_dump_folder_path):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to our ConvNext structure.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # define ConvNext configuration based on URL
 | 
			
		||||
    config, expected_shape = get_convnext_config(checkpoint_url)
 | 
			
		||||
    # load original state_dict from URL
 | 
			
		||||
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]
 | 
			
		||||
    # rename keys
 | 
			
		||||
    for key in state_dict.copy():
 | 
			
		||||
        val = state_dict.pop(key)
 | 
			
		||||
        state_dict[rename_key(key)] = val
 | 
			
		||||
    # add prefix to all keys except classifier head
 | 
			
		||||
    for key in state_dict.copy():
 | 
			
		||||
        val = state_dict.pop(key)
 | 
			
		||||
        if not key.startswith("classifier"):
 | 
			
		||||
            key = "convnext." + key
 | 
			
		||||
        state_dict[key] = val
 | 
			
		||||
 | 
			
		||||
    # load HuggingFace model
 | 
			
		||||
    model = ConvNextForImageClassification(config)
 | 
			
		||||
    model.load_state_dict(state_dict)
 | 
			
		||||
    model.eval()
 | 
			
		||||
 | 
			
		||||
    # Check outputs on an image, prepared by ConvNextImageProcessor
 | 
			
		||||
    size = 224 if "224" in checkpoint_url else 384
 | 
			
		||||
    image_processor = ConvNextImageProcessor(size=size)
 | 
			
		||||
    pixel_values = image_processor(images=prepare_img(), return_tensors="pt").pixel_values
 | 
			
		||||
 | 
			
		||||
    logits = model(pixel_values).logits
 | 
			
		||||
 | 
			
		||||
    # note: the logits below were obtained without center cropping
 | 
			
		||||
    if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth":
 | 
			
		||||
        expected_logits = torch.tensor([-0.1210, -0.6605, 0.1918])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth":
 | 
			
		||||
        expected_logits = torch.tensor([-0.4473, -0.1847, -0.6365])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth":
 | 
			
		||||
        expected_logits = torch.tensor([0.4525, 0.7539, 0.0308])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth":
 | 
			
		||||
        expected_logits = torch.tensor([0.3561, 0.6350, -0.0384])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth":
 | 
			
		||||
        expected_logits = torch.tensor([0.4174, -0.0989, 0.1489])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth":
 | 
			
		||||
        expected_logits = torch.tensor([0.2513, -0.1349, -0.1613])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth":
 | 
			
		||||
        expected_logits = torch.tensor([1.2980, 0.3631, -0.1198])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth":
 | 
			
		||||
        expected_logits = torch.tensor([1.2963, 0.1227, 0.1723])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth":
 | 
			
		||||
        expected_logits = torch.tensor([1.7956, 0.8390, 0.2820])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth":
 | 
			
		||||
        expected_logits = torch.tensor([-0.2822, -0.0502, -0.0878])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth":
 | 
			
		||||
        expected_logits = torch.tensor([-0.5672, -0.0730, -0.4348])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth":
 | 
			
		||||
        expected_logits = torch.tensor([0.2681, 0.2365, 0.6246])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth":
 | 
			
		||||
        expected_logits = torch.tensor([-0.2642, 0.3931, 0.5116])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth":
 | 
			
		||||
        expected_logits = torch.tensor([-0.6677, -0.1873, -0.8379])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth":
 | 
			
		||||
        expected_logits = torch.tensor([-0.7749, -0.2967, -0.6444])
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError(f"Unknown URL: {checkpoint_url}")
 | 
			
		||||
 | 
			
		||||
    assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3)
 | 
			
		||||
    assert logits.shape == expected_shape
 | 
			
		||||
 | 
			
		||||
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
 | 
			
		||||
    print(f"Saving model to {pytorch_dump_folder_path}")
 | 
			
		||||
    model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
    print(f"Saving image processor to {pytorch_dump_folder_path}")
 | 
			
		||||
    image_processor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
    print("Pushing model to the hub...")
 | 
			
		||||
    model_name = "convnext"
 | 
			
		||||
    if "tiny" in checkpoint_url:
 | 
			
		||||
        model_name += "-tiny"
 | 
			
		||||
    elif "small" in checkpoint_url:
 | 
			
		||||
        model_name += "-small"
 | 
			
		||||
    elif "base" in checkpoint_url:
 | 
			
		||||
        model_name += "-base"
 | 
			
		||||
    elif "xlarge" in checkpoint_url:
 | 
			
		||||
        model_name += "-xlarge"
 | 
			
		||||
    elif "large" in checkpoint_url:
 | 
			
		||||
        model_name += "-large"
 | 
			
		||||
    if "224" in checkpoint_url:
 | 
			
		||||
        model_name += "-224"
 | 
			
		||||
    elif "384" in checkpoint_url:
 | 
			
		||||
        model_name += "-384"
 | 
			
		||||
    if "22k" in checkpoint_url and "1k" not in checkpoint_url:
 | 
			
		||||
        model_name += "-22k"
 | 
			
		||||
    if "22k" in checkpoint_url and "1k" in checkpoint_url:
 | 
			
		||||
        model_name += "-22k-1k"
 | 
			
		||||
 | 
			
		||||
    model.push_to_hub(
 | 
			
		||||
        repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
 | 
			
		||||
        organization="nielsr",
 | 
			
		||||
        commit_message="Add model",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--checkpoint_url",
 | 
			
		||||
        default="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="URL of the original ConvNeXT checkpoint you'd like to convert.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="Path to the output PyTorch model directory.",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_convnext_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
 | 
			
		||||
@ -1,286 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2023 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert ConvNeXTV2 checkpoints from the original repository.
 | 
			
		||||
 | 
			
		||||
URL: https://github.com/facebookresearch/ConvNeXt"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
from transformers import ConvNextImageProcessor, ConvNextV2Config, ConvNextV2ForImageClassification
 | 
			
		||||
from transformers.image_utils import PILImageResampling
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_convnextv2_config(checkpoint_url):
 | 
			
		||||
    config = ConvNextV2Config()
 | 
			
		||||
 | 
			
		||||
    if "atto" in checkpoint_url:
 | 
			
		||||
        depths = [2, 2, 6, 2]
 | 
			
		||||
        hidden_sizes = [40, 80, 160, 320]
 | 
			
		||||
    if "femto" in checkpoint_url:
 | 
			
		||||
        depths = [2, 2, 6, 2]
 | 
			
		||||
        hidden_sizes = [48, 96, 192, 384]
 | 
			
		||||
    if "pico" in checkpoint_url:
 | 
			
		||||
        depths = [2, 2, 6, 2]
 | 
			
		||||
        hidden_sizes = [64, 128, 256, 512]
 | 
			
		||||
    if "nano" in checkpoint_url:
 | 
			
		||||
        depths = [2, 2, 8, 2]
 | 
			
		||||
        hidden_sizes = [80, 160, 320, 640]
 | 
			
		||||
    if "tiny" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 9, 3]
 | 
			
		||||
        hidden_sizes = [96, 192, 384, 768]
 | 
			
		||||
    if "base" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 27, 3]
 | 
			
		||||
        hidden_sizes = [128, 256, 512, 1024]
 | 
			
		||||
    if "large" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 27, 3]
 | 
			
		||||
        hidden_sizes = [192, 384, 768, 1536]
 | 
			
		||||
    if "huge" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 27, 3]
 | 
			
		||||
        hidden_sizes = [352, 704, 1408, 2816]
 | 
			
		||||
 | 
			
		||||
    num_labels = 1000
 | 
			
		||||
    filename = "imagenet-1k-id2label.json"
 | 
			
		||||
    expected_shape = (1, 1000)
 | 
			
		||||
 | 
			
		||||
    repo_id = "huggingface/label-files"
 | 
			
		||||
    config.num_labels = num_labels
 | 
			
		||||
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
    id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
 | 
			
		||||
    config.id2label = id2label
 | 
			
		||||
    config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
    config.hidden_sizes = hidden_sizes
 | 
			
		||||
    config.depths = depths
 | 
			
		||||
 | 
			
		||||
    return config, expected_shape
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_key(name):
 | 
			
		||||
    if "downsample_layers.0.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings")
 | 
			
		||||
    if "downsample_layers.0.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.0.1", "embeddings.norm")  # we rename to layernorm later on
 | 
			
		||||
    if "downsample_layers.1.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0")
 | 
			
		||||
    if "downsample_layers.1.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1")
 | 
			
		||||
    if "downsample_layers.2.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0")
 | 
			
		||||
    if "downsample_layers.2.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1")
 | 
			
		||||
    if "downsample_layers.3.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0")
 | 
			
		||||
    if "downsample_layers.3.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1")
 | 
			
		||||
    if "stages" in name and "downsampling_layer" not in name:
 | 
			
		||||
        # stages.0.0. for instance should be renamed to stages.0.layers.0.
 | 
			
		||||
        name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
 | 
			
		||||
    if "gamma" in name:
 | 
			
		||||
        name = name.replace("gamma", "weight")
 | 
			
		||||
    if "beta" in name:
 | 
			
		||||
        name = name.replace("beta", "bias")
 | 
			
		||||
    if "stages" in name:
 | 
			
		||||
        name = name.replace("stages", "encoder.stages")
 | 
			
		||||
    if "norm" in name:
 | 
			
		||||
        name = name.replace("norm", "layernorm")
 | 
			
		||||
    if "head" in name:
 | 
			
		||||
        name = name.replace("head", "classifier")
 | 
			
		||||
 | 
			
		||||
    return name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# We will verify our results on an image of cute cats
 | 
			
		||||
def prepare_img():
 | 
			
		||||
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 | 
			
		||||
    im = Image.open(requests.get(url, stream=True).raw)
 | 
			
		||||
    return im
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_preprocessor(checkpoint_url):
 | 
			
		||||
    if "224" in checkpoint_url:
 | 
			
		||||
        size = 224
 | 
			
		||||
        crop_pct = 224 / 256
 | 
			
		||||
    elif "384" in checkpoint_url:
 | 
			
		||||
        size = 384
 | 
			
		||||
        crop_pct = None
 | 
			
		||||
    else:
 | 
			
		||||
        size = 512
 | 
			
		||||
        crop_pct = None
 | 
			
		||||
 | 
			
		||||
    return ConvNextImageProcessor(
 | 
			
		||||
        size=size,
 | 
			
		||||
        crop_pct=crop_pct,
 | 
			
		||||
        image_mean=[0.485, 0.456, 0.406],
 | 
			
		||||
        image_std=[0.229, 0.224, 0.225],
 | 
			
		||||
        resample=PILImageResampling.BICUBIC,
 | 
			
		||||
    )
 | 
			
		||||
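# Sketch of the selection above (added for clarity, not part of the original script):
# URLs containing "224" use size=224 with crop_pct=224/256, while "384" and the
# 512-resolution checkpoints pass crop_pct=None; the mean/std values are the standard
# ImageNet normalization constants and resizing uses bicubic resampling.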
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_convnextv2_checkpoint(checkpoint_url, pytorch_dump_folder_path, save_model, push_to_hub):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to our ConvNeXTV2 structure.
 | 
			
		||||
    """
 | 
			
		||||
    print("Downloading original model from checkpoint...")
 | 
			
		||||
    # define ConvNeXTV2 configuration based on URL
 | 
			
		||||
    config, expected_shape = get_convnextv2_config(checkpoint_url)
 | 
			
		||||
    # load original state_dict from URL
 | 
			
		||||
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]
 | 
			
		||||
 | 
			
		||||
    print("Converting model parameters...")
 | 
			
		||||
    # rename keys
 | 
			
		||||
    for key in state_dict.copy():
 | 
			
		||||
        val = state_dict.pop(key)
 | 
			
		||||
        state_dict[rename_key(key)] = val
 | 
			
		||||
    # add prefix to all keys except classifier head
 | 
			
		||||
    for key in state_dict.copy():
        val = state_dict.pop(key)
        if not key.startswith("classifier"):
            key = "convnextv2." + key
        state_dict[key] = val

    # load HuggingFace model
    model = ConvNextV2ForImageClassification(config)
    model.load_state_dict(state_dict)
    model.eval()

    # Check outputs on an image, prepared by ConvNextImageProcessor
    preprocessor = convert_preprocessor(checkpoint_url)
    inputs = preprocessor(images=prepare_img(), return_tensors="pt")
    logits = model(**inputs).logits

    # note: the logits below were obtained without center cropping
    if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt":
        expected_logits = torch.tensor([-0.3930, 0.1747, -0.5246, 0.4177, 0.4295])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt":
        expected_logits = torch.tensor([-0.1727, -0.5341, -0.7818, -0.4745, -0.6566])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt":
        expected_logits = torch.tensor([-0.0333, 0.1563, -0.9137, 0.1054, 0.0381])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt":
        expected_logits = torch.tensor([-0.1744, -0.1555, -0.0713, 0.0950, -0.1431])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt":
        expected_logits = torch.tensor([0.9996, 0.1966, -0.4386, -0.3472, 0.6661])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt":
        expected_logits = torch.tensor([-0.2553, -0.6708, -0.1359, 0.2518, -0.2488])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt":
        expected_logits = torch.tensor([-0.0673, -0.5627, -0.3753, -0.2722, 0.0178])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt":
        expected_logits = torch.tensor([-0.6377, -0.7458, -0.2150, 0.1184, -0.0597])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt":
        expected_logits = torch.tensor([1.0799, 0.2322, -0.8860, 1.0219, 0.6231])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt":
        expected_logits = torch.tensor([0.3766, 0.4917, -1.1426, 0.9942, 0.6024])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt":
        expected_logits = torch.tensor([0.4220, -0.6919, -0.4317, -0.2881, -0.6609])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt":
        expected_logits = torch.tensor([0.1082, -0.8286, -0.5095, 0.4681, -0.8085])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt":
        expected_logits = torch.tensor([-0.2419, -0.6221, 0.2176, -0.0980, -0.7527])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt":
        expected_logits = torch.tensor([0.0391, -0.4371, 0.3786, 0.1251, -0.2784])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt":
        expected_logits = torch.tensor([-0.0504, 0.5636, -0.1729, -0.6507, -0.3949])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt":
        expected_logits = torch.tensor([0.3560, 0.9486, 0.3149, -0.2667, -0.5138])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt":
        expected_logits = torch.tensor([-0.2469, -0.4550, -0.5853, -0.0810, 0.0309])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt":
        expected_logits = torch.tensor([-0.3090, 0.0802, -0.0682, -0.1979, -0.2826])
    else:
        raise ValueError(f"Unknown URL: {checkpoint_url}")

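    # sanity check: the first five logits of the converted model must match the reference
    # values above within an absolute tolerance of 1e-3, and the output shape must match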
    assert torch.allclose(logits[0, :5], expected_logits, atol=1e-3)
    assert logits.shape == expected_shape
    print("Model outputs match the original results!")

    if save_model:
        print("Saving model to local...")
        # Create folder to save model
        if not os.path.isdir(pytorch_dump_folder_path):
            os.mkdir(pytorch_dump_folder_path)

        model.save_pretrained(pytorch_dump_folder_path)
        preprocessor.save_pretrained(pytorch_dump_folder_path)

    model_name = "convnextv2"
    if "atto" in checkpoint_url:
        model_name += "-atto"
    if "femto" in checkpoint_url:
        model_name += "-femto"
    if "pico" in checkpoint_url:
        model_name += "-pico"
    if "nano" in checkpoint_url:
        model_name += "-nano"
    elif "tiny" in checkpoint_url:
        model_name += "-tiny"
    elif "base" in checkpoint_url:
        model_name += "-base"
    elif "large" in checkpoint_url:
        model_name += "-large"
    elif "huge" in checkpoint_url:
        model_name += "-huge"
    if "22k" in checkpoint_url and "1k" not in checkpoint_url:
        model_name += "-22k"
    elif "22k" in checkpoint_url and "1k" in checkpoint_url:
        model_name += "-22k-1k"
    elif "1k" in checkpoint_url:
        model_name += "-1k"
    if "224" in checkpoint_url:
        model_name += "-224"
    elif "384" in checkpoint_url:
        model_name += "-384"
    elif "512" in checkpoint_url:
        model_name += "-512"
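    # e.g. the atto ImageNet-1k 224px checkpoint above resolves to "convnextv2-atto-1k-224"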

    if push_to_hub:
        print(f"Pushing {model_name} to the hub...")
        model.push_to_hub(model_name)
        preprocessor.push_to_hub(model_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_url",
        default="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt",
        type=str,
        help="URL of the original ConvNeXTV2 checkpoint you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="model",
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument("--save_model", action="store_true", help="Save model to local")
    parser.add_argument("--push_to_hub", action="store_true", help="Push model and image preprocessor to the hub")

    args = parser.parse_args()
    convert_convnextv2_checkpoint(
        args.checkpoint_url, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub
    )
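
# Example invocation (the filename below is an assumption; adjust it to wherever this script lives):
#   python convert_convnextv2_to_pytorch.py \
#       --checkpoint_url https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt \
#       --pytorch_dump_folder_path model \
#       --save_model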