Compare commits

...

13 Commits

Author SHA1 Message Date
9c641dc161 v4.54.1 2025-07-29 17:32:03 +02:00
b04aae7398 Fix Cache.max_cache_len max value for Hybrid models (#39737)
* fix gemma

* fix min

* fix quant init issue

* fix gemma 3n

* skip quant cache test

* fix modular

* new test for Gemma

* include cyril change

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
2025-07-29 17:29:33 +02:00
0297e595c1 [modenbert] fix regression (#39750)
* fix regression

* add FA2 test
2025-07-29 17:29:33 +02:00
b8e1b282a9 Fix version issue in modeling_utils.py (#39759)
fix version issue
2025-07-29 17:29:33 +02:00
166c1b47db Fix GPT2 with cross attention (#39754)
* fix

* use new mask API

* style

* fix copies and attention tests

* fix head pruning tests
2025-07-29 17:29:33 +02:00
ab2a6091d9 Fix mamba regression (#39728)
* fix mamba regression

* fix compile test
2025-07-29 17:29:33 +02:00
3a6d13c887 Fix: add back base model plan (#39733)
* Fix: add back base model plan

* Fix: typo

* fixup

* remove unused import

---------

Co-authored-by: Arthur <arthur.zucker@gmail.com>
2025-07-29 17:29:33 +02:00
a033ae4876 fix cache inheritance (#39748)
* fix cache inheritance

* styule
2025-07-29 17:29:33 +02:00
457b478e4a Fix cache-related tests (#39676)
* fix

* fix kyutai at last

* fix unrelated tests and copies

* update musicgen as well

* revert tensor

* fix old test failures

* why it wasn't added?
2025-07-29 17:29:33 +02:00
862cb55017 Fix Layer device placement in Caches (#39732)
* fix device placement

* style

* typo in comment
2025-07-29 17:29:33 +02:00
67fd36fcc8 PATCH: add back n-dim device-mesh + fix tp trainer saving (#39693)
* Feat: something

* Feat: initial changes

* tmp changes to unblock

* Refactor

* remove todo

* Feat: docstring

* Fix: saving of distributed model in trainer

* Fix: distributed saving with trainer

* Feat: add pure tp saving

* Only require tp dim if ndim > 1

* Fix: default to None

* Fix: better comments/errors

* Fix: properly check tp_size attribute

* Fix: properly check for None in tp_size

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2025-07-29 17:29:33 +02:00
709c6fd008 fix missing model._tp_size from ep refactor (#39688)
* fix missing model._tp_size from ep refactor

* restore setting device_mesh too
2025-07-29 17:29:32 +02:00
3fd456b200 v4.54-release 2025-07-25 20:44:40 +02:00
374 changed files with 458 additions and 76516 deletions

View File

@ -60,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "datasets[audio]>=1.14.0",
# "evaluate",
# "librosa",
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
# "datasets>=1.8.0",
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate>=0.12.0",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
@ -68,7 +68,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate>=0.12.0",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = get_logger(__name__)

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
# "datasets>=1.8.0",
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
# "datasets>=1.8.0",
@ -56,7 +56,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "torch>=1.5.0",
# "torchvision>=0.6.0",
# "datasets>=1.8.0",
@ -61,7 +61,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "timm",
# "datasets",
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "timm",
# "datasets",
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -69,7 +69,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = get_logger(__name__)

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -72,7 +72,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -74,7 +74,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = get_logger(__name__)

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -68,7 +68,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = get_logger(__name__)
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "accelerate >= 0.12.0",
# "torch >= 1.3",
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "sentencepiece != 0.1.92",
# "protobuf",
@ -57,7 +57,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = logging.getLogger(__name__)

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "sentencepiece != 0.1.92",
# "protobuf",
@ -65,7 +65,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "timm",
# "datasets>=4.0",
@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "albumentations >= 1.4.16",
# "timm",
# "datasets>=4.0",
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "datasets >= 2.0.0",
# "torch >= 1.3",
# "accelerate",
@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "datasets >= 2.0.0",
# "torch >= 1.3",
# "accelerate",
@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = get_logger(__name__)

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "datasets[audio] >= 1.12.0",
# "torch >= 1.5",
# "torchaudio",

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "datasets[audio] >= 1.18.0",
# "torch >= 1.5",
# "torchaudio",
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "datasets[audio] >= 1.18.0",
# "torch >= 1.5",
# "torchaudio",
@ -64,7 +64,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "datasets[audio] >= 1.18.0",
# "torch >= 1.5",
# "torchaudio",
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -67,7 +67,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -14,7 +14,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = get_logger(__name__)

View File

@ -16,7 +16,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -62,7 +62,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -16,7 +16,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.21.0",
# "sentencepiece != 0.1.92",
# "protobuf",

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.21.0",
# "sentencepiece != 0.1.92",
# "protobuf",

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "seqeval",
# "datasets >= 1.8.0",
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "seqeval",
# "datasets >= 1.8.0",
@ -67,7 +67,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -66,7 +66,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -15,7 +15,7 @@
# /// script
# dependencies = [
# "transformers @ git+https://github.com/huggingface/transformers.git",
# "transformers==4.54.1",
# "accelerate >= 0.12.0",
# "datasets >= 1.8.0",
# "sentencepiece != 0.1.92",
@ -71,7 +71,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version(
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = logging.getLogger(__name__)

View File

@ -61,7 +61,7 @@ except (ModuleNotFoundError, ImportError):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
logger = logging.getLogger(__name__)

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
task_to_keys = {
"cola": ("sentence", None),

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.54.0.dev0")
check_min_version("4.54.0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -462,7 +462,7 @@ install_requires = [
setup(
name="transformers",
version="4.54.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.54.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.54.0.dev0"
__version__ = "4.54.1"
from pathlib import Path
from typing import TYPE_CHECKING

View File

@ -263,6 +263,14 @@ class StaticLayer(CacheLayerMixin):
key_states = key_states.to(self.keys.dtype)
value_states = value_states.to(self.values.dtype)
# This may be needed if the Layer was not created with the right device in the beginning, i.e. if it did not respect
# the device_map. However, even if it is the case, this will only run once, because then the new states received
# will always have the same device
if self.device != key_states.device:
self.device = key_states.device
self.keys = self.keys.to(self.device)
self.values = self.values.to(self.device)
if cache_position is None:
# Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
self.keys.copy_(key_states)
@ -317,8 +325,9 @@ class SlidingWindowLayer(StaticLayer):
sliding_window (`int`):
Effective window size: number of tokens that are kept on each update call.
"""
kwargs.pop("max_cache_len", None)
super().__init__(*args, max_cache_len=sliding_window, *args, **kwargs)
max_cache_len = kwargs.pop("max_cache_len", None)
max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window
super().__init__(*args, max_cache_len=max_cache_len, *args, **kwargs)
def update(
self,
@ -341,6 +350,14 @@ class SlidingWindowLayer(StaticLayer):
if cache_position is None:
raise ValueError("`cache_position` must be provided for SlidingWindowLayer.")
# This may be needed if the Layer was not created with the right device in the beginning, i.e. if it did not respect
# the device_map. However, even if it is the case, this will only run once, because then the new states received
# will always have the same device
if self.device != key_states.device:
self.device = key_states.device
self.keys = self.keys.to(self.device)
self.values = self.values.to(self.device)
key_states = key_states.to(self.keys.dtype)
value_states = value_states.to(self.values.dtype)
@ -354,7 +371,7 @@ class SlidingWindowLayer(StaticLayer):
return key_states, value_states
# Sliding window logic for generation phase or prefill < window
slicing = torch.arange(self.max_cache_len, device=value_states.device)
slicing = torch.arange(self.max_cache_len, device=self.device)
current_seq_len = cache_position[-1] + 1 # Use last position to determine current length
to_shift = current_seq_len > self.max_cache_len
indices = (slicing + to_shift.sum()) % self.max_cache_len
@ -411,6 +428,14 @@ class ChunkedSlidingLayer(SlidingWindowLayer):
if cache_position is None:
raise ValueError("`cache_position` must be provided for ChunkedSlidingLayer.")
# This may be needed if the Layer was not created with the right device in the beginning, i.e. if it did not respect
# the device_map. However, even if it is the case, this will only run once, because then the new states received
# will always have the same device
if self.device != key_states.device:
self.device = key_states.device
self.keys = self.keys.to(self.device)
self.values = self.values.to(self.device)
cumulative_length = self.cumulative_length
self.cumulative_length += key_states.shape[-2]
is_full = cumulative_length >= self.max_cache_len
@ -1253,9 +1278,7 @@ class Cache:
def max_cache_len(self) -> int:
"""Return the maximum cache length of the cache"""
values = [layer.max_cache_len for layer in self.layers]
if len(set(values)) > 1:
raise ValueError(f"Max cache length is not consistent across layers: {values}")
return values[0]
return max(values)
@property
def is_compileable(self) -> bool:
@ -1631,7 +1654,7 @@ class QuantoQuantizedCache(QuantizedCache):
"""
def __init__(self, **kwargs) -> None:
Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs)
DynamicCache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs)
class HQQQuantizedCache(QuantizedCache):
@ -1673,7 +1696,7 @@ class HQQQuantizedCache(QuantizedCache):
def __init__(self, backend="HQQ", **kwargs) -> None:
assert backend == "HQQ"
Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)
DynamicCache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)
class EncoderDecoderCache(Cache):
@ -1927,10 +1950,6 @@ def parse_layer_args_from_model_config(
)
# Adjust max_cache_len for sliding window layers (they can't be larger than sliding window)
max_cache_len = max_cache_len or config.max_position_embeddings
if getattr(config, "sliding_window", None) is not None:
sliding_window_len = min(config.sliding_window, max_cache_len)
else:
sliding_window_len = None
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads:
head_dim = (
config.head_dim
@ -1957,7 +1976,7 @@ def parse_layer_args_from_model_config(
"layer_device_map": layer_device_map,
"head_dim": head_dim,
"num_heads": num_heads,
"sliding_window": sliding_window_len,
"sliding_window": getattr(config, "sliding_window", None),
}
return {k: v for k, v in layer_args.items() if v is not None}

View File

@ -32,7 +32,6 @@ from tokenizers.decoders import DecodeStream
from torch.profiler import profile, schedule, tensorboard_trace_handler
from tqdm import tqdm
from ..cache_utils import Cache
from ..configuration_utils import PretrainedConfig
from ..generation.configuration_utils import GenerationConfig
from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
@ -153,7 +152,7 @@ class RequestState:
@attach_tracer()
class PagedAttentionCache(Cache):
class PagedAttentionCache:
def __init__(
self,
config: PretrainedConfig,

View File

@ -2055,7 +2055,7 @@ class GenerationMixin(ContinuousMixin):
generation_config.cache_implementation = None
generation_config.cache_implementation = generation_config.cache_implementation or getattr(
self.config.get_text_config(), "cache_implementation", None
self.config.get_text_config(decoder=True), "cache_implementation", None
)
if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:

View File

@ -821,8 +821,6 @@ class GroupedGemmParallel(TensorParallelLayer):
param = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
if to_contiguous:
param = param.contiguous()
if "gate_up" in param_type and False:
param = torch.cat([param[..., ::2], param[..., 1::2]], dim=-1)
return param
@ -1014,7 +1012,8 @@ def shard_and_distribute_module(
"""
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
tp_plan = model._tp_plan
tp_plan = model._tp_plan or {}
tp_plan.update(getattr(type(model), "_tp_plan", {}))
module_to_tp = model.get_submodule(param_name) # TODO: can i loop over modules?
rank = int(rank)
current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
@ -1080,9 +1079,14 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
def distribute_model(model, distributed_config, device_mesh, tp_size):
_plan = "_tp_plan"
tp_plan = getattr(model, "_tp_plan", {}).copy()
model._tp_plan = getattr(model.config, "base_model_tp_plan").copy()
model._tp_plan.update(tp_plan)
model._tp_size = tp_size
model._device_mesh = device_mesh
if distributed_config is not None:
distributed_config = DistributedConfig.from_config(distributed_config)
if isinstance(distributed_config, dict):
distributed_config = DistributedConfig.from_dict(distributed_config)
if distributed_config.enable_expert_parallel:
_plan = "_ep_plan"
model._tp_plan = getattr(model.config, "base_model_ep_plan", model._tp_plan).copy()

View File

@ -2881,9 +2881,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
# between modelings (because they are prefixed with the model name)
elif (
isinstance(
module, (nn.LayerNorm, nn.RMSNorm, nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
)
isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
or "LayerNorm" in module.__class__.__name__
or "RMSNorm" in module.__class__.__name__
):
@ -4472,7 +4470,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
A torch tensor parallel degree. If not provided would default to world size.
device_mesh (`torch.distributed.DeviceMesh`, *optional*):
A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
If provided, it has to contain dimension named `"tp"` which will be used for tensor parallelism
If provided, it has to contain dimension named `"tp"` in case it's > 1 dimensional, this dimension will be used for tensor parallelism
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
@ -4617,10 +4615,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if device_mesh is None:
tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
else:
# TODO: make device_mesh support multiple dimensions
if device_mesh.ndim > 1:
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
if "tp" not in device_mesh.mesh_dim_names:
raise ValueError(
"When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
"Please provide a valid `device_mesh`."
)
device_mesh = device_mesh["tp"]
tp_size = device_mesh.size()
device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
if tp_size is None:
tp_size = torch.distributed.get_world_size()

View File

@ -1,269 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re
from typing import Optional
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import (
Aimv2Config,
Aimv2Model,
Aimv2VisionConfig,
Aimv2VisionModel,
AutoImageProcessor,
AutoProcessor,
)
ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL = {
# Embeddings
r"preprocessor.patchifier.proj": r"embeddings.patch_embed",
r"preprocessor.pos_embed": r"embeddings.position_embedding.weight",
r"preprocessor.patchifier.norm.weight": r"embeddings.rms_norm.weight",
# Encoder Layers
r"trunk.blocks.(\d+).attn.qkv": r"encoder.layers.\1.attention.qkv",
r"trunk.blocks.(\d+).attn.proj": r"encoder.layers.\1.attention.out_proj",
r"trunk.blocks.(\d+).mlp.fc1": r"encoder.layers.\1.ffn.gate_proj",
r"trunk.blocks.(\d+).mlp.fc2": r"encoder.layers.\1.ffn.down_proj",
r"trunk.blocks.(\d+).mlp.fc3": r"encoder.layers.\1.ffn.up_proj",
# Normalization Layers
r"trunk.blocks.(\d+).norm_1": r"encoder.layers.\1.rms_norm1",
r"trunk.blocks.(\d+).norm_2": r"encoder.layers.\1.rms_norm2",
# Final Norm
r"trunk.post_trunk_norm": r"rms_norm",
}
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
# Vision Embeddings
r"image_encoder.preprocessor.patchifier.proj": r"vision_model.embeddings.patch_embed",
r"image_encoder.preprocessor.pos_embed": r"vision_model.embeddings.position_embedding.weight",
r"image_encoder.preprocessor.patchifier.norm.weight": r"vision_model.embeddings.rms_norm.weight",
# Vision Encoder Layers
r"image_encoder.trunk.blocks.(\d+).attn.qkv": r"vision_model.encoder.layers.\1.attention.qkv",
r"image_encoder.trunk.blocks.(\d+).attn.proj": r"vision_model.encoder.layers.\1.attention.out_proj",
r"image_encoder.trunk.blocks.(\d+).mlp.fc1": r"vision_model.encoder.layers.\1.ffn.gate_proj",
r"image_encoder.trunk.blocks.(\d+).mlp.fc2": r"vision_model.encoder.layers.\1.ffn.down_proj",
r"image_encoder.trunk.blocks.(\d+).mlp.fc3": r"vision_model.encoder.layers.\1.ffn.up_proj",
# Normalization Layers
r"image_encoder.trunk.blocks.(\d+).norm_1": r"vision_model.encoder.layers.\1.rms_norm1",
r"image_encoder.trunk.blocks.(\d+).norm_2": r"vision_model.encoder.layers.\1.rms_norm2",
r"image_encoder.trunk.post_trunk_norm": r"vision_model.rms_norm",
r"image_projector": r"visual_projection",
# Vision Head
r"image_encoder.head.cls_token": r"vision_model.head.cls_token",
r"image_encoder.head.k": r"vision_model.head.k_proj",
r"image_encoder.head.v": r"vision_model.head.v_proj",
r"image_encoder.head.linear": r"vision_model.head.output_proj",
# Text Embeddings
r"text_encoder.preprocessor.text_embedding.weight": r"text_model.embeddings.token_embedding.weight",
r"text_encoder.preprocessor.positional_embedding": r"text_model.embeddings.position_embedding.weight",
# Text Encoder Layers
r"text_encoder.trunk.blocks.(\d+).attn.qkv": r"text_model.encoder.layers.\1.attention.qkv",
r"text_encoder.trunk.blocks.(\d+).attn.proj": r"text_model.encoder.layers.\1.attention.out_proj",
r"text_encoder.trunk.blocks.(\d+).mlp.fc1": r"text_model.encoder.layers.\1.ffn.gate_proj",
r"text_encoder.trunk.blocks.(\d+).mlp.fc2": r"text_model.encoder.layers.\1.ffn.down_proj",
r"text_encoder.trunk.blocks.(\d+).mlp.fc3": r"text_model.encoder.layers.\1.ffn.up_proj",
# Text Normalization Layers
r"text_encoder.trunk.blocks.(\d+).norm_1": r"text_model.encoder.layers.\1.rms_norm1",
r"text_encoder.trunk.blocks.(\d+).norm_2": r"text_model.encoder.layers.\1.rms_norm2",
r"text_encoder.trunk.post_trunk_norm": r"text_model.rms_norm",
r"text_projector": r"text_projection",
r"log_logit_scale": r"logit_scale",
}
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
# Download only the model.safetensors file
directory_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["model.safetensors"],
)
original_state_dict = {}
safetensor_path = f"{directory_path}/model.safetensors"
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
return original_state_dict
def convert_old_keys_to_new_keys(state_dict_keys: dict, ORIGINAL_TO_CONVERTED_KEY_MAPPING: dict):
"""Converts state dict keys from the old format to the new format."""
output_dict = {}
if state_dict_keys is not None:
old_text = "\n".join(state_dict_keys)
new_text = old_text
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
if replacement is None:
new_text = re.sub(pattern, "", new_text) # an empty line
continue
new_text = re.sub(pattern, replacement, new_text)
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
return output_dict
def split_qkv_tensor(key, tensor):
"""Splits a qkv tensor into separate q, k, v tensors and updates the key accordingly."""
new_keys = ["q_proj", "k_proj", "v_proj"]
split_size = tensor.shape[0] // 3
split_tensors = torch.split(tensor, split_size, dim=0)
return {key.replace("qkv", new_key): split_tensors[i] for i, new_key in enumerate(new_keys)}
def get_model_config_mapping(model_id: str):
"""Determines the correct model, config, and key mappings based on the checkpoint name."""
if model_id == "apple/aimv2-large-patch14-224-lit":
return Aimv2Model, Aimv2Config, ORIGINAL_TO_CONVERTED_KEY_MAPPING
else:
return Aimv2VisionModel, Aimv2VisionConfig, ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION_MODEL
def write_model(
hf_repo_id: str,
output_dir: str,
safe_serialization: bool = True,
):
"""
Converts a model checkpoint to Hugging Face format and saves it.
Args:
hf_repo_id (str): The Hugging Face repo ID to load from.
output_dir (str): The directory to save the converted model.
safe_serialization (bool): Whether to use safe serialization.
Returns:
model: The reloaded Hugging Face model.
"""
os.makedirs(output_dir, exist_ok=True)
# Get the appropriate model, config, and key mapping
model_class, config_class, key_mapping = get_model_config_mapping(hf_repo_id)
# Load config and original state dict
config = config_class.from_pretrained(hf_repo_id)
# Checkpoint `apple/aimv2-large-patch14-224-lit` uses AttentionPoolingHead hence set the required attr in config.
if hf_repo_id != "apple/aimv2-large-patch14-224-lit":
config.use_head = False
if hf_repo_id == "apple/aimv2-large-patch14-native":
config.is_native = True
original_state_dict = load_original_state_dict(hf_repo_id)
print("Converting model...")
state_dict = {}
result = convert_old_keys_to_new_keys(original_state_dict, key_mapping)
all_keys = list(original_state_dict.keys())
for key in all_keys:
value = original_state_dict[key]
new_key = result.pop(key)
if "qkv" in new_key:
qkv_state_dict = split_qkv_tensor(new_key, value)
state_dict.update(qkv_state_dict)
else:
state_dict[new_key] = value
# Check if position embeddings exist before squeezing
if new_key.endswith("position_embedding.weight"):
state_dict[new_key] = value.squeeze(0)
print(f"Loading the checkpoint in a {model_class.__name__}.")
model = model_class(config)
model.load_state_dict(state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
print("Saving the model.")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
del state_dict, model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
model = model_class.from_pretrained(output_dir, device_map="auto")
print("Model reloaded successfully.")
return model
def write_image_processor(hf_repo_id: str, output_dir: str):
if hf_repo_id == "apple/aimv2-large-patch14-224-lit":
image_processor = AutoProcessor.from_pretrained(hf_repo_id, use_fast=True)
else:
image_processor = AutoImageProcessor.from_pretrained(hf_repo_id, use_fast=True)
image_processor.save_pretrained(output_dir)
return image_processor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--hf_repo_id",
default="apple/aimv2-large-patch14-224",
help="Location of official weights from apple on HF",
)
parser.add_argument(
"--output_dir",
default="aimv2_model",
help="Location to write the converted model and processor",
)
parser.add_argument(
"--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
)
parser.add_argument(
"--push_to_hub",
action=argparse.BooleanOptionalAction,
help="Whether or not to push the converted model to the huggingface hub.",
)
parser.add_argument(
"--hub_repo_id",
default=None,
help="Huggingface hub repo to write the converted model and processor",
)
args = parser.parse_args()
model = write_model(
hf_repo_id=args.hf_repo_id,
output_dir=args.output_dir,
safe_serialization=args.safe_serialization,
)
image_processor = write_image_processor(
hf_repo_id=args.hf_repo_id,
output_dir=args.output_dir,
)
if args.push_to_hub:
print("Pushing to hub...")
model.push_to_hub(args.hub_repo_id)
image_processor.push_to_hub(args.hub_repo_id)
if __name__ == "__main__":
main()

View File

@ -1,62 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALBERT checkpoint."""
import argparse
import torch
from ...utils import logging
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = AlbertConfig.from_json_file(albert_config_file)
print(f"Building PyTorch model from configuration: {config}")
model = AlbertForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_albert(model, config, tf_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--albert_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained ALBERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)

View File

@ -1,389 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALIGN checkpoints from the original repository."""
import argparse
import os
import align
import numpy as np
import requests
import tensorflow as tf
import torch
from PIL import Image
from tokenizer import Tokenizer
from transformers import (
AlignConfig,
AlignModel,
AlignProcessor,
BertConfig,
BertTokenizer,
EfficientNetConfig,
EfficientNetImageProcessor,
)
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def preprocess(image):
image = tf.image.resize(image, (346, 346))
image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
return image
def get_align_config():
vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
vision_config.image_size = 289
vision_config.hidden_dim = 640
vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
vision_config.depthwise_padding = []
text_config = BertConfig()
config = AlignConfig.from_text_vision_configs(
text_config=text_config, vision_config=vision_config, projection_dim=640
)
return config
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
def get_processor():
image_processor = EfficientNetImageProcessor(
do_center_crop=True,
rescale_factor=1 / 127.5,
rescale_offset=True,
do_normalize=False,
include_top=False,
resample=Image.BILINEAR,
)
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
tokenizer.model_max_length = 64
processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
return processor
# here we list all keys to be renamed (original name on the left, our name on the right)
def rename_keys(original_param_names):
# EfficientNet image encoder
block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
block_names = list(set(block_names))
block_names = sorted(block_names)
num_blocks = len(block_names)
block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}
rename_keys = []
rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))
for b in block_names:
hf_b = block_name_mapping[b]
rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
rename_keys.append(
(f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
)
rename_keys.append(
(f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
)
rename_keys.append(
(f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
)
rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
rename_keys.append(
(f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
)
rename_keys.append(
(f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
)
rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
rename_keys.append(
(f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
)
rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
rename_keys.append(
(f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
)
rename_keys.append(
(f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
)
key_mapping = {}
for item in rename_keys:
if item[0] in original_param_names:
key_mapping[item[0]] = "vision_model." + item[1]
# BERT text encoder
rename_keys = []
old = "tf_bert_model/bert"
new = "text_model"
for i in range(12):
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0",
f"{new}.encoder.layer.{i}.attention.self.query.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/query/bias:0",
f"{new}.encoder.layer.{i}.attention.self.query.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0",
f"{new}.encoder.layer.{i}.attention.self.key.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/key/bias:0",
f"{new}.encoder.layer.{i}.attention.self.key.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0",
f"{new}.encoder.layer.{i}.attention.self.value.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/self/value/bias:0",
f"{new}.encoder.layer.{i}.attention.self.value.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0",
f"{new}.encoder.layer.{i}.attention.output.dense.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0",
f"{new}.encoder.layer.{i}.attention.output.dense.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0",
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0",
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0",
f"{new}.encoder.layer.{i}.intermediate.dense.weight",
)
)
rename_keys.append(
(
f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0",
f"{new}.encoder.layer.{i}.intermediate.dense.bias",
)
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight")
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias")
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight")
)
rename_keys.append(
(f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias")
)
rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight"))
rename_keys.append(
(f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight")
)
rename_keys.append(
(f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight")
)
rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight"))
rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias"))
rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight"))
rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias"))
rename_keys.append(("dense/kernel:0", "text_projection.weight"))
rename_keys.append(("dense/bias:0", "text_projection.bias"))
rename_keys.append(("dense/bias:0", "text_projection.bias"))
rename_keys.append(("temperature:0", "temperature"))
for item in rename_keys:
if item[0] in original_param_names:
key_mapping[item[0]] = item[1]
return key_mapping
def replace_params(hf_params, tf_params, key_mapping):
list(hf_params.keys())
for key, value in tf_params.items():
if key not in key_mapping:
continue
hf_key = key_mapping[key]
if "_conv" in key and "kernel" in key:
new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
elif "embeddings" in key:
new_hf_value = torch.from_numpy(value)
elif "depthwise_kernel" in key:
new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
elif "kernel" in key:
new_hf_value = torch.from_numpy(np.transpose(value))
elif "temperature" in key:
new_hf_value = value
elif "bn/gamma" or "bn/beta" in key:
new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
else:
new_hf_value = torch.from_numpy(value)
# Replace HF parameters with original TF model parameters
hf_params[hf_key].copy_(new_hf_value)
@torch.no_grad()
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
"""
Copy/paste/tweak model's weights to our ALIGN structure.
"""
# Load original model
seq_length = 64
tok = Tokenizer(seq_length)
original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
original_model.compile()
original_model.load_weights(checkpoint_path)
tf_params = original_model.trainable_variables
tf_non_train_params = original_model.non_trainable_variables
tf_params = {param.name: param.numpy() for param in tf_params}
for param in tf_non_train_params:
tf_params[param.name] = param.numpy()
tf_param_names = list(tf_params.keys())
# Load HuggingFace model
config = get_align_config()
hf_model = AlignModel(config).eval()
hf_params = hf_model.state_dict()
# Create src-to-dst parameter name mapping dictionary
print("Converting parameters...")
key_mapping = rename_keys(tf_param_names)
replace_params(hf_params, tf_params, key_mapping)
# Initialize processor
processor = get_processor()
inputs = processor(
images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
)
# HF model inference
hf_model.eval()
with torch.no_grad():
outputs = hf_model(**inputs)
hf_image_features = outputs.image_embeds.detach().numpy()
hf_text_features = outputs.text_embeds.detach().numpy()
# Original model inference
original_model.trainable = False
tf_image_processor = EfficientNetImageProcessor(
do_center_crop=True,
do_rescale=False,
do_normalize=False,
include_top=False,
resample=Image.BILINEAR,
)
image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"]
text = tok(tf.constant(["A picture of a cat"]))
image_features = original_model.image_encoder(image, training=False)
text_features = original_model.text_encoder(text, training=False)
image_features = tf.nn.l2_normalize(image_features, axis=-1)
text_features = tf.nn.l2_normalize(text_features, axis=-1)
# Check whether original and HF model outputs match -> np.allclose
if not np.allclose(image_features, hf_image_features, atol=1e-3):
raise ValueError("The predicted image features are not the same.")
if not np.allclose(text_features, hf_text_features, atol=1e-3):
raise ValueError("The predicted text features are not the same.")
print("Model outputs match!")
if save_model:
# Create folder to save model
if not os.path.isdir(pytorch_dump_folder_path):
os.mkdir(pytorch_dump_folder_path)
# Save converted model and image processor
hf_model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
# Push model and image processor to hub
print("Pushing converted ALIGN to the hub...")
processor.push_to_hub("align-base")
hf_model.push_to_hub("align-base")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint_path",
default="./weights/model-weights",
type=str,
help="Path to the pretrained TF ALIGN checkpoint.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default="hf_model",
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument("--save_model", action="store_true", help="Save model to local")
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
args = parser.parse_args()
convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)

View File

@ -1,162 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import (
AddedToken,
AriaForConditionalGeneration,
AriaProcessor,
AutoConfig,
AutoTokenizer,
)
EPILOG_TXT = """Example:
python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
Example for creating the old state dict file with Python:
import torch
from aria.model.language_model.aria_llama import AriaTextForCausalLM
# load model
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", **kwargs)
# load vision tower
model.get_vision_tower().load_model()
# Save state dict
torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
"""
KEYS_TO_MODIFY_MAPPING = {
"vision_tower.vision_model": "vision_tower",
"ln_ffn": "layer_norm",
"ffn": "feed_forward",
"ln_kv": "layer_norm_kv",
}
def load_original_state_dict(model_id):
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
return original_state_dict
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".inv_freq"):
continue
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
return new_state_dict
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
torch.set_default_dtype(torch.float16)
tokenizer = AutoTokenizer.from_pretrained(
text_model_id,
extra_special_tokens={
"image_token": "<|img|>",
"pad_token": "<pad>",
},
)
tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
tokenizer.add_special_tokens({"pad_token": "<pad>"})
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
processor = AriaProcessor.from_pretrained(
text_model_id,
tokenizer=tokenizer,
)
config = AutoConfig.from_pretrained(text_model_id)
config.vision_config.hidden_size = 1152
config.vision_config.attention_heads = 16
config.pad_token_id = 2
config.image_token_id = 9
config.intermediate_size = config.moe_intermediate_size
config.auto_map = {
"AutoConfig": "modeling_aria.AriaConfig",
"AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
}
with torch.device("meta"):
model = AriaForConditionalGeneration(config)
state_dict = load_original_state_dict(old_state_dict_id)
state_dict = convert_state_dict_to_hf(state_dict)
model.load_state_dict(state_dict, strict=False, assign=True)
# print("Saving models")
# model.save_pretrained("local_aria", safe_serialization=False)
# processor.save_pretrained("local_aria")
print("Pushing to hub")
model.push_to_hub(output_hub_path, create_pr=True)
processor.push_to_hub(output_hub_path, create_pr=True)
def main():
parser = argparse.ArgumentParser(
epilog=EPILOG_TXT,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--text_model_id",
default="rhymes-ai/Aria",
help="Hub location of the text model",
)
parser.add_argument(
"--vision_model_id",
default="rhymes-ai/Aria",
help="Hub location of the vision model",
)
parser.add_argument(
"--output_hub_path",
default="rhymes-ai/Aria",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--old_state_dict_id",
default="rhymes-ai/Aria",
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
)
args = parser.parse_args()
convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
if __name__ == "__main__":
main()

View File

@ -1,279 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast"""
import argparse
import json
from pathlib import Path
import torch
import torchaudio
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_audio_spectrogram_transformer_config(model_name):
config = ASTConfig()
if "10-10" in model_name:
pass
elif "speech-commands" in model_name:
config.max_length = 128
elif "12-12" in model_name:
config.time_stride = 12
config.frequency_stride = 12
elif "14-14" in model_name:
config.time_stride = 14
config.frequency_stride = 14
elif "16-16" in model_name:
config.time_stride = 16
config.frequency_stride = 16
else:
raise ValueError("Model not supported")
repo_id = "huggingface/label-files"
if "speech-commands" in model_name:
config.num_labels = 35
filename = "speech-commands-v2-id2label.json"
else:
config.num_labels = 527
filename = "audioset-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
return config
def rename_key(name):
if "module.v" in name:
name = name.replace("module.v", "audio_spectrogram_transformer")
if "cls_token" in name:
name = name.replace("cls_token", "embeddings.cls_token")
if "dist_token" in name:
name = name.replace("dist_token", "embeddings.distillation_token")
if "pos_embed" in name:
name = name.replace("pos_embed", "embeddings.position_embeddings")
if "patch_embed.proj" in name:
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
# transformer blocks
if "blocks" in name:
name = name.replace("blocks", "encoder.layer")
if "attn.proj" in name:
name = name.replace("attn.proj", "attention.output.dense")
if "attn" in name:
name = name.replace("attn", "attention.self")
if "norm1" in name:
name = name.replace("norm1", "layernorm_before")
if "norm2" in name:
name = name.replace("norm2", "layernorm_after")
if "mlp.fc1" in name:
name = name.replace("mlp.fc1", "intermediate.dense")
if "mlp.fc2" in name:
name = name.replace("mlp.fc2", "output.dense")
# final layernorm
if "audio_spectrogram_transformer.norm" in name:
name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm")
# classifier head
if "module.mlp_head.0" in name:
name = name.replace("module.mlp_head.0", "classifier.layernorm")
if "module.mlp_head.1" in name:
name = name.replace("module.mlp_head.1", "classifier.dense")
return name
def convert_state_dict(orig_state_dict, config):
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
if "qkv" in key:
key_split = key.split(".")
layer_num = int(key_split[3])
dim = config.hidden_size
if "weight" in key:
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight"
] = val[:dim, :]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight"
] = val[dim : dim * 2, :]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight"
] = val[-dim:, :]
else:
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias"
] = val[:dim]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias"
] = val[dim : dim * 2]
orig_state_dict[
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias"
] = val[-dim:]
else:
orig_state_dict[rename_key(key)] = val
return orig_state_dict
def remove_keys(state_dict):
ignore_keys = [
"module.v.head.weight",
"module.v.head.bias",
"module.v.head_dist.weight",
"module.v.head_dist.bias",
]
for k in ignore_keys:
state_dict.pop(k, None)
@torch.no_grad()
def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
"""
Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure.
"""
config = get_audio_spectrogram_transformer_config(model_name)
model_name_to_url = {
"ast-finetuned-audioset-10-10-0.4593": (
"https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1"
),
"ast-finetuned-audioset-10-10-0.450": (
"https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1"
),
"ast-finetuned-audioset-10-10-0.448": (
"https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1"
),
"ast-finetuned-audioset-10-10-0.448-v2": (
"https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1"
),
"ast-finetuned-audioset-12-12-0.447": (
"https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1"
),
"ast-finetuned-audioset-14-14-0.443": (
"https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1"
),
"ast-finetuned-audioset-16-16-0.442": (
"https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1"
),
"ast-finetuned-speech-commands-v2": (
"https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1"
),
}
# load original state_dict
checkpoint_url = model_name_to_url[model_name]
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
# remove some keys
remove_keys(state_dict)
# rename some keys
new_state_dict = convert_state_dict(state_dict, config)
# load 🤗 model
model = ASTForAudioClassification(config)
model.eval()
model.load_state_dict(new_state_dict)
# verify outputs on dummy input
# source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62
mean = -4.2677393 if "speech-commands" not in model_name else -6.845978
std = 4.5689974 if "speech-commands" not in model_name else 5.5654526
max_length = 1024 if "speech-commands" not in model_name else 128
feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length)
if "speech-commands" in model_name:
# TODO: Convert dataset to Parquet
dataset = load_dataset("google/speech_commands", "v0.02", split="validation")
waveform = dataset[0]["audio"]["array"]
else:
filepath = hf_hub_download(
repo_id="nielsr/audio-spectogram-transformer-checkpoint",
filename="sample_audio.flac",
repo_type="dataset",
)
waveform, _ = torchaudio.load(filepath)
waveform = waveform.squeeze().numpy()
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
# forward pass
outputs = model(**inputs)
logits = outputs.logits
if model_name == "ast-finetuned-audioset-10-10-0.4593":
expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602])
elif model_name == "ast-finetuned-audioset-10-10-0.450":
expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718])
elif model_name == "ast-finetuned-audioset-10-10-0.448":
expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344])
elif model_name == "ast-finetuned-audioset-10-10-0.448-v2":
expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917])
elif model_name == "ast-finetuned-audioset-12-12-0.447":
expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843])
elif model_name == "ast-finetuned-audioset-14-14-0.443":
expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413])
elif model_name == "ast-finetuned-audioset-16-16-0.442":
expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470])
elif model_name == "ast-finetuned-speech-commands-v2":
expected_slice = torch.tensor([6.1589, -8.0566, -8.7984])
else:
raise ValueError("Unknown model name")
if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4):
raise ValueError("Logits don't match")
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
feature_extractor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing model and feature extractor to the hub...")
model.push_to_hub(f"MIT/{model_name}")
feature_extractor.push_to_hub(f"MIT/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="ast-finetuned-audioset-10-10-0.4593",
type=str,
help="Name of the Audio Spectrogram Transformer model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -1,273 +0,0 @@
# coding=utf-8
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
import argparse
import json
import os
import re
from os import path
from typing import Optional, Union
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file
from transformers import AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from .configuration_bamba import BambaConfig
def convert_state_dict_from_mamba_ssm(original_sd: dict) -> dict[str, torch.Tensor]:
state_dict = {}
for orig_k, param in original_sd.items():
k = orig_k.replace("backbone", "model")
# for embeddings
k = k.replace("embedding", "embed_tokens")
# for mixer
k = k.replace("mixer", "mamba")
# for final layernorm
k = k.replace("norm_f", "final_layernorm")
# for block layernorm
k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
# for mlp
k = k.replace("mlp.fc2", "feed_forward.down_proj")
if "mlp.fc1" in k:
param, param2 = torch.chunk(param, 2, dim=0)
k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
state_dict[k2] = param2
k = k.replace("mlp.fc1", "feed_forward.up_proj")
if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
"out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
):
# then this must be a mamba
pass
else:
# for attn
# - because mixer was replaced to mamba above
k = k.replace("mamba.out_proj", "self_attn.o_proj")
if "mamba.in_proj" in k:
m, n = param.shape
d = (m - n) // 2
param, param2, param3 = torch.split(param, [n, d, d], dim=0)
k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
state_dict[k2] = param2
k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
state_dict[k2] = param3
k = k.replace("mamba.in_proj", "self_attn.q_proj")
state_dict[k] = param
return state_dict
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
def convert_ssm_config_to_hf_config(
config_ssm: dict,
**kwargs,
) -> BambaConfig:
"""Convert a config from mamba_ssm to a BambaConfig from here."""
hf_config: BambaConfig = BambaConfig(**kwargs)
hf_config.architectures = ["BambaForCausalLM"]
# Set important values from config and recalculate other resulting entries
hf_config.hidden_size = config_ssm["d_model"]
hf_config.intermediate_size = config_ssm["d_intermediate"]
hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
hf_config.num_hidden_layers = config_ssm["n_layer"]
hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
# currently this script assumes config_ssm belongs to v2
if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
raise ValueError("Conversion script only supports Mamba2")
# Set attention values
attn_cfg = config_ssm.get("attn_cfg")
if attn_cfg:
assert attn_cfg["causal"], "Only support non-causal attention."
assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
assert not attn_cfg["out_proj_bias"], "Only support no out bias."
hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
hf_config.num_attention_heads = attn_cfg["num_heads"]
hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
attention_layer_indices = config_ssm.get("attn_layer_idx")
if attention_layer_indices:
hf_config.attn_layer_indices = attention_layer_indices
# Padded vocab size, mostly of 16 but 32 is also very common in different models
vocab_size = config_ssm["vocab_size"]
pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
if (vocab_size % pad_vocab_size_multiple) != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
hf_config.vocab_size = vocab_size
return hf_config
def save_single_safetensor(
state_dict: dict,
save_directory: str,
metadata: dict,
):
save_file(
state_dict,
os.path.join(save_directory, SAFE_WEIGHTS_NAME),
metadata,
)
def save_sharded_safetensors(
state_dict: dict,
save_directory: str,
metadata: dict,
max_shard_size: Union[int, str] = "5GB",
):
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the index
with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
mamba_ssm_checkpoint_path: str,
precision: str,
output_dir: str,
tokenizer_path: Optional[str] = None,
save_model: Union[bool, str] = True,
) -> None:
# load tokenizer if provided, this will be used to set the
# token_ids in the config file
token_ids = {}
if tokenizer_path:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
for key in [
"bos_token_id",
"eos_token_id",
"pad_token_id",
]:
id = getattr(tokenizer, key, None)
if id:
token_ids[key] = id
# there are some configs unsettable by mamba_ssn config, so
# if there are changes from the defaults, have to pass them into
# the function
unsettables = {
"mamba_d_head": 64,
"mamba_d_state": 128,
"mamba_n_groups": 1,
"rms_norm_eps": 1e-5,
}
# Load and save config based on name
config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
with open(config_path, "r", encoding="utf-8") as json_file:
config = json.load(json_file)
# convert the config
hf_config = convert_ssm_config_to_hf_config(
config_ssm=config,
**token_ids,
**unsettables,
)
hf_config.save_pretrained(output_dir)
# Load state dict of the original model and transfer to hf model
state_dict = torch.load(
path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
map_location="cpu",
weights_only=True,
)
# FIXME: allow other parameters to pass in
state_dict = convert_state_dict_from_mamba_ssm(state_dict)
# Save new model to pytorch_dump_path
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
save_file_fn = None
if isinstance(save_model, bool) and save_model:
save_file_fn = save_single_safetensor
elif isinstance(save_model, str) and save_model == "sharded":
save_file_fn = save_sharded_safetensors
if save_file_fn:
save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--mamba_ssm_checkpoint_directory",
type=str,
required=True,
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
)
parser.add_argument(
"-p",
"--precision",
type=str,
default="fp16",
required=True,
choices=("fp32", "fp16", "bf16"),
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
)
parser.add_argument(
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
)
parser.add_argument(
"-t",
"--tokenizer_model_path",
type=str,
default=None,
required=False,
help="Path to a the tokenizer file.",
)
args = parser.parse_args()
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
args.mamba_ssm_checkpoint_directory,
args.precision,
args.output_dir,
save_model="sharded",
)

View File

@ -1,263 +0,0 @@
"""Convert Bark checkpoint."""
import argparse
import os
from pathlib import Path
import torch
from bark.generation import _load_model as _bark_load_model
from huggingface_hub import hf_hub_download
from transformers import EncodecConfig, EncodecModel, set_seed
from transformers.models.bark.configuration_bark import (
BarkCoarseConfig,
BarkConfig,
BarkFineConfig,
BarkSemanticConfig,
)
from transformers.models.bark.generation_configuration_bark import (
BarkCoarseGenerationConfig,
BarkFineGenerationConfig,
BarkGenerationConfig,
BarkSemanticGenerationConfig,
)
from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
set_seed(770)
new_layer_name_dict = {
"c_attn": "att_proj",
"c_proj": "out_proj",
"c_fc": "in_proj",
"transformer.": "",
"h.": "layers.",
"ln_1": "layernorm_1",
"ln_2": "layernorm_2",
"ln_f": "layernorm_final",
"wpe": "position_embeds_layer",
"wte": "input_embeds_layer",
}
REMOTE_MODEL_PATHS = {
"text_small": {
"repo_id": "suno/bark",
"file_name": "text.pt",
},
"coarse_small": {
"repo_id": "suno/bark",
"file_name": "coarse.pt",
},
"fine_small": {
"repo_id": "suno/bark",
"file_name": "fine.pt",
},
"text": {
"repo_id": "suno/bark",
"file_name": "text_2.pt",
},
"coarse": {
"repo_id": "suno/bark",
"file_name": "coarse_2.pt",
},
"fine": {
"repo_id": "suno/bark",
"file_name": "fine_2.pt",
},
}
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
def _get_ckpt_path(model_type, use_small=False):
key = model_type
if use_small:
key += "_small"
return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
def _download(from_hf_path, file_name):
os.makedirs(CACHE_DIR, exist_ok=True)
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
if model_type == "text":
ModelClass = BarkSemanticModel
ConfigClass = BarkSemanticConfig
GenerationConfigClass = BarkSemanticGenerationConfig
elif model_type == "coarse":
ModelClass = BarkCoarseModel
ConfigClass = BarkCoarseConfig
GenerationConfigClass = BarkCoarseGenerationConfig
elif model_type == "fine":
ModelClass = BarkFineModel
ConfigClass = BarkFineConfig
GenerationConfigClass = BarkFineGenerationConfig
else:
raise NotImplementedError()
model_key = f"{model_type}_small" if use_small else model_type
model_info = REMOTE_MODEL_PATHS[model_key]
if not os.path.exists(ckpt_path):
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
_download(model_info["repo_id"], model_info["file_name"])
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
# this is a hack
model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args:
model_args["input_vocab_size"] = model_args["vocab_size"]
model_args["output_vocab_size"] = model_args["vocab_size"]
del model_args["vocab_size"]
# convert Bark model arguments to HF Bark model arguments
model_args["num_heads"] = model_args.pop("n_head")
model_args["hidden_size"] = model_args.pop("n_embd")
model_args["num_layers"] = model_args.pop("n_layer")
model_config = ConfigClass(**checkpoint["model_args"])
model = ModelClass(config=model_config)
model_generation_config = GenerationConfigClass()
model.generation_config = model_generation_config
state_dict = checkpoint["model"]
# fixup checkpoint
unwanted_prefix = "_orig_mod."
for k in state_dict:
if k.startswith(unwanted_prefix):
# replace part of the key with corresponding layer name in HF implementation
new_k = k[len(unwanted_prefix) :]
for old_layer_name, new_layer_name in new_layer_name_dict.items():
new_k = new_k.replace(old_layer_name, new_layer_name)
state_dict[new_k] = state_dict.pop(k)
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")}
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")}
if len(extra_keys) != 0:
raise ValueError(f"extra keys found: {extra_keys}")
if len(missing_keys) != 0:
raise ValueError(f"missing keys: {missing_keys}")
model.load_state_dict(state_dict, strict=False)
n_params = model.num_parameters(exclude_embeddings=True)
val_loss = checkpoint["best_val_loss"].item()
logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params, {round(val_loss, 3)} loss")
model.eval()
model.to(device)
del checkpoint, state_dict
return model
def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"):
if model_type not in ("text", "coarse", "fine"):
raise NotImplementedError()
device = "cpu" # do conversion on cpu
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small)
# load bark initial model
bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small)
if model_type == "text":
bark_model = bark_model["model"]
if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params():
raise ValueError("initial and new models don't have the same number of parameters")
# check if same output as the bark model
batch_size = 5
sequence_length = 10
if model_type in ["text", "coarse"]:
vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int)
output_old_model = bark_model(vec)[0]
output_new_model_total = model(vec)
# take last logits
output_new_model = output_new_model_total.logits[:, [-1], :]
else:
prediction_codebook_channel = 3
n_codes_total = 8
vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int)
output_new_model_total = model(prediction_codebook_channel, vec)
output_old_model = bark_model(prediction_codebook_channel, vec)
output_new_model = output_new_model_total.logits
# output difference should come from the difference of self-attention implementation design
if output_new_model.shape != output_old_model.shape:
raise ValueError("initial and new outputs don't have the same shape")
if (output_new_model - output_old_model).abs().max().item() > 1e-3:
raise ValueError("initial and new outputs are not equal")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
def load_whole_bark_model(
semantic_path,
coarse_path,
fine_path,
append_text,
hub_path,
folder_path,
):
pytorch_dump_folder_path = os.path.join(folder_path, append_text)
semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json"))
coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json"))
fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json"))
codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz")
semantic = BarkSemanticModel.from_pretrained(semantic_path)
coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path)
fineAcoustic = BarkFineModel.from_pretrained(fine_path)
codec = EncodecModel.from_pretrained("facebook/encodec_24khz")
bark_config = BarkConfig.from_sub_model_configs(
semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig
)
bark_generation_config = BarkGenerationConfig.from_sub_model_configs(
semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config
)
bark = BarkModel(bark_config)
bark.semantic = semantic
bark.coarse_acoustics = coarseAcoustic
bark.fine_acoustics = fineAcoustic
bark.codec_model = codec
bark.generation_config = bark_generation_config
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("model_type", type=str, help="text, coarse or fine.")
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.")
args = parser.parse_args()
load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small)

View File

@ -1,156 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BART checkpoint."""
import argparse
import os
from pathlib import Path
import fairseq
import torch
from packaging import version
from torch import nn
from transformers import (
BartConfig,
BartForConditionalGeneration,
BartForSequenceClassification,
BartModel,
BartTokenizer,
)
from transformers.utils import logging
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
raise Exception("requires fairseq >= 0.9.0")
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
SAMPLE_TEXT = " Hello world! cécé herlolip"
mnli_rename_keys = [
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
]
def remove_ignore_keys_(state_dict):
ignore_keys = [
"encoder.version",
"decoder.version",
"model.encoder.version",
"model.decoder.version",
"_float_tensor",
]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def load_xsum_checkpoint(checkpoint_path):
"""Checkpoint path should end in model.pt"""
sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
hub_interface.model.load_state_dict(sd["model"])
return hub_interface
def make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
lin_layer.weight.data = emb.weight.data
return lin_layer
@torch.no_grad()
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
"""
Copy/paste/tweak model's weights to our BERT structure.
"""
if not os.path.exists(checkpoint_path):
bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
else:
bart = load_xsum_checkpoint(checkpoint_path)
bart.model.upgrade_state_dict(bart.model.state_dict())
if hf_checkpoint_name is None:
hf_checkpoint_name = checkpoint_path.replace(".", "-")
config = BartConfig.from_pretrained(hf_checkpoint_name)
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
if not torch.eq(tokens, tokens2).all():
raise ValueError(
f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
)
if checkpoint_path == "bart.large.mnli":
state_dict = bart.state_dict()
remove_ignore_keys_(state_dict)
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
for src, dest in mnli_rename_keys:
rename_key(state_dict, src, dest)
model = BartForSequenceClassification(config).eval()
model.load_state_dict(state_dict)
fairseq_output = bart.predict("mnli", tokens, return_logits=True)
new_model_outputs = model(tokens)[0] # logits
else: # no classification heads to worry about
state_dict = bart.model.state_dict()
remove_ignore_keys_(state_dict)
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
fairseq_output = bart.extract_features(tokens)
if hf_checkpoint_name == "facebook/bart-large":
model = BartModel(config).eval()
model.load_state_dict(state_dict)
new_model_outputs = model(tokens).model[0]
else:
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
model.model.load_state_dict(state_dict)
if hasattr(model, "lm_head"):
model.lm_head = make_linear_from_emb(model.model.shared)
new_model_outputs = model.model(tokens)[0]
# Check results
if fairseq_output.shape != new_model_outputs.shape:
raise ValueError(
f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
)
if (fairseq_output != new_model_outputs).any().item():
raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
)
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
)
args = parser.parse_args()
convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)

View File

@ -1,373 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BEiT checkpoints from the unilm repository."""
import argparse
import json
from pathlib import Path
import requests
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
BeitConfig,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitImageProcessor,
)
from transformers.image_utils import PILImageResampling
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
prefix = "backbone." if is_semantic else ""
rename_keys = []
for i in range(config.num_hidden_layers):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
)
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
)
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
# projection layer + position embeddings
rename_keys.extend(
[
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
]
)
if has_lm_head:
# mask token + shared relative position bias + layernorm
rename_keys.extend(
[
("mask_token", "beit.embeddings.mask_token"),
(
"rel_pos_bias.relative_position_bias_table",
"beit.encoder.relative_position_bias.relative_position_bias_table",
),
(
"rel_pos_bias.relative_position_index",
"beit.encoder.relative_position_bias.relative_position_index",
),
("norm.weight", "layernorm.weight"),
("norm.bias", "layernorm.bias"),
]
)
elif is_semantic:
# semantic segmentation classification heads
rename_keys.extend(
[
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
]
)
else:
# layernorm + classification head
rename_keys.extend(
[
("fc_norm.weight", "beit.pooler.layernorm.weight"),
("fc_norm.bias", "beit.pooler.layernorm.bias"),
("head.weight", "classifier.weight"),
("head.bias", "classifier.bias"),
]
)
return rename_keys
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
for i in range(config.num_hidden_layers):
prefix = "backbone." if is_semantic else ""
# queries, keys and values
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
# gamma_1 and gamma_2
# we call them lambda because otherwise they are renamed when using .from_pretrained
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
# relative_position bias table + index
if not has_lm_head:
# each layer has its own relative position bias
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
state_dict[
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
] = table
state_dict[
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
] = index
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our BEiT structure.
"""
# define default BEiT configuration
config = BeitConfig()
has_lm_head = False
is_semantic = False
repo_id = "huggingface/label-files"
# set config parameters based on URL
if checkpoint_url[-9:-4] == "pt22k":
# masked image modeling
config.use_shared_relative_position_bias = True
config.use_mask_token = True
has_lm_head = True
elif checkpoint_url[-9:-4] == "ft22k":
# intermediate fine-tuning on ImageNet-22k
config.use_relative_position_bias = True
config.num_labels = 21841
filename = "imagenet-22k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
# this dataset contains 21843 labels but the model only has 21841
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
del id2label[9205]
del id2label[15027]
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
elif checkpoint_url[-8:-4] == "to1k":
# fine-tuning on ImageNet-1k
config.use_relative_position_bias = True
config.num_labels = 1000
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
if "384" in checkpoint_url:
config.image_size = 384
if "512" in checkpoint_url:
config.image_size = 512
elif "ade20k" in checkpoint_url:
# fine-tuning
config.use_relative_position_bias = True
config.num_labels = 150
filename = "ade20k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.image_size = 640
is_semantic = True
else:
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
# size of the architecture
if "base" in checkpoint_url:
pass
elif "large" in checkpoint_url:
config.hidden_size = 1024
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
if "ade20k" in checkpoint_url:
config.image_size = 640
config.out_indices = [7, 11, 15, 23]
else:
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
# load state_dict of original model, remove and rename some keys
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
if is_semantic:
# add prefix to decoder keys
for key, val in state_dict.copy().items():
val = state_dict.pop(key)
if key.startswith("backbone.fpn"):
key = key.replace("backbone.fpn", "fpn")
state_dict[key] = val
# load HuggingFace model
if checkpoint_url[-9:-4] == "pt22k":
model = BeitForMaskedImageModeling(config)
elif "ade20k" in checkpoint_url:
model = BeitForSemanticSegmentation(config)
else:
model = BeitForImageClassification(config)
model.eval()
model.load_state_dict(state_dict)
# Check outputs on an image
if is_semantic:
image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = Image.open(ds[0]["file"])
else:
image_processor = BeitImageProcessor(
size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
)
image = prepare_img()
encoding = image_processor(images=image, return_tensors="pt")
pixel_values = encoding["pixel_values"]
outputs = model(pixel_values)
logits = outputs.logits
# verify logits
expected_shape = torch.Size([1, 1000])
if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
expected_shape = torch.Size([1, 196, 8192])
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
expected_shape = torch.Size([1, 196, 8192])
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
expected_shape = torch.Size([1, 21841])
expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
expected_class_idx = 2397
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
expected_shape = torch.Size([1, 21841])
expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
expected_class_idx = 2396
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
expected_class_idx = 285
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
expected_class_idx = 281
elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
]
)
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
]
)
else:
raise ValueError("Can't verify logits as model is not supported")
if logits.shape != expected_shape:
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
if not has_lm_head:
if is_semantic:
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
raise ValueError("First elements of logits not as expected")
else:
print("Predicted class idx:", logits.argmax(-1).item())
if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
raise ValueError("First elements of logits not as expected")
if logits.argmax(-1).item() != expected_class_idx:
raise ValueError("Predicted class index not as expected")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_url",
default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
type=str,
help="URL to the original PyTorch checkpoint (.pth file).",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
)
args = parser.parse_args()
convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)

View File

@ -1,246 +0,0 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now
deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
weight names to the original names, so the model can be imported with Huggingface/transformer.
You may adapt this script to include classification/MLM/NSP/etc. heads.
Note: This script is only working with an older version of the TensorFlow models repository (<= v2.3.0).
Models trained with never versions are not compatible with this script.
"""
import argparse
import os
import re
import tensorflow as tf
import torch
from transformers import BertConfig, BertModel
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
layer_depth = []
for full_name, shape in init_vars:
# logger.info(f"Loading TF weight {name} with shape {shape}")
name = full_name.split("/")
if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
logger.info(f"Skipping non-model layer {full_name}")
continue
if "optimizer" in full_name:
logger.info(f"Skipping optimization layer {full_name}")
continue
if name[0] == "model":
# ignore initial 'model'
name = name[1:]
# figure out how many levels deep the name is
depth = 0
for _name in name:
if _name.startswith("layer_with_weights"):
depth += 1
else:
break
layer_depth.append(depth)
# read data
array = tf.train.load_variable(tf_path, full_name)
names.append("/".join(name))
arrays.append(array)
logger.info(f"Read a total of {len(arrays):,} layers")
# Sanity check
if len(set(layer_depth)) != 1:
raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
layer_depth = list(set(layer_depth))[0]
if layer_depth != 1:
raise ValueError(
"The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
" heads."
)
# convert layers
logger.info("Converting weights...")
for full_name, array in zip(names, arrays):
name = full_name.split("/")
pointer = model
trace = []
for i, m_name in enumerate(name):
if m_name == ".ATTRIBUTES":
# variable names end with .ATTRIBUTES/VARIABLE_VALUE
break
if m_name.startswith("layer_with_weights"):
layer_num = int(m_name.split("-")[-1])
if layer_num <= 2:
# embedding layers
# layer_num 0: word_embeddings
# layer_num 1: position_embeddings
# layer_num 2: token_type_embeddings
continue
elif layer_num == 3:
# embedding LayerNorm
trace.extend(["embeddings", "LayerNorm"])
pointer = getattr(pointer, "embeddings")
pointer = getattr(pointer, "LayerNorm")
elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
# encoder layers
trace.extend(["encoder", "layer", str(layer_num - 4)])
pointer = getattr(pointer, "encoder")
pointer = getattr(pointer, "layer")
pointer = pointer[layer_num - 4]
elif layer_num == config.num_hidden_layers + 4:
# pooler layer
trace.extend(["pooler", "dense"])
pointer = getattr(pointer, "pooler")
pointer = getattr(pointer, "dense")
elif m_name == "embeddings":
trace.append("embeddings")
pointer = getattr(pointer, "embeddings")
if layer_num == 0:
trace.append("word_embeddings")
pointer = getattr(pointer, "word_embeddings")
elif layer_num == 1:
trace.append("position_embeddings")
pointer = getattr(pointer, "position_embeddings")
elif layer_num == 2:
trace.append("token_type_embeddings")
pointer = getattr(pointer, "token_type_embeddings")
else:
raise ValueError(f"Unknown embedding layer with name {full_name}")
trace.append("weight")
pointer = getattr(pointer, "weight")
elif m_name == "_attention_layer":
# self-attention layer
trace.extend(["attention", "self"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "self")
elif m_name == "_attention_layer_norm":
# output attention norm
trace.extend(["attention", "output", "LayerNorm"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "LayerNorm")
elif m_name == "_attention_output_dense":
# output attention dense
trace.extend(["attention", "output", "dense"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "dense")
elif m_name == "_output_dense":
# output dense
trace.extend(["output", "dense"])
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "dense")
elif m_name == "_output_layer_norm":
# output dense
trace.extend(["output", "LayerNorm"])
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "LayerNorm")
elif m_name == "_key_dense":
# attention key
trace.append("key")
pointer = getattr(pointer, "key")
elif m_name == "_query_dense":
# attention query
trace.append("query")
pointer = getattr(pointer, "query")
elif m_name == "_value_dense":
# attention value
trace.append("value")
pointer = getattr(pointer, "value")
elif m_name == "_intermediate_dense":
# attention intermediate dense
trace.extend(["intermediate", "dense"])
pointer = getattr(pointer, "intermediate")
pointer = getattr(pointer, "dense")
elif m_name == "_output_layer_norm":
# output layer norm
trace.append("output")
pointer = getattr(pointer, "output")
# weights & biases
elif m_name in ["bias", "beta"]:
trace.append("bias")
pointer = getattr(pointer, "bias")
elif m_name in ["kernel", "gamma"]:
trace.append("weight")
pointer = getattr(pointer, "weight")
else:
logger.warning(f"Ignored {m_name}")
# for certain layers reshape is necessary
trace = ".".join(trace)
if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
r"(\S+)\.attention\.output\.dense\.weight", trace
):
array = array.reshape(pointer.data.shape)
if "kernel" in full_name:
array = array.transpose()
if pointer.shape == array.shape:
pointer.data = torch.from_numpy(array)
else:
raise ValueError(
f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
f" {array.shape}"
)
logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
return model
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
# Instantiate model
logger.info(f"Loading model based on config from {config_path}...")
config = BertConfig.from_json_file(config_path)
model = BertModel(config)
# Load weights from checkpoint
logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
# Save pytorch-model
logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
)
parser.add_argument(
"--bert_config_file",
type=str,
required=True,
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path",
type=str,
required=True,
help="Path to the output PyTorch model (must include filename).",
)
args = parser.parse_args()
convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

View File

@ -1,62 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""
import argparse
import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file)
print(f"Building PyTorch model from configuration: {config}")
model = BertForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--bert_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

View File

@ -1,112 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
import argparse
import os
import numpy as np
import tensorflow as tf
import torch
from transformers import BertModel
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
"""
Args:
model: BertModel Pytorch model instance to be converted
ckpt_dir: Tensorflow model directory
model_name: model name
Currently supported HF models:
- Y BertModel
- N BertForMaskedLM
- N BertForPreTraining
- N BertForMultipleChoice
- N BertForNextSentencePrediction
- N BertForSequenceClassification
- N BertForQuestionAnswering
"""
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
var_map = (
("layer.", "layer_"),
("word_embeddings.weight", "word_embeddings"),
("position_embeddings.weight", "position_embeddings"),
("token_type_embeddings.weight", "token_type_embeddings"),
(".", "/"),
("LayerNorm/weight", "LayerNorm/gamma"),
("LayerNorm/bias", "LayerNorm/beta"),
("weight", "kernel"),
)
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
state_dict = model.state_dict()
def to_tf_var_name(name: str):
for patt, repl in iter(var_map):
name = name.replace(patt, repl)
return f"bert/{name}"
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
session.run(tf.variables_initializer([tf_var]))
session.run(tf_var)
return tf_var
tf.reset_default_graph()
with tf.Session() as session:
for var_name in state_dict:
tf_name = to_tf_var_name(var_name)
torch_tensor = state_dict[var_name].numpy()
if any(x in var_name for x in tensors_to_transpose):
torch_tensor = torch_tensor.T
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
tf_weight = session.run(tf_var)
print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")
saver = tf.train.Saver(tf.trainable_variables())
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
def main(raw_args=None):
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
parser.add_argument(
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
)
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
args = parser.parse_args(raw_args)
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path, weights_only=True),
cache_dir=args.cache_dir,
)
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
if __name__ == "__main__":
main()

View File

@ -1,188 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT
model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository:
https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
"""
import argparse
import tensorflow as tf
import torch
from transformers import BertConfig, BertForMaskedLM
from transformers.models.bert.modeling_bert import (
BertIntermediate,
BertLayer,
BertOutput,
BertPooler,
BertSelfAttention,
BertSelfOutput,
)
from transformers.utils import logging
logging.set_verbosity_info()
def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
def get_masked_lm_array(name: str):
full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
def get_encoder_array(name: str):
full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
def get_encoder_layer_array(layer_index: int, name: str):
full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape):
full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
array = tf.train.load_variable(tf_checkpoint_path, full_name)
array = array.reshape(original_shape)
if "kernel" in name:
array = array.transpose()
return torch.from_numpy(array)
print(f"Loading model based on config from {config_path}...")
config = BertConfig.from_json_file(config_path)
model = BertForMaskedLM(config)
# Layers
for layer_index in range(0, config.num_hidden_layers):
layer: BertLayer = model.bert.encoder.layer[layer_index]
# Self-attention
self_attn: BertSelfAttention = layer.attention.self
self_attn.query.weight.data = get_encoder_attention_layer_array(
layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
)
self_attn.query.bias.data = get_encoder_attention_layer_array(
layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
)
self_attn.key.weight.data = get_encoder_attention_layer_array(
layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
)
self_attn.key.bias.data = get_encoder_attention_layer_array(
layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
)
self_attn.value.weight.data = get_encoder_attention_layer_array(
layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
)
self_attn.value.bias.data = get_encoder_attention_layer_array(
layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
)
# Self-attention Output
self_output: BertSelfOutput = layer.attention.output
self_output.dense.weight.data = get_encoder_attention_layer_array(
layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
)
self_output.dense.bias.data = get_encoder_attention_layer_array(
layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
)
self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
# Intermediate
intermediate: BertIntermediate = layer.intermediate
intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
# Output
bert_output: BertOutput = layer.output
bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
# Embeddings
model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
# LM Head
lm_head = model.cls.predictions.transform
lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
# Pooling
model.bert.pooler = BertPooler(config=config)
model.bert.pooler.dense.weight.data: BertPooler = get_encoder_array("_pooler_layer/kernel")
model.bert.pooler.dense.bias.data: BertPooler = get_encoder_array("_pooler_layer/bias")
# Export final model
model.save_pretrained(pytorch_dump_path)
# Integration test - should load without any errors ;)
new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
print(new_model.eval())
print("Model conversion was done successfully!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
)
parser.add_argument(
"--bert_config_file",
type=str,
required=True,
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path",
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

View File

@ -1,69 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigBird checkpoint."""
import argparse
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
# Initialise PyTorch model
config = BigBirdConfig.from_json_file(big_bird_config_file)
print(f"Building PyTorch model from configuration: {config}")
if is_trivia_qa:
model = BigBirdForQuestionAnswering(config)
else:
model = BigBirdForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--big_bird_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(
args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
)

View File

@ -1,169 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import tensorflow as tf
import torch
from tqdm import tqdm
from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration
INIT_COMMON = [
# tf -> hf
("/", "."),
("layer_", "layers."),
("kernel", "weight"),
("beta", "bias"),
("gamma", "weight"),
("pegasus", "model"),
]
END_COMMON = [
(".output.dense", ".fc2"),
("intermediate.LayerNorm", "final_layer_norm"),
("intermediate.dense", "fc1"),
]
DECODER_PATTERNS = (
INIT_COMMON
+ [
("attention.self.LayerNorm", "self_attn_layer_norm"),
("attention.output.dense", "self_attn.out_proj"),
("attention.self", "self_attn"),
("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
("attention.encdec_output.dense", "encoder_attn.out_proj"),
("attention.encdec", "encoder_attn"),
("key", "k_proj"),
("value", "v_proj"),
("query", "q_proj"),
("decoder.LayerNorm", "decoder.layernorm_embedding"),
]
+ END_COMMON
)
REMAINING_PATTERNS = (
INIT_COMMON
+ [
("embeddings.word_embeddings", "shared.weight"),
("embeddings.position_embeddings", "embed_positions.weight"),
("attention.self.LayerNorm", "self_attn_layer_norm"),
("attention.output.dense", "self_attn.output"),
("attention.self", "self_attn.self"),
("encoder.LayerNorm", "encoder.layernorm_embedding"),
]
+ END_COMMON
)
KEYS_TO_IGNORE = [
"encdec/key/bias",
"encdec/query/bias",
"encdec/value/bias",
"self/key/bias",
"self/query/bias",
"self/value/bias",
"encdec_output/dense/bias",
"attention/output/dense/bias",
]
def rename_state_dict_key(k, patterns):
for tf_name, hf_name in patterns:
k = k.replace(tf_name, hf_name)
return k
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
cfg = BigBirdPegasusConfig(**config_update)
torch_model = BigBirdPegasusForConditionalGeneration(cfg)
state_dict = torch_model.state_dict()
mapping = {}
# separating decoder weights
decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
if any(conditions):
continue
patterns = DECODER_PATTERNS
new_k = rename_state_dict_key(k, patterns)
if new_k not in state_dict:
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
v = v.T
mapping[new_k] = torch.from_numpy(v)
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
if any(conditions):
continue
patterns = REMAINING_PATTERNS
new_k = rename_state_dict_key(k, patterns)
if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
v = v.T
mapping[new_k] = torch.from_numpy(v)
if k != "pegasus/embeddings/position_embeddings":
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
missing, extra = torch_model.load_state_dict(mapping, strict=False)
unexpected_missing = [
k
for k in missing
if k
not in [
"final_logits_bias",
"model.encoder.embed_tokens.weight",
"model.decoder.embed_tokens.weight",
"lm_head.weight",
]
]
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
assert extra == [], f"no matches found for the following tf keys {extra}"
return torch_model
def get_tf_weights_as_numpy(path) -> dict:
init_vars = tf.train.list_variables(path)
tf_weights = {}
ignore_name = ["global_step"]
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
skip_key = any(pat in name for pat in ignore_name)
if skip_key:
continue
array = tf.train.load_variable(path, name)
tf_weights[name] = array
return tf_weights
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
tf_weights = get_tf_weights_as_numpy(ckpt_path)
torch_model = convert_bigbird_pegasus(tf_weights, config_update)
torch_model.save_pretrained(save_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
config_update = {}
convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)

View File

@ -1,292 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
import shutil
import torch
from transformers import BioGptConfig, BioGptForCausalLM
from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
from transformers.utils import WEIGHTS_NAME, logging
logging.set_verbosity_warning()
json_indent = 2
# modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18
class Dictionary:
"""A mapping from symbols to consecutive integers"""
def __init__(
self,
*, # begin keyword-only arguments
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
extra_special_symbols=None,
):
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
def __eq__(self, other):
return self.indices == other.indices
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def __contains__(self, sym):
return sym in self.indices
@classmethod
def load(cls, f):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
d = cls()
d.add_from_file(f)
return d
def add_symbol(self, word, n=1, overwrite=False):
"""Adds a word to the dictionary"""
if word in self.indices and not overwrite:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def _load_meta(self, lines):
return 0
def add_from_file(self, f):
"""
Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
"""
if isinstance(f, str):
try:
with open(f, "r", encoding="utf-8") as fd:
self.add_from_file(fd)
except FileNotFoundError as fnfe:
raise fnfe
except UnicodeError:
raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset")
return
lines = f.readlines()
indices_start_line = self._load_meta(lines)
for line in lines[indices_start_line:]:
try:
line, field = line.rstrip().rsplit(" ", 1)
if field == "#fairseq:overwrite":
overwrite = True
line, field = line.rsplit(" ", 1)
else:
overwrite = False
count = int(field)
word = line
if word in self and not overwrite:
raise RuntimeError(
f"Duplicate word found when loading Dictionary: '{word}'. "
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. If using the Camembert model, please "
"download an updated copy of the model file."
)
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'")
def rewrite_dict_keys(d):
# (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
# e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
keep_keys = "<s> <pad> </s> <unk>".split()
# restore the special tokens
for k in keep_keys:
del d2[f"{k}</w>"]
d2[k] = d[k] # restore
return d2
def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path):
# prep
if not os.path.exists(biogpt_checkpoint_path):
raise ValueError(f"path {biogpt_checkpoint_path} does not exist!")
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
print(f"Writing results to {pytorch_dump_folder_path}")
# handle various types of models
checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
if not os.path.isfile(checkpoint_file):
raise ValueError(f"path to the file {checkpoint_file} does not exist!")
chkpt = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
args = chkpt["cfg"]["model"]
# dicts
dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt")
if not os.path.isfile(dict_file):
raise ValueError(f"path to the file {dict_file} does not exist!")
src_dict = Dictionary.load(dict_file)
src_vocab = rewrite_dict_keys(src_dict.indices)
src_vocab_size = len(src_vocab)
src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"])
print(f"Generating {src_vocab_file} of {src_vocab_size} records")
with open(src_vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
# merges_file (bpecodes)
bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes")
if not os.path.isfile(bpecodes_file):
raise ValueError(f"path to the file {bpecodes_file} does not exist!")
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
shutil.copyfile(bpecodes_file, merges_file)
# model config
biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")
model_conf = {
"activation_dropout": args["activation_dropout"],
"architectures": ["BioGptForCausalLM"],
"attention_probs_dropout_prob": args["attention_dropout"],
"bos_token_id": 0,
"eos_token_id": 2,
"hidden_act": args["activation_fn"],
"hidden_dropout_prob": args["dropout"],
"hidden_size": args["decoder_embed_dim"],
"initializer_range": 0.02,
"intermediate_size": args["decoder_ffn_embed_dim"],
"layer_norm_eps": 1e-12,
"layerdrop": args["decoder_layerdrop"],
"max_position_embeddings": args["max_target_positions"],
"model_type": "biogpt",
"num_attention_heads": args["decoder_attention_heads"],
"num_hidden_layers": args["decoder_layers"],
"pad_token_id": 1,
"scale_embedding": not args["no_scale_embedding"],
"tie_word_embeddings": args["share_decoder_input_output_embed"],
"vocab_size": src_vocab_size,
}
# good hparam defaults to start with
print(f"Generating {biogpt_model_config_file}")
with open(biogpt_model_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))
# tokenizer config
biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)
tokenizer_conf = {
"bos_token": "<s>",
"eos_token": "</s>",
"model_max_length": 1024,
"pad_token": "<pad>",
"special_tokens_map_file": None,
"tokenizer_class": "BioGptTokenizer",
"unk_token": "<unk>",
}
print(f"Generating {biogpt_tokenizer_config_file}")
with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))
# model
model_state_dict = chkpt["model"]
# remove unneeded keys
ignore_keys = [
"decoder.version",
]
for k in ignore_keys:
model_state_dict.pop(k, None)
layer_names = list(model_state_dict.keys())
for layer_name in layer_names:
if layer_name.endswith("output_projection.weight"):
model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name)
else:
model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name)
config = BioGptConfig.from_pretrained(pytorch_dump_folder_path)
model_new = BioGptForCausalLM(config)
# check that it loads ok
model_new.load_state_dict(model_state_dict)
# save
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
print(f"Generating {pytorch_weights_dump_path}")
torch.save(model_state_dict, pytorch_weights_dump_path)
print("Conversion is done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--biogpt_checkpoint_path",
default=None,
type=str,
required=True,
help=(
"Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
" bpecodes, etc."
),
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path)

View File

@ -1,177 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BiT checkpoints from the timm library."""
import argparse
import json
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers import BitConfig, BitForImageClassification, BitImageProcessor
from transformers.image_utils import PILImageResampling
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_config(model_name):
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
conv_layer = "std_conv" if "bit" in model_name else False
# note that when using BiT as backbone for ViT-hybrid checkpoints,
# one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same",
# config.conv_layer = "std_conv_same"
config = BitConfig(
conv_layer=conv_layer,
num_labels=1000,
id2label=id2label,
label2id=label2id,
)
return config
def rename_key(name):
if "stem.conv" in name:
name = name.replace("stem.conv", "bit.embedder.convolution")
if "blocks" in name:
name = name.replace("blocks", "layers")
if "head.fc" in name:
name = name.replace("head.fc", "classifier.1")
if name.startswith("norm"):
name = "bit." + name
if "bit" not in name and "classifier" not in name:
name = "bit.encoder." + name
return name
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
"""
Copy/paste/tweak model's weights to our BiT structure.
"""
# define default BiT configuration
config = get_config(model_name)
# load original model from timm
timm_model = create_model(model_name, pretrained=True)
timm_model.eval()
# load state_dict of original model
state_dict = timm_model.state_dict()
for key in state_dict.copy().keys():
val = state_dict.pop(key)
state_dict[rename_key(key)] = val.squeeze() if "head" in key else val
# load HuggingFace model
model = BitForImageClassification(config)
model.eval()
model.load_state_dict(state_dict)
# create image processor
transform = create_transform(**resolve_data_config({}, model=timm_model))
timm_transforms = transform.transforms
pillow_resamplings = {
"bilinear": PILImageResampling.BILINEAR,
"bicubic": PILImageResampling.BICUBIC,
"nearest": PILImageResampling.NEAREST,
}
processor = BitImageProcessor(
do_resize=True,
size={"shortest_edge": timm_transforms[0].size},
resample=pillow_resamplings[timm_transforms[0].interpolation.value],
do_center_crop=True,
crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
do_normalize=True,
image_mean=timm_transforms[-1].mean.tolist(),
image_std=timm_transforms[-1].std.tolist(),
)
image = prepare_img()
timm_pixel_values = transform(image).unsqueeze(0)
pixel_values = processor(image, return_tensors="pt").pixel_values
# verify pixel values
assert torch.allclose(timm_pixel_values, pixel_values)
# verify logits
with torch.no_grad():
outputs = model(pixel_values)
logits = outputs.logits
print("Logits:", logits[0, :3])
print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
timm_logits = timm_model(pixel_values)
assert timm_logits.shape == outputs.logits.shape
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print(f"Pushing model {model_name} and processor to the hub")
model.push_to_hub(f"ybelkada/{model_name}")
processor.push_to_hub(f"ybelkada/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="resnetv2_50x1_bitm",
type=str,
help="Name of the BiT timm model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the model to the hub.",
)
args = parser.parse_args()
convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -1,114 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Blenderbot checkpoint."""
import argparse
import torch
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
PATTERNS = [
["attention", "attn"],
["encoder_attention", "encoder_attn"],
["q_lin", "q_proj"],
["k_lin", "k_proj"],
["v_lin", "v_proj"],
["out_lin", "out_proj"],
["norm_embeddings", "layernorm_embedding"],
["position_embeddings", "embed_positions"],
["embeddings", "embed_tokens"],
["ffn.lin", "fc"],
]
def rename_state_dict_key(k):
if k == "embeddings.weight":
return "shared.weight"
for parlai_name, hf_name in PATTERNS:
k = k.replace(parlai_name, hf_name)
if k.startswith("encoder"):
k = k.replace(".attn", ".self_attn")
k = k.replace("norm1", "self_attn_layer_norm")
k = k.replace("norm2", "final_layer_norm")
elif k.startswith("decoder"):
k = k.replace("norm1", "self_attn_layer_norm")
k = k.replace("norm2", "encoder_attn_layer_norm")
k = k.replace("norm3", "final_layer_norm")
return k
def rename_layernorm_keys(sd):
keys = [
"model.encoder.layernorm_embedding.weight",
"model.encoder.layernorm_embedding.bias",
"model.decoder.layernorm_embedding.weight",
"model.decoder.layernorm_embedding.bias",
]
for k in keys:
v = sd.pop(k)
new_k = k.replace("layernorm_embedding", "layer_norm")
assert new_k not in sd
sd[new_k] = v
IGNORE_KEYS = ["START"]
@torch.no_grad()
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
"""
Copy/paste/tweak model's weights to our BERT structure.
"""
model = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
sd = model["model"]
cfg = BlenderbotConfig.from_json_file(config_json_path)
m = BlenderbotForConditionalGeneration(cfg)
valid_keys = m.model.state_dict().keys()
failures = []
mapping = {}
for k, v in sd.items():
if k in IGNORE_KEYS:
continue
new_k = rename_state_dict_key(k)
if new_k not in valid_keys:
failures.append([k, new_k])
else:
mapping[new_k] = v
if cfg.normalize_before: # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
rename_layernorm_keys(sd)
m.model.load_state_dict(mapping, strict=True)
m.half()
m.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
parser.add_argument(
"--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
)
args = parser.parse_args()
convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)

View File

@ -1,191 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re
import requests
import torch
# git clone https://github.com/salesforce/BLIP.git
from models.blip import blip_decoder
from models.blip_itm import blip_itm
from models.blip_vqa import blip_vqa
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import (
BertTokenizer,
BlipConfig,
BlipForConditionalGeneration,
BlipForImageTextRetrieval,
BlipForQuestionAnswering,
)
def load_demo_image(image_size, device):
img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
transform = transforms.Compose(
[
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]
)
image = transform(raw_image).unsqueeze(0).to(device)
return image
def rename_key(key):
if "visual_encoder" in key:
key = re.sub("visual_encoder*", "vision_model.encoder", key)
if "blocks" in key:
key = re.sub(r"blocks", "layers", key)
if "attn" in key:
key = re.sub(r"attn", "self_attn", key)
if "norm1" in key:
key = re.sub(r"norm1", "layer_norm1", key)
if "norm2" in key:
key = re.sub(r"norm2", "layer_norm2", key)
if "encoder.norm" in key:
key = re.sub(r"encoder.norm", "post_layernorm", key)
if "encoder.patch_embed.proj" in key:
key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)
if "encoder.pos_embed" in key:
key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
if "encoder.cls_token" in key:
key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)
if "self_attn" in key:
key = re.sub(r"self_attn.proj", "self_attn.projection", key)
return key
@torch.no_grad()
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = BlipConfig.from_pretrained(config_path)
else:
config = BlipConfig(projection_dim=512, text_config={}, vision_config={})
hf_model = BlipForConditionalGeneration(config).eval()
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
pt_model = pt_model.eval()
modified_state_dict = pt_model.state_dict()
for key in modified_state_dict.copy():
value = modified_state_dict.pop(key)
renamed_key = rename_key(key)
modified_state_dict[renamed_key] = value
hf_model.load_state_dict(modified_state_dict)
image_size = 384
image = load_demo_image(image_size=image_size, device="cpu")
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
input_ids = tokenizer(["a picture of"]).input_ids
out = hf_model.generate(image, input_ids)
assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
out = hf_model.generate(image)
assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
if pytorch_dump_folder_path is not None:
hf_model.save_pretrained(pytorch_dump_folder_path)
# model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
model_url = (
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
)
vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
vqa_model.eval()
modified_state_dict = vqa_model.state_dict()
for key in modified_state_dict.copy():
value = modified_state_dict.pop(key)
renamed_key = rename_key(key)
modified_state_dict[renamed_key] = value
hf_vqa_model = BlipForQuestionAnswering(config)
hf_vqa_model.load_state_dict(modified_state_dict)
question = ["How many dogs are in this image?"]
question_input_ids = tokenizer(question, return_tensors="pt").input_ids
answer = hf_vqa_model.generate(question_input_ids, image)
print(tokenizer.decode(answer[0]))
assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"
if pytorch_dump_folder_path is not None:
hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"
itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
itm_model.eval()
modified_state_dict = itm_model.state_dict()
for key in modified_state_dict.copy():
value = modified_state_dict.pop(key)
renamed_key = rename_key(key)
modified_state_dict[renamed_key] = value
hf_itm_model = BlipForImageTextRetrieval(config)
question = ["A picture of a woman with a dog sitting in a beach"]
question_input_ids = tokenizer(
question,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=35,
).input_ids
hf_itm_model.load_state_dict(modified_state_dict)
hf_itm_model.eval()
out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
out = hf_itm_model(question_input_ids, image, use_itm_head=False)
assert out[0].item() == 0.2110687494277954
assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127
if pytorch_dump_folder_path is not None:
hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
args = parser.parse_args()
convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)

View File

@ -1,390 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert BLIP-2 checkpoints from the original repository.
URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
"""
import argparse
import requests
import torch
# pip3 install salesforce-lavis
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
# to make sure we can compare both original and HF implementation in float32
from lavis.models import load_model_and_preprocess
from PIL import Image
from transformers import (
AutoTokenizer,
BertTokenizer,
Blip2Config,
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
Blip2Processor,
Blip2QFormerConfig,
Blip2VisionConfig,
BlipImageProcessor,
OPTConfig,
T5Config,
set_seed,
)
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
def load_demo_image():
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
return image
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, model_name):
rename_keys = []
# fmt: off
# vision encoder
rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))
for i in range(config.vision_config.num_hidden_layers):
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
# QFormer
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
if "itm" in model_name:
rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
rename_keys.append(("text_proj.weight", "text_projection.weight"))
rename_keys.append(("text_proj.bias", "text_projection.bias"))
# fmt: on
return rename_keys
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def read_in_q_v_bias(state_dict, config):
for i in range(config.vision_config.num_hidden_layers):
# read in original q and v biases
q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")
# next, set bias in the state dict
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
def get_blip2_config(model_name, eos_token_id):
image_size = 364 if "coco" in model_name else 224
vision_config = Blip2VisionConfig(image_size=image_size).to_dict()
# make sure the models have proper bos_token_id and eos_token_id set (important for generation)
# seems like flan-T5 models don't have bos_token_id properly set?
if "opt-2.7b" in model_name:
text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
elif "opt-6.7b" in model_name:
text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
elif "t5-xl" in model_name:
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
elif "t5-xxl" in model_name:
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
elif "itm" in model_name:
text_config = {}
else:
raise ValueError("Model name not supported")
if "itm" in model_name:
config = Blip2Config(
vision_config=vision_config,
qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
)
else:
config = Blip2Config(vision_config=vision_config, text_config=text_config)
return config, image_size
@torch.no_grad()
def convert_blip2_checkpoint(
model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
):
"""
Copy/paste/tweak model's weights to Transformers design.
"""
if "opt" in model_name:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
elif "itm" in model_name:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
else:
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
if "itm" in model_name:
eos_token_id = None
else:
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
if "itm" in model_name:
hf_model = Blip2ForImageTextRetrieval(config).eval()
else:
hf_model = Blip2ForConditionalGeneration(config).eval()
model_name_to_original = {
"blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
"blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
"blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
"blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
"blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
"blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
}
name, type = model_name_to_original[model_name]
# load original model
print("Loading original model...")
original_model, vis_processors, _ = load_model_and_preprocess(
name=name, model_type=type, is_eval=True, device=lavis_device
)
original_model.eval()
print("Done!")
# update state dict keys
state_dict = original_model.state_dict()
rename_keys = create_rename_keys(config, model_name)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
# some keys can be renamed efficiently
for key, val in state_dict.copy().items():
val = state_dict.pop(key)
if key.startswith("Qformer.bert"):
key = key.replace("Qformer.bert", "qformer")
if "attention.self" in key:
key = key.replace("self", "attention")
if "opt_proj" in key:
key = key.replace("opt_proj", "language_projection")
if "t5_proj" in key:
key = key.replace("t5_proj", "language_projection")
if key.startswith("opt"):
key = key.replace("opt", "language")
if key.startswith("t5"):
key = key.replace("t5", "language")
state_dict[key] = val
# read in qv biases
read_in_q_v_bias(state_dict, config)
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
assert len(missing_keys) == 0
if "itm" in model_name:
unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
else:
assert unexpected_keys == ["qformer.embeddings.position_ids"]
image = load_demo_image()
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
# create processor
image_processor = BlipImageProcessor(
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
)
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
# make sure processor creates exact same pixel values
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
original_model.to(lavis_device)
hf_model.to(hf_model_device)
if "itm" in model_name:
caption = "a large fountain spewing water into the air"
input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)
with torch.no_grad():
original_logits = original_model(
{"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
)
logits = hf_model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
use_image_text_matching_head=True,
)
assert original_logits.shape == logits.logits_per_image.shape
print("First values of original logits:", original_logits[0, :3])
print("First values of HF logits:", logits.logits_per_image[0, :3])
# assert values
# cast to same type
target_dtype = logits.logits_per_image.dtype
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
print("Looks ok!")
with torch.no_grad():
original_logits = original_model(
{"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
)
logits = hf_model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
use_image_text_matching_head=False,
)
assert original_logits.shape == logits.logits_per_image.shape
print("First values of original logits:", original_logits[0, :3])
print("First values of HF logits:", logits.logits_per_image[0, :3])
# assert values
# cast to same type
target_dtype = logits.logits_per_image.dtype
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
print("Looks ok!")
else:
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
with torch.no_grad():
if "opt" in model_name:
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
logits = hf_model(pixel_values, input_ids).logits
else:
original_logits = original_model(
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
).logits
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
logits = hf_model(pixel_values, input_ids, labels=labels).logits
assert original_logits.shape == logits.shape
print("First values of original logits:", original_logits[0, :3, :3])
print("First values of HF logits:", logits[0, :3, :3])
# assert values
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
print("Looks ok!")
print("Generating a caption...")
prompt = "Question: what object is in this image? Answer:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
set_seed(42)
original_outputs = original_model.generate(
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
)
outputs = hf_model.generate(
pixel_values,
input_ids,
do_sample=True,
num_beams=5,
max_length=30,
min_length=1,
top_p=0.9,
repetition_penalty=1.0,
length_penalty=1.0,
temperature=1,
)
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
output_text = [text.strip() for text in output_text]
print("Original generation:", original_outputs)
print("HF generation:", output_text)
if pytorch_dump_folder_path is not None:
processor.save_pretrained(pytorch_dump_folder_path)
hf_model.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
processor.push_to_hub(f"nielsr/{model_name}")
hf_model.push_to_hub(f"nielsr/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
choices = [
"blip2-opt-2.7b",
"blip2-opt-6.7b",
"blip2-opt-2.7b-coco",
"blip2-opt-6.7b-coco",
"blip2-flan-t5-xl",
"blip2-flan-t5-xl-coco",
"blip2-flan-t5-xxl",
"blip2-itm-vit-g",
"blip2-itm-vit-g-coco",
]
parser.add_argument(
"--model_name",
default="blip2-opt-2.7b",
choices=choices,
type=str,
help="Path to hf config.json of model to convert",
)
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the model and processor to the hub after converting",
)
# note: this script is tested on 2 GPUs, as models are compared in float32,
# which requires quite some memory. Hence loading both on a
# separate device is the easiest to compare
parser.add_argument(
"--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
)
parser.add_argument(
"--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
)
args = parser.parse_args()
convert_blip2_checkpoint(
args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
)

View File

@ -1,254 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigScience BLOOM checkpoint."""
import argparse
import json
import os
import re
import torch
from transformers import BloomConfig, BloomModel
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging
logging.set_verbosity_info()
WEIGHTS_TO_AVERAGE_ENDSWITH = [
"word_embeddings_layernorm.weight",
"word_embeddings_layernorm.bias",
"input_layernorm.weight",
"input_layernorm.bias",
"post_attention_layernorm.weight",
"post_attention_layernorm.bias",
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
"ln_f.weight",
"ln_f.bias",
]
WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
"mlp.dense_4h_to_h.weight",
"self_attention.dense.weight",
]
def layer_name_mapping(key, file):
"""Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
# Handle first and last layers
layer_rename_map = {
"word_embeddings.weight": "word_embeddings.weight",
"word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
"word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
"weight": "ln_f.weight",
"bias": "ln_f.bias",
}
if key in layer_rename_map:
return layer_rename_map[key]
# Handle transformer blocks
layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
layer_number -= 3
return f"h.{layer_number}." + key
def get_dtype_size(dtype):
if dtype == torch.bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
def convert_bloom_checkpoint_to_pytorch(
bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
):
# Construct model
if bloom_config_file == "":
config = BloomConfig()
else:
config = BloomConfig.from_json_file(bloom_config_file)
if shard_model:
file_names = os.listdir(bloom_checkpoint_path)
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
index_dict = {"weight_map": {}, "metadata": {}}
total_size = 0
missing_keys = None
config = BloomConfig()
for j, file in enumerate(file_names):
print(f"Processing file: {file}")
tensors = None
for i in range(pretraining_tp):
# load all TP files
f_name = file.replace("model_00", f"model_0{i}")
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
# Rename keys in the transformers names
keys = list(temp.keys())
for key in keys:
temp[layer_name_mapping(key, file)] = temp.pop(key)
if tensors is None:
tensors = temp
else:
for key in tensors.keys():
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
# We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
tensors[key] += temp[key]
else:
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
# We concatenate these weights across TP ranks
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
# Divide by the number of TP the weights we want to average
for key in tensors.keys():
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
tensors[key] = tensors[key] / pretraining_tp
torch.save(
tensors,
os.path.join(
pytorch_dump_folder_path,
f"pytorch_model_{str(j + 1).zfill(5)}-of-{str(len(file_names)).zfill(5)}.bin",
),
)
for key in tensors.keys():
value = tensors[key]
total_size += value.numel() * get_dtype_size(value.dtype)
if key not in index_dict["weight_map"]:
index_dict["weight_map"][key] = (
f"pytorch_model_{str(j + 1).zfill(5)}-of-{str(len(file_names)).zfill(5)}.bin"
)
config = BloomConfig()
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
index_dict["metadata"]["total_size"] = total_size
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
f.write(json_config)
else:
model = BloomModel(config)
file_names = os.listdir(bloom_checkpoint_path)
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
missing_keys = None
for i, file in enumerate(file_names):
tensors = None
for i in range(pretraining_tp):
# load all TP files
f_name = file.replace("model_00", f"model_0{i}")
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
# Rename keys in the transformers names
keys = list(temp.keys())
for key in keys:
temp[layer_name_mapping(key, file)] = temp.pop(key)
if tensors is None:
tensors = temp
else:
for key in tensors.keys():
# We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
tensors[key] += temp[key]
else:
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
# We concatenate these weights across TP ranks
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
# Divide by the number of TP the weights we want to average
for key in tensors.keys():
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
tensors[key] = tensors[key] / pretraining_tp
other_keys = model.load_state_dict(tensors, strict=False)
assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
if missing_keys is None:
missing_keys = set(other_keys.missing_keys)
else:
missing_keys = missing_keys.intersection(set(other_keys.missing_keys))
assert not missing_keys, f"The keys {missing_keys} are missing"
# Save pytorch-model
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
if config.torch_dtype is not None:
model = model.to(config.torch_dtype)
torch.save(model.state_dict(), pytorch_weights_dump_path)
print(f"Save configuration file to {pytorch_config_dump_path}")
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--bloom_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the Megatron-LM checkpoint path.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--bloom_config_file",
default="",
type=str,
help=(
"An optional config json file corresponding to the pre-trained model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--shard_model",
action="store_true",
help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
)
parser.add_argument(
"--pretraining_tp",
default=4,
type=int,
help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
)
args = parser.parse_args()
convert_bloom_checkpoint_to_pytorch(
args.bloom_checkpoint_path,
args.bloom_config_file,
args.pytorch_dump_folder_path,
args.shard_model,
args.pretraining_tp,
)

View File

@ -1,145 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Bros checkpoints."""
import argparse
import bros # original repo
import torch
from transformers import BrosConfig, BrosModel, BrosProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_configs(model_name):
bros_config = BrosConfig.from_pretrained(model_name)
return bros_config
def remove_ignore_keys_(state_dict):
ignore_keys = [
"embeddings.bbox_sinusoid_emb.inv_freq",
]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(name):
if name == "embeddings.bbox_projection.weight":
name = "bbox_embeddings.bbox_projection.weight"
if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"
if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"
return name
def convert_state_dict(orig_state_dict, model):
# rename keys
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
orig_state_dict[rename_key(key)] = val
# remove ignore keys
remove_ignore_keys_(orig_state_dict)
return orig_state_dict
def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
# load original model
original_model = bros.BrosModel.from_pretrained(model_name).eval()
# load HuggingFace Model
bros_config = get_configs(model_name)
model = BrosModel.from_pretrained(model_name, config=bros_config)
model.eval()
state_dict = original_model.state_dict()
new_state_dict = convert_state_dict(state_dict, model)
model.load_state_dict(new_state_dict)
# verify results
# original BROS model require 4 points (8 float values) for each bbox, prepare bbox with [batch_size, seq_len, 8] shape
bbox = torch.tensor(
[
[
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
]
]
)
processor = BrosProcessor.from_pretrained(model_name)
encoding = processor("His name is Rocco.", return_tensors="pt")
encoding["bbox"] = bbox
original_hidden_states = original_model(**encoding).last_hidden_state
# pixel_values = processor(image, return_tensors="pt").pixel_values
last_hidden_states = model(**encoding).last_hidden_state
assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)
if pytorch_dump_folder_path is not None:
print(f"Saving model and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="jinho8345/bros-base-uncased",
required=False,
type=str,
help="Name of the original model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
required=False,
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether or not to push the converted model and processor to the 🤗 hub.",
)
args = parser.parse_args()
convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -1,59 +0,0 @@
# coding=utf-8
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert T5 checkpoint."""
import argparse
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
# Initialise PyTorch model
config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
model = T5ForConditionalGeneration(config)
# Load weights from tf checkpoint
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)

View File

@ -1,65 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert CANINE checkpoint."""
import argparse
from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path):
# Initialize PyTorch model
config = CanineConfig()
model = CanineModel(config)
model.eval()
print(f"Building PyTorch model from configuration: {config}")
# Load weights from tf checkpoint
load_tf_weights_in_canine(model, config, tf_checkpoint_path)
# Save pytorch-model (weights and configuration)
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
# Save tokenizer files
tokenizer = CanineTokenizer()
print(f"Save tokenizer files to {pytorch_dump_path}")
tokenizer.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the TensorFlow checkpoint. Should end with model.ckpt",
)
parser.add_argument(
"--pytorch_dump_path",
default=None,
type=str,
required=True,
help="Path to a folder where the PyTorch model will be placed.",
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path)

View File

@ -1,478 +0,0 @@
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
import requests
import torch
import yaml
from accelerate import init_empty_weights
from PIL import Image
from transformers import (
ChameleonConfig,
ChameleonForConditionalGeneration,
ChameleonImageProcessor,
ChameleonProcessor,
)
try:
from transformers import LlamaTokenizerFast
except ImportError:
raise ValueError(
"Chameleon conversion supports only FastTokenizer and LlamaTokenizerFast can't be imported! "
"Update your `tokenizers` library and re-run the tokenizer conversion."
)
"""
Sample usage:
```
python src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py \
--input_dir /path/to/downloaded/chameleon/weights --model_size 7B --output_dir /output/path
```
Thereafter, models can be loaded via:
```py
from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast
model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""
NUM_SHARDS = {
"7B": 1,
"30B": 4,
}
VOCAB_SIZE = 65536
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(model_path, input_base_path, model_size, chameleon_version=1):
os.makedirs(model_path, exist_ok=True)
input_model_path = os.path.join(input_base_path, "models", model_size.lower())
params_path = os.path.join(input_model_path, "params.json")
consolidate_params_path = os.path.join(input_model_path, "consolidate_params.json")
params = read_json(params_path)
if os.path.isfile(consolidate_params_path):
params = {**params, **read_json(consolidate_params_path)}
num_shards = NUM_SHARDS[model_size]
model_parallel_size = params["model_parallel_size"]
params = params.get("model", params)
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads
base = params.get("rope_theta", 10000.0)
swin_norm = params["swin_norm"]
if base > 10000.0:
max_position_embeddings = 16384
else:
# Depending on the Chameleon version, the default max_position_embeddings has different values.
if chameleon_version == 1:
max_position_embeddings = 4096
else:
raise NotImplementedError(
f"Version {chameleon_version} of chameleon is not supported yet. "
"Current supported versions of chameleon are [1]."
)
if params.get("n_kv_heads", None) is not None:
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
key_value_dim = dim // num_key_value_heads
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
num_local_key_value_heads = n_heads_per_shard
key_value_dim = dim
print(f"Fetching all parameters from the checkpoint at {input_model_path}.")
# Load weights
if num_shards == 1:
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
loaded = None
for possible_name in ["consolidated.pth", "consolidated.00.pth"]:
possible_path = os.path.join(input_model_path, possible_name)
if os.path.exists(possible_path):
loaded = torch.load(possible_path, map_location="cpu", weights_only=True)
break
assert loaded is not None
else:
# Sharded
loaded = [
torch.load(
os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu", weights_only=True
)
for i in range(num_shards)
]
# permute for sliced rotary
def permute(w, n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
# Load weights to the state dict
state_dict = {}
for layer_i in range(n_layers):
if num_shards == 1:
# Unsharded
state_dict.update(
{
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
),
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wk.weight"],
n_heads=num_key_value_heads,
dim1=key_value_dim,
),
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
f"model.layers.{layer_i}.input_layernorm.weight": loaded[
f"layers.{layer_i}.attention_norm.weight"
],
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
f"layers.{layer_i}.ffn_norm.weight"
],
}
)
# qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
loaded[f"layers.{layer_i}.attention.q_normalization.weight"]
.view(dims_per_head // 2, 2)
.t()
.reshape(1, -1)
.repeat_interleave(n_heads, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
loaded[f"layers.{layer_i}.attention.q_normalization.bias"]
.view(dims_per_head // 2, 2)
.t()
.reshape(1, -1)
.repeat_interleave(n_heads, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
loaded[f"layers.{layer_i}.attention.k_normalization.weight"]
.view(dims_per_head // 2, 2)
.t()
.reshape(1, -1)
.repeat_interleave(num_key_value_heads, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
loaded[f"layers.{layer_i}.attention.k_normalization.bias"]
.view(dims_per_head // 2, 2)
.t()
.reshape(1, -1)
.repeat_interleave(num_key_value_heads, 0)
)
else:
# Sharded
state_dict.update(
{
f"model.layers.{layer_i}.input_layernorm.weight": torch.stack(
[l[f"layers.{layer_i}.attention_norm.weight"] for l in loaded]
).mean(dim=0),
f"model.layers.{layer_i}.post_attention_layernorm.weight": torch.stack(
[l[f"layers.{layer_i}.ffn_norm.weight"] for l in loaded]
).mean(dim=0),
}
)
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
for i in range(num_shards)
],
dim=0,
).reshape(dim, dim),
n_heads=n_heads,
)
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim),
n_heads=num_key_value_heads,
dim1=key_value_dim,
)
# qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
torch.cat([l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0) for l in loaded])
.view(num_shards, dims_per_head // 2, 2)
.transpose(1, 2)
.reshape(num_shards, -1)
.repeat_interleave(n_heads // num_shards, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
torch.cat([l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0) for l in loaded])
.view(num_shards, dims_per_head // 2, 2)
.transpose(1, 2)
.reshape(num_shards, -1)
.repeat_interleave(n_heads // num_shards, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
torch.cat([l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0) for l in loaded])
.view(num_shards, dims_per_head // 2, 2)
.transpose(1, 2)
.reshape(num_shards, -1)
.repeat_interleave(num_key_value_heads // num_shards, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
torch.cat([l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0) for l in loaded])
.view(num_shards, dims_per_head // 2, 2)
.transpose(1, 2)
.reshape(num_shards, -1)
.repeat_interleave(num_key_value_heads // num_shards, 0)
)
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
)
if num_shards == 1:
# Unsharded
state_dict.update(
{
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
"model.norm.weight": loaded["norm.weight"],
"lm_head.weight": loaded["output.weight"],
}
)
else:
state_dict.update(
{
"model.embed_tokens.weight": torch.cat(
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
),
"model.norm.weight": torch.stack([loaded[i]["norm.weight"] for i in range(num_shards)]).mean(dim=0),
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
}
)
# Load VQGAN weights
vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt")
vqgan_state_dict = torch.load(vqgan_path, map_location="cpu", weights_only=True)["state_dict"]
for k, v in vqgan_state_dict.items():
if "decoder" in k:
continue # we dont do image generation yet
state_dict[f"model.vqmodel.{k}"] = v
# Write configs
ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
multiple_of = params["multiple_of"] if "multiple_of" in params else 256
with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file:
tokenizer_config = json.load(tokenizer_file)
vocabulary_map = tokenizer_config["model"]["vocab"]
vocabulary_map["<image>"] = vocabulary_map[
"<reserved08707>"
] # use a reserved token instead of adding a new one
del vocabulary_map["<reserved08707>"]
for token in tokenizer_config["added_tokens"]:
if token["content"] == "<reserved08707>":
token["content"] = "<image>"
with open(os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), "w") as f:
json.dump(tokenizer_config, f) # save the new file to init tokenizer later
vq_keys_to_replace = [
("ch", "base_channels"),
("out_ch", "out_channels"),
("n_embed", "num_embeddings"),
("ch_mult", "channel_multiplier"),
("double_z", "double_latent"),
("z_channels", "latent_channels"),
]
with open(os.path.join(input_base_path, "tokenizer/vqgan.yaml")) as vqgan_cfg_file:
vq_config = yaml.safe_load(vqgan_cfg_file)["model"]["params"]
vq_config.update(**vq_config["ddconfig"])
for old, new in vq_keys_to_replace:
vq_config[new] = vq_config[old]
del vq_config["ddconfig"]
del vq_config["ckpt_path"]
del vq_config["lossconfig"]
config = ChameleonConfig(
hidden_size=dim,
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
num_key_value_heads=num_key_value_heads,
vocab_size=VOCAB_SIZE,
rope_theta=base,
max_position_embeddings=max_position_embeddings,
model_parallel_size=model_parallel_size,
swin_norm=swin_norm,
vq_config=vq_config,
vocabulary_map=vocabulary_map,
)
with init_empty_weights():
model = ChameleonForConditionalGeneration(config)
model.load_state_dict(state_dict, assign=True, strict=False)
model.save_pretrained(model_path, safe_serialization=True)
# Load and save the processor
tokenizer = LlamaTokenizerFast(
tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False
)
tokenizer.sep_token_id = 8710 # assign <reserved08706> to sep so that we can append it after input text
tokenizer.pad_token_id = 1 # assign <pad> to special pad_token
image_processor = ChameleonImageProcessor()
processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer)
processor.save_pretrained(model_path)
# Make space so we can load the model properly now.
del state_dict
del loaded
del vqgan_state_dict
gc.collect()
# Short inference on a few examples to check if generation makes sense
# taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
print("Loading the checkpoint in a Chameleon model...")
print("*" * 100)
model = ChameleonForConditionalGeneration.from_pretrained(
model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
)
processor = ChameleonProcessor.from_pretrained(model_path)
prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
image = Image.open(
requests.get(
"https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
).raw
)
inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
length = inputs.input_ids.shape[1]
out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
print(f"Generation for single-image: {generated_text}")
print("*" * 100)
# Multi-image example
prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
image = Image.open(
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
)
image_2 = Image.open(
requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
)
inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
length = inputs.input_ids.shape[1]
out = model.generate(**inputs, max_new_tokens=50, do_sample=False)
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
print(f"Generation for multi-image: {generated_text}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of Chameleon weights",
)
parser.add_argument(
"--model_size",
choices=["7B", "30B"],
help=""
" models correspond to the finetuned versions, and are specific to the Chameleon official release. For more details on Chameleon, check out the original repo: https://github.com/facebookresearch/chameleon",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model",
)
parser.add_argument(
"--test_inference",
action="store_true",
help="Whether to load the model for generation to test it's converted correctly.",
)
# Different Chameleon versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
parser.add_argument(
"--chameleon_version",
choices=[1],
default=1,
type=int,
help="Version of the Chameleon model to convert",
)
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
chameleon_version=args.chameleon_version,
)
if __name__ == "__main__":
main()

View File

@ -1,134 +0,0 @@
# coding=utf-8
# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from transformers import ChineseCLIPConfig, ChineseCLIPModel
def copy_attn_layer(hf_attn_layer, pt_weights, prefix):
q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0)
q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0)
out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"]
out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"]
hf_attn_layer.q_proj.weight.data = q_proj
hf_attn_layer.q_proj.bias.data = q_proj_bias
hf_attn_layer.k_proj.weight.data = k_proj
hf_attn_layer.k_proj.bias.data = k_proj_bias
hf_attn_layer.v_proj.weight.data = v_proj
hf_attn_layer.v_proj.bias.data = v_proj_bias
hf_attn_layer.out_proj.weight.data = out_proj_weights
hf_attn_layer.out_proj.bias.data = out_proj_bias
def copy_mlp(hf_mlp, pt_weights, prefix):
copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc")
copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj")
def copy_linear(hf_linear, pt_weights, prefix):
hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data
hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data
def copy_layer(hf_layer, pt_weights, prefix):
# copy layer norms
copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1")
copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2")
# copy MLP
copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp")
# copy attn
copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn")
def copy_layers(hf_layers, pt_weights, prefix):
for layer_id, hf_layer in enumerate(hf_layers):
copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}")
def copy_text_model_and_projection(hf_model, pt_weights):
# copy projection
hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T
# copy text encoder
for name, param in hf_model.text_model.named_parameters():
param.data = pt_weights[f"bert.{name}"].data
def copy_vision_model_and_projection(hf_model, pt_weights):
# copy projection
hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T
# copy layer norms
copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre")
copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post")
# copy embeddings
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data
hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data
# copy encoder
copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks")
@torch.no_grad()
def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
"""
Copy/paste/tweak model's weights to transformers design.
"""
assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size."
config = ChineseCLIPConfig.from_pretrained(config_path)
hf_model = ChineseCLIPModel(config).eval()
pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
copy_text_model_and_projection(hf_model, pt_weights)
copy_vision_model_and_projection(hf_model, pt_weights)
hf_model.logit_scale.data = pt_weights["logit_scale"].data
hf_model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output folder storing converted hf PyTorch model.",
)
parser.add_argument(
"--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint."
)
parser.add_argument(
"--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert."
)
args = parser.parse_args()
convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
print("The conversion is finished!")

View File

@ -1,133 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re
from laion_clap import CLAP_Module
from transformers import AutoFeatureExtractor, ClapConfig, ClapModel
KEYS_TO_MODIFY_MAPPING = {
"text_branch": "text_model",
"audio_branch": "audio_model.audio_encoder",
"attn": "attention.self",
"self.proj": "output.dense",
"attention.self_mask": "attn_mask",
"mlp.fc1": "intermediate.dense",
"mlp.fc2": "output.dense",
"norm1": "layernorm_before",
"norm2": "layernorm_after",
"bn0": "batch_norm",
}
processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")
def init_clap(checkpoint_path, model_type, enable_fusion=False):
model = CLAP_Module(
amodel=model_type,
enable_fusion=enable_fusion,
)
model.load_ckpt(checkpoint_path)
return model
def get_config_from_original(clap_model):
audio_config = {
"patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
"depths": clap_model.model.audio_branch.depths,
"hidden_size": clap_model.model.audio_projection[0].in_features,
}
text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}
return ClapConfig(audio_config=audio_config, text_config=text_config)
def rename_state_dict(state_dict):
model_state_dict = {}
sequential_layers_pattern = r".*sequential.(\d+).*"
text_projection_pattern = r".*_projection.(\d+).*"
for key, value in state_dict.items():
# check if any key needs to be modified
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
if re.match(sequential_layers_pattern, key):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
# Because in CLAP they use `nn.Sequential`...
transformers_projection_layer = 1 if projecton_layer == 0 else 2
key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")
if "audio" and "qkv" in key:
# split qkv into query key and value
mixed_qkv = value
qkv_dim = mixed_qkv.size(0) // 3
query_layer = mixed_qkv[:qkv_dim]
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
value_layer = mixed_qkv[qkv_dim * 2 :]
model_state_dict[key.replace("qkv", "query")] = query_layer
model_state_dict[key.replace("qkv", "key")] = key_layer
model_state_dict[key.replace("qkv", "value")] = value_layer
else:
model_state_dict[key] = value
return model_state_dict
def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)
clap_model.eval()
state_dict = clap_model.model.state_dict()
state_dict = rename_state_dict(state_dict)
transformers_config = get_config_from_original(clap_model)
transformers_config.audio_config.enable_fusion = enable_fusion
model = ClapModel(transformers_config)
# ignore the spectrogram embedding layer
model.load_state_dict(state_dict, strict=False)
model.save_pretrained(pytorch_dump_folder_path)
transformers_config.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Whether to enable fusion or not")
args = parser.parse_args()
convert_clap_checkpoint(
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
)

View File

@ -1,156 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from clip import load
from transformers import CLIPConfig, CLIPModel
def copy_attn_layer(hf_attn_layer, pt_attn_layer):
q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
out_proj_weights = pt_attn_layer.out_proj.weight
out_proj_bias = pt_attn_layer.out_proj.bias
hf_attn_layer.q_proj.weight.data = q_proj
hf_attn_layer.q_proj.bias.data = q_proj_bias
hf_attn_layer.k_proj.weight.data = k_proj
hf_attn_layer.k_proj.bias.data = k_proj_bias
hf_attn_layer.v_proj.weight.data = v_proj
hf_attn_layer.v_proj.bias.data = v_proj_bias
hf_attn_layer.out_proj.weight = out_proj_weights
hf_attn_layer.out_proj.bias = out_proj_bias
def copy_mlp(hf_mlp, pt_mlp):
copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
copy_linear(hf_mlp.fc2, pt_mlp.c_proj)
def copy_linear(hf_linear, pt_linear):
hf_linear.weight = pt_linear.weight
hf_linear.bias = pt_linear.bias
def copy_layer(hf_layer, pt_layer):
# copy layer norms
copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)
# copy MLP
copy_mlp(hf_layer.mlp, pt_layer.mlp)
# copy attn
copy_attn_layer(hf_layer.self_attn, pt_layer.attn)
def copy_layers(hf_layers, pt_layers):
for hf_layer, pt_layer in zip(hf_layers, pt_layers):
copy_layer(hf_layer, pt_layer)
def copy_encoder(hf_encoder, pt_model):
# copy embeds
hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding
# copy layer norm
copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
# copy hidden layers
copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)
def copy_text_model_and_projection(hf_model, pt_model):
# copy projection
hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()
# copy text encoder
copy_encoder(hf_model.text_model, pt_model)
def copy_vison_model_and_projection(hf_model, pt_model):
# copy projection
hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous()
# copy layer norms
copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)
# copy embeds
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data
# copy encoder
copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)
@torch.no_grad()
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = CLIPConfig.from_pretrained(config_path)
else:
config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})
hf_model = CLIPModel(config).eval()
pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
pt_model = pt_model.eval()
copy_text_model_and_projection(hf_model, pt_model)
copy_vison_model_and_projection(hf_model, pt_model)
hf_model.logit_scale = pt_model.logit_scale
# Use `eos_token` so the example is more meaningful
input_ids = torch.tensor(
[
[config.text_config.bos_token_id]
+ list(range(3, 77))
+ [config.text_config.eos_token_id]
+ [config.text_config.pad_token_id]
]
)
pixel_values = torch.randn(1, 3, 224, 224)
hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
hf_logits_per_image = hf_outputs.logits_per_image
hf_logits_per_text = hf_outputs.logits_per_text
pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)
assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)
hf_model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
args = parser.parse_args()
convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)

View File

@ -1,264 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert CLIPSeg checkpoints from the original repository. URL: https://github.com/timojl/clipseg."""
import argparse
import requests
import torch
from PIL import Image
from transformers import (
CLIPSegConfig,
CLIPSegForImageSegmentation,
CLIPSegProcessor,
CLIPSegTextConfig,
CLIPSegVisionConfig,
CLIPTokenizer,
ViTImageProcessor,
)
def get_clipseg_config(model_name):
text_config = CLIPSegTextConfig()
vision_config = CLIPSegVisionConfig(patch_size=16)
use_complex_transposed_convolution = True if "refined" in model_name else False
reduce_dim = 16 if "rd16" in model_name else 64
config = CLIPSegConfig.from_text_vision_configs(
text_config,
vision_config,
use_complex_transposed_convolution=use_complex_transposed_convolution,
reduce_dim=reduce_dim,
)
return config
def rename_key(name):
# update prefixes
if "clip_model" in name:
name = name.replace("clip_model", "clip")
if "transformer" in name:
if "visual" in name:
name = name.replace("visual.transformer", "vision_model")
else:
name = name.replace("transformer", "text_model")
if "resblocks" in name:
name = name.replace("resblocks", "encoder.layers")
if "ln_1" in name:
name = name.replace("ln_1", "layer_norm1")
if "ln_2" in name:
name = name.replace("ln_2", "layer_norm2")
if "c_fc" in name:
name = name.replace("c_fc", "fc1")
if "c_proj" in name:
name = name.replace("c_proj", "fc2")
if "attn" in name and "self" not in name:
name = name.replace("attn", "self_attn")
# text encoder
if "token_embedding" in name:
name = name.replace("token_embedding", "text_model.embeddings.token_embedding")
if "positional_embedding" in name and "visual" not in name:
name = name.replace("positional_embedding", "text_model.embeddings.position_embedding.weight")
if "ln_final" in name:
name = name.replace("ln_final", "text_model.final_layer_norm")
# vision encoder
if "visual.class_embedding" in name:
name = name.replace("visual.class_embedding", "vision_model.embeddings.class_embedding")
if "visual.conv1" in name:
name = name.replace("visual.conv1", "vision_model.embeddings.patch_embedding")
if "visual.positional_embedding" in name:
name = name.replace("visual.positional_embedding", "vision_model.embeddings.position_embedding.weight")
if "visual.ln_pre" in name:
name = name.replace("visual.ln_pre", "vision_model.pre_layrnorm")
if "visual.ln_post" in name:
name = name.replace("visual.ln_post", "vision_model.post_layernorm")
# projection layers
if "visual.proj" in name:
name = name.replace("visual.proj", "visual_projection.weight")
if "text_projection" in name:
name = name.replace("text_projection", "text_projection.weight")
# decoder
if "trans_conv" in name:
name = name.replace("trans_conv", "transposed_convolution")
if "film_mul" in name or "film_add" in name or "reduce" in name or "transposed_convolution" in name:
name = "decoder." + name
if "blocks" in name:
name = name.replace("blocks", "decoder.layers")
if "linear1" in name:
name = name.replace("linear1", "mlp.fc1")
if "linear2" in name:
name = name.replace("linear2", "mlp.fc2")
if "norm1" in name and "layer_" not in name:
name = name.replace("norm1", "layer_norm1")
if "norm2" in name and "layer_" not in name:
name = name.replace("norm2", "layer_norm2")
return name
def convert_state_dict(orig_state_dict, config):
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
if key.startswith("clip_model") and "attn.in_proj" in key:
key_split = key.split(".")
if "visual" in key:
layer_num = int(key_split[4])
dim = config.vision_config.hidden_size
prefix = "vision_model"
else:
layer_num = int(key_split[3])
dim = config.text_config.hidden_size
prefix = "text_model"
if "weight" in key:
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[
dim : dim * 2, :
]
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
else:
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
elif "self_attn" in key and "out_proj" not in key:
key_split = key.split(".")
layer_num = int(key_split[1])
dim = config.reduce_dim
if "weight" in key:
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[dim : dim * 2, :]
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
else:
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
else:
new_name = rename_key(key)
if "visual_projection" in new_name or "text_projection" in new_name:
val = val.T
orig_state_dict[new_name] = val
return orig_state_dict
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
return image
def convert_clipseg_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub):
config = get_clipseg_config(model_name)
model = CLIPSegForImageSegmentation(config)
model.eval()
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
# remove some keys
for key in state_dict.copy().keys():
if key.startswith("model"):
state_dict.pop(key, None)
# rename some keys
state_dict = convert_state_dict(state_dict, config)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys != ["clip.text_model.embeddings.position_ids", "clip.vision_model.embeddings.position_ids"]:
raise ValueError(f"Missing keys that are not expected: {missing_keys}")
if unexpected_keys != ["decoder.reduce.weight", "decoder.reduce.bias"]:
raise ValueError(f"Unexpected keys: {unexpected_keys}")
image_processor = ViTImageProcessor(size=352)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)
image = prepare_img()
text = ["a glass", "something to fill", "wood", "a jar"]
inputs = processor(text=text, images=[image] * len(text), padding="max_length", return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
# verify values
expected_conditional = torch.tensor([0.1110, -0.1882, 0.1645])
expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328])
if model_name == "clipseg-rd64-refined":
expected_masks_slice = torch.tensor(
[[-10.0407, -9.9431, -10.2646], [-9.9751, -9.7064, -9.9586], [-9.6891, -9.5645, -9.9618]]
)
elif model_name == "clipseg-rd64":
expected_masks_slice = torch.tensor(
[[-7.2877, -7.2711, -7.2463], [-7.2652, -7.2780, -7.2520], [-7.2239, -7.2204, -7.2001]]
)
elif model_name == "clipseg-rd16":
expected_masks_slice = torch.tensor(
[[-6.3955, -6.4055, -6.4151], [-6.3911, -6.4033, -6.4100], [-6.3474, -6.3702, -6.3762]]
)
else:
raise ValueError(f"Model name {model_name} not supported.")
assert torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)
assert torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)
assert torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print(f"Pushing model and processor for {model_name} to the hub")
model.push_to_hub(f"CIDAS/{model_name}")
processor.push_to_hub(f"CIDAS/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="clipseg-rd64",
type=str,
choices=["clipseg-rd16", "clipseg-rd64", "clipseg-rd64-refined"],
help=(
"Name of the model. Supported models are: clipseg-rd64, clipseg-rd16 and clipseg-rd64-refined (rd meaning"
" reduce dimension)"
),
)
parser.add_argument(
"--checkpoint_path",
default="/Users/nielsrogge/Documents/CLIPSeg/clip_plus_rd64-uni.pth",
type=str,
help=(
"Path to the original checkpoint. Note that the script assumes that the checkpoint includes both CLIP and"
" the decoder weights."
),
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_clipseg_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -1,234 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Weights conversion script for CLVP
"""
import argparse
import os
import torch
from huggingface_hub import hf_hub_download
from transformers import ClvpConfig, ClvpModelForConditionalGeneration
_MODELS = {
"clvp": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/clvp2.pth",
"decoder": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/autoregressive.pth",
}
dim = 1024
sub_dim = dim // 16
CLVP_ENCODERS_MAPPING = {
"text_transformer.transformer.attn_layers": "text_encoder_model",
"speech_transformer.transformer.attn_layers": "speech_encoder_model",
"text_transformer.transformer.norm": "text_encoder_model.final_layer_norm",
"speech_transformer.transformer.norm": "speech_encoder_model.final_layer_norm",
"to_text_latent": "text_encoder_model.projection",
"to_speech_latent": "speech_encoder_model.projection",
"text_emb": "text_encoder_model.token_embedding",
"speech_emb": "speech_encoder_model.token_embedding",
"1.wrap.net.0": "mlp.fc1",
"1.wrap.net.3": "mlp.fc2",
"1.wrap": "self_attn",
"to_out": "out_proj",
"to_q": "q_proj",
"to_k": "k_proj",
"to_v": "v_proj",
"temperature": "logit_scale",
}
CLVP_DECODER_MAPPING = {
"conditioning_encoder.init": "conditioning_encoder.mel_conv",
"conditioning_encoder.attn": "conditioning_encoder.mel_attn_blocks",
"mel_attn_blocks": "group_norms",
".norm.weight": ".weight",
".norm.bias": ".bias",
"text_embedding": "conditioning_encoder.text_token_embedding",
"text_pos_embedding.emb": "conditioning_encoder.text_position_embedding",
"final_norm": "speech_decoder_model.final_norm",
"mel_head": "speech_decoder_model.lm_head",
"gpt.ln_f": "speech_decoder_model.model.decoder.layer_norm",
"mel_embedding": "speech_decoder_model.model.decoder.input_embeds_layer",
"mel_pos_embedding.emb": "speech_decoder_model.model.decoder.position_embeds_layer",
"gpt.h": "speech_decoder_model.model.decoder.layers",
"ln_1": "input_layernorm",
"ln_2": "post_attention_layernorm",
}
def update_index(present_index):
if present_index % 2 == 0:
return int(present_index / 2)
else:
return int((present_index - 1) / 2)
def convert_encoder_weights(original_weights):
converted_weights = {}
original_weights_keys = sorted(original_weights.keys())
for original_key in original_weights_keys:
updated_key = original_key
# for input_rmsnorm.weight and post_attention_rmsnorm.weight
if "0.0.g" in updated_key:
present_index = updated_key.split(".")[4]
if int(present_index) % 2 == 0:
updated_key = updated_key.replace("0.0.g", "input_rmsnorm.weight")
else:
updated_key = updated_key.replace("0.0.g", "post_attention_rmsnorm.weight")
if "transformer.attn_layers.layers" in updated_key:
present_index = updated_key.split(".")[4]
updated_index = update_index(int(present_index))
updated_key = updated_key.replace(
f"transformer.attn_layers.layers.{present_index}", f"transformer.attn_layers.layers.{updated_index}"
)
for k, v in CLVP_ENCODERS_MAPPING.items():
if k in updated_key:
updated_key = updated_key.replace(k, v)
converted_weights[updated_key] = original_weights.pop(original_key)
return converted_weights
def convert_decoder_weights(original_weights):
converted_weights = {}
original_weights_keys = sorted(original_weights.keys())
for original_key in original_weights_keys:
updated_key = original_key
if len(updated_key.split(".")) > 3:
index, attr = updated_key.split(".")[2], updated_key.split(".")[-1]
# for decoder attention
if "attn.c_attn" in updated_key:
if attr == "weight":
slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).T.split(split_size=dim, dim=0)
else:
slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.q_proj.{attr}"] = slice1
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.k_proj.{attr}"] = slice2
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.v_proj.{attr}"] = slice3
continue
if "attn.c_proj" in updated_key:
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.out_proj.{attr}"] = (
original_weights[updated_key].squeeze(-1).T
)
continue
if "attn.bias" in updated_key or "attn.masked_bias" in updated_key or "text_head" in updated_key:
original_weights.pop(updated_key)
continue
# conditional encoder attention
if "qkv" in updated_key:
if attr == "weight":
slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).split(split_size=dim, dim=0)
else:
slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
indices = torch.arange(dim)
index1, index2, index3 = (
indices.unfold(0, sub_dim, sub_dim * 3).flatten(),
indices[sub_dim:].unfold(0, sub_dim, sub_dim * 3).flatten(),
indices[2 * sub_dim :].unfold(0, sub_dim, sub_dim * 3).flatten(),
)
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.q_proj.{attr}"] = torch.concatenate(
[slice1[index1], slice2[index3], slice3[index2]],
axis=0,
)
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.k_proj.{attr}"] = torch.concatenate(
[slice1[index2], slice2[index1], slice3[index3]],
axis=0,
)
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.v_proj.{attr}"] = torch.concatenate(
[slice1[index3], slice2[index2], slice3[index1]],
axis=0,
)
continue
if "proj_out" in updated_key:
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.out_proj.{attr}"] = original_weights[
updated_key
].squeeze(-1)
continue
for k, v in CLVP_DECODER_MAPPING.items():
if k in updated_key:
updated_key = updated_key.replace(k, v)
converted_weights[updated_key] = original_weights.pop(original_key)
return converted_weights
def _download(url: str, root: str):
repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
filename = f"{url.split('/')[-2]}/{url.split('/')[-1]}"
hf_hub_download(
repo_id=repo_id,
filename=filename,
force_filename=root,
local_dir_use_symlinks=False,
)
def convert_clvp_weights(checkpoint_path, pytorch_dump_folder_path):
converted_checkpoint = {}
for each_model_name, each_model_url in _MODELS.items():
each_model_path = os.path.join(checkpoint_path, each_model_url.split("/")[-1])
if not os.path.exists(each_model_path):
print(f"\n{each_model_name} was not found! Downloading it to {each_model_path}")
_download(url=each_model_url, root=each_model_path)
if each_model_name == "clvp":
clvp_checkpoint = torch.load(each_model_path, map_location="cpu", weights_only=True)
else:
decoder_checkpoint = torch.load(each_model_path, map_location="cpu", weights_only=True)
# Converting the weights
converted_checkpoint.update(**convert_encoder_weights(clvp_checkpoint))
converted_checkpoint.update(**convert_decoder_weights(decoder_checkpoint))
config = ClvpConfig.from_pretrained("susnato/clvp_dev")
model = ClvpModelForConditionalGeneration(config)
model.load_state_dict(converted_checkpoint, strict=True)
model.save_pretrained(pytorch_dump_folder_path)
print(f"Model saved at {pytorch_dump_folder_path}!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# # Required parameters
parser.add_argument(
"--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model. (Please enter full path)",
)
args = parser.parse_args()
convert_clvp_weights(args.checkpoint_path, args.pytorch_dump_folder_path)

View File

@ -1,214 +0,0 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert ColPali weights from the original repository to the HF model format.
Original repository: https://github.com/illuin-tech/colpali.
NOTE: This script was originally run using `torch==2.5.1` and with:
```bash
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
--model_id vidore/colpali-v1.2-merged \
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
--output_dir vidore/colpali-v1.2-hf-internal \
--push_to_hub
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
--model_id vidore/colpali-v1.3-merged \
--revision 5b955e3415a7c5468ab33119d98d6d45c3a5b2c3 \
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
--output_dir vidore/colpali-v1.3-hf \
--push_to_hub
```
"""
import argparse
import glob
from pathlib import Path
from typing import Any, Optional
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig
from transformers.models.colpali import ColPaliForRetrieval
from transformers.models.colpali.configuration_colpali import ColPaliConfig
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
ORIGINAL_DTYPE = torch.bfloat16
def rename_state_dict_keys(state_dict: dict[str, Any]) -> dict[str, Any]:
new_state_dict = {}
for key, value in state_dict.items():
new_key = key
if key.startswith("custom_text_proj"):
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
if key.startswith("model."):
new_key = key.replace("model.", "vlm.", 1)
new_state_dict[new_key] = value
return new_state_dict
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
directory_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["*.safetensors"],
)
original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
# Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
if "lm_head.weight" not in original_state_dict:
original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[
"model.language_model.model.embed_tokens.weight"
].clone()
return original_state_dict
@torch.no_grad()
def convert_colpali_weights_to_hf(
model_id: str,
output_dir: str,
push_to_hub: bool,
revision: Optional[str] = None,
original_vlm_name_or_path: Optional[str] = None,
):
# Load the original model data
original_config = AutoConfig.from_pretrained(
model_id,
revision=revision,
)
if original_vlm_name_or_path is not None:
original_config._name_or_path = original_vlm_name_or_path
if hasattr(original_config, "architectures"):
delattr(original_config, "architectures")
original_state_dict = load_original_state_dict(model_id, revision=revision)
# Format the state_dict keys
original_state_dict = rename_state_dict_keys(original_state_dict)
# Create the new config
config = ColPaliConfig(
vlm_config=original_config,
embedding_dim=128, # hardcoded in the original model
)
config.model_type = "colpali"
config.is_composition = False
# Load the untrained model
model = ColPaliForRetrieval(config=config).to("cpu").eval()
print("Created model with new config and randomly initialized weights")
# NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision.
# There are two ways to set the model's dtype:
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
# the new weights' dtypes match the original model.
for param in model.parameters():
param.data = param.data.to(ORIGINAL_DTYPE)
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
# Load the original weights
model.load_state_dict(original_state_dict)
print("Loaded original model weights")
# Tie the weights (following ColPali's `__init__`` step)
if model.vlm.language_model._tied_weights_keys is not None:
model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys]
# Sanity check: ensure all keys are the same
state_dict_keys_old = set(original_state_dict.keys())
state_dict_keys_new = set(model.state_dict().keys())
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
if disjoint_keys:
raise ValueError(f"Incompatible keys: {disjoint_keys}")
# Save the model
if push_to_hub:
model.push_to_hub(output_dir, private=True)
print(f"Model pushed to the hub at `{output_dir}`")
else:
Path(output_dir).mkdir(exist_ok=True, parents=True)
model.save_pretrained(output_dir)
print(f"Model saved to `{output_dir}`")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
This script converts the original ColPali model to the HF model format.
Example usage:
```bash
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
--model_id vidore/colpali-v1.2-merged \
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
--output_dir vidore/colpali-v1.2-hf \
--push_to_hub
```
"""
)
parser.add_argument(
"--model_id",
help="Model ID of the original model to convert",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
action="store_true",
default=False,
)
parser.add_argument(
"--revision",
help="Revision of the model to download",
default=None,
)
parser.add_argument(
"--original_vlm_name_or_path",
help="Name or path of the original VLM backbone model",
default=None,
)
args = parser.parse_args()
convert_colpali_weights_to_hf(
model_id=args.model_id,
output_dir=args.output_dir,
push_to_hub=args.push_to_hub,
revision=args.revision,
original_vlm_name_or_path=args.original_vlm_name_or_path,
)

View File

@ -1,212 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert ColQwen2 weights from the original repository to the HF model format.
Don't forget to manually upload the processor-related files to the HF model repository
after running this script.
Original repository: https://github.com/illuin-tech/colqwen2.
NOTE: This script was originally run using `torch==2.5.1` and with:
```bash
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
--model_id vidore/colqwen2-v1.0-merged \
--revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
--original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
--output_dir vidore/colqwen2-v1.0-hf-internal \
--push_to_hub
```
"""
import argparse
import glob
from pathlib import Path
from typing import Any, Optional
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig
from transformers.models.colqwen2 import ColQwen2ForRetrieval
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
ORIGINAL_DTYPE = torch.bfloat16
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> dict[str, torch.Tensor]:
directory_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["*.safetensors"],
)
original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
# Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
if "lm_head.weight" not in original_state_dict:
original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
return original_state_dict
def rename_state_dict_keys(state_dict: dict[str, Any]) -> dict[str, Any]:
new_state_dict: dict[str, Any] = {}
for key, value in state_dict.items():
if key.startswith("custom_text_proj"):
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
else:
# The original ColQwen2 inherits from Qwen2VL, so we simply need to add the `vlm.` prefix
# to all remaining keys.
if key.startswith("model."):
key = key.replace("model.", "model.language_model.")
if key.startswith("visual."):
key = key.replace("visual.", "model.visual.")
new_key = "vlm." + key
new_state_dict[new_key] = value
return new_state_dict
@torch.no_grad()
def convert_colqwen2_weights_to_hf(
model_id: str,
output_dir: str,
push_to_hub: bool,
revision: Optional[str] = None,
original_vlm_name_or_path: Optional[str] = None,
):
# Load the original model data
original_config = AutoConfig.from_pretrained(
model_id,
revision=revision,
)
if original_vlm_name_or_path is not None:
original_config._name_or_path = original_vlm_name_or_path
if hasattr(original_config, "architectures"):
delattr(original_config, "architectures")
original_state_dict = load_original_state_dict(model_id, revision=revision)
# Format the state_dict keys
original_state_dict = rename_state_dict_keys(original_state_dict)
# Create the new config
config = ColQwen2Config(
vlm_config=original_config,
embedding_dim=128, # hardcoded in the original model
)
config.model_type = "colqwen2"
config.is_composition = False
# Load the untrained model
model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
print("Created model with new config and randomly initialized weights")
# NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
# There are two ways to set the model's dtype:
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
# the new weights' dtypes match the original model.
for param in model.parameters():
param.data = param.data.to(ORIGINAL_DTYPE)
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
# Load the original weights
model.load_state_dict(original_state_dict)
print("Loaded original model weights")
# # Sanity check: ensure all keys are the same
state_dict_keys_old = set(original_state_dict.keys())
state_dict_keys_new = set(model.state_dict().keys())
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
if disjoint_keys:
raise ValueError(f"Incompatible keys: {disjoint_keys}")
# Save the model
if push_to_hub:
model.push_to_hub(output_dir, private=True)
print(f"Model pushed to the hub at `{output_dir}`")
else:
Path(output_dir).mkdir(exist_ok=True, parents=True)
model.save_pretrained(output_dir)
print(f"Model saved to `{output_dir}`")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
This script converts the original ColQwen2 model to the HF model format.
Don't forget to manually upload the processor-related files to the HF model repository
after running this script.
Example usage:
```bash
python src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py \
--model_id vidore/colqwen2-v1.0-merged \
--revision eeccbae1d44bdcb0c83b1788127a2b2cad7d718e \
--original_vlm_name_or_path Qwen/Qwen2-VL-2B-Instruct \
--output_dir vidore/colqwen2-v1.0-hf-internal \
--push_to_hub
```
"""
)
parser.add_argument(
"--model_id",
help="Model ID of the original model to convert",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
action="store_true",
default=False,
)
parser.add_argument(
"--revision",
help="Revision of the model to download",
default=None,
)
parser.add_argument(
"--original_vlm_name_or_path",
help="Name or path of the original VLM backbone model",
default=None,
)
args = parser.parse_args()
convert_colqwen2_weights_to_hf(
model_id=args.model_id,
output_dir=args.output_dir,
push_to_hub=args.push_to_hub,
revision=args.revision,
original_vlm_name_or_path=args.original_vlm_name_or_path,
)

View File

@ -1,324 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Conditional DETR checkpoints."""
import argparse
import json
from collections import OrderedDict
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
ConditionalDetrConfig,
ConditionalDetrForObjectDetection,
ConditionalDetrForSegmentation,
ConditionalDetrImageProcessor,
)
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# here we list all keys to be renamed (original name on the left, our name on the right)
rename_keys = []
for i in range(6):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append(
(f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight")
)
rename_keys.append(
(f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
)
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
rename_keys.append(
(f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
)
rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias"))
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight"))
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
rename_keys.append(
(f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
)
rename_keys.append(
(
f"transformer.decoder.layers.{i}.cross_attn.out_proj.weight",
f"decoder.layers.{i}.encoder_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"transformer.decoder.layers.{i}.cross_attn.out_proj.bias",
f"decoder.layers.{i}.encoder_attn.out_proj.bias",
)
)
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
)
rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
)
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight"))
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))
# q, k, v projections in self/cross-attention in decoder for conditional DETR
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_qcontent_proj.weight", f"decoder.layers.{i}.sa_qcontent_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_kcontent_proj.weight", f"decoder.layers.{i}.sa_kcontent_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_qpos_proj.weight", f"decoder.layers.{i}.sa_qpos_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_kpos_proj.weight", f"decoder.layers.{i}.sa_kpos_proj.weight")
)
rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.weight", f"decoder.layers.{i}.sa_v_proj.weight"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_qcontent_proj.weight", f"decoder.layers.{i}.ca_qcontent_proj.weight")
)
# rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.weight", f"decoder.layers.{i}.ca_qpos_proj.weight"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_kcontent_proj.weight", f"decoder.layers.{i}.ca_kcontent_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_kpos_proj.weight", f"decoder.layers.{i}.ca_kpos_proj.weight")
)
rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.weight", f"decoder.layers.{i}.ca_v_proj.weight"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.weight", f"decoder.layers.{i}.ca_qpos_sine_proj.weight")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_qcontent_proj.bias", f"decoder.layers.{i}.sa_qcontent_proj.bias")
)
rename_keys.append(
(f"transformer.decoder.layers.{i}.sa_kcontent_proj.bias", f"decoder.layers.{i}.sa_kcontent_proj.bias")
)
rename_keys.append((f"transformer.decoder.layers.{i}.sa_qpos_proj.bias", f"decoder.layers.{i}.sa_qpos_proj.bias"))
rename_keys.append((f"transformer.decoder.layers.{i}.sa_kpos_proj.bias", f"decoder.layers.{i}.sa_kpos_proj.bias"))
rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.bias", f"decoder.layers.{i}.sa_v_proj.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_qcontent_proj.bias", f"decoder.layers.{i}.ca_qcontent_proj.bias")
)
# rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.bias", f"decoder.layers.{i}.ca_qpos_proj.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_kcontent_proj.bias", f"decoder.layers.{i}.ca_kcontent_proj.bias")
)
rename_keys.append((f"transformer.decoder.layers.{i}.ca_kpos_proj.bias", f"decoder.layers.{i}.ca_kpos_proj.bias"))
rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.bias", f"decoder.layers.{i}.ca_v_proj.bias"))
rename_keys.append(
(f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.bias", f"decoder.layers.{i}.ca_qpos_sine_proj.bias")
)
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
# for conditional DETR, also convert reference point head and query scale MLP
rename_keys.extend(
[
("input_proj.weight", "input_projection.weight"),
("input_proj.bias", "input_projection.bias"),
("query_embed.weight", "query_position_embeddings.weight"),
("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
("class_embed.weight", "class_labels_classifier.weight"),
("class_embed.bias", "class_labels_classifier.bias"),
("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
("transformer.decoder.ref_point_head.layers.0.weight", "decoder.ref_point_head.layers.0.weight"),
("transformer.decoder.ref_point_head.layers.0.bias", "decoder.ref_point_head.layers.0.bias"),
("transformer.decoder.ref_point_head.layers.1.weight", "decoder.ref_point_head.layers.1.weight"),
("transformer.decoder.ref_point_head.layers.1.bias", "decoder.ref_point_head.layers.1.bias"),
("transformer.decoder.query_scale.layers.0.weight", "decoder.query_scale.layers.0.weight"),
("transformer.decoder.query_scale.layers.0.bias", "decoder.query_scale.layers.0.bias"),
("transformer.decoder.query_scale.layers.1.weight", "decoder.query_scale.layers.1.weight"),
("transformer.decoder.query_scale.layers.1.bias", "decoder.query_scale.layers.1.bias"),
("transformer.decoder.layers.0.ca_qpos_proj.weight", "decoder.layers.0.ca_qpos_proj.weight"),
("transformer.decoder.layers.0.ca_qpos_proj.bias", "decoder.layers.0.ca_qpos_proj.bias"),
]
)
def rename_key(state_dict, old, new):
val = state_dict.pop(old)
state_dict[new] = val
def rename_backbone_keys(state_dict):
new_state_dict = OrderedDict()
for key, value in state_dict.items():
if "backbone.0.body" in key:
new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model")
new_state_dict[new_key] = value
else:
new_state_dict[key] = value
return new_state_dict
def read_in_q_k_v(state_dict, is_panoptic=False):
prefix = ""
if is_panoptic:
prefix = "conditional_detr."
# first: transformer encoder
for i in range(6):
# read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_conditional_detr_checkpoint(model_name, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our CONDITIONAL_DETR structure.
"""
# load default config
config = ConditionalDetrConfig()
# set backbone and dilation attributes
if "resnet101" in model_name:
config.backbone = "resnet101"
if "dc5" in model_name:
config.dilation = True
is_panoptic = "panoptic" in model_name
if is_panoptic:
config.num_labels = 250
else:
config.num_labels = 91
repo_id = "huggingface/label-files"
filename = "coco-detection-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
# load image processor
format = "coco_panoptic" if is_panoptic else "coco_detection"
image_processor = ConditionalDetrImageProcessor(format=format)
# prepare image
img = prepare_img()
encoding = image_processor(images=img, return_tensors="pt")
pixel_values = encoding["pixel_values"]
logger.info(f"Converting model {model_name}...")
# load original model from torch hub
conditional_detr = torch.hub.load("DeppMeng/ConditionalDETR", model_name, pretrained=True).eval()
state_dict = conditional_detr.state_dict()
# rename keys
for src, dest in rename_keys:
if is_panoptic:
src = "conditional_detr." + src
rename_key(state_dict, src, dest)
state_dict = rename_backbone_keys(state_dict)
# query, key and value matrices need special treatment
read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
prefix = "conditional_detr.model." if is_panoptic else "model."
for key in state_dict.copy().keys():
if is_panoptic:
if (
key.startswith("conditional_detr")
and not key.startswith("class_labels_classifier")
and not key.startswith("bbox_predictor")
):
val = state_dict.pop(key)
state_dict["conditional_detr.model" + key[4:]] = val
elif "class_labels_classifier" in key or "bbox_predictor" in key:
val = state_dict.pop(key)
state_dict["conditional_detr." + key] = val
elif key.startswith("bbox_attention") or key.startswith("mask_head"):
continue
else:
val = state_dict.pop(key)
state_dict[prefix + key] = val
else:
if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
val = state_dict.pop(key)
state_dict[prefix + key] = val
# finally, create HuggingFace model and load state dict
model = ConditionalDetrForSegmentation(config) if is_panoptic else ConditionalDetrForObjectDetection(config)
model.load_state_dict(state_dict)
model.eval()
model.push_to_hub(repo_id=model_name, organization="DepuMeng", commit_message="Add model")
# verify our conversion
original_outputs = conditional_detr(pixel_values)
outputs = model(pixel_values)
assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
if is_panoptic:
assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
# Save model and image processor
logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="conditional_detr_resnet50",
type=str,
help="Name of the CONDITIONAL_DETR model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
)
args = parser.parse_args()
convert_conditional_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)

View File

@ -1,57 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ConvBERT checkpoint."""
import argparse
from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert
from transformers.utils import logging
logging.set_verbosity_info()
def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path):
conf = ConvBertConfig.from_json_file(convbert_config_file)
model = ConvBertModel(conf)
model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path)
model.save_pretrained(pytorch_dump_path)
tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True)
tf_model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--convbert_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained ConvBERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path)

View File

@ -1,242 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ConvNext checkpoints from the original repository.
URL: https://github.com/facebookresearch/ConvNeXt"""
import argparse
import json
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import ConvNextConfig, ConvNextForImageClassification, ConvNextImageProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_convnext_config(checkpoint_url):
config = ConvNextConfig()
if "tiny" in checkpoint_url:
depths = [3, 3, 9, 3]
hidden_sizes = [96, 192, 384, 768]
if "small" in checkpoint_url:
depths = [3, 3, 27, 3]
hidden_sizes = [96, 192, 384, 768]
if "base" in checkpoint_url:
depths = [3, 3, 27, 3]
hidden_sizes = [128, 256, 512, 1024]
if "large" in checkpoint_url:
depths = [3, 3, 27, 3]
hidden_sizes = [192, 384, 768, 1536]
if "xlarge" in checkpoint_url:
depths = [3, 3, 27, 3]
hidden_sizes = [256, 512, 1024, 2048]
if "1k" in checkpoint_url:
num_labels = 1000
filename = "imagenet-1k-id2label.json"
expected_shape = (1, 1000)
else:
num_labels = 21841
filename = "imagenet-22k-id2label.json"
expected_shape = (1, 21841)
repo_id = "huggingface/label-files"
config.num_labels = num_labels
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
if "1k" not in checkpoint_url:
# this dataset contains 21843 labels but the model only has 21841
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
del id2label[9205]
del id2label[15027]
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.hidden_sizes = hidden_sizes
config.depths = depths
return config, expected_shape
def rename_key(name):
if "downsample_layers.0.0" in name:
name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings")
if "downsample_layers.0.1" in name:
name = name.replace("downsample_layers.0.1", "embeddings.norm") # we rename to layernorm later on
if "downsample_layers.1.0" in name:
name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0")
if "downsample_layers.1.1" in name:
name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1")
if "downsample_layers.2.0" in name:
name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0")
if "downsample_layers.2.1" in name:
name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1")
if "downsample_layers.3.0" in name:
name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0")
if "downsample_layers.3.1" in name:
name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1")
if "stages" in name and "downsampling_layer" not in name:
# stages.0.0. for instance should be renamed to stages.0.layers.0.
name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
if "stages" in name:
name = name.replace("stages", "encoder.stages")
if "norm" in name:
name = name.replace("norm", "layernorm")
if "gamma" in name:
name = name.replace("gamma", "layer_scale_parameter")
if "head" in name:
name = name.replace("head", "classifier")
return name
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_convnext_checkpoint(checkpoint_url, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our ConvNext structure.
"""
# define ConvNext configuration based on URL
config, expected_shape = get_convnext_config(checkpoint_url)
# load original state_dict from URL
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]
# rename keys
for key in state_dict.copy().keys():
val = state_dict.pop(key)
state_dict[rename_key(key)] = val
# add prefix to all keys expect classifier head
for key in state_dict.copy().keys():
val = state_dict.pop(key)
if not key.startswith("classifier"):
key = "convnext." + key
state_dict[key] = val
# load HuggingFace model
model = ConvNextForImageClassification(config)
model.load_state_dict(state_dict)
model.eval()
# Check outputs on an image, prepared by ConvNextImageProcessor
size = 224 if "224" in checkpoint_url else 384
image_processor = ConvNextImageProcessor(size=size)
pixel_values = image_processor(images=prepare_img(), return_tensors="pt").pixel_values
logits = model(pixel_values).logits
# note: the logits below were obtained without center cropping
if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth":
expected_logits = torch.tensor([-0.1210, -0.6605, 0.1918])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth":
expected_logits = torch.tensor([-0.4473, -0.1847, -0.6365])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth":
expected_logits = torch.tensor([0.4525, 0.7539, 0.0308])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth":
expected_logits = torch.tensor([0.3561, 0.6350, -0.0384])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth":
expected_logits = torch.tensor([0.4174, -0.0989, 0.1489])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth":
expected_logits = torch.tensor([0.2513, -0.1349, -0.1613])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth":
expected_logits = torch.tensor([1.2980, 0.3631, -0.1198])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth":
expected_logits = torch.tensor([1.2963, 0.1227, 0.1723])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth":
expected_logits = torch.tensor([1.7956, 0.8390, 0.2820])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth":
expected_logits = torch.tensor([-0.2822, -0.0502, -0.0878])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth":
expected_logits = torch.tensor([-0.5672, -0.0730, -0.4348])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth":
expected_logits = torch.tensor([0.2681, 0.2365, 0.6246])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth":
expected_logits = torch.tensor([-0.2642, 0.3931, 0.5116])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth":
expected_logits = torch.tensor([-0.6677, -0.1873, -0.8379])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth":
expected_logits = torch.tensor([-0.7749, -0.2967, -0.6444])
else:
raise ValueError(f"Unknown URL: {checkpoint_url}")
assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3)
assert logits.shape == expected_shape
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
print("Pushing model to the hub...")
model_name = "convnext"
if "tiny" in checkpoint_url:
model_name += "-tiny"
elif "small" in checkpoint_url:
model_name += "-small"
elif "base" in checkpoint_url:
model_name += "-base"
elif "xlarge" in checkpoint_url:
model_name += "-xlarge"
elif "large" in checkpoint_url:
model_name += "-large"
if "224" in checkpoint_url:
model_name += "-224"
elif "384" in checkpoint_url:
model_name += "-384"
if "22k" in checkpoint_url and "1k" not in checkpoint_url:
model_name += "-22k"
if "22k" in checkpoint_url and "1k" in checkpoint_url:
model_name += "-22k-1k"
model.push_to_hub(
repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
organization="nielsr",
commit_message="Add model",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint_url",
default="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
type=str,
help="URL of the original ConvNeXT checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model directory.",
)
args = parser.parse_args()
convert_convnext_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)

View File

@ -1,286 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ConvNeXTV2 checkpoints from the original repository.
URL: https://github.com/facebookresearch/ConvNeXt"""
import argparse
import json
import os
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import ConvNextImageProcessor, ConvNextV2Config, ConvNextV2ForImageClassification
from transformers.image_utils import PILImageResampling
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_convnextv2_config(checkpoint_url):
config = ConvNextV2Config()
if "atto" in checkpoint_url:
depths = [2, 2, 6, 2]
hidden_sizes = [40, 80, 160, 320]
if "femto" in checkpoint_url:
depths = [2, 2, 6, 2]
hidden_sizes = [48, 96, 192, 384]
if "pico" in checkpoint_url:
depths = [2, 2, 6, 2]
hidden_sizes = [64, 128, 256, 512]
if "nano" in checkpoint_url:
depths = [2, 2, 8, 2]
hidden_sizes = [80, 160, 320, 640]
if "tiny" in checkpoint_url:
depths = [3, 3, 9, 3]
hidden_sizes = [96, 192, 384, 768]
if "base" in checkpoint_url:
depths = [3, 3, 27, 3]
hidden_sizes = [128, 256, 512, 1024]
if "large" in checkpoint_url:
depths = [3, 3, 27, 3]
hidden_sizes = [192, 384, 768, 1536]
if "huge" in checkpoint_url:
depths = [3, 3, 27, 3]
hidden_sizes = [352, 704, 1408, 2816]
num_labels = 1000
filename = "imagenet-1k-id2label.json"
expected_shape = (1, 1000)
repo_id = "huggingface/label-files"
config.num_labels = num_labels
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.hidden_sizes = hidden_sizes
config.depths = depths
return config, expected_shape
def rename_key(name):
if "downsample_layers.0.0" in name:
name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings")
if "downsample_layers.0.1" in name:
name = name.replace("downsample_layers.0.1", "embeddings.norm") # we rename to layernorm later on
if "downsample_layers.1.0" in name:
name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0")
if "downsample_layers.1.1" in name:
name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1")
if "downsample_layers.2.0" in name:
name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0")
if "downsample_layers.2.1" in name:
name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1")
if "downsample_layers.3.0" in name:
name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0")
if "downsample_layers.3.1" in name:
name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1")
if "stages" in name and "downsampling_layer" not in name:
# stages.0.0. for instance should be renamed to stages.0.layers.0.
name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
if "gamma" in name:
name = name.replace("gamma", "weight")
if "beta" in name:
name = name.replace("beta", "bias")
if "stages" in name:
name = name.replace("stages", "encoder.stages")
if "norm" in name:
name = name.replace("norm", "layernorm")
if "head" in name:
name = name.replace("head", "classifier")
return name
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
def convert_preprocessor(checkpoint_url):
if "224" in checkpoint_url:
size = 224
crop_pct = 224 / 256
elif "384" in checkpoint_url:
size = 384
crop_pct = None
else:
size = 512
crop_pct = None
return ConvNextImageProcessor(
size=size,
crop_pct=crop_pct,
image_mean=[0.485, 0.456, 0.406],
image_std=[0.229, 0.224, 0.225],
resample=PILImageResampling.BICUBIC,
)
@torch.no_grad()
def convert_convnextv2_checkpoint(checkpoint_url, pytorch_dump_folder_path, save_model, push_to_hub):
"""
Copy/paste/tweak model's weights to our ConvNeXTV2 structure.
"""
print("Downloading original model from checkpoint...")
# define ConvNeXTV2 configuration based on URL
config, expected_shape = get_convnextv2_config(checkpoint_url)
# load original state_dict from URL
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]
print("Converting model parameters...")
# rename keys
for key in state_dict.copy().keys():
val = state_dict.pop(key)
state_dict[rename_key(key)] = val
# add prefix to all keys expect classifier head
for key in state_dict.copy().keys():
val = state_dict.pop(key)
if not key.startswith("classifier"):
key = "convnextv2." + key
state_dict[key] = val
# load HuggingFace model
model = ConvNextV2ForImageClassification(config)
model.load_state_dict(state_dict)
model.eval()
# Check outputs on an image, prepared by ConvNextImageProcessor
preprocessor = convert_preprocessor(checkpoint_url)
inputs = preprocessor(images=prepare_img(), return_tensors="pt")
logits = model(**inputs).logits
# note: the logits below were obtained without center cropping
if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt":
expected_logits = torch.tensor([-0.3930, 0.1747, -0.5246, 0.4177, 0.4295])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt":
expected_logits = torch.tensor([-0.1727, -0.5341, -0.7818, -0.4745, -0.6566])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt":
expected_logits = torch.tensor([-0.0333, 0.1563, -0.9137, 0.1054, 0.0381])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt":
expected_logits = torch.tensor([-0.1744, -0.1555, -0.0713, 0.0950, -0.1431])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt":
expected_logits = torch.tensor([0.9996, 0.1966, -0.4386, -0.3472, 0.6661])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt":
expected_logits = torch.tensor([-0.2553, -0.6708, -0.1359, 0.2518, -0.2488])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt":
expected_logits = torch.tensor([-0.0673, -0.5627, -0.3753, -0.2722, 0.0178])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt":
expected_logits = torch.tensor([-0.6377, -0.7458, -0.2150, 0.1184, -0.0597])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt":
expected_logits = torch.tensor([1.0799, 0.2322, -0.8860, 1.0219, 0.6231])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt":
expected_logits = torch.tensor([0.3766, 0.4917, -1.1426, 0.9942, 0.6024])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt":
expected_logits = torch.tensor([0.4220, -0.6919, -0.4317, -0.2881, -0.6609])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt":
expected_logits = torch.tensor([0.1082, -0.8286, -0.5095, 0.4681, -0.8085])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt":
expected_logits = torch.tensor([-0.2419, -0.6221, 0.2176, -0.0980, -0.7527])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt":
expected_logits = torch.tensor([0.0391, -0.4371, 0.3786, 0.1251, -0.2784])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt":
expected_logits = torch.tensor([-0.0504, 0.5636, -0.1729, -0.6507, -0.3949])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt":
expected_logits = torch.tensor([0.3560, 0.9486, 0.3149, -0.2667, -0.5138])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt":
expected_logits = torch.tensor([-0.2469, -0.4550, -0.5853, -0.0810, 0.0309])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt":
expected_logits = torch.tensor([-0.3090, 0.0802, -0.0682, -0.1979, -0.2826])
else:
raise ValueError(f"Unknown URL: {checkpoint_url}")
assert torch.allclose(logits[0, :5], expected_logits, atol=1e-3)
assert logits.shape == expected_shape
print("Model outputs match the original results!")
if save_model:
print("Saving model to local...")
# Create folder to save model
if not os.path.isdir(pytorch_dump_folder_path):
os.mkdir(pytorch_dump_folder_path)
model.save_pretrained(pytorch_dump_folder_path)
preprocessor.save_pretrained(pytorch_dump_folder_path)
model_name = "convnextv2"
if "atto" in checkpoint_url:
model_name += "-atto"
if "femto" in checkpoint_url:
model_name += "-femto"
if "pico" in checkpoint_url:
model_name += "-pico"
if "nano" in checkpoint_url:
model_name += "-nano"
elif "tiny" in checkpoint_url:
model_name += "-tiny"
elif "base" in checkpoint_url:
model_name += "-base"
elif "large" in checkpoint_url:
model_name += "-large"
elif "huge" in checkpoint_url:
model_name += "-huge"
if "22k" in checkpoint_url and "1k" not in checkpoint_url:
model_name += "-22k"
elif "22k" in checkpoint_url and "1k" in checkpoint_url:
model_name += "-22k-1k"
elif "1k" in checkpoint_url:
model_name += "-1k"
if "224" in checkpoint_url:
model_name += "-224"
elif "384" in checkpoint_url:
model_name += "-384"
elif "512" in checkpoint_url:
model_name += "-512"
if push_to_hub:
print(f"Pushing {model_name} to the hub...")
model.push_to_hub(model_name)
preprocessor.push_to_hub(model_name)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint_url",
default="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt",
type=str,
help="URL of the original ConvNeXTV2 checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default="model",
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument("--save_model", action="store_true", help="Save model to local")
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image preprocessor to the hub")
args = parser.parse_args()
convert_convnextv2_checkpoint(
args.checkpoint_url, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub
)

View File

@ -1,339 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re
import torch
from tokenizers.processors import TemplateProcessing
from transformers import (
AutoFeatureExtractor,
AutoTokenizer,
CsmConfig,
CsmDepthDecoderConfig,
CsmForConditionalGeneration,
CsmProcessor,
MimiModel,
)
from transformers.utils.hub import cached_file
# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
r"backbone\.layers\.(\d+)": r"backbone_model.layers.\1",
r"decoder\.layers\.(\d+)": r"depth_decoder.model.layers.\1",
r"attn": r"self_attn",
r"output_proj": r"o_proj",
r"w1": r"gate_proj",
r"w2": r"down_proj",
r"w3": r"up_proj",
r"text_embeddings": r"embed_text_tokens",
r"audio_embeddings": r"backbone_model.embed_tokens.embed_audio_tokens",
r"codebook0_head": r"lm_head",
r"audio_head": r"depth_decoder.codebooks_head.weight",
r"projection": r"depth_decoder.model.inputs_embeds_projector",
r"sa_norm.scale": r"input_layernorm.weight",
r"mlp_norm.scale": r"post_attention_layernorm.weight",
r"decoder.norm.scale": r"depth_decoder.model.norm.weight",
r"backbone.norm.scale": r"backbone_model.norm.weight",
}
# fmt: on
def permute_for_rope(input_tensor, n_heads, dim1, dim2):
"""
When you go from the complex ROPE formulation to sin and cos one, you need
to permute the query and key weights (to avoid doing it on the fly)
"""
input_tensor = input_tensor.reshape(dim1, dim2)
input_tensor = input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
input_tensor = input_tensor.transpose(1, 2).reshape(dim1, dim2)
return input_tensor
def convert_key(key, mapping):
for pattern, replacement in mapping.items():
key = re.sub(pattern, replacement, key)
return key
def write_model(
input_path_or_repo,
model_name,
codec_model_path_or_repo,
output_dir,
safe_serialization=True,
):
print("Converting the model.")
os.makedirs(output_dir, exist_ok=True)
codec_model = MimiModel.from_pretrained(codec_model_path_or_repo)
codec_model.config._attn_implementation_autoset = False
# prepare rope scaling args: the model uses originally
# 1 - for the depth decoder
# rope_theta=500000,
# rope_scaling={
# "factor": 32.0,
# "high_freq_factor": 4.0,
# "low_freq_factor": 1.0,
# "original_max_position_embeddings": 8192,
# "rope_type": "llama3",
# },
# 2 - for the backbone
# rope_theta=500000,
# rope_scaling={
# "factor": 32.0,
# "high_freq_factor": 4.0,
# "low_freq_factor": 1.0,
# "original_max_position_embeddings": 8192,
# "rope_type": "llama3",
# },
#
# Yet we want to use max_position_embeddings=32, resp. 2048
# This will throw warning as we would have original_max_position_embeddings >= max_position_embeddings
# Therefore, we convert values to equivalent ones
depth_decoder_config = CsmDepthDecoderConfig(
rope_scaling={
"factor": 32.0,
"high_freq_factor": 0.0078125,
"low_freq_factor": 0.001953125,
"original_max_position_embeddings": 16,
"rope_type": "llama3",
},
)
config = CsmConfig(
codec_config=codec_model.config,
depth_decoder_config=depth_decoder_config,
rope_scaling={
"factor": 32.0,
"high_freq_factor": 0.5,
"low_freq_factor": 0.125,
"original_max_position_embeddings": 1024,
"rope_type": "llama3",
},
)
params = {
"backbone": {
"num_attention_heads": config.num_attention_heads,
"num_key_value_heads": config.num_key_value_heads,
"dim_per_head": config.head_dim,
"key_value_dim": config.head_dim * config.num_key_value_heads,
"dim": config.hidden_size,
},
"depth_decoder": {
"num_attention_heads": config.depth_decoder_config.num_attention_heads,
"num_key_value_heads": config.depth_decoder_config.num_key_value_heads,
"dim_per_head": config.depth_decoder_config.head_dim,
"key_value_dim": config.depth_decoder_config.head_dim * config.depth_decoder_config.num_key_value_heads,
"dim": config.depth_decoder_config.hidden_size,
},
}
model_path = cached_file(
input_path_or_repo,
model_name,
)
print(f"Fetching all parameters from the checkpoint at {model_path}...")
loaded = torch.load(model_path, map_location="cpu")
print("Converting model...")
state_dict = {}
# -----------------------
# convert parameter names
# -----------------------
# Add codec_model. prefix to every key in the codec model state dict
codec_state_dict = {f"codec_model.{k}": v for k, v in codec_model.state_dict().items()}
state_dict.update(codec_state_dict)
for key, value in loaded.items():
new_key = convert_key(key, ORIGINAL_TO_CONVERTED_KEY_MAPPING)
current_parameter = value
# Post-process the current_parameter.
if re.search("(k|q)_proj.weight", new_key):
params_keys = "backbone" if "backbone" in new_key else "depth_decoder"
if "q_proj" in new_key:
num_heads = params[params_keys]["num_attention_heads"]
dim_per_head = params[params_keys]["dim_per_head"]
param_dim = params[params_keys]["dim"]
dim = params[params_keys]["dim"]
else:
num_heads = params[params_keys]["num_key_value_heads"]
dim_per_head = params[params_keys]["dim_per_head"]
param_dim = params[params_keys]["key_value_dim"]
dim = params[params_keys]["dim"]
current_parameter = permute_for_rope(value, num_heads, param_dim, dim)
state_dict[new_key] = current_parameter.reshape(num_heads * dim_per_head, dim)
state_dict[new_key] = current_parameter
# add the depth decoder embed audio tokens weights, latter tied to the backbone embed audio tokens weights
state_dict["depth_decoder.model.embed_tokens.weight"] = state_dict[
"backbone_model.embed_tokens.embed_audio_tokens.weight"
].clone()
del loaded
gc.collect()
# -------------------------
# load the weights and save
# -------------------------
print("Loading the checkpoint in a Csm model.")
with torch.device("meta"):
model = CsmForConditionalGeneration(config)
model.load_state_dict(state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
del model.config._name_or_path
# default generation config
model.generation_config._from_model_config = False
model.generation_config.max_new_tokens = 125
model.generation_config.do_sample = True
model.generation_config.top_k = 50
model.generation_config.temperature = 0.9
model.generation_config.depth_decoder_do_sample = True
model.generation_config.depth_decoder_top_k = 50
model.generation_config.depth_decoder_temperature = 0.9
print("Saving the model.")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
del state_dict, model
# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
CsmForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
print("Model reloaded successfully.")
def write_tokenizer(output_dir):
# from https://github.com/SesameAILabs/csm/blob/2d720827843b653c4d67bb4445b1c0a4f59e646f/generator.py#L22-L36
def load_llama3_tokenizer():
"""
https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
"""
tokenizer_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
bos = tokenizer.bos_token
eos = tokenizer.eos_token
tokenizer._tokenizer.post_processor = TemplateProcessing(
single=f"{bos}:0 $A:0 {eos}:0",
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
)
return tokenizer
tokenizer = load_llama3_tokenizer()
tokenizer.pad_token = tokenizer.eos_token
tokenizer.save_pretrained(output_dir)
# manually modify in tokenizer_config.json
# "128002": {
# "content": "<|AUDIO|>",
# ...
# }
# "128003": {
# "content": "<|audio_eos|>",
# ...
# }
print(
"Tokenizer saved successfully. Please manually modify in tokenizer_config.json AND tokenizer.json as follows: "
)
print("""
# "128002": {
# "content": "<|AUDIO|>",
# ...
# }
# "128003": {
# "content": "<|audio_eos|>",
# ...
# }
""")
def write_processor(output_dir, codec_model_path_or_repo):
chat_template = "\n{%- for message in messages %}\n {#-- Validate role is a stringified integer --#}\n {%- if not message['role'] is string or not message['role'].isdigit() %}\n {{- raise_exception(\"The role must be an integer or a stringified integer (e.g. '0') designating the speaker id\") }}\n {%- endif %}\n\n {#-- Validate content is a list --#}\n {%- set content = message['content'] %}\n {%- if content is not iterable or content is string %}\n {{- raise_exception(\"The content must be a list\") }}\n {%- endif %}\n\n {#-- Collect content types --#}\n {%- set content_types = content | map(attribute='type') | list %}\n {%- set is_last = loop.last %}\n\n {#-- Last message validation --#}\n {%- if is_last %}\n {%- if 'text' not in content_types %}\n {{- raise_exception(\"The last message must include one item of type 'text'\") }}\n {%- elif (content_types | select('equalto', 'text') | list | length > 1) or (content_types | select('equalto', 'audio') | list | length > 1) %}\n {{- raise_exception(\"At most two items are allowed in the last message: one 'text' and one 'audio'\") }}\n {%- endif %}\n\n {#-- All other messages validation --#}\n {%- else %}\n {%- if content_types | select('equalto', 'text') | list | length != 1\n or content_types | select('equalto', 'audio') | list | length != 1 %}\n {{- raise_exception(\"Each message (except the last) must contain exactly one 'text' and one 'audio' item\") }}\n {%- elif content_types | reject('in', ['text', 'audio']) | list | length > 0 %}\n {{- raise_exception(\"Only 'text' and 'audio' types are allowed in content\") }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n\n{%- for message in messages %}\n {{- bos_token }}\n {{- '[' + message['role'] + ']' }}\n {{- message['content'][0]['text'] }}\n {{- eos_token }}\n {%- if message['content']|length > 1 %}\n {{- '<|AUDIO|><|audio_eos|>' }}\n {%- endif %}\n{%- endfor %}\n"
tokenizer = AutoTokenizer.from_pretrained(output_dir)
feature_extractor = AutoFeatureExtractor.from_pretrained(codec_model_path_or_repo)
processor = CsmProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
chat_template=chat_template,
)
processor.save_pretrained(output_dir)
print("Processor saved successfully.")
def main():
parser = argparse.ArgumentParser(description="Convert Csm weights to HuggingFace format")
parser.add_argument(
"--input_path_or_repo",
type=str,
required=True,
help="Path or repo containing Csm weights",
)
parser.add_argument(
"--model_name",
type=str,
required=True,
help="Name of the model in input_path_or_repo",
)
parser.add_argument(
"--codec_model_path_or_repo",
type=str,
required=True,
help="Path or repo containing the codec model",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`."
)
args = parser.parse_args()
write_model(
args.input_path_or_repo,
args.model_name,
args.codec_model_path_or_repo,
output_dir=args.output_dir,
safe_serialization=args.safe_serialization,
)
write_tokenizer(args.output_dir)
write_processor(args.output_dir, args.codec_model_path_or_repo)
if __name__ == "__main__":
main()

View File

@ -1,362 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert CvT checkpoints from the original repository.
URL: https://github.com/microsoft/CvT"""
import argparse
import json
from collections import OrderedDict
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoImageProcessor, CvtConfig, CvtForImageClassification
def embeddings(idx):
"""
The function helps in renaming embedding layer weights.
Args:
idx: stage number in original model
"""
embed = []
embed.append(
(
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.weight",
f"stage{idx}.patch_embed.proj.weight",
)
)
embed.append(
(
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.bias",
f"stage{idx}.patch_embed.proj.bias",
)
)
embed.append(
(
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.weight",
f"stage{idx}.patch_embed.norm.weight",
)
)
embed.append(
(
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.bias",
f"stage{idx}.patch_embed.norm.bias",
)
)
return embed
def attention(idx, cnt):
"""
The function helps in renaming attention block layers weights.
Args:
idx: stage number in original model
cnt: count of blocks in each stage
"""
attention_weights = []
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.convolution.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.conv.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.bias",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_mean",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_mean",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_var",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_var",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.num_batches_tracked",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.num_batches_tracked",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.convolution.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.conv.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.bias",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_mean",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_mean",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_var",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_var",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.num_batches_tracked",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.num_batches_tracked",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.convolution.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.conv.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.bias",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_mean",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_mean",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_var",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_var",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.num_batches_tracked",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.num_batches_tracked",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.weight",
f"stage{idx}.blocks.{cnt}.attn.proj_q.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.bias",
f"stage{idx}.blocks.{cnt}.attn.proj_q.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.weight",
f"stage{idx}.blocks.{cnt}.attn.proj_k.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.bias",
f"stage{idx}.blocks.{cnt}.attn.proj_k.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.weight",
f"stage{idx}.blocks.{cnt}.attn.proj_v.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.bias",
f"stage{idx}.blocks.{cnt}.attn.proj_v.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.weight",
f"stage{idx}.blocks.{cnt}.attn.proj.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.bias",
f"stage{idx}.blocks.{cnt}.attn.proj.bias",
)
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc1.weight")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc1.bias")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc2.weight")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc2.bias")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.weight", f"stage{idx}.blocks.{cnt}.norm1.weight")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.bias", f"stage{idx}.blocks.{cnt}.norm1.bias")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.weight", f"stage{idx}.blocks.{cnt}.norm2.weight")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.bias", f"stage{idx}.blocks.{cnt}.norm2.bias")
)
return attention_weights
def cls_token(idx):
"""
Function helps in renaming cls_token weights
"""
token = []
token.append((f"cvt.encoder.stages.{idx}.cls_token", "stage2.cls_token"))
return token
def final():
"""
Function helps in renaming final classification layer
"""
head = []
head.append(("layernorm.weight", "norm.weight"))
head.append(("layernorm.bias", "norm.bias"))
head.append(("classifier.weight", "head.weight"))
head.append(("classifier.bias", "head.bias"))
return head
def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_folder):
"""
Function to convert the microsoft cvt checkpoint to huggingface checkpoint
"""
img_labels_file = "imagenet-1k-id2label.json"
num_labels = 1000
repo_id = "huggingface/label-files"
num_labels = num_labels
id2label = json.loads(Path(hf_hub_download(repo_id, img_labels_file, repo_type="dataset")).read_text())
id2label = {int(k): v for k, v in id2label.items()}
id2label = id2label
label2id = {v: k for k, v in id2label.items()}
config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id)
# For depth size 13 (13 = 1+2+10)
if cvt_model.rsplit("/", 1)[-1][4:6] == "13":
config.depth = [1, 2, 10]
# For depth size 21 (21 = 1+4+16)
elif cvt_model.rsplit("/", 1)[-1][4:6] == "21":
config.depth = [1, 4, 16]
# For wide cvt (similar to wide-resnet) depth size 24 (w24 = 2 + 2 20)
else:
config.depth = [2, 2, 20]
config.num_heads = [3, 12, 16]
config.embed_dim = [192, 768, 1024]
model = CvtForImageClassification(config)
image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k")
image_processor.size["shortest_edge"] = image_size
original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu"), weights_only=True)
huggingface_weights = OrderedDict()
list_of_state_dict = []
for idx in range(len(config.depth)):
if config.cls_token[idx]:
list_of_state_dict = list_of_state_dict + cls_token(idx)
list_of_state_dict = list_of_state_dict + embeddings(idx)
for cnt in range(config.depth[idx]):
list_of_state_dict = list_of_state_dict + attention(idx, cnt)
list_of_state_dict = list_of_state_dict + final()
for gg in list_of_state_dict:
print(gg)
for i in range(len(list_of_state_dict)):
huggingface_weights[list_of_state_dict[i][0]] = original_weights[list_of_state_dict[i][1]]
model.load_state_dict(huggingface_weights)
model.save_pretrained(pytorch_dump_folder)
image_processor.save_pretrained(pytorch_dump_folder)
# Download the weights from zoo: https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--cvt_model",
default="cvt-w24",
type=str,
help="Name of the cvt model you'd like to convert.",
)
parser.add_argument(
"--image_size",
default=384,
type=int,
help="Input Image Size",
)
parser.add_argument(
"--cvt_file_name",
default=r"cvtmodels\CvT-w24-384x384-IN-22k.pth",
type=str,
help="Input Image Size",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
args = parser.parse_args()
convert_cvt_checkpoint(args.cvt_model, args.image_size, args.cvt_file_name, args.pytorch_dump_folder_path)

View File

@ -1,689 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import re
from pathlib import Path
from typing import Optional
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import DFineConfig, DFineForObjectDetection, RTDetrImageProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_d_fine_config(model_name: str) -> DFineConfig:
config = DFineConfig()
config.num_labels = 80
repo_id = "huggingface/label-files"
filename = "object365-id2label.json" if "obj365" in model_name else "coco-detection-mmdet-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.backbone_config.hidden_sizes = [64, 128, 256, 512]
config.backbone_config.layer_type = "basic"
config.backbone_config.embedding_size = 32
config.hidden_expansion = 1.0
config.decoder_layers = 6
if model_name in ["dfine_x_coco", "dfine_x_obj2coco", "dfine_x_obj365"]:
config.backbone_config.hidden_sizes = [256, 512, 1024, 2048]
config.backbone_config.stage_in_channels = [64, 128, 512, 1024]
config.backbone_config.stage_mid_channels = [64, 128, 256, 512]
config.backbone_config.stage_out_channels = [128, 512, 1024, 2048]
config.backbone_config.stage_num_blocks = [1, 2, 5, 2]
config.backbone_config.stage_downsample = [False, True, True, True]
config.backbone_config.stage_light_block = [False, False, True, True]
config.backbone_config.stage_kernel_size = [3, 3, 5, 5]
config.backbone_config.stage_numb_of_layers = [6, 6, 6, 6]
config.backbone_config.stem_channels = [3, 32, 64]
config.encoder_in_channels = [512, 1024, 2048]
config.encoder_hidden_dim = 384
config.encoder_ffn_dim = 2048
config.decoder_n_points = [3, 6, 3]
config.decoder_in_channels = [384, 384, 384]
if model_name == "dfine_x_obj365":
config.num_labels = 366
elif model_name in ["dfine_m_coco", "dfine_m_obj2coco", "dfine_m_obj365"]:
config.backbone_config.hidden_sizes = [192, 384, 768, 1536]
config.backbone_config.stem_channels = [3, 24, 32]
config.backbone_config.stage_in_channels = [32, 96, 384, 768]
config.backbone_config.stage_mid_channels = [32, 64, 128, 256]
config.backbone_config.stage_out_channels = [96, 384, 768, 1536]
config.backbone_config.stage_num_blocks = [1, 1, 3, 1]
config.backbone_config.stage_downsample = [False, True, True, True]
config.backbone_config.stage_light_block = [False, False, True, True]
config.backbone_config.stage_kernel_size = [3, 3, 5, 5]
config.backbone_config.stage_numb_of_layers = [4, 4, 4, 4]
config.decoder_layers = 4
config.decoder_n_points = [3, 6, 3]
config.encoder_in_channels = [384, 768, 1536]
config.backbone_config.use_learnable_affine_block = True
config.depth_mult = 0.67
if model_name == "dfine_m_obj365":
config.num_labels = 366
elif model_name in ["dfine_l_coco", "dfine_l_obj2coco_e25", "dfine_l_obj365"]:
config.backbone_config.hidden_sizes = [256, 512, 1024, 2048]
config.backbone_config.stem_channels = [3, 32, 48]
config.backbone_config.stage_in_channels = [48, 128, 512, 1024]
config.backbone_config.stage_mid_channels = [48, 96, 192, 384]
config.backbone_config.stage_out_channels = [128, 512, 1024, 2048]
config.backbone_config.stage_num_blocks = [1, 1, 3, 1]
config.backbone_config.stage_downsample = [False, True, True, True]
config.backbone_config.stage_light_block = [False, False, True, True]
config.backbone_config.stage_kernel_size = [3, 3, 5, 5]
config.backbone_config.stage_numb_of_layers = [6, 6, 6, 6]
config.encoder_ffn_dim = 1024
config.encoder_in_channels = [512, 1024, 2048]
config.decoder_n_points = [3, 6, 3]
if model_name == "dfine_l_obj365":
config.num_labels = 366
elif model_name in ["dfine_n_coco", "dfine_n_obj2coco_e25", "dfine_n_obj365"]:
config.backbone_config.hidden_sizes = [128, 256, 512, 1024]
config.backbone_config.stem_channels = [3, 16, 16]
config.backbone_config.stage_in_channels = [16, 64, 256, 512]
config.backbone_config.stage_mid_channels = [16, 32, 64, 128]
config.backbone_config.stage_out_channels = [64, 256, 512, 1024]
config.backbone_config.stage_num_blocks = [1, 1, 2, 1]
config.backbone_config.stage_downsample = [False, True, True, True]
config.backbone_config.stage_light_block = [False, False, True, True]
config.backbone_config.stage_kernel_size = [3, 3, 5, 5]
config.backbone_config.stage_numb_of_layers = [3, 3, 3, 3]
config.backbone_config.out_indices = [3, 4]
config.backbone_config.use_learnable_affine_block = True
config.num_feature_levels = 2
config.encoder_ffn_dim = 512
config.encode_proj_layers = [1]
config.d_model = 128
config.encoder_hidden_dim = 128
config.decoder_ffn_dim = 512
config.encoder_in_channels = [512, 1024]
config.decoder_n_points = [6, 6]
config.decoder_in_channels = [128, 128]
config.feat_strides = [16, 32]
config.depth_mult = 0.5
config.decoder_layers = 3
config.hidden_expansion = 0.34
if model_name == "dfine_n_obj365":
config.num_labels = 366
else:
config.backbone_config.hidden_sizes = [128, 256, 512, 1024]
config.backbone_config.stem_channels = [3, 16, 16]
config.backbone_config.stage_in_channels = [16, 64, 256, 512]
config.backbone_config.stage_mid_channels = [16, 32, 64, 128]
config.backbone_config.stage_out_channels = [64, 256, 512, 1024]
config.backbone_config.stage_num_blocks = [1, 1, 2, 1]
config.backbone_config.stage_downsample = [False, True, True, True]
config.backbone_config.stage_light_block = [False, False, True, True]
config.backbone_config.stage_kernel_size = [3, 3, 5, 5]
config.backbone_config.stage_numb_of_layers = [3, 3, 3, 3]
config.decoder_layers = 3
config.hidden_expansion = 0.5
config.depth_mult = 0.34
config.decoder_n_points = [3, 6, 3]
config.encoder_in_channels = [256, 512, 1024]
config.backbone_config.use_learnable_affine_block = True
if model_name == "dfine_s_obj365":
config.num_labels = 366
return config
def load_original_state_dict(repo_id, model_name):
directory_path = hf_hub_download(repo_id=repo_id, filename=f"{model_name}.pth")
original_state_dict = {}
model = torch.load(directory_path, map_location="cpu")["model"]
for key in model.keys():
original_state_dict[key] = model[key]
return original_state_dict
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
# Decoder base mappings
r"decoder.valid_mask": r"model.decoder.valid_mask",
r"decoder.anchors": r"model.decoder.anchors",
r"decoder.up": r"model.decoder.up",
r"decoder.reg_scale": r"model.decoder.reg_scale",
# Backbone stem mappings - including stem2a and stem2b
r"backbone.stem.stem1.conv.weight": r"model.backbone.model.embedder.stem1.convolution.weight",
r"backbone.stem.stem2a.conv.weight": r"model.backbone.model.embedder.stem2a.convolution.weight",
r"backbone.stem.stem2b.conv.weight": r"model.backbone.model.embedder.stem2b.convolution.weight",
r"backbone.stem.stem3.conv.weight": r"model.backbone.model.embedder.stem3.convolution.weight",
r"backbone.stem.stem4.conv.weight": r"model.backbone.model.embedder.stem4.convolution.weight",
# Stem normalization
r"backbone.stem.stem1.bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.stem1.normalization.\1",
r"backbone.stem.stem2a.bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.stem2a.normalization.\1",
r"backbone.stem.stem2b.bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.stem2b.normalization.\1",
r"backbone.stem.stem3.bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.stem3.normalization.\1",
r"backbone.stem.stem4.bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.stem4.normalization.\1",
# Stem lab parameters - fixed with .lab in the path
r"backbone.stem.stem1.lab.(scale|bias)": r"model.backbone.model.embedder.stem1.lab.\1",
r"backbone.stem.stem2a.lab.(scale|bias)": r"model.backbone.model.embedder.stem2a.lab.\1",
r"backbone.stem.stem2b.lab.(scale|bias)": r"model.backbone.model.embedder.stem2b.lab.\1",
r"backbone.stem.stem3.lab.(scale|bias)": r"model.backbone.model.embedder.stem3.lab.\1",
r"backbone.stem.stem4.lab.(scale|bias)": r"model.backbone.model.embedder.stem4.lab.\1",
# Backbone stages mappings
r"backbone.stages.(\d+).blocks.(\d+).layers.(\d+).conv.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.convolution.weight",
r"backbone.stages.(\d+).blocks.(\d+).layers.(\d+).bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.normalization.\4",
r"backbone.stages.(\d+).blocks.(\d+).layers.(\d+).conv1.conv.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.convolution.weight",
r"backbone.stages.(\d+).blocks.(\d+).layers.(\d+).conv2.conv.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.convolution.weight",
r"backbone.stages.(\d+).blocks.(\d+).layers.(\d+).conv1.bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.normalization.\4",
r"backbone.stages.(\d+).blocks.(\d+).layers.(\d+).conv2.bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.normalization.\4",
# Backbone stages aggregation
r"backbone.stages.(\d+).blocks.(\d+).aggregation.0.conv.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.0.convolution.weight",
r"backbone.stages.(\d+).blocks.(\d+).aggregation.1.conv.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.1.convolution.weight",
r"backbone.stages.(\d+).blocks.(\d+).aggregation.0.bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.0.normalization.\3",
r"backbone.stages.(\d+).blocks.(\d+).aggregation.1.bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.1.normalization.\3",
# Backbone stages lab parameters for aggregation
r"backbone.stages.(\d+).blocks.(\d+).aggregation.0.lab.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.0.lab.\3",
r"backbone.stages.(\d+).blocks.(\d+).aggregation.1.lab.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.1.lab.\3",
r"backbone.stages.(\d+).blocks.(\d+).layers.(\d+).lab.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.lab.\4",
# Conv1/Conv2 layers with lab
r"backbone.stages.(\d+).blocks.(\d+).layers.(\d+).conv1.lab.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.lab.\4",
r"backbone.stages.(\d+).blocks.(\d+).layers.(\d+).conv2.lab.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.lab.\4",
# Downsample with lab
r"backbone.stages.(\d+).downsample.lab.(scale|bias)": r"model.backbone.model.encoder.stages.\1.downsample.lab.\2",
# Backbone downsample
r"backbone.stages.(\d+).downsample.conv.weight": r"model.backbone.model.encoder.stages.\1.downsample.convolution.weight",
r"backbone.stages.(\d+).downsample.bn.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.downsample.normalization.\2",
# Encoder mappings
r"encoder.encoder.(\d+).layers.0.self_attn.out_proj.(weight|bias)": r"model.encoder.encoder.\1.layers.0.self_attn.out_proj.\2",
r"encoder.encoder.(\d+).layers.0.linear1.(weight|bias)": r"model.encoder.encoder.\1.layers.0.fc1.\2",
r"encoder.encoder.(\d+).layers.0.linear2.(weight|bias)": r"model.encoder.encoder.\1.layers.0.fc2.\2",
r"encoder.encoder.(\d+).layers.0.norm1.(weight|bias)": r"model.encoder.encoder.\1.layers.0.self_attn_layer_norm.\2",
r"encoder.encoder.(\d+).layers.0.norm2.(weight|bias)": r"model.encoder.encoder.\1.layers.0.final_layer_norm.\2",
# Encoder projections and convolutions
r"encoder.input_proj.(\d+).conv.weight": r"model.encoder_input_proj.\1.0.weight",
r"encoder.input_proj.(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder_input_proj.\1.1.\2",
r"encoder.lateral_convs.(\d+).conv.weight": r"model.encoder.lateral_convs.\1.conv.weight",
r"encoder.lateral_convs.(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder.lateral_convs.\1.norm.\2",
# FPN blocks - complete structure
# Basic convolutions
r"encoder.fpn_blocks.(\d+).cv1.conv.weight": r"model.encoder.fpn_blocks.\1.conv1.conv.weight",
r"encoder.fpn_blocks.(\d+).cv1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv1.norm.\2",
# CSP Rep1 path
r"encoder.fpn_blocks.(\d+).cv2.0.conv1.conv.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.conv1.conv.weight",
r"encoder.fpn_blocks.(\d+).cv2.0.conv1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.conv1.norm.\2",
r"encoder.fpn_blocks.(\d+).cv2.0.conv2.conv.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.conv2.conv.weight",
r"encoder.fpn_blocks.(\d+).cv2.0.conv2.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.conv2.norm.\2",
r"encoder.fpn_blocks.(\d+).cv2.1.conv.weight": r"model.encoder.fpn_blocks.\1.conv2.conv.weight",
r"encoder.fpn_blocks.(\d+).cv2.1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv2.norm.\2",
# CSP Rep2 path
r"encoder.fpn_blocks.(\d+).cv3.0.conv1.conv.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.conv1.conv.weight",
r"encoder.fpn_blocks.(\d+).cv3.0.conv1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.conv1.norm.\2",
r"encoder.fpn_blocks.(\d+).cv3.0.conv2.conv.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.conv2.conv.weight",
r"encoder.fpn_blocks.(\d+).cv3.0.conv2.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.conv2.norm.\2",
r"encoder.fpn_blocks.(\d+).cv3.1.conv.weight": r"model.encoder.fpn_blocks.\1.conv3.conv.weight",
r"encoder.fpn_blocks.(\d+).cv3.1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv3.norm.\2",
# Final conv
r"encoder.fpn_blocks.(\d+).cv4.conv.weight": r"model.encoder.fpn_blocks.\1.conv4.conv.weight",
r"encoder.fpn_blocks.(\d+).cv4.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv4.norm.\2",
# Bottlenecks for CSP Rep1
r"encoder.fpn_blocks.(\d+).cv2.0.bottlenecks.(\d+).conv1.conv.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv1.conv.weight",
r"encoder.fpn_blocks.(\d+).cv2.0.bottlenecks.(\d+).conv1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv1.norm.\3",
r"encoder.fpn_blocks.(\d+).cv2.0.bottlenecks.(\d+).conv2.conv.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv2.conv.weight",
r"encoder.fpn_blocks.(\d+).cv2.0.bottlenecks.(\d+).conv2.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv2.norm.\3",
# Bottlenecks for CSP Rep2
r"encoder.fpn_blocks.(\d+).cv3.0.bottlenecks.(\d+).conv1.conv.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv1.conv.weight",
r"encoder.fpn_blocks.(\d+).cv3.0.bottlenecks.(\d+).conv1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv1.norm.\3",
r"encoder.fpn_blocks.(\d+).cv3.0.bottlenecks.(\d+).conv2.conv.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv2.conv.weight",
r"encoder.fpn_blocks.(\d+).cv3.0.bottlenecks.(\d+).conv2.norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv2.norm.\3",
# PAN blocks - complete structure
# Basic convolutions
r"encoder.pan_blocks.(\d+).cv1.conv.weight": r"model.encoder.pan_blocks.\1.conv1.conv.weight",
r"encoder.pan_blocks.(\d+).cv1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv1.norm.\2",
# CSP Rep1 path
r"encoder.pan_blocks.(\d+).cv2.0.conv1.conv.weight": r"model.encoder.pan_blocks.\1.csp_rep1.conv1.conv.weight",
r"encoder.pan_blocks.(\d+).cv2.0.conv1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.conv1.norm.\2",
r"encoder.pan_blocks.(\d+).cv2.0.conv2.conv.weight": r"model.encoder.pan_blocks.\1.csp_rep1.conv2.conv.weight",
r"encoder.pan_blocks.(\d+).cv2.0.conv2.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.conv2.norm.\2",
r"encoder.pan_blocks.(\d+).cv2.1.conv.weight": r"model.encoder.pan_blocks.\1.conv2.conv.weight",
r"encoder.pan_blocks.(\d+).cv2.1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv2.norm.\2",
# CSP Rep2 path
r"encoder.pan_blocks.(\d+).cv3.0.conv1.conv.weight": r"model.encoder.pan_blocks.\1.csp_rep2.conv1.conv.weight",
r"encoder.pan_blocks.(\d+).cv3.0.conv1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.conv1.norm.\2",
r"encoder.pan_blocks.(\d+).cv3.0.conv2.conv.weight": r"model.encoder.pan_blocks.\1.csp_rep2.conv2.conv.weight",
r"encoder.pan_blocks.(\d+).cv3.0.conv2.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.conv2.norm.\2",
r"encoder.pan_blocks.(\d+).cv3.1.conv.weight": r"model.encoder.pan_blocks.\1.conv3.conv.weight",
r"encoder.pan_blocks.(\d+).cv3.1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv3.norm.\2",
# Final conv
r"encoder.pan_blocks.(\d+).cv4.conv.weight": r"model.encoder.pan_blocks.\1.conv4.conv.weight",
r"encoder.pan_blocks.(\d+).cv4.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv4.norm.\2",
# Bottlenecks for CSP Rep1
r"encoder.pan_blocks.(\d+).cv2.0.bottlenecks.(\d+).conv1.conv.weight": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv1.conv.weight",
r"encoder.pan_blocks.(\d+).cv2.0.bottlenecks.(\d+).conv1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv1.norm.\3",
r"encoder.pan_blocks.(\d+).cv2.0.bottlenecks.(\d+).conv2.conv.weight": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv2.conv.weight",
r"encoder.pan_blocks.(\d+).cv2.0.bottlenecks.(\d+).conv2.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv2.norm.\3",
# Bottlenecks for CSP Rep2
r"encoder.pan_blocks.(\d+).cv3.0.bottlenecks.(\d+).conv1.conv.weight": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv1.conv.weight",
r"encoder.pan_blocks.(\d+).cv3.0.bottlenecks.(\d+).conv1.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv1.norm.\3",
r"encoder.pan_blocks.(\d+).cv3.0.bottlenecks.(\d+).conv2.conv.weight": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv2.conv.weight",
r"encoder.pan_blocks.(\d+).cv3.0.bottlenecks.(\d+).conv2.norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv2.norm.\3",
# Downsample convolutions
r"encoder.downsample_convs.(\d+).0.cv(\d+).conv.weight": r"model.encoder.downsample_convs.\1.conv\2.conv.weight",
r"encoder.downsample_convs.(\d+).0.cv(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder.downsample_convs.\1.conv\2.norm.\3",
# Decoder layers
r"decoder.decoder.layers.(\d+).self_attn.out_proj.(weight|bias)": r"model.decoder.layers.\1.self_attn.out_proj.\2",
r"decoder.decoder.layers.(\d+).cross_attn.sampling_offsets.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.sampling_offsets.\2",
r"decoder.decoder.layers.(\d+).cross_attn.attention_weights.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.attention_weights.\2",
r"decoder.decoder.layers.(\d+).cross_attn.value_proj.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.value_proj.\2",
r"decoder.decoder.layers.(\d+).cross_attn.output_proj.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.output_proj.\2",
r"decoder.decoder.layers.(\d+).cross_attn.num_points_scale": r"model.decoder.layers.\1.encoder_attn.num_points_scale",
r"decoder.decoder.layers.(\d+).gateway.gate.(weight|bias)": r"model.decoder.layers.\1.gateway.gate.\2",
r"decoder.decoder.layers.(\d+).gateway.norm.(weight|bias)": r"model.decoder.layers.\1.gateway.norm.\2",
r"decoder.decoder.layers.(\d+).norm1.(weight|bias)": r"model.decoder.layers.\1.self_attn_layer_norm.\2",
r"decoder.decoder.layers.(\d+).norm2.(weight|bias)": r"model.decoder.layers.\1.encoder_attn_layer_norm.\2",
r"decoder.decoder.layers.(\d+).norm3.(weight|bias)": r"model.decoder.layers.\1.final_layer_norm.\2",
r"decoder.decoder.layers.(\d+).linear1.(weight|bias)": r"model.decoder.layers.\1.fc1.\2",
r"decoder.decoder.layers.(\d+).linear2.(weight|bias)": r"model.decoder.layers.\1.fc2.\2",
# LQE layers
r"decoder.decoder.lqe_layers.(\d+).reg_conf.layers.(\d+).(weight|bias)": r"model.decoder.lqe_layers.\1.reg_conf.layers.\2.\3",
# Decoder heads and projections
r"decoder.dec_score_head.(\d+).(weight|bias)": r"model.decoder.class_embed.\1.\2",
r"decoder.dec_bbox_head.(\d+).layers.(\d+).(weight|bias)": r"model.decoder.bbox_embed.\1.layers.\2.\3",
r"decoder.pre_bbox_head.layers.(\d+).(weight|bias)": r"model.decoder.pre_bbox_head.layers.\1.\2",
r"decoder.input_proj.(\d+).conv.weight": r"model.decoder_input_proj.\1.0.weight",
r"decoder.input_proj.(\d+).norm.(weight|bias|running_mean|running_var)": r"model.decoder_input_proj.\1.1.\2",
# Other decoder components
r"decoder.denoising_class_embed.weight": r"model.denoising_class_embed.weight",
r"decoder.query_pos_head.layers.(\d+).(weight|bias)": r"model.decoder.query_pos_head.layers.\1.\2",
r"decoder.enc_output.proj.(weight|bias)": r"model.enc_output.0.\1",
r"decoder.enc_output.norm.(weight|bias)": r"model.enc_output.1.\1",
r"decoder.enc_score_head.(weight|bias)": r"model.enc_score_head.\1",
r"decoder.enc_bbox_head.layers.(\d+).(weight|bias)": r"model.enc_bbox_head.layers.\1.\2",
}
def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None):
# Use the mapping to rename keys
for original_key, converted_key in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
for key in list(state_dict_keys.keys()):
new_key = re.sub(original_key, converted_key, key)
if new_key != key:
state_dict_keys[new_key] = state_dict_keys.pop(key)
return state_dict_keys
def read_in_q_k_v(state_dict, config, model_name):
prefix = ""
encoder_hidden_dim = config.encoder_hidden_dim
# first: transformer encoder
for i in range(config.encoder_layers):
# read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[
:encoder_hidden_dim, :
]
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim]
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[
encoder_hidden_dim : 2 * encoder_hidden_dim, :
]
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[
encoder_hidden_dim : 2 * encoder_hidden_dim
]
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[
-encoder_hidden_dim:, :
]
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:]
# next: transformer decoder (which is a bit more complex because it also includes cross-attention)
for i in range(config.decoder_layers):
# read in weights + bias of input projection layer of self-attention
in_proj_weight = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_weight", None)
in_proj_bias = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_bias", None)
# next, add query, keys and values (in that order) to the state dict
if model_name in ["dfine_n_coco", "dfine_n_obj2coco_e25", "dfine_n_obj365"]:
state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:128, :]
state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:128]
state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:384, :]
state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:384]
state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-128:, :]
state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-128:]
else:
state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_d_fine_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, repo_id):
"""
Copy/paste/tweak model's weights to our D-FINE structure.
"""
# load default config
config = get_d_fine_config(model_name)
state_dict = load_original_state_dict(repo_id, model_name)
state_dict.pop("decoder.valid_mask", None)
state_dict.pop("decoder.anchors", None)
model = DFineForObjectDetection(config)
logger.info(f"Converting model {model_name}...")
state_dict = convert_old_keys_to_new_keys(state_dict)
state_dict.pop("decoder.model.decoder.up", None)
state_dict.pop("decoder.model.decoder.reg_scale", None)
# query, key and value matrices need special treatment
read_in_q_k_v(state_dict, config, model_name)
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
for key in state_dict.copy().keys():
if key.endswith("num_batches_tracked"):
del state_dict[key]
# for two_stage
if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key):
state_dict[key.split("model.decoder.")[-1]] = state_dict[key]
# finally, create HuggingFace model and load state dict
model.load_state_dict(state_dict)
model.eval()
# load image processor
image_processor = RTDetrImageProcessor()
# prepare image
img = prepare_img()
# preprocess image
transformations = transforms.Compose(
[
transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
]
)
original_pixel_values = transformations(img).unsqueeze(0) # insert batch dimension
encoding = image_processor(images=img, return_tensors="pt")
pixel_values = encoding["pixel_values"]
assert torch.allclose(original_pixel_values, pixel_values)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
pixel_values = pixel_values.to(device)
outputs = model(pixel_values)
if model_name == "dfine_x_coco":
expected_slice_logits = torch.tensor(
[
[-4.844723, -4.7293096, -4.5971327],
[-4.554266, -4.61723, -4.627926],
[-4.3934402, -4.6064143, -4.139952],
]
)
expected_slice_boxes = torch.tensor(
[
[0.2565248, 0.5477609, 0.47644863],
[0.7690029, 0.41423926, 0.46148556],
[0.1688096, 0.19923759, 0.21118002],
]
)
elif model_name == "dfine_x_obj2coco":
expected_slice_logits = torch.tensor(
[
[-4.230433, -6.6295037, -4.8339615],
[-4.085411, -6.3280816, -4.695468],
[-3.8968022, -6.336813, -4.67051],
]
)
expected_slice_boxes = torch.tensor(
[
[0.25707328, 0.54842496, 0.47624254],
[0.76967394, 0.41272867, 0.45970756],
[0.16882066, 0.19918433, 0.2112098],
]
)
elif model_name == "dfine_x_obj365":
expected_slice_logits = torch.tensor(
[
[-6.3844957, -3.7549126, -4.6873264],
[-5.8433194, -3.4490552, -3.3228905],
[-6.5314736, -3.7856622, -4.895984],
]
)
expected_slice_boxes = torch.tensor(
[
[0.7703046, 0.41329497, 0.45932162],
[0.16898105, 0.19876392, 0.21050783],
[0.25134972, 0.5517619, 0.4864124],
]
)
elif model_name == "dfine_m_coco":
expected_slice_logits = torch.tensor(
[
[-4.5187078, -4.71708, -4.117749],
[-4.513984, -4.937715, -3.829125],
[-4.830042, -6.931682, -3.1740026],
]
)
expected_slice_boxes = torch.tensor(
[
[0.25851426, 0.5489963, 0.4757598],
[0.769683, 0.41411665, 0.45988125],
[0.16866133, 0.19921188, 0.21207744],
]
)
elif model_name == "dfine_m_obj2coco":
expected_slice_logits = torch.tensor(
[
[-4.520666, -7.6678333, -5.739887],
[-4.5053635, -7.510611, -5.452532],
[-4.70348, -5.6098466, -5.0199957],
]
)
expected_slice_boxes = torch.tensor(
[
[0.2567608, 0.5485795, 0.4767465],
[0.77035284, 0.41236404, 0.4580645],
[0.5498525, 0.27548885, 0.05886984],
]
)
elif model_name == "dfine_m_obj365":
expected_slice_logits = torch.tensor(
[
[-5.770525, -3.1610885, -5.2807794],
[-5.7809954, -3.768266, -5.1146393],
[-6.180705, -3.7357295, -3.1651964],
]
)
expected_slice_boxes = torch.tensor(
[
[0.2529114, 0.5526663, 0.48270613],
[0.7712474, 0.41294736, 0.457174],
[0.5497157, 0.27588123, 0.05813372],
]
)
elif model_name == "dfine_l_coco":
expected_slice_logits = torch.tensor(
[
[-4.068779, -5.169955, -4.339212],
[-3.9461594, -5.0279613, -4.0161457],
[-4.218292, -6.196324, -5.175245],
]
)
expected_slice_boxes = torch.tensor(
[
[0.2564867, 0.5489948, 0.4748876],
[0.7693534, 0.4138953, 0.4598034],
[0.16875696, 0.19875404, 0.21196914],
]
)
elif model_name == "dfine_l_obj365":
expected_slice_logits = torch.tensor(
[
[-5.7953215, -3.4901116, -5.4394145],
[-5.7032104, -3.671125, -5.76121],
[-6.09466, -3.1512096, -4.285499],
]
)
expected_slice_boxes = torch.tensor(
[
[0.7693825, 0.41265628, 0.4606362],
[0.25306237, 0.55187637, 0.4832178],
[0.16892478, 0.19880727, 0.21115331],
]
)
elif model_name == "dfine_l_obj2coco_e25":
expected_slice_logits = torch.tensor(
[
[-3.6098495, -6.633563, -5.1227236],
[-3.682696, -6.9178205, -5.414557],
[-4.491674, -6.0823426, -4.5718226],
]
)
expected_slice_boxes = torch.tensor(
[
[0.7697078, 0.41368833, 0.45879585],
[0.2573691, 0.54856044, 0.47715297],
[0.16895264, 0.19871138, 0.2115552],
]
)
elif model_name == "dfine_n_coco":
expected_slice_logits = torch.tensor(
[
[-3.7827945, -5.0889463, -4.8341026],
[-5.3046904, -6.2801714, -2.9276395],
[-4.497901, -5.2670407, -6.2380104],
]
)
expected_slice_boxes = torch.tensor(
[
[0.73334837, 0.4270624, 0.39424777],
[0.1680235, 0.1988639, 0.21031213],
[0.25370035, 0.5534435, 0.48496848],
]
)
elif model_name == "dfine_s_coco":
expected_slice_logits = torch.tensor(
[
[-3.8097816, -4.7724586, -5.994499],
[-5.2974715, -9.499067, -6.1653666],
[-5.3502765, -3.9530406, -6.3630295],
]
)
expected_slice_boxes = torch.tensor(
[
[0.7677696, 0.41479152, 0.46441072],
[0.16912134, 0.19869131, 0.2123824],
[0.2581653, 0.54818195, 0.47512347],
]
)
elif model_name == "dfine_s_obj2coco":
expected_slice_logits = torch.tensor(
[
[-6.0208125, -7.532673, -5.0572147],
[-3.3595953, -9.057545, -6.376975],
[-4.3203554, -9.546032, -6.075504],
]
)
expected_slice_boxes = torch.tensor(
[
[0.16901012, 0.19883151, 0.21121952],
[0.76784194, 0.41266578, 0.46402973],
[00.2563128, 0.54797643, 0.47937632],
]
)
elif model_name == "dfine_s_obj365":
expected_slice_logits = torch.tensor(
[
[-6.3807316, -4.320986, -6.4775343],
[-6.5818424, -3.5009093, -5.75824],
[-5.748005, -4.3228016, -4.003726],
]
)
expected_slice_boxes = torch.tensor(
[
[0.2532072, 0.5491191, 0.48222217],
[0.76586807, 0.41175705, 0.46789962],
[0.169111, 0.19844547, 0.21069047],
]
)
else:
raise ValueError(f"Unknown d_fine_name: {model_name}")
assert torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits.to(outputs.logits.device), atol=1e-3)
assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes.to(outputs.pred_boxes.device), atol=1e-4)
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
# Upload model, image processor and config to the hub
logger.info("Uploading PyTorch model and image processor to the hub...")
config.push_to_hub(
repo_id=repo_id,
commit_message="Add config from convert_d_fine_original_pytorch_checkpoint_to_hf.py",
)
model.push_to_hub(
repo_id=repo_id,
commit_message="Add model from convert_d_fine_original_pytorch_checkpoint_to_hf.py",
)
image_processor.push_to_hub(
repo_id=repo_id,
commit_message="Add image processor from convert_d_fine_original_pytorch_checkpoint_to_hf.py",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="dfine_s_coco",
type=str,
help="model_name of the checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.")
parser.add_argument(
"--repo_id",
type=str,
help="repo_id where the model will be pushed to.",
)
args = parser.parse_args()
convert_d_fine_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.repo_id)

Some files were not shown because too many files have changed in this diff Show More