mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 01:23:56 +08:00
Compare commits
13 Commits
serve-quan
...
v4.48.1
Author | SHA1 | Date | |
---|---|---|---|
2e752ead46 | |||
785b5cf444 | |||
3b09464364 | |||
b00807fac2 | |||
612bfd0801 | |||
6bc0fbcfa7 | |||
59e28c30fa | |||
7cf6230e25 | |||
d6f446ffa7 | |||
8ce1e9578a | |||
af2d7caff3 | |||
42b8e7916b | |||
e39c9f7a78 |
@ -505,7 +505,7 @@
|
||||
- local: model_doc/mobilebert
|
||||
title: MobileBERT
|
||||
- local: model_doc/modernbert
|
||||
title: ModernBERT
|
||||
title: ModernBert
|
||||
- local: model_doc/mpnet
|
||||
title: MPNet
|
||||
- local: model_doc/mpt
|
||||
@ -768,6 +768,8 @@
|
||||
title: Mimi
|
||||
- local: model_doc/mms
|
||||
title: MMS
|
||||
- local: model_doc/moonshine
|
||||
title: Moonshine
|
||||
- local: model_doc/moshi
|
||||
title: Moshi
|
||||
- local: model_doc/musicgen
|
||||
@ -858,6 +860,8 @@
|
||||
title: DePlot
|
||||
- local: model_doc/donut
|
||||
title: Donut
|
||||
- local: model_doc/emu3
|
||||
title: Emu3
|
||||
- local: model_doc/flava
|
||||
title: FLAVA
|
||||
- local: model_doc/git
|
||||
|
@ -137,6 +137,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [EfficientFormer](model_doc/efficientformer) | ✅ | ✅ | ❌ |
|
||||
| [EfficientNet](model_doc/efficientnet) | ✅ | ❌ | ❌ |
|
||||
| [ELECTRA](model_doc/electra) | ✅ | ✅ | ✅ |
|
||||
| [Emu3](model_doc/emu3) | ✅ | ❌ | ❌ |
|
||||
| [EnCodec](model_doc/encodec) | ✅ | ❌ | ❌ |
|
||||
| [Encoder decoder](model_doc/encoder-decoder) | ✅ | ✅ | ✅ |
|
||||
| [ERNIE](model_doc/ernie) | ✅ | ❌ | ❌ |
|
||||
@ -235,6 +236,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [MobileViT](model_doc/mobilevit) | ✅ | ✅ | ❌ |
|
||||
| [MobileViTV2](model_doc/mobilevitv2) | ✅ | ❌ | ❌ |
|
||||
| [ModernBERT](model_doc/modernbert) | ✅ | ❌ | ❌ |
|
||||
| [Moonshine](model_doc/moonshine) | ✅ | ❌ | ❌ |
|
||||
| [Moshi](model_doc/moshi) | ✅ | ❌ | ❌ |
|
||||
| [MPNet](model_doc/mpnet) | ✅ | ✅ | ❌ |
|
||||
| [MPT](model_doc/mpt) | ✅ | ❌ | ❌ |
|
||||
|
179
docs/source/en/model_doc/emu3.md
Normal file
179
docs/source/en/model_doc/emu3.md
Normal file
@ -0,0 +1,179 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Emu3
|
||||
|
||||
## Overview
|
||||
|
||||
The Emu3 model was proposed in [Emu3: Next-Token Prediction is All You Need](https://arxiv.org/abs/2409.18869) by Xinlong Wang, Xiaosong Zhang, Zhengxiong Luo, Quan Sun, Yufeng Cui, Jinsheng Wang, Fan Zhang, Yueze Wang, Zhen Li, Qiying Yu, Yingli Zhao, Yulong Ao, Xuebin Min, Tao Li, Boya Wu, Bo Zhao, Bowen Zhang, Liangdong Wang, Guang Liu, Zheqi He, Xi Yang, Jingjing Liu, Yonghua Lin, Tiejun Huang, Zhongyuan Wang.
|
||||
|
||||
Emu3 is a multimodal LLM that uses vector quantization to tokenize images into discrete tokens. Discretized image tokens are later fused with text token ids for image and text generation. The model can additionally generate images by predicting image token ids.
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*While next-token prediction is considered a promising path towards artificial general intelligence, it has struggled to excel in multimodal tasks, which are still dominated by diffusion models (e.g., Stable Diffusion) and compositional approaches (e.g., CLIP combined with LLMs). In this paper, we introduce Emu3, a new suite of state-of-the-art multimodal models trained solely with next-token prediction. By tokenizing images, text, and videos into a discrete space, we train a single transformer from scratch on a mixture of multimodal sequences. Emu3 outperforms several well-established task-specific models in both generation and perception tasks, surpassing flagship models such as SDXL and LLaVA-1.6, while eliminating the need for diffusion or compositional architectures. Emu3 is also capable of generating high-fidelity video via predicting the next token in a video sequence. We simplify complex multimodal model designs by converging on a singular focus: tokens, unlocking great potential for scaling both during training and inference. Our results demonstrate that next-token prediction is a promising path towards building general multimodal intelligence beyond language. We open-source key techniques and models to support further research in this direction.*
|
||||
|
||||
Tips:
|
||||
|
||||
- We advise users to set `processor.tokenizer.padding_side = "left"` before batched generation as it leads to more accurate results.
|
||||
|
||||
- Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.
|
||||
|
||||
- Emu3 has two different checkpoints for image-generation and text-generation, make sure to use the correct checkpoint when loading the model. To generate an image, it is advised to use `prefix_constraints` so that the generated tokens are sampled only from possible image tokens. See more below for usage examples.
|
||||
|
||||
> [!TIP]
|
||||
> Emu3 implementation in Transformers uses a special image token to indicate where to merge image embeddings. The special image token isn't new and uses one of the reserved tokens: `<|extra_0|>`. You have to add `<image>` to your prompt in the place where the image should be embedded for correct generation.
|
||||
|
||||
|
||||
This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
|
||||
The original code can be found [here](https://github.com/baaivision/Emu3).
|
||||
|
||||
|
||||
## Usage example
|
||||
|
||||
### Text generation inference
|
||||
|
||||
Here's how to load the model and perform inference in half-precision (`torch.bfloat16`) to generate textual output from text or text and image inputs:
|
||||
|
||||
```python
|
||||
from transformers import Emu3Processor, Emu3ForConditionalGeneration
|
||||
import torch
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf")
|
||||
model = Emu3ForConditionalGeneration.from_pretrained("Emu3-community/Emu3-Chat-hf", torch_dtype=torch.bfloat16, device_map="cuda")
|
||||
|
||||
# prepare image and text prompt
|
||||
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
prompt = "What do you see in this image?<image>"
|
||||
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
# autoregressively complete prompt
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
### Image generation inference
|
||||
|
||||
Emu3 can also generate images from textual input. Here is how you can do it:
|
||||
|
||||
```python
|
||||
processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Gen-hf")
|
||||
model = Emu3ForConditionalGeneration.from_pretrained("Emu3-community/Emu3-Gen-hf", torch_dtype="bfloat16", device_map="auto", attn_implementation="flash_attention_2")
|
||||
|
||||
|
||||
inputs = processor(
|
||||
text=["a portrait of young girl. masterpiece, film grained, best quality.", "a dog running under the rain"],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
return_for_image_generation=True,
|
||||
)
|
||||
inputs = inputs.to(device="cuda:0", dtype=torch.bfloat16)
|
||||
|
||||
neg_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
|
||||
neg_inputs = processor(text=[neg_prompt] * 2, return_tensors="pt").to(device="cuda:0")
|
||||
|
||||
image_sizes = inputs.pop("image_sizes")
|
||||
HEIGHT, WIDTH = image_sizes[0]
|
||||
VISUAL_TOKENS = model.vocabulary_mapping.image_tokens
|
||||
|
||||
def prefix_allowed_tokens_fn(batch_id, input_ids):
|
||||
height, width = HEIGHT, WIDTH
|
||||
visual_tokens = VISUAL_TOKENS
|
||||
image_wrapper_token_id = torch.tensor([processor.tokenizer.image_wrapper_token_id], device=model.device)
|
||||
eoi_token_id = torch.tensor([processor.tokenizer.eoi_token_id], device=model.device)
|
||||
eos_token_id = torch.tensor([processor.tokenizer.eos_token_id], device=model.device)
|
||||
pad_token_id = torch.tensor([processor.tokenizer.pad_token_id], device=model.device)
|
||||
eof_token_id = torch.tensor([processor.tokenizer.eof_token_id], device=model.device)
|
||||
eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0]
|
||||
|
||||
position = torch.nonzero(input_ids == image_wrapper_token_id, as_tuple=True)[0][0]
|
||||
offset = input_ids.shape[0] - position
|
||||
if offset % (width + 1) == 0:
|
||||
return (eol_token_id, )
|
||||
elif offset == (width + 1) * height + 1:
|
||||
return (eof_token_id, )
|
||||
elif offset == (width + 1) * height + 2:
|
||||
return (eoi_token_id, )
|
||||
elif offset == (width + 1) * height + 3:
|
||||
return (eos_token_id, )
|
||||
elif offset > (width + 1) * height + 3:
|
||||
return (pad_token_id, )
|
||||
else:
|
||||
return visual_tokens
|
||||
|
||||
|
||||
out = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=50_000, # make sure to have enough tokens for one image
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
return_dict_in_generate=True,
|
||||
negative_prompt_ids=neg_inputs.input_ids, # indicate for Classifier-Free Guidance
|
||||
negative_prompt_attention_mask=neg_inputs.attention_mask,
|
||||
)
|
||||
|
||||
image = model.decode_image_tokens(out.sequences[:, inputs.input_ids.shape[1]: ], height=HEIGHT, width=WIDTH)
|
||||
images = processor.postprocess(list(image.float()), return_tensors="PIL.Image.Image") # internally we convert to np but it's not supported in bf16 precision
|
||||
for i, image in enumerate(images['pixel_values']):
|
||||
image.save(f"result{i}.png")
|
||||
|
||||
```
|
||||
|
||||
|
||||
## Emu3Config
|
||||
|
||||
[[autodoc]] Emu3Config
|
||||
|
||||
## Emu3VQVAEConfig
|
||||
|
||||
[[autodoc]] Emu3VQVAEConfig
|
||||
|
||||
## Emu3TextConfig
|
||||
|
||||
[[autodoc]] Emu3TextConfig
|
||||
|
||||
## Emu3Processor
|
||||
|
||||
[[autodoc]] Emu3Processor
|
||||
|
||||
## Emu3ImageProcessor
|
||||
|
||||
[[autodoc]] Emu3ImageProcessor
|
||||
- preprocess
|
||||
|
||||
## Emu3VQVAE
|
||||
|
||||
[[autodoc]] Emu3VQVAE
|
||||
- forward
|
||||
|
||||
## Emu3TextModel
|
||||
|
||||
[[autodoc]] Emu3TextModel
|
||||
- forward
|
||||
|
||||
## Emu3ForCausalLM
|
||||
|
||||
[[autodoc]] Emu3ForCausalLM
|
||||
- forward
|
||||
|
||||
## Emu3ForConditionalGeneration
|
||||
|
||||
[[autodoc]] Emu3ForConditionalGeneration
|
||||
- forward
|
56
docs/source/en/model_doc/moonshine.md
Normal file
56
docs/source/en/model_doc/moonshine.md
Normal file
@ -0,0 +1,56 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Moonshine
|
||||
|
||||
## Overview
|
||||
|
||||
The Moonshine model was proposed in [Moonshine: Speech Recognition for Live Transcription and Voice Commands
|
||||
](https://arxiv.org/abs/2410.15608) by Nat Jeffries, Evan King, Manjunath Kudlur, Guy Nicholson, James Wang, Pete Warden.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*This paper introduces Moonshine, a family of speech recognition models optimized for live transcription and voice command processing. Moonshine is based on an encoder-decoder transformer architecture and employs Rotary Position Embedding (RoPE) instead of traditional absolute position embeddings. The model is trained on speech segments of various lengths, but without using zero-padding, leading to greater efficiency for the encoder during inference time. When benchmarked against OpenAI's Whisper tiny-en, Moonshine Tiny demonstrates a 5x reduction in compute requirements for transcribing a 10-second speech segment while incurring no increase in word error rates across standard evaluation datasets. These results highlight Moonshine's potential for real-time and resource-constrained applications.*
|
||||
|
||||
Tips:
|
||||
|
||||
- Moonshine improves upon Whisper's architecture:
|
||||
1. It uses SwiGLU activation instead of GELU in the decoder layers
|
||||
2. Most importantly, it replaces absolute position embeddings with Rotary Position Embeddings (RoPE). This allows Moonshine to handle audio inputs of any length, unlike Whisper which is restricted to fixed 30-second windows.
|
||||
|
||||
This model was contributed by [Eustache Le Bihan (eustlb)](https://huggingface.co/eustlb).
|
||||
The original code can be found [here](https://github.com/usefulsensors/moonshine).
|
||||
|
||||
## Resources
|
||||
|
||||
- [Automatic speech recognition task guide](../tasks/asr)
|
||||
|
||||
## MoonshineConfig
|
||||
|
||||
[[autodoc]] MoonshineConfig
|
||||
|
||||
## MoonshineModel
|
||||
|
||||
[[autodoc]] MoonshineModel
|
||||
- forward
|
||||
- _mask_input_features
|
||||
|
||||
## MoonshineForConditionalGeneration
|
||||
|
||||
[[autodoc]] MoonshineForConditionalGeneration
|
||||
- forward
|
||||
- generate
|
||||
|
@ -49,6 +49,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
|
||||
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
|
||||
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
|
||||
* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3)
|
||||
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
||||
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
|
||||
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
|
||||
@ -68,6 +69,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
|
||||
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
|
||||
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
|
||||
* [Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine#transformers.MoonshineModel)
|
||||
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
|
||||
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
|
||||
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
|
||||
@ -244,6 +246,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
|
||||
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
|
||||
* [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel)
|
||||
* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3)
|
||||
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
|
||||
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
||||
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
|
||||
@ -265,6 +268,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
|
||||
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
|
||||
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100#transformers.M2M100Model)
|
||||
* [Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine#transformers.MoonshineModel)
|
||||
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration)
|
||||
@ -283,8 +287,8 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
|
||||
* [PhiMoE](https://huggingface.co/docs/transformers/model_doc/phimoe#transformers.PhimoeModel)
|
||||
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
|
||||
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
||||
* [mBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
|
||||
* [Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine#transformers.MoonshineModel)
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
||||
|
@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -437,7 +437,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.48.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.48.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.48.0.dev0"
|
||||
__version__ = "4.48.1"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -428,6 +428,12 @@ _import_structure = {
|
||||
"ElectraConfig",
|
||||
"ElectraTokenizer",
|
||||
],
|
||||
"models.emu3": [
|
||||
"Emu3Config",
|
||||
"Emu3Processor",
|
||||
"Emu3TextConfig",
|
||||
"Emu3VQVAEConfig",
|
||||
],
|
||||
"models.encodec": [
|
||||
"EncodecConfig",
|
||||
"EncodecFeatureExtractor",
|
||||
@ -610,6 +616,7 @@ _import_structure = {
|
||||
"models.mobilevit": ["MobileViTConfig"],
|
||||
"models.mobilevitv2": ["MobileViTV2Config"],
|
||||
"models.modernbert": ["ModernBertConfig"],
|
||||
"models.moonshine": ["MoonshineConfig"],
|
||||
"models.moshi": [
|
||||
"MoshiConfig",
|
||||
"MoshiDepthConfig",
|
||||
@ -1221,6 +1228,7 @@ else:
|
||||
_import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"])
|
||||
_import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
|
||||
_import_structure["models.efficientnet"].append("EfficientNetImageProcessor")
|
||||
_import_structure["models.emu3"].append("Emu3ImageProcessor")
|
||||
_import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
|
||||
_import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"])
|
||||
_import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
|
||||
@ -2242,6 +2250,15 @@ else:
|
||||
"load_tf_weights_in_electra",
|
||||
]
|
||||
)
|
||||
_import_structure["models.emu3"].extend(
|
||||
[
|
||||
"Emu3ForCausalLM",
|
||||
"Emu3ForConditionalGeneration",
|
||||
"Emu3PreTrainedModel",
|
||||
"Emu3TextModel",
|
||||
"Emu3VQVAE",
|
||||
]
|
||||
)
|
||||
_import_structure["models.encodec"].extend(
|
||||
[
|
||||
"EncodecModel",
|
||||
@ -2907,6 +2924,13 @@ else:
|
||||
"ModernBertPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.moonshine"].extend(
|
||||
[
|
||||
"MoonshineForConditionalGeneration",
|
||||
"MoonshineModel",
|
||||
"MoonshinePreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.moshi"].extend(
|
||||
[
|
||||
"MoshiForCausalLM",
|
||||
@ -5432,6 +5456,12 @@ if TYPE_CHECKING:
|
||||
ElectraConfig,
|
||||
ElectraTokenizer,
|
||||
)
|
||||
from .models.emu3 import (
|
||||
Emu3Config,
|
||||
Emu3Processor,
|
||||
Emu3TextConfig,
|
||||
Emu3VQVAEConfig,
|
||||
)
|
||||
from .models.encodec import (
|
||||
EncodecConfig,
|
||||
EncodecFeatureExtractor,
|
||||
@ -5633,6 +5663,7 @@ if TYPE_CHECKING:
|
||||
MobileViTV2Config,
|
||||
)
|
||||
from .models.modernbert import ModernBertConfig
|
||||
from .models.moonshine import MoonshineConfig
|
||||
from .models.moshi import (
|
||||
MoshiConfig,
|
||||
MoshiDepthConfig,
|
||||
@ -6261,6 +6292,7 @@ if TYPE_CHECKING:
|
||||
from .models.donut import DonutFeatureExtractor, DonutImageProcessor
|
||||
from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
|
||||
from .models.efficientnet import EfficientNetImageProcessor
|
||||
from .models.emu3 import Emu3ImageProcessor
|
||||
from .models.flava import (
|
||||
FlavaFeatureExtractor,
|
||||
FlavaImageProcessor,
|
||||
@ -7130,6 +7162,13 @@ if TYPE_CHECKING:
|
||||
ElectraPreTrainedModel,
|
||||
load_tf_weights_in_electra,
|
||||
)
|
||||
from .models.emu3 import (
|
||||
Emu3ForCausalLM,
|
||||
Emu3ForConditionalGeneration,
|
||||
Emu3PreTrainedModel,
|
||||
Emu3TextModel,
|
||||
Emu3VQVAE,
|
||||
)
|
||||
from .models.encodec import (
|
||||
EncodecModel,
|
||||
EncodecPreTrainedModel,
|
||||
@ -7652,6 +7691,11 @@ if TYPE_CHECKING:
|
||||
ModernBertModel,
|
||||
ModernBertPreTrainedModel,
|
||||
)
|
||||
from .models.moonshine import (
|
||||
MoonshineForConditionalGeneration,
|
||||
MoonshineModel,
|
||||
MoonshinePreTrainedModel,
|
||||
)
|
||||
from .models.moshi import (
|
||||
MoshiForCausalLM,
|
||||
MoshiForConditionalGeneration,
|
||||
|
@ -1634,17 +1634,18 @@ class GenerationMixin:
|
||||
cache_dtype = self.get_output_embeddings().weight.dtype
|
||||
|
||||
def get_layer_device_map(execution_device_map: Optional[dict] = None):
|
||||
num_hidden_layers = self.config.get_text_config().num_hidden_layers
|
||||
if execution_device_map is None:
|
||||
return None
|
||||
elif len(execution_device_map) == 1 and "" in execution_device_map:
|
||||
return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)}
|
||||
return {idx: execution_device_map[""] for idx in range(num_hidden_layers)}
|
||||
layer_device_map = {}
|
||||
for layer in execution_device_map:
|
||||
for idx in range(self.config.num_hidden_layers):
|
||||
for idx in range(num_hidden_layers):
|
||||
if f".{idx}." in f"{layer}.":
|
||||
layer_device_map[idx] = execution_device_map[layer]
|
||||
break
|
||||
for idx in range(self.config.num_hidden_layers):
|
||||
for idx in range(num_hidden_layers):
|
||||
if idx not in layer_device_map:
|
||||
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
|
||||
return layer_device_map
|
||||
|
@ -27,7 +27,7 @@ def flex_attention_forward(
|
||||
if softcap is not None:
|
||||
score = softcap * torch.tanh(score / softcap)
|
||||
if causal_mask is not None:
|
||||
score += causal_mask[b][0][q_idx][kv_idx]
|
||||
score = score + causal_mask[b][0][q_idx][kv_idx]
|
||||
return score
|
||||
|
||||
attn_output, attention_weights = flex_attention(
|
||||
|
@ -86,6 +86,7 @@ from . import (
|
||||
dpt,
|
||||
efficientnet,
|
||||
electra,
|
||||
emu3,
|
||||
encodec,
|
||||
encoder_decoder,
|
||||
ernie,
|
||||
@ -170,6 +171,7 @@ from . import (
|
||||
mobilevit,
|
||||
mobilevitv2,
|
||||
modernbert,
|
||||
moonshine,
|
||||
moshi,
|
||||
mpnet,
|
||||
mpt,
|
||||
|
@ -1,62 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ALBERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = AlbertConfig.from_json_file(albert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = AlbertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_albert(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--albert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained ALBERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
|
@ -1,389 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ALIGN checkpoints from the original repository."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import align
|
||||
import numpy as np
|
||||
import requests
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
from transformers import (
|
||||
AlignConfig,
|
||||
AlignModel,
|
||||
AlignProcessor,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
EfficientNetConfig,
|
||||
EfficientNetImageProcessor,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
image = tf.image.resize(image, (346, 346))
|
||||
image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
|
||||
return image
|
||||
|
||||
|
||||
def get_align_config():
|
||||
vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
|
||||
vision_config.image_size = 289
|
||||
vision_config.hidden_dim = 640
|
||||
vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
|
||||
vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
|
||||
vision_config.depthwise_padding = []
|
||||
|
||||
text_config = BertConfig()
|
||||
config = AlignConfig.from_text_vision_configs(
|
||||
text_config=text_config, vision_config=vision_config, projection_dim=640
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
def get_processor():
|
||||
image_processor = EfficientNetImageProcessor(
|
||||
do_center_crop=True,
|
||||
rescale_factor=1 / 127.5,
|
||||
rescale_offset=True,
|
||||
do_normalize=False,
|
||||
include_top=False,
|
||||
resample=Image.BILINEAR,
|
||||
)
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
tokenizer.model_max_length = 64
|
||||
processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
return processor
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def rename_keys(original_param_names):
|
||||
# EfficientNet image encoder
|
||||
block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
|
||||
block_names = list(set(block_names))
|
||||
block_names = sorted(block_names)
|
||||
num_blocks = len(block_names)
|
||||
block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}
|
||||
|
||||
rename_keys = []
|
||||
rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
|
||||
rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
|
||||
rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
|
||||
rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
|
||||
rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))
|
||||
|
||||
for b in block_names:
|
||||
hf_b = block_name_mapping[b]
|
||||
rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
|
||||
rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
|
||||
rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
|
||||
)
|
||||
rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
|
||||
rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
|
||||
)
|
||||
|
||||
rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
|
||||
rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
|
||||
rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
|
||||
rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
|
||||
)
|
||||
rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
|
||||
rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
|
||||
)
|
||||
|
||||
key_mapping = {}
|
||||
for item in rename_keys:
|
||||
if item[0] in original_param_names:
|
||||
key_mapping[item[0]] = "vision_model." + item[1]
|
||||
|
||||
# BERT text encoder
|
||||
rename_keys = []
|
||||
old = "tf_bert_model/bert"
|
||||
new = "text_model"
|
||||
for i in range(12):
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.query.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/query/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.query.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.key.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/key/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.key.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.value.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/value/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.value.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.dense.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.dense.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.intermediate.dense.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0",
|
||||
f"{new}.encoder.layer.{i}.intermediate.dense.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias")
|
||||
)
|
||||
|
||||
rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight"))
|
||||
rename_keys.append(
|
||||
(f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight")
|
||||
)
|
||||
rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight"))
|
||||
rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias"))
|
||||
|
||||
rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight"))
|
||||
rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias"))
|
||||
rename_keys.append(("dense/kernel:0", "text_projection.weight"))
|
||||
rename_keys.append(("dense/bias:0", "text_projection.bias"))
|
||||
rename_keys.append(("dense/bias:0", "text_projection.bias"))
|
||||
rename_keys.append(("temperature:0", "temperature"))
|
||||
|
||||
for item in rename_keys:
|
||||
if item[0] in original_param_names:
|
||||
key_mapping[item[0]] = item[1]
|
||||
return key_mapping
|
||||
|
||||
|
||||
def replace_params(hf_params, tf_params, key_mapping):
|
||||
list(hf_params.keys())
|
||||
|
||||
for key, value in tf_params.items():
|
||||
if key not in key_mapping:
|
||||
continue
|
||||
|
||||
hf_key = key_mapping[key]
|
||||
if "_conv" in key and "kernel" in key:
|
||||
new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
|
||||
elif "embeddings" in key:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
elif "depthwise_kernel" in key:
|
||||
new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
|
||||
elif "kernel" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value))
|
||||
elif "temperature" in key:
|
||||
new_hf_value = value
|
||||
elif "bn/gamma" or "bn/beta" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
|
||||
else:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
|
||||
# Replace HF parameters with original TF model parameters
|
||||
hf_params[hf_key].copy_(new_hf_value)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our ALIGN structure.
|
||||
"""
|
||||
# Load original model
|
||||
seq_length = 64
|
||||
tok = Tokenizer(seq_length)
|
||||
original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
|
||||
original_model.compile()
|
||||
original_model.load_weights(checkpoint_path)
|
||||
|
||||
tf_params = original_model.trainable_variables
|
||||
tf_non_train_params = original_model.non_trainable_variables
|
||||
tf_params = {param.name: param.numpy() for param in tf_params}
|
||||
for param in tf_non_train_params:
|
||||
tf_params[param.name] = param.numpy()
|
||||
tf_param_names = list(tf_params.keys())
|
||||
|
||||
# Load HuggingFace model
|
||||
config = get_align_config()
|
||||
hf_model = AlignModel(config).eval()
|
||||
hf_params = hf_model.state_dict()
|
||||
|
||||
# Create src-to-dst parameter name mapping dictionary
|
||||
print("Converting parameters...")
|
||||
key_mapping = rename_keys(tf_param_names)
|
||||
replace_params(hf_params, tf_params, key_mapping)
|
||||
|
||||
# Initialize processor
|
||||
processor = get_processor()
|
||||
inputs = processor(
|
||||
images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
|
||||
)
|
||||
|
||||
# HF model inference
|
||||
hf_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = hf_model(**inputs)
|
||||
|
||||
hf_image_features = outputs.image_embeds.detach().numpy()
|
||||
hf_text_features = outputs.text_embeds.detach().numpy()
|
||||
|
||||
# Original model inference
|
||||
original_model.trainable = False
|
||||
tf_image_processor = EfficientNetImageProcessor(
|
||||
do_center_crop=True,
|
||||
do_rescale=False,
|
||||
do_normalize=False,
|
||||
include_top=False,
|
||||
resample=Image.BILINEAR,
|
||||
)
|
||||
image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"]
|
||||
text = tok(tf.constant(["A picture of a cat"]))
|
||||
|
||||
image_features = original_model.image_encoder(image, training=False)
|
||||
text_features = original_model.text_encoder(text, training=False)
|
||||
|
||||
image_features = tf.nn.l2_normalize(image_features, axis=-1)
|
||||
text_features = tf.nn.l2_normalize(text_features, axis=-1)
|
||||
|
||||
# Check whether original and HF model outputs match -> np.allclose
|
||||
if not np.allclose(image_features, hf_image_features, atol=1e-3):
|
||||
raise ValueError("The predicted image features are not the same.")
|
||||
if not np.allclose(text_features, hf_text_features, atol=1e-3):
|
||||
raise ValueError("The predicted text features are not the same.")
|
||||
print("Model outputs match!")
|
||||
|
||||
if save_model:
|
||||
# Create folder to save model
|
||||
if not os.path.isdir(pytorch_dump_folder_path):
|
||||
os.mkdir(pytorch_dump_folder_path)
|
||||
# Save converted model and image processor
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
# Push model and image processor to hub
|
||||
print("Pushing converted ALIGN to the hub...")
|
||||
processor.push_to_hub("align-base")
|
||||
hf_model.push_to_hub("align-base")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default="./weights/model-weights",
|
||||
type=str,
|
||||
help="Path to the pretrained TF ALIGN checkpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default="hf_model",
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument("--save_model", action="store_true", help="Save model to local")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)
|
@ -1,162 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AriaForConditionalGeneration,
|
||||
AriaProcessor,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
|
||||
EPILOG_TXT = """Example:
|
||||
python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
|
||||
|
||||
Example for creating the old state dict file with Python:
|
||||
|
||||
import torch
|
||||
from aria.model.language_model.aria_llama import AriaTextForCausalLM
|
||||
|
||||
# load model
|
||||
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
|
||||
model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", low_cpu_mem_usage=True, **kwargs)
|
||||
|
||||
# load vision tower
|
||||
model.get_vision_tower().load_model()
|
||||
|
||||
# Save state dict
|
||||
torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
|
||||
"""
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"vision_tower.vision_model": "vision_tower",
|
||||
"ln_ffn": "layer_norm",
|
||||
"ffn": "feed_forward",
|
||||
"ln_kv": "layer_norm_kv",
|
||||
}
|
||||
|
||||
|
||||
def load_original_state_dict(model_id):
|
||||
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_hf(state_dict):
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key.endswith(".inv_freq"):
|
||||
continue
|
||||
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in key:
|
||||
key = key.replace(key_to_modify, new_key)
|
||||
|
||||
new_state_dict[key] = value
|
||||
new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
|
||||
new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
text_model_id,
|
||||
extra_special_tokens={
|
||||
"image_token": "<|img|>",
|
||||
"pad_token": "<pad>",
|
||||
},
|
||||
)
|
||||
tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
|
||||
processor = AriaProcessor.from_pretrained(
|
||||
text_model_id,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(text_model_id)
|
||||
config.vision_config.hidden_size = 1152
|
||||
config.vision_config.attention_heads = 16
|
||||
config.pad_token_id = 2
|
||||
config.image_token_index = 9
|
||||
config.intermediate_size = config.moe_intermediate_size
|
||||
config.auto_map = {
|
||||
"AutoConfig": "modeling_aria.AriaConfig",
|
||||
"AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
|
||||
}
|
||||
|
||||
with torch.device("meta"):
|
||||
model = AriaForConditionalGeneration(config)
|
||||
|
||||
state_dict = load_original_state_dict(old_state_dict_id)
|
||||
|
||||
state_dict = convert_state_dict_to_hf(state_dict)
|
||||
model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
|
||||
# print("Saving models")
|
||||
# model.save_pretrained("local_aria", safe_serialization=False)
|
||||
# processor.save_pretrained("local_aria")
|
||||
print("Pushing to hub")
|
||||
model.push_to_hub(output_hub_path, create_pr=True)
|
||||
processor.push_to_hub(output_hub_path, create_pr=True)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog=EPILOG_TXT,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_model_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Hub location of the text model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vision_model_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Hub location of the vision model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_hub_path",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Location on the hub of the converted model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old_state_dict_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,279 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_audio_spectrogram_transformer_config(model_name):
|
||||
config = ASTConfig()
|
||||
|
||||
if "10-10" in model_name:
|
||||
pass
|
||||
elif "speech-commands" in model_name:
|
||||
config.max_length = 128
|
||||
elif "12-12" in model_name:
|
||||
config.time_stride = 12
|
||||
config.frequency_stride = 12
|
||||
elif "14-14" in model_name:
|
||||
config.time_stride = 14
|
||||
config.frequency_stride = 14
|
||||
elif "16-16" in model_name:
|
||||
config.time_stride = 16
|
||||
config.frequency_stride = 16
|
||||
else:
|
||||
raise ValueError("Model not supported")
|
||||
|
||||
repo_id = "huggingface/label-files"
|
||||
if "speech-commands" in model_name:
|
||||
config.num_labels = 35
|
||||
filename = "speech-commands-v2-id2label.json"
|
||||
else:
|
||||
config.num_labels = 527
|
||||
filename = "audioset-id2label.json"
|
||||
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "module.v" in name:
|
||||
name = name.replace("module.v", "audio_spectrogram_transformer")
|
||||
if "cls_token" in name:
|
||||
name = name.replace("cls_token", "embeddings.cls_token")
|
||||
if "dist_token" in name:
|
||||
name = name.replace("dist_token", "embeddings.distillation_token")
|
||||
if "pos_embed" in name:
|
||||
name = name.replace("pos_embed", "embeddings.position_embeddings")
|
||||
if "patch_embed.proj" in name:
|
||||
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
|
||||
# transformer blocks
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "encoder.layer")
|
||||
if "attn.proj" in name:
|
||||
name = name.replace("attn.proj", "attention.output.dense")
|
||||
if "attn" in name:
|
||||
name = name.replace("attn", "attention.self")
|
||||
if "norm1" in name:
|
||||
name = name.replace("norm1", "layernorm_before")
|
||||
if "norm2" in name:
|
||||
name = name.replace("norm2", "layernorm_after")
|
||||
if "mlp.fc1" in name:
|
||||
name = name.replace("mlp.fc1", "intermediate.dense")
|
||||
if "mlp.fc2" in name:
|
||||
name = name.replace("mlp.fc2", "output.dense")
|
||||
# final layernorm
|
||||
if "audio_spectrogram_transformer.norm" in name:
|
||||
name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm")
|
||||
# classifier head
|
||||
if "module.mlp_head.0" in name:
|
||||
name = name.replace("module.mlp_head.0", "classifier.layernorm")
|
||||
if "module.mlp_head.1" in name:
|
||||
name = name.replace("module.mlp_head.1", "classifier.dense")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, config):
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
if "qkv" in key:
|
||||
key_split = key.split(".")
|
||||
layer_num = int(key_split[3])
|
||||
dim = config.hidden_size
|
||||
if "weight" in key:
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight"
|
||||
] = val[:dim, :]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight"
|
||||
] = val[dim : dim * 2, :]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight"
|
||||
] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias"
|
||||
] = val[:dim]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias"
|
||||
] = val[dim : dim * 2]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias"
|
||||
] = val[-dim:]
|
||||
else:
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
def remove_keys(state_dict):
|
||||
ignore_keys = [
|
||||
"module.v.head.weight",
|
||||
"module.v.head.bias",
|
||||
"module.v.head_dist.weight",
|
||||
"module.v.head_dist.bias",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure.
|
||||
"""
|
||||
config = get_audio_spectrogram_transformer_config(model_name)
|
||||
|
||||
model_name_to_url = {
|
||||
"ast-finetuned-audioset-10-10-0.4593": (
|
||||
"https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.450": (
|
||||
"https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.448": (
|
||||
"https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.448-v2": (
|
||||
"https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-12-12-0.447": (
|
||||
"https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-14-14-0.443": (
|
||||
"https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-16-16-0.442": (
|
||||
"https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-speech-commands-v2": (
|
||||
"https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1"
|
||||
),
|
||||
}
|
||||
|
||||
# load original state_dict
|
||||
checkpoint_url = model_name_to_url[model_name]
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
|
||||
# remove some keys
|
||||
remove_keys(state_dict)
|
||||
# rename some keys
|
||||
new_state_dict = convert_state_dict(state_dict, config)
|
||||
|
||||
# load 🤗 model
|
||||
model = ASTForAudioClassification(config)
|
||||
model.eval()
|
||||
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# verify outputs on dummy input
|
||||
# source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62
|
||||
mean = -4.2677393 if "speech-commands" not in model_name else -6.845978
|
||||
std = 4.5689974 if "speech-commands" not in model_name else 5.5654526
|
||||
max_length = 1024 if "speech-commands" not in model_name else 128
|
||||
feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length)
|
||||
|
||||
if "speech-commands" in model_name:
|
||||
# TODO: Convert dataset to Parquet
|
||||
dataset = load_dataset("google/speech_commands", "v0.02", split="validation", trust_remote_code=True)
|
||||
waveform = dataset[0]["audio"]["array"]
|
||||
else:
|
||||
filepath = hf_hub_download(
|
||||
repo_id="nielsr/audio-spectogram-transformer-checkpoint",
|
||||
filename="sample_audio.flac",
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
waveform, _ = torchaudio.load(filepath)
|
||||
waveform = waveform.squeeze().numpy()
|
||||
|
||||
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
if model_name == "ast-finetuned-audioset-10-10-0.4593":
|
||||
expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.450":
|
||||
expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.448":
|
||||
expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.448-v2":
|
||||
expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917])
|
||||
elif model_name == "ast-finetuned-audioset-12-12-0.447":
|
||||
expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843])
|
||||
elif model_name == "ast-finetuned-audioset-14-14-0.443":
|
||||
expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413])
|
||||
elif model_name == "ast-finetuned-audioset-16-16-0.442":
|
||||
expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470])
|
||||
elif model_name == "ast-finetuned-speech-commands-v2":
|
||||
expected_slice = torch.tensor([6.1589, -8.0566, -8.7984])
|
||||
else:
|
||||
raise ValueError("Unknown model name")
|
||||
if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4):
|
||||
raise ValueError("Logits don't match")
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print("Pushing model and feature extractor to the hub...")
|
||||
model.push_to_hub(f"MIT/{model_name}")
|
||||
feature_extractor.push_to_hub(f"MIT/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="ast-finetuned-audioset-10-10-0.4593",
|
||||
type=str,
|
||||
help="Name of the Audio Spectrogram Transformer model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -103,6 +103,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("efficientformer", "EfficientFormerConfig"),
|
||||
("efficientnet", "EfficientNetConfig"),
|
||||
("electra", "ElectraConfig"),
|
||||
("emu3", "Emu3Config"),
|
||||
("encodec", "EncodecConfig"),
|
||||
("encoder-decoder", "EncoderDecoderConfig"),
|
||||
("ernie", "ErnieConfig"),
|
||||
@ -190,6 +191,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevit", "MobileViTConfig"),
|
||||
("mobilevitv2", "MobileViTV2Config"),
|
||||
("modernbert", "ModernBertConfig"),
|
||||
("moonshine", "MoonshineConfig"),
|
||||
("moshi", "MoshiConfig"),
|
||||
("mpnet", "MPNetConfig"),
|
||||
("mpt", "MptConfig"),
|
||||
@ -419,6 +421,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("efficientformer", "EfficientFormer"),
|
||||
("efficientnet", "EfficientNet"),
|
||||
("electra", "ELECTRA"),
|
||||
("emu3", "Emu3"),
|
||||
("encodec", "EnCodec"),
|
||||
("encoder-decoder", "Encoder decoder"),
|
||||
("ernie", "ERNIE"),
|
||||
@ -519,6 +522,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("mobilevit", "MobileViT"),
|
||||
("mobilevitv2", "MobileViTV2"),
|
||||
("modernbert", "ModernBERT"),
|
||||
("moonshine", "Moonshine"),
|
||||
("moshi", "Moshi"),
|
||||
("mpnet", "MPNet"),
|
||||
("mpt", "MPT"),
|
||||
|
@ -73,6 +73,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("mobilenet_v1", "MobileNetV1FeatureExtractor"),
|
||||
("mobilenet_v2", "MobileNetV2FeatureExtractor"),
|
||||
("mobilevit", "MobileViTFeatureExtractor"),
|
||||
("moonshine", "Wav2Vec2FeatureExtractor"),
|
||||
("moshi", "EncodecFeatureExtractor"),
|
||||
("nat", "ViTFeatureExtractor"),
|
||||
("owlvit", "OwlViTFeatureExtractor"),
|
||||
|
@ -179,6 +179,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevit", "MobileViTModel"),
|
||||
("mobilevitv2", "MobileViTV2Model"),
|
||||
("modernbert", "ModernBertModel"),
|
||||
("moonshine", "MoonshineModel"),
|
||||
("moshi", "MoshiModel"),
|
||||
("mpnet", "MPNetModel"),
|
||||
("mpt", "MptModel"),
|
||||
@ -436,6 +437,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("mega", "MegaForMaskedLM"),
|
||||
("megatron-bert", "MegatronBertForCausalLM"),
|
||||
("mobilebert", "MobileBertForMaskedLM"),
|
||||
("moonshine", "MoonshineForConditionalGeneration"),
|
||||
("mpnet", "MPNetForMaskedLM"),
|
||||
("mpt", "MptForCausalLM"),
|
||||
("mra", "MraForMaskedLM"),
|
||||
@ -497,6 +499,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("dbrx", "DbrxForCausalLM"),
|
||||
("diffllama", "DiffLlamaForCausalLM"),
|
||||
("electra", "ElectraForCausalLM"),
|
||||
("emu3", "Emu3ForCausalLM"),
|
||||
("ernie", "ErnieForCausalLM"),
|
||||
("falcon", "FalconForCausalLM"),
|
||||
("falcon_mamba", "FalconMambaForCausalLM"),
|
||||
@ -798,6 +801,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
||||
("blip", "BlipForConditionalGeneration"),
|
||||
("blip-2", "Blip2ForConditionalGeneration"),
|
||||
("chameleon", "ChameleonForConditionalGeneration"),
|
||||
("emu3", "Emu3ForConditionalGeneration"),
|
||||
("fuyu", "FuyuForCausalLM"),
|
||||
("git", "GitForCausalLM"),
|
||||
("idefics", "IdeficsForVisionText2Text"),
|
||||
@ -937,6 +941,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("moonshine", "MoonshineForConditionalGeneration"),
|
||||
("pop2piano", "Pop2PianoForConditionalGeneration"),
|
||||
("seamless_m4t", "SeamlessM4TForSpeechToText"),
|
||||
("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
|
||||
@ -1425,6 +1430,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
||||
("deberta-v2", "DebertaV2Model"),
|
||||
("distilbert", "DistilBertModel"),
|
||||
("electra", "ElectraModel"),
|
||||
("emu3", "Emu3TextModel"),
|
||||
("flaubert", "FlaubertModel"),
|
||||
("ibert", "IBertModel"),
|
||||
("longformer", "LongformerModel"),
|
||||
|
@ -59,6 +59,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("clipseg", "CLIPSegProcessor"),
|
||||
("clvp", "ClvpProcessor"),
|
||||
("colpali", "ColPaliProcessor"),
|
||||
("emu3", "Emu3Processor"),
|
||||
("flava", "FlavaProcessor"),
|
||||
("fuyu", "FuyuProcessor"),
|
||||
("git", "GitProcessor"),
|
||||
@ -81,6 +82,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("mctct", "MCTCTProcessor"),
|
||||
("mgp-str", "MgpstrProcessor"),
|
||||
("mllama", "MllamaProcessor"),
|
||||
("moonshine", "Wav2Vec2Processor"),
|
||||
("oneformer", "OneFormerProcessor"),
|
||||
("owlv2", "Owlv2Processor"),
|
||||
("owlvit", "OwlViTProcessor"),
|
||||
|
@ -186,6 +186,7 @@ else:
|
||||
),
|
||||
),
|
||||
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("esm", ("EsmTokenizer", None)),
|
||||
@ -321,6 +322,7 @@ else:
|
||||
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
@ -1,273 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from os import path
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from .configuration_bamba import BambaConfig
|
||||
|
||||
|
||||
def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
|
||||
state_dict = {}
|
||||
|
||||
for orig_k, param in original_sd.items():
|
||||
k = orig_k.replace("backbone", "model")
|
||||
|
||||
# for embeddings
|
||||
k = k.replace("embedding", "embed_tokens")
|
||||
|
||||
# for mixer
|
||||
k = k.replace("mixer", "mamba")
|
||||
|
||||
# for final layernorm
|
||||
k = k.replace("norm_f", "final_layernorm")
|
||||
|
||||
# for block layernorm
|
||||
k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
|
||||
k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
|
||||
|
||||
# for mlp
|
||||
k = k.replace("mlp.fc2", "feed_forward.down_proj")
|
||||
|
||||
if "mlp.fc1" in k:
|
||||
param, param2 = torch.chunk(param, 2, dim=0)
|
||||
k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
|
||||
state_dict[k2] = param2
|
||||
k = k.replace("mlp.fc1", "feed_forward.up_proj")
|
||||
|
||||
if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
|
||||
"out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
|
||||
):
|
||||
# then this must be a mamba
|
||||
pass
|
||||
else:
|
||||
# for attn
|
||||
# - because mixer was replaced to mamba above
|
||||
k = k.replace("mamba.out_proj", "self_attn.o_proj")
|
||||
if "mamba.in_proj" in k:
|
||||
m, n = param.shape
|
||||
d = (m - n) // 2
|
||||
param, param2, param3 = torch.split(param, [n, d, d], dim=0)
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
|
||||
state_dict[k2] = param2
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
|
||||
state_dict[k2] = param3
|
||||
k = k.replace("mamba.in_proj", "self_attn.q_proj")
|
||||
|
||||
state_dict[k] = param
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_ssm_config_to_hf_config(
|
||||
config_ssm: Dict,
|
||||
**kwargs,
|
||||
) -> BambaConfig:
|
||||
"""Convert a config from mamba_ssm to a BambaConfig from here."""
|
||||
hf_config: BambaConfig = BambaConfig(**kwargs)
|
||||
|
||||
hf_config.architectures = ["BambaForCausalLM"]
|
||||
|
||||
# Set important values from config and recalculate other resulting entries
|
||||
hf_config.hidden_size = config_ssm["d_model"]
|
||||
hf_config.intermediate_size = config_ssm["d_intermediate"]
|
||||
hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
|
||||
hf_config.num_hidden_layers = config_ssm["n_layer"]
|
||||
hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
|
||||
|
||||
# currently this script assumes config_ssm belongs to v2
|
||||
if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
|
||||
raise ValueError("Conversion script only supports Mamba2")
|
||||
|
||||
# Set attention values
|
||||
attn_cfg = config_ssm.get("attn_cfg")
|
||||
if attn_cfg:
|
||||
assert attn_cfg["causal"], "Only support non-causal attention."
|
||||
assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
|
||||
assert not attn_cfg["out_proj_bias"], "Only support no out bias."
|
||||
hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
|
||||
hf_config.num_attention_heads = attn_cfg["num_heads"]
|
||||
hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
|
||||
|
||||
attention_layer_indices = config_ssm.get("attn_layer_idx")
|
||||
if attention_layer_indices:
|
||||
hf_config.attn_layer_indices = attention_layer_indices
|
||||
|
||||
# Padded vocab size, mostly of 16 but 32 is also very common in different models
|
||||
vocab_size = config_ssm["vocab_size"]
|
||||
pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
|
||||
if (vocab_size % pad_vocab_size_multiple) != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
||||
hf_config.vocab_size = vocab_size
|
||||
|
||||
return hf_config
|
||||
|
||||
|
||||
def save_single_safetensor(
|
||||
state_dict: Dict,
|
||||
save_directory: str,
|
||||
metadata: Dict,
|
||||
):
|
||||
save_file(
|
||||
state_dict,
|
||||
os.path.join(save_directory, SAFE_WEIGHTS_NAME),
|
||||
metadata,
|
||||
)
|
||||
|
||||
|
||||
def save_sharded_safetensors(
|
||||
state_dict: Dict,
|
||||
save_directory: str,
|
||||
metadata: Dict,
|
||||
max_shard_size: Union[int, str] = "5GB",
|
||||
):
|
||||
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||
)
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
# Save the index
|
||||
with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
mamba_ssm_checkpoint_path: str,
|
||||
precision: str,
|
||||
output_dir: str,
|
||||
tokenizer_path: str = None,
|
||||
save_model: Union[bool, str] = True,
|
||||
) -> None:
|
||||
# load tokenizer if provided, this will be used to set the
|
||||
# token_ids in the config file
|
||||
token_ids = {}
|
||||
if tokenizer_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
for key in [
|
||||
"bos_token_id",
|
||||
"eos_token_id",
|
||||
"pad_token_id",
|
||||
]:
|
||||
id = getattr(tokenizer, key, None)
|
||||
if id:
|
||||
token_ids[key] = id
|
||||
|
||||
# there are some configs unsettable by mamba_ssn config, so
|
||||
# if there are changes from the defaults, have to pass them into
|
||||
# the function
|
||||
unsettables = {
|
||||
"mamba_d_head": 64,
|
||||
"mamba_d_state": 128,
|
||||
"mamba_n_groups": 1,
|
||||
"rms_norm_eps": 1e-5,
|
||||
}
|
||||
|
||||
# Load and save config based on name
|
||||
config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
|
||||
with open(config_path, "r", encoding="utf-8") as json_file:
|
||||
config = json.load(json_file)
|
||||
|
||||
# convert the config
|
||||
hf_config = convert_ssm_config_to_hf_config(
|
||||
config_ssm=config,
|
||||
**token_ids,
|
||||
**unsettables,
|
||||
)
|
||||
hf_config.save_pretrained(output_dir)
|
||||
|
||||
# Load state dict of the original model and transfer to hf model
|
||||
state_dict = torch.load(
|
||||
path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
# FIXME: allow other parameters to pass in
|
||||
state_dict = convert_state_dict_from_mamba_ssm(state_dict)
|
||||
|
||||
# Save new model to pytorch_dump_path
|
||||
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
|
||||
|
||||
save_file_fn = None
|
||||
if isinstance(save_model, bool) and save_model:
|
||||
save_file_fn = save_single_safetensor
|
||||
elif isinstance(save_model, str) and save_model == "sharded":
|
||||
save_file_fn = save_sharded_safetensors
|
||||
|
||||
if save_file_fn:
|
||||
save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--mamba_ssm_checkpoint_directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
const="fp16",
|
||||
required=True,
|
||||
choices=("fp32", "fp16", "bf16"),
|
||||
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tokenizer_model_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Path to a the tokenizer file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
args.mamba2_checkpoint_directory,
|
||||
args.precision,
|
||||
args.output_dir,
|
||||
)
|
@ -1,263 +0,0 @@
|
||||
"""Convert Bark checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from bark.generation import _load_model as _bark_load_model
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import EncodecConfig, EncodecModel, set_seed
|
||||
from transformers.models.bark.configuration_bark import (
|
||||
BarkCoarseConfig,
|
||||
BarkConfig,
|
||||
BarkFineConfig,
|
||||
BarkSemanticConfig,
|
||||
)
|
||||
from transformers.models.bark.generation_configuration_bark import (
|
||||
BarkCoarseGenerationConfig,
|
||||
BarkFineGenerationConfig,
|
||||
BarkGenerationConfig,
|
||||
BarkSemanticGenerationConfig,
|
||||
)
|
||||
from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
set_seed(770)
|
||||
|
||||
|
||||
new_layer_name_dict = {
|
||||
"c_attn": "att_proj",
|
||||
"c_proj": "out_proj",
|
||||
"c_fc": "in_proj",
|
||||
"transformer.": "",
|
||||
"h.": "layers.",
|
||||
"ln_1": "layernorm_1",
|
||||
"ln_2": "layernorm_2",
|
||||
"ln_f": "layernorm_final",
|
||||
"wpe": "position_embeds_layer",
|
||||
"wte": "input_embeds_layer",
|
||||
}
|
||||
|
||||
|
||||
REMOTE_MODEL_PATHS = {
|
||||
"text_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "text.pt",
|
||||
},
|
||||
"coarse_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "coarse.pt",
|
||||
},
|
||||
"fine_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "fine.pt",
|
||||
},
|
||||
"text": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "text_2.pt",
|
||||
},
|
||||
"coarse": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "coarse_2.pt",
|
||||
},
|
||||
"fine": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "fine_2.pt",
|
||||
},
|
||||
}
|
||||
|
||||
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
|
||||
|
||||
|
||||
def _get_ckpt_path(model_type, use_small=False):
|
||||
key = model_type
|
||||
if use_small:
|
||||
key += "_small"
|
||||
return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
|
||||
|
||||
|
||||
def _download(from_hf_path, file_name):
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
|
||||
|
||||
|
||||
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
||||
if model_type == "text":
|
||||
ModelClass = BarkSemanticModel
|
||||
ConfigClass = BarkSemanticConfig
|
||||
GenerationConfigClass = BarkSemanticGenerationConfig
|
||||
elif model_type == "coarse":
|
||||
ModelClass = BarkCoarseModel
|
||||
ConfigClass = BarkCoarseConfig
|
||||
GenerationConfigClass = BarkCoarseGenerationConfig
|
||||
elif model_type == "fine":
|
||||
ModelClass = BarkFineModel
|
||||
ConfigClass = BarkFineConfig
|
||||
GenerationConfigClass = BarkFineGenerationConfig
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
model_key = f"{model_type}_small" if use_small else model_type
|
||||
model_info = REMOTE_MODEL_PATHS[model_key]
|
||||
if not os.path.exists(ckpt_path):
|
||||
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
||||
_download(model_info["repo_id"], model_info["file_name"])
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
# this is a hack
|
||||
model_args = checkpoint["model_args"]
|
||||
if "input_vocab_size" not in model_args:
|
||||
model_args["input_vocab_size"] = model_args["vocab_size"]
|
||||
model_args["output_vocab_size"] = model_args["vocab_size"]
|
||||
del model_args["vocab_size"]
|
||||
|
||||
# convert Bark model arguments to HF Bark model arguments
|
||||
model_args["num_heads"] = model_args.pop("n_head")
|
||||
model_args["hidden_size"] = model_args.pop("n_embd")
|
||||
model_args["num_layers"] = model_args.pop("n_layer")
|
||||
|
||||
model_config = ConfigClass(**checkpoint["model_args"])
|
||||
model = ModelClass(config=model_config)
|
||||
model_generation_config = GenerationConfigClass()
|
||||
|
||||
model.generation_config = model_generation_config
|
||||
state_dict = checkpoint["model"]
|
||||
# fixup checkpoint
|
||||
unwanted_prefix = "_orig_mod."
|
||||
for k, v in list(state_dict.items()):
|
||||
if k.startswith(unwanted_prefix):
|
||||
# replace part of the key with corresponding layer name in HF implementation
|
||||
new_k = k[len(unwanted_prefix) :]
|
||||
for old_layer_name in new_layer_name_dict:
|
||||
new_k = new_k.replace(old_layer_name, new_layer_name_dict[old_layer_name])
|
||||
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
|
||||
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
|
||||
extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")}
|
||||
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
||||
missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")}
|
||||
if len(extra_keys) != 0:
|
||||
raise ValueError(f"extra keys found: {extra_keys}")
|
||||
if len(missing_keys) != 0:
|
||||
raise ValueError(f"missing keys: {missing_keys}")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
n_params = model.num_parameters(exclude_embeddings=True)
|
||||
val_loss = checkpoint["best_val_loss"].item()
|
||||
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
|
||||
model.eval()
|
||||
model.to(device)
|
||||
del checkpoint, state_dict
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"):
|
||||
if model_type not in ("text", "coarse", "fine"):
|
||||
raise NotImplementedError()
|
||||
|
||||
device = "cpu" # do conversion on cpu
|
||||
|
||||
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
|
||||
model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small)
|
||||
|
||||
# load bark initial model
|
||||
bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small)
|
||||
|
||||
if model_type == "text":
|
||||
bark_model = bark_model["model"]
|
||||
|
||||
if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params():
|
||||
raise ValueError("initial and new models don't have the same number of parameters")
|
||||
|
||||
# check if same output as the bark model
|
||||
batch_size = 5
|
||||
sequence_length = 10
|
||||
|
||||
if model_type in ["text", "coarse"]:
|
||||
vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int)
|
||||
output_old_model = bark_model(vec)[0]
|
||||
|
||||
output_new_model_total = model(vec)
|
||||
|
||||
# take last logits
|
||||
output_new_model = output_new_model_total.logits[:, [-1], :]
|
||||
|
||||
else:
|
||||
prediction_codeboook_channel = 3
|
||||
n_codes_total = 8
|
||||
vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int)
|
||||
|
||||
output_new_model_total = model(prediction_codeboook_channel, vec)
|
||||
output_old_model = bark_model(prediction_codeboook_channel, vec)
|
||||
|
||||
output_new_model = output_new_model_total.logits
|
||||
|
||||
# output difference should come from the difference of self-attention implementation design
|
||||
if output_new_model.shape != output_old_model.shape:
|
||||
raise ValueError("initial and new outputs don't have the same shape")
|
||||
if (output_new_model - output_old_model).abs().max().item() > 1e-3:
|
||||
raise ValueError("initial and new outputs are not equal")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
def load_whole_bark_model(
|
||||
semantic_path,
|
||||
coarse_path,
|
||||
fine_path,
|
||||
append_text,
|
||||
hub_path,
|
||||
folder_path,
|
||||
):
|
||||
pytorch_dump_folder_path = os.path.join(folder_path, append_text)
|
||||
|
||||
semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json"))
|
||||
coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json"))
|
||||
fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json"))
|
||||
codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz")
|
||||
|
||||
semantic = BarkSemanticModel.from_pretrained(semantic_path)
|
||||
coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path)
|
||||
fineAcoustic = BarkFineModel.from_pretrained(fine_path)
|
||||
codec = EncodecModel.from_pretrained("facebook/encodec_24khz")
|
||||
|
||||
bark_config = BarkConfig.from_sub_model_configs(
|
||||
semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig
|
||||
)
|
||||
|
||||
bark_generation_config = BarkGenerationConfig.from_sub_model_configs(
|
||||
semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config
|
||||
)
|
||||
|
||||
bark = BarkModel(bark_config)
|
||||
|
||||
bark.semantic = semantic
|
||||
bark.coarse_acoustics = coarseAcoustic
|
||||
bark.fine_acoustics = fineAcoustic
|
||||
bark.codec_model = codec
|
||||
|
||||
bark.generation_config = bark_generation_config
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
|
||||
parser.add_argument("model_type", type=str, help="text, coarse or fine.")
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small)
|
@ -1,156 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BART checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import fairseq
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
BartConfig,
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartTokenizer,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
||||
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
|
||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
raise Exception("requires fairseq >= 0.9.0")
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SAMPLE_TEXT = " Hello world! cécé herlolip"
|
||||
|
||||
mnli_rename_keys = [
|
||||
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
|
||||
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
|
||||
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
|
||||
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
|
||||
]
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"encoder.version",
|
||||
"decoder.version",
|
||||
"model.encoder.version",
|
||||
"model.decoder.version",
|
||||
"_float_tensor",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def load_xsum_checkpoint(checkpoint_path):
|
||||
"""Checkpoint path should end in model.pt"""
|
||||
sd = torch.load(checkpoint_path, map_location="cpu")
|
||||
hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
|
||||
hub_interface.model.load_state_dict(sd["model"])
|
||||
return hub_interface
|
||||
|
||||
|
||||
def make_linear_from_emb(emb):
|
||||
vocab_size, emb_size = emb.weight.shape
|
||||
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||
lin_layer.weight.data = emb.weight.data
|
||||
return lin_layer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BERT structure.
|
||||
"""
|
||||
if not os.path.exists(checkpoint_path):
|
||||
bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
|
||||
else:
|
||||
bart = load_xsum_checkpoint(checkpoint_path)
|
||||
|
||||
bart.model.upgrade_state_dict(bart.model.state_dict())
|
||||
if hf_checkpoint_name is None:
|
||||
hf_checkpoint_name = checkpoint_path.replace(".", "-")
|
||||
config = BartConfig.from_pretrained(hf_checkpoint_name)
|
||||
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
|
||||
tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
|
||||
if not torch.eq(tokens, tokens2).all():
|
||||
raise ValueError(
|
||||
f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
|
||||
)
|
||||
|
||||
if checkpoint_path == "bart.large.mnli":
|
||||
state_dict = bart.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
|
||||
for src, dest in mnli_rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
model = BartForSequenceClassification(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
fairseq_output = bart.predict("mnli", tokens, return_logits=True)
|
||||
new_model_outputs = model(tokens)[0] # logits
|
||||
else: # no classification heads to worry about
|
||||
state_dict = bart.model.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||
fairseq_output = bart.extract_features(tokens)
|
||||
if hf_checkpoint_name == "facebook/bart-large":
|
||||
model = BartModel(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
new_model_outputs = model(tokens).model[0]
|
||||
else:
|
||||
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
|
||||
model.model.load_state_dict(state_dict)
|
||||
if hasattr(model, "lm_head"):
|
||||
model.lm_head = make_linear_from_emb(model.model.shared)
|
||||
new_model_outputs = model.model(tokens)[0]
|
||||
|
||||
# Check results
|
||||
if fairseq_output.shape != new_model_outputs.shape:
|
||||
raise ValueError(
|
||||
f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
|
||||
)
|
||||
if (fairseq_output != new_model_outputs).any().item():
|
||||
raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
|
||||
)
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
|
@ -1,373 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BEiT checkpoints from the unilm repository."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
BeitConfig,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
BeitImageProcessor,
|
||||
)
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
|
||||
rename_keys = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
|
||||
)
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
|
||||
|
||||
# projection layer + position embeddings
|
||||
rename_keys.extend(
|
||||
[
|
||||
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
|
||||
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
|
||||
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
if has_lm_head:
|
||||
# mask token + shared relative position bias + layernorm
|
||||
rename_keys.extend(
|
||||
[
|
||||
("mask_token", "beit.embeddings.mask_token"),
|
||||
(
|
||||
"rel_pos_bias.relative_position_bias_table",
|
||||
"beit.encoder.relative_position_bias.relative_position_bias_table",
|
||||
),
|
||||
(
|
||||
"rel_pos_bias.relative_position_index",
|
||||
"beit.encoder.relative_position_bias.relative_position_index",
|
||||
),
|
||||
("norm.weight", "layernorm.weight"),
|
||||
("norm.bias", "layernorm.bias"),
|
||||
]
|
||||
)
|
||||
elif is_semantic:
|
||||
# semantic segmentation classification heads
|
||||
rename_keys.extend(
|
||||
[
|
||||
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
|
||||
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
|
||||
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
|
||||
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
# layernorm + classification head
|
||||
rename_keys.extend(
|
||||
[
|
||||
("fc_norm.weight", "beit.pooler.layernorm.weight"),
|
||||
("fc_norm.bias", "beit.pooler.layernorm.bias"),
|
||||
("head.weight", "classifier.weight"),
|
||||
("head.bias", "classifier.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
|
||||
for i in range(config.num_hidden_layers):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
# queries, keys and values
|
||||
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
|
||||
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||
: config.hidden_size, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
config.hidden_size : config.hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||
-config.hidden_size :, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||
|
||||
# gamma_1 and gamma_2
|
||||
# we call them lambda because otherwise they are renamed when using .from_pretrained
|
||||
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
|
||||
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
|
||||
|
||||
# relative_position bias table + index
|
||||
if not has_lm_head:
|
||||
# each layer has its own relative position bias
|
||||
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
|
||||
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
|
||||
|
||||
state_dict[
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
|
||||
] = table
|
||||
state_dict[
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
|
||||
] = index
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BEiT structure.
|
||||
"""
|
||||
|
||||
# define default BEiT configuration
|
||||
config = BeitConfig()
|
||||
has_lm_head = False
|
||||
is_semantic = False
|
||||
repo_id = "huggingface/label-files"
|
||||
# set config parameters based on URL
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
# masked image modeling
|
||||
config.use_shared_relative_position_bias = True
|
||||
config.use_mask_token = True
|
||||
has_lm_head = True
|
||||
elif checkpoint_url[-9:-4] == "ft22k":
|
||||
# intermediate fine-tuning on ImageNet-22k
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 21841
|
||||
filename = "imagenet-22k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
# this dataset contains 21843 labels but the model only has 21841
|
||||
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
|
||||
del id2label[9205]
|
||||
del id2label[15027]
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
elif checkpoint_url[-8:-4] == "to1k":
|
||||
# fine-tuning on ImageNet-1k
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 1000
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
if "384" in checkpoint_url:
|
||||
config.image_size = 384
|
||||
if "512" in checkpoint_url:
|
||||
config.image_size = 512
|
||||
elif "ade20k" in checkpoint_url:
|
||||
# fine-tuning
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 150
|
||||
filename = "ade20k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
config.image_size = 640
|
||||
is_semantic = True
|
||||
else:
|
||||
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
|
||||
|
||||
# size of the architecture
|
||||
if "base" in checkpoint_url:
|
||||
pass
|
||||
elif "large" in checkpoint_url:
|
||||
config.hidden_size = 1024
|
||||
config.intermediate_size = 4096
|
||||
config.num_hidden_layers = 24
|
||||
config.num_attention_heads = 16
|
||||
if "ade20k" in checkpoint_url:
|
||||
config.image_size = 640
|
||||
config.out_indices = [7, 11, 15, 23]
|
||||
else:
|
||||
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
|
||||
|
||||
# load state_dict of original model, remove and rename some keys
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
|
||||
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
|
||||
|
||||
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
if is_semantic:
|
||||
# add prefix to decoder keys
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("backbone.fpn"):
|
||||
key = key.replace("backbone.fpn", "fpn")
|
||||
state_dict[key] = val
|
||||
|
||||
# load HuggingFace model
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
model = BeitForMaskedImageModeling(config)
|
||||
elif "ade20k" in checkpoint_url:
|
||||
model = BeitForSemanticSegmentation(config)
|
||||
else:
|
||||
model = BeitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# Check outputs on an image
|
||||
if is_semantic:
|
||||
image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
|
||||
image = Image.open(ds[0]["file"])
|
||||
else:
|
||||
image_processor = BeitImageProcessor(
|
||||
size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
|
||||
)
|
||||
image = prepare_img()
|
||||
|
||||
encoding = image_processor(images=image, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify logits
|
||||
expected_shape = torch.Size([1, 1000])
|
||||
if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
|
||||
expected_shape = torch.Size([1, 196, 8192])
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
|
||||
expected_shape = torch.Size([1, 196, 8192])
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
|
||||
expected_shape = torch.Size([1, 21841])
|
||||
expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
|
||||
expected_class_idx = 2397
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
|
||||
expected_shape = torch.Size([1, 21841])
|
||||
expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
|
||||
expected_class_idx = 2396
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
|
||||
expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
|
||||
expected_class_idx = 285
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
|
||||
expected_class_idx = 281
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
|
||||
expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
|
||||
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
|
||||
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
|
||||
]
|
||||
)
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
|
||||
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
|
||||
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError("Can't verify logits as model is not supported")
|
||||
|
||||
if logits.shape != expected_shape:
|
||||
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
|
||||
if not has_lm_head:
|
||||
if is_semantic:
|
||||
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
else:
|
||||
print("Predicted class idx:", logits.argmax(-1).item())
|
||||
|
||||
if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
if logits.argmax(-1).item() != expected_class_idx:
|
||||
raise ValueError("Predicted class index not as expected")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
|
||||
type=str,
|
||||
help="URL to the original PyTorch checkpoint (.pth file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
|
@ -1,246 +0,0 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now
|
||||
deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert
|
||||
|
||||
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
|
||||
weight names to the original names, so the model can be imported with Huggingface/transformer.
|
||||
|
||||
You may adapt this script to include classification/MLM/NSP/etc. heads.
|
||||
|
||||
Note: This script is only working with an older version of the TensorFlow models repository (<= v2.3.0).
|
||||
Models trained with never versions are not compatible with this script.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
layer_depth = []
|
||||
for full_name, shape in init_vars:
|
||||
# logger.info(f"Loading TF weight {name} with shape {shape}")
|
||||
name = full_name.split("/")
|
||||
if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
|
||||
logger.info(f"Skipping non-model layer {full_name}")
|
||||
continue
|
||||
if "optimizer" in full_name:
|
||||
logger.info(f"Skipping optimization layer {full_name}")
|
||||
continue
|
||||
if name[0] == "model":
|
||||
# ignore initial 'model'
|
||||
name = name[1:]
|
||||
# figure out how many levels deep the name is
|
||||
depth = 0
|
||||
for _name in name:
|
||||
if _name.startswith("layer_with_weights"):
|
||||
depth += 1
|
||||
else:
|
||||
break
|
||||
layer_depth.append(depth)
|
||||
# read data
|
||||
array = tf.train.load_variable(tf_path, full_name)
|
||||
names.append("/".join(name))
|
||||
arrays.append(array)
|
||||
logger.info(f"Read a total of {len(arrays):,} layers")
|
||||
|
||||
# Sanity check
|
||||
if len(set(layer_depth)) != 1:
|
||||
raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
|
||||
layer_depth = list(set(layer_depth))[0]
|
||||
if layer_depth != 1:
|
||||
raise ValueError(
|
||||
"The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
|
||||
" heads."
|
||||
)
|
||||
|
||||
# convert layers
|
||||
logger.info("Converting weights...")
|
||||
for full_name, array in zip(names, arrays):
|
||||
name = full_name.split("/")
|
||||
pointer = model
|
||||
trace = []
|
||||
for i, m_name in enumerate(name):
|
||||
if m_name == ".ATTRIBUTES":
|
||||
# variable names end with .ATTRIBUTES/VARIABLE_VALUE
|
||||
break
|
||||
if m_name.startswith("layer_with_weights"):
|
||||
layer_num = int(m_name.split("-")[-1])
|
||||
if layer_num <= 2:
|
||||
# embedding layers
|
||||
# layer_num 0: word_embeddings
|
||||
# layer_num 1: position_embeddings
|
||||
# layer_num 2: token_type_embeddings
|
||||
continue
|
||||
elif layer_num == 3:
|
||||
# embedding LayerNorm
|
||||
trace.extend(["embeddings", "LayerNorm"])
|
||||
pointer = getattr(pointer, "embeddings")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
|
||||
# encoder layers
|
||||
trace.extend(["encoder", "layer", str(layer_num - 4)])
|
||||
pointer = getattr(pointer, "encoder")
|
||||
pointer = getattr(pointer, "layer")
|
||||
pointer = pointer[layer_num - 4]
|
||||
elif layer_num == config.num_hidden_layers + 4:
|
||||
# pooler layer
|
||||
trace.extend(["pooler", "dense"])
|
||||
pointer = getattr(pointer, "pooler")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "embeddings":
|
||||
trace.append("embeddings")
|
||||
pointer = getattr(pointer, "embeddings")
|
||||
if layer_num == 0:
|
||||
trace.append("word_embeddings")
|
||||
pointer = getattr(pointer, "word_embeddings")
|
||||
elif layer_num == 1:
|
||||
trace.append("position_embeddings")
|
||||
pointer = getattr(pointer, "position_embeddings")
|
||||
elif layer_num == 2:
|
||||
trace.append("token_type_embeddings")
|
||||
pointer = getattr(pointer, "token_type_embeddings")
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding layer with name {full_name}")
|
||||
trace.append("weight")
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif m_name == "_attention_layer":
|
||||
# self-attention layer
|
||||
trace.extend(["attention", "self"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "self")
|
||||
elif m_name == "_attention_layer_norm":
|
||||
# output attention norm
|
||||
trace.extend(["attention", "output", "LayerNorm"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif m_name == "_attention_output_dense":
|
||||
# output attention dense
|
||||
trace.extend(["attention", "output", "dense"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_dense":
|
||||
# output dense
|
||||
trace.extend(["output", "dense"])
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_layer_norm":
|
||||
# output dense
|
||||
trace.extend(["output", "LayerNorm"])
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif m_name == "_key_dense":
|
||||
# attention key
|
||||
trace.append("key")
|
||||
pointer = getattr(pointer, "key")
|
||||
elif m_name == "_query_dense":
|
||||
# attention query
|
||||
trace.append("query")
|
||||
pointer = getattr(pointer, "query")
|
||||
elif m_name == "_value_dense":
|
||||
# attention value
|
||||
trace.append("value")
|
||||
pointer = getattr(pointer, "value")
|
||||
elif m_name == "_intermediate_dense":
|
||||
# attention intermediate dense
|
||||
trace.extend(["intermediate", "dense"])
|
||||
pointer = getattr(pointer, "intermediate")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_layer_norm":
|
||||
# output layer norm
|
||||
trace.append("output")
|
||||
pointer = getattr(pointer, "output")
|
||||
# weights & biases
|
||||
elif m_name in ["bias", "beta"]:
|
||||
trace.append("bias")
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif m_name in ["kernel", "gamma"]:
|
||||
trace.append("weight")
|
||||
pointer = getattr(pointer, "weight")
|
||||
else:
|
||||
logger.warning(f"Ignored {m_name}")
|
||||
# for certain layers reshape is necessary
|
||||
trace = ".".join(trace)
|
||||
if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
|
||||
r"(\S+)\.attention\.output\.dense\.weight", trace
|
||||
):
|
||||
array = array.reshape(pointer.data.shape)
|
||||
if "kernel" in full_name:
|
||||
array = array.transpose()
|
||||
if pointer.shape == array.shape:
|
||||
pointer.data = torch.from_numpy(array)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
|
||||
f" {array.shape}"
|
||||
)
|
||||
logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
|
||||
return model
|
||||
|
||||
|
||||
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
|
||||
# Instantiate model
|
||||
logger.info(f"Loading model based on config from {config_path}...")
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
model = BertModel(config)
|
||||
|
||||
# Load weights from checkpoint
|
||||
logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
|
||||
load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
|
||||
|
||||
# Save pytorch-model
|
||||
logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model (must include filename).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,62 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(bert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = BertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,112 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertModel
|
||||
|
||||
|
||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
||||
"""
|
||||
Args:
|
||||
model: BertModel Pytorch model instance to be converted
|
||||
ckpt_dir: Tensorflow model directory
|
||||
model_name: model name
|
||||
|
||||
Currently supported HF models:
|
||||
|
||||
- Y BertModel
|
||||
- N BertForMaskedLM
|
||||
- N BertForPreTraining
|
||||
- N BertForMultipleChoice
|
||||
- N BertForNextSentencePrediction
|
||||
- N BertForSequenceClassification
|
||||
- N BertForQuestionAnswering
|
||||
"""
|
||||
|
||||
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
|
||||
|
||||
var_map = (
|
||||
("layer.", "layer_"),
|
||||
("word_embeddings.weight", "word_embeddings"),
|
||||
("position_embeddings.weight", "position_embeddings"),
|
||||
("token_type_embeddings.weight", "token_type_embeddings"),
|
||||
(".", "/"),
|
||||
("LayerNorm/weight", "LayerNorm/gamma"),
|
||||
("LayerNorm/bias", "LayerNorm/beta"),
|
||||
("weight", "kernel"),
|
||||
)
|
||||
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
os.makedirs(ckpt_dir)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
|
||||
def to_tf_var_name(name: str):
|
||||
for patt, repl in iter(var_map):
|
||||
name = name.replace(patt, repl)
|
||||
return f"bert/{name}"
|
||||
|
||||
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
|
||||
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
|
||||
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
|
||||
session.run(tf.variables_initializer([tf_var]))
|
||||
session.run(tf_var)
|
||||
return tf_var
|
||||
|
||||
tf.reset_default_graph()
|
||||
with tf.Session() as session:
|
||||
for var_name in state_dict:
|
||||
tf_name = to_tf_var_name(var_name)
|
||||
torch_tensor = state_dict[var_name].numpy()
|
||||
if any(x in var_name for x in tensors_to_transpose):
|
||||
torch_tensor = torch_tensor.T
|
||||
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
|
||||
tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
|
||||
tf_weight = session.run(tf_var)
|
||||
print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")
|
||||
|
||||
saver = tf.train.Saver(tf.trainable_variables())
|
||||
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
|
||||
|
||||
|
||||
def main(raw_args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
|
||||
parser.add_argument(
|
||||
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
|
||||
)
|
||||
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
|
||||
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
|
||||
args = parser.parse_args(raw_args)
|
||||
|
||||
model = BertModel.from_pretrained(
|
||||
pretrained_model_name_or_path=args.model_name,
|
||||
state_dict=torch.load(args.pytorch_model_path),
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
|
||||
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,188 +0,0 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT
|
||||
model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository:
|
||||
|
||||
https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForMaskedLM
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertIntermediate,
|
||||
BertLayer,
|
||||
BertOutput,
|
||||
BertPooler,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
|
||||
def get_masked_lm_array(name: str):
|
||||
full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_array(name: str):
|
||||
full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_layer_array(layer_index: int, name: str):
|
||||
full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_attention_layer_array(layer_index: int, name: str, orginal_shape):
|
||||
full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
array = array.reshape(orginal_shape)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
print(f"Loading model based on config from {config_path}...")
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
model = BertForMaskedLM(config)
|
||||
|
||||
# Layers
|
||||
for layer_index in range(0, config.num_hidden_layers):
|
||||
layer: BertLayer = model.bert.encoder.layer[layer_index]
|
||||
|
||||
# Self-attention
|
||||
self_attn: BertSelfAttention = layer.attention.self
|
||||
|
||||
self_attn.query.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
|
||||
)
|
||||
self_attn.query.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
|
||||
)
|
||||
self_attn.key.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
|
||||
)
|
||||
self_attn.key.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
|
||||
)
|
||||
self_attn.value.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
|
||||
)
|
||||
self_attn.value.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
|
||||
)
|
||||
|
||||
# Self-attention Output
|
||||
self_output: BertSelfOutput = layer.attention.output
|
||||
|
||||
self_output.dense.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
|
||||
)
|
||||
self_output.dense.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
|
||||
)
|
||||
|
||||
self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
|
||||
self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
|
||||
|
||||
# Intermediate
|
||||
intermediate: BertIntermediate = layer.intermediate
|
||||
|
||||
intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
|
||||
intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
|
||||
|
||||
# Output
|
||||
bert_output: BertOutput = layer.output
|
||||
|
||||
bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
|
||||
bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
|
||||
|
||||
bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
|
||||
bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
|
||||
|
||||
# Embeddings
|
||||
model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
|
||||
model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
|
||||
model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
|
||||
model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
|
||||
|
||||
# LM Head
|
||||
lm_head = model.cls.predictions.transform
|
||||
|
||||
lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
|
||||
lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
|
||||
|
||||
lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
|
||||
lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
|
||||
|
||||
model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
|
||||
|
||||
# Pooling
|
||||
model.bert.pooler = BertPooler(config=config)
|
||||
model.bert.pooler.dense.weight.data: BertPooler = get_encoder_array("_pooler_layer/kernel")
|
||||
model.bert.pooler.dense.bias.data: BertPooler = get_encoder_array("_pooler_layer/bias")
|
||||
|
||||
# Export final model
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
# Integration test - should load without any errors ;)
|
||||
new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
|
||||
print(new_model.eval())
|
||||
|
||||
print("Model conversion was done sucessfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,69 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BigBird checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
|
||||
# Initialise PyTorch model
|
||||
config = BigBirdConfig.from_json_file(big_bird_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
|
||||
if is_trivia_qa:
|
||||
model = BigBirdForQuestionAnswering(config)
|
||||
else:
|
||||
model = BigBirdForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--big_bird_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(
|
||||
args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
|
||||
)
|
@ -1,170 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from typing import Dict
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration
|
||||
|
||||
|
||||
INIT_COMMON = [
|
||||
# tf -> hf
|
||||
("/", "."),
|
||||
("layer_", "layers."),
|
||||
("kernel", "weight"),
|
||||
("beta", "bias"),
|
||||
("gamma", "weight"),
|
||||
("pegasus", "model"),
|
||||
]
|
||||
END_COMMON = [
|
||||
(".output.dense", ".fc2"),
|
||||
("intermediate.LayerNorm", "final_layer_norm"),
|
||||
("intermediate.dense", "fc1"),
|
||||
]
|
||||
|
||||
DECODER_PATTERNS = (
|
||||
INIT_COMMON
|
||||
+ [
|
||||
("attention.self.LayerNorm", "self_attn_layer_norm"),
|
||||
("attention.output.dense", "self_attn.out_proj"),
|
||||
("attention.self", "self_attn"),
|
||||
("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
|
||||
("attention.encdec_output.dense", "encoder_attn.out_proj"),
|
||||
("attention.encdec", "encoder_attn"),
|
||||
("key", "k_proj"),
|
||||
("value", "v_proj"),
|
||||
("query", "q_proj"),
|
||||
("decoder.LayerNorm", "decoder.layernorm_embedding"),
|
||||
]
|
||||
+ END_COMMON
|
||||
)
|
||||
|
||||
REMAINING_PATTERNS = (
|
||||
INIT_COMMON
|
||||
+ [
|
||||
("embeddings.word_embeddings", "shared.weight"),
|
||||
("embeddings.position_embeddings", "embed_positions.weight"),
|
||||
("attention.self.LayerNorm", "self_attn_layer_norm"),
|
||||
("attention.output.dense", "self_attn.output"),
|
||||
("attention.self", "self_attn.self"),
|
||||
("encoder.LayerNorm", "encoder.layernorm_embedding"),
|
||||
]
|
||||
+ END_COMMON
|
||||
)
|
||||
|
||||
KEYS_TO_IGNORE = [
|
||||
"encdec/key/bias",
|
||||
"encdec/query/bias",
|
||||
"encdec/value/bias",
|
||||
"self/key/bias",
|
||||
"self/query/bias",
|
||||
"self/value/bias",
|
||||
"encdec_output/dense/bias",
|
||||
"attention/output/dense/bias",
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k, patterns):
|
||||
for tf_name, hf_name in patterns:
|
||||
k = k.replace(tf_name, hf_name)
|
||||
return k
|
||||
|
||||
|
||||
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
|
||||
cfg = BigBirdPegasusConfig(**config_update)
|
||||
torch_model = BigBirdPegasusForConditionalGeneration(cfg)
|
||||
state_dict = torch_model.state_dict()
|
||||
mapping = {}
|
||||
|
||||
# separating decoder weights
|
||||
decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
|
||||
remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
|
||||
|
||||
for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
|
||||
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
|
||||
if any(conditions):
|
||||
continue
|
||||
patterns = DECODER_PATTERNS
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict:
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
|
||||
for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
|
||||
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
|
||||
if any(conditions):
|
||||
continue
|
||||
patterns = REMAINING_PATTERNS
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
if k != "pegasus/embeddings/position_embeddings":
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
|
||||
mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
|
||||
mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
|
||||
missing, extra = torch_model.load_state_dict(mapping, strict=False)
|
||||
unexpected_missing = [
|
||||
k
|
||||
for k in missing
|
||||
if k
|
||||
not in [
|
||||
"final_logits_bias",
|
||||
"model.encoder.embed_tokens.weight",
|
||||
"model.decoder.embed_tokens.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
]
|
||||
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
|
||||
assert extra == [], f"no matches found for the following tf keys {extra}"
|
||||
return torch_model
|
||||
|
||||
|
||||
def get_tf_weights_as_numpy(path) -> Dict:
|
||||
init_vars = tf.train.list_variables(path)
|
||||
tf_weights = {}
|
||||
ignore_name = ["global_step"]
|
||||
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
|
||||
skip_key = any(pat in name for pat in ignore_name)
|
||||
if skip_key:
|
||||
continue
|
||||
array = tf.train.load_variable(path, name)
|
||||
tf_weights[name] = array
|
||||
return tf_weights
|
||||
|
||||
|
||||
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
|
||||
tf_weights = get_tf_weights_as_numpy(ckpt_path)
|
||||
torch_model = convert_bigbird_pegasus(tf_weights, config_update)
|
||||
torch_model.save_pretrained(save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
|
||||
parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
config_update = {}
|
||||
convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)
|
@ -1,292 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BioGptConfig, BioGptForCausalLM
|
||||
from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from transformers.utils import WEIGHTS_NAME, logging
|
||||
|
||||
|
||||
logging.set_verbosity_warning()
|
||||
|
||||
json_indent = 2
|
||||
|
||||
|
||||
# modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18
|
||||
class Dictionary:
|
||||
"""A mapping from symbols to consecutive integers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*, # begin keyword-only arguments
|
||||
bos="<s>",
|
||||
pad="<pad>",
|
||||
eos="</s>",
|
||||
unk="<unk>",
|
||||
extra_special_symbols=None,
|
||||
):
|
||||
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
||||
self.symbols = []
|
||||
self.count = []
|
||||
self.indices = {}
|
||||
self.bos_index = self.add_symbol(bos)
|
||||
self.pad_index = self.add_symbol(pad)
|
||||
self.eos_index = self.add_symbol(eos)
|
||||
self.unk_index = self.add_symbol(unk)
|
||||
if extra_special_symbols:
|
||||
for s in extra_special_symbols:
|
||||
self.add_symbol(s)
|
||||
self.nspecial = len(self.symbols)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.indices == other.indices
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < len(self.symbols):
|
||||
return self.symbols[idx]
|
||||
return self.unk_word
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of symbols in the dictionary"""
|
||||
return len(self.symbols)
|
||||
|
||||
def __contains__(self, sym):
|
||||
return sym in self.indices
|
||||
|
||||
@classmethod
|
||||
def load(cls, f):
|
||||
"""Loads the dictionary from a text file with the format:
|
||||
|
||||
```
|
||||
<symbol0> <count0>
|
||||
<symbol1> <count1>
|
||||
...
|
||||
```
|
||||
"""
|
||||
d = cls()
|
||||
d.add_from_file(f)
|
||||
return d
|
||||
|
||||
def add_symbol(self, word, n=1, overwrite=False):
|
||||
"""Adds a word to the dictionary"""
|
||||
if word in self.indices and not overwrite:
|
||||
idx = self.indices[word]
|
||||
self.count[idx] = self.count[idx] + n
|
||||
return idx
|
||||
else:
|
||||
idx = len(self.symbols)
|
||||
self.indices[word] = idx
|
||||
self.symbols.append(word)
|
||||
self.count.append(n)
|
||||
return idx
|
||||
|
||||
def _load_meta(self, lines):
|
||||
return 0
|
||||
|
||||
def add_from_file(self, f):
|
||||
"""
|
||||
Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
|
||||
"""
|
||||
if isinstance(f, str):
|
||||
try:
|
||||
with open(f, "r", encoding="utf-8") as fd:
|
||||
self.add_from_file(fd)
|
||||
except FileNotFoundError as fnfe:
|
||||
raise fnfe
|
||||
except UnicodeError:
|
||||
raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f))
|
||||
return
|
||||
|
||||
lines = f.readlines()
|
||||
indices_start_line = self._load_meta(lines)
|
||||
|
||||
for line in lines[indices_start_line:]:
|
||||
try:
|
||||
line, field = line.rstrip().rsplit(" ", 1)
|
||||
if field == "#fairseq:overwrite":
|
||||
overwrite = True
|
||||
line, field = line.rsplit(" ", 1)
|
||||
else:
|
||||
overwrite = False
|
||||
count = int(field)
|
||||
word = line
|
||||
if word in self and not overwrite:
|
||||
raise RuntimeError(
|
||||
"Duplicate word found when loading Dictionary: '{}'. "
|
||||
"Duplicate words can overwrite earlier ones by adding the "
|
||||
"#fairseq:overwrite flag at the end of the corresponding row "
|
||||
"in the dictionary file. If using the Camembert model, please "
|
||||
"download an updated copy of the model file.".format(word)
|
||||
)
|
||||
self.add_symbol(word, n=count, overwrite=overwrite)
|
||||
except ValueError:
|
||||
raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'")
|
||||
|
||||
|
||||
def rewrite_dict_keys(d):
|
||||
# (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
|
||||
# e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
|
||||
d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
|
||||
keep_keys = "<s> <pad> </s> <unk>".split()
|
||||
# restore the special tokens
|
||||
for k in keep_keys:
|
||||
del d2[f"{k}</w>"]
|
||||
d2[k] = d[k] # restore
|
||||
return d2
|
||||
|
||||
|
||||
def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path):
|
||||
# prep
|
||||
if not os.path.exists(biogpt_checkpoint_path):
|
||||
raise ValueError(f"path {biogpt_checkpoint_path} does not exist!")
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
print(f"Writing results to {pytorch_dump_folder_path}")
|
||||
|
||||
# handle various types of models
|
||||
|
||||
checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
|
||||
if not os.path.isfile(checkpoint_file):
|
||||
raise ValueError(f"path to the file {checkpoint_file} does not exist!")
|
||||
chkpt = torch.load(checkpoint_file, map_location="cpu")
|
||||
|
||||
args = chkpt["cfg"]["model"]
|
||||
|
||||
# dicts
|
||||
dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt")
|
||||
if not os.path.isfile(dict_file):
|
||||
raise ValueError(f"path to the file {dict_file} does not exist!")
|
||||
src_dict = Dictionary.load(dict_file)
|
||||
src_vocab = rewrite_dict_keys(src_dict.indices)
|
||||
src_vocab_size = len(src_vocab)
|
||||
src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"])
|
||||
print(f"Generating {src_vocab_file} of {src_vocab_size} records")
|
||||
with open(src_vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# merges_file (bpecodes)
|
||||
bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes")
|
||||
if not os.path.isfile(bpecodes_file):
|
||||
raise ValueError(f"path to the file {bpecodes_file} does not exist!")
|
||||
|
||||
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
|
||||
shutil.copyfile(bpecodes_file, merges_file)
|
||||
|
||||
# model config
|
||||
biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")
|
||||
|
||||
model_conf = {
|
||||
"activation_dropout": args["activation_dropout"],
|
||||
"architectures": ["BioGptForCausalLM"],
|
||||
"attention_probs_dropout_prob": args["attention_dropout"],
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": args["activation_fn"],
|
||||
"hidden_dropout_prob": args["dropout"],
|
||||
"hidden_size": args["decoder_embed_dim"],
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": args["decoder_ffn_embed_dim"],
|
||||
"layer_norm_eps": 1e-12,
|
||||
"layerdrop": args["decoder_layerdrop"],
|
||||
"max_position_embeddings": args["max_target_positions"],
|
||||
"model_type": "biogpt",
|
||||
"num_attention_heads": args["decoder_attention_heads"],
|
||||
"num_hidden_layers": args["decoder_layers"],
|
||||
"pad_token_id": 1,
|
||||
"scale_embedding": not args["no_scale_embedding"],
|
||||
"tie_word_embeddings": args["share_decoder_input_output_embed"],
|
||||
"vocab_size": src_vocab_size,
|
||||
}
|
||||
|
||||
# good hparam defaults to start with
|
||||
|
||||
print(f"Generating {biogpt_model_config_file}")
|
||||
with open(biogpt_model_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# tokenizer config
|
||||
biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)
|
||||
|
||||
tokenizer_conf = {
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"model_max_length": 1024,
|
||||
"pad_token": "<pad>",
|
||||
"special_tokens_map_file": None,
|
||||
"tokenizer_class": "BioGptTokenizer",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
|
||||
print(f"Generating {biogpt_tokenizer_config_file}")
|
||||
with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# model
|
||||
model_state_dict = chkpt["model"]
|
||||
|
||||
# remove unneeded keys
|
||||
ignore_keys = [
|
||||
"decoder.version",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
model_state_dict.pop(k, None)
|
||||
|
||||
layer_names = list(model_state_dict.keys())
|
||||
for layer_name in layer_names:
|
||||
if layer_name.endswith("output_projection.weight"):
|
||||
model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name)
|
||||
else:
|
||||
model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name)
|
||||
|
||||
config = BioGptConfig.from_pretrained(pytorch_dump_folder_path)
|
||||
model_new = BioGptForCausalLM(config)
|
||||
|
||||
# check that it loads ok
|
||||
model_new.load_state_dict(model_state_dict)
|
||||
|
||||
# save
|
||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||
print(f"Generating {pytorch_weights_dump_path}")
|
||||
torch.save(model_state_dict, pytorch_weights_dump_path)
|
||||
|
||||
print("Conversion is done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--biogpt_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
|
||||
" bpecodes, etc."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path)
|
@ -1,177 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BiT checkpoints from the timm library."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from timm import create_model
|
||||
from timm.data import resolve_data_config
|
||||
from timm.data.transforms_factory import create_transform
|
||||
|
||||
from transformers import BitConfig, BitForImageClassification, BitImageProcessor
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_config(model_name):
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
conv_layer = "std_conv" if "bit" in model_name else False
|
||||
|
||||
# note that when using BiT as backbone for ViT-hybrid checkpoints,
|
||||
# one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same",
|
||||
# config.conv_layer = "std_conv_same"
|
||||
config = BitConfig(
|
||||
conv_layer=conv_layer,
|
||||
num_labels=1000,
|
||||
id2label=id2label,
|
||||
label2id=label2id,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "stem.conv" in name:
|
||||
name = name.replace("stem.conv", "bit.embedder.convolution")
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "layers")
|
||||
if "head.fc" in name:
|
||||
name = name.replace("head.fc", "classifier.1")
|
||||
if name.startswith("norm"):
|
||||
name = "bit." + name
|
||||
if "bit" not in name and "classifier" not in name:
|
||||
name = "bit.encoder." + name
|
||||
|
||||
return name
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BiT structure.
|
||||
"""
|
||||
|
||||
# define default BiT configuration
|
||||
config = get_config(model_name)
|
||||
|
||||
# load original model from timm
|
||||
timm_model = create_model(model_name, pretrained=True)
|
||||
timm_model.eval()
|
||||
|
||||
# load state_dict of original model
|
||||
state_dict = timm_model.state_dict()
|
||||
for key in state_dict.copy().keys():
|
||||
val = state_dict.pop(key)
|
||||
state_dict[rename_key(key)] = val.squeeze() if "head" in key else val
|
||||
|
||||
# load HuggingFace model
|
||||
model = BitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# create image processor
|
||||
transform = create_transform(**resolve_data_config({}, model=timm_model))
|
||||
timm_transforms = transform.transforms
|
||||
|
||||
pillow_resamplings = {
|
||||
"bilinear": PILImageResampling.BILINEAR,
|
||||
"bicubic": PILImageResampling.BICUBIC,
|
||||
"nearest": PILImageResampling.NEAREST,
|
||||
}
|
||||
|
||||
processor = BitImageProcessor(
|
||||
do_resize=True,
|
||||
size={"shortest_edge": timm_transforms[0].size},
|
||||
resample=pillow_resamplings[timm_transforms[0].interpolation.value],
|
||||
do_center_crop=True,
|
||||
crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
|
||||
do_normalize=True,
|
||||
image_mean=timm_transforms[-1].mean.tolist(),
|
||||
image_std=timm_transforms[-1].std.tolist(),
|
||||
)
|
||||
|
||||
image = prepare_img()
|
||||
timm_pixel_values = transform(image).unsqueeze(0)
|
||||
pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
# verify pixel values
|
||||
assert torch.allclose(timm_pixel_values, pixel_values)
|
||||
|
||||
# verify logits
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
print("Logits:", logits[0, :3])
|
||||
print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
|
||||
timm_logits = timm_model(pixel_values)
|
||||
assert timm_logits.shape == outputs.logits.shape
|
||||
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model {model_name} and processor to the hub")
|
||||
model.push_to_hub(f"ybelkada/{model_name}")
|
||||
processor.push_to_hub(f"ybelkada/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="resnetv2_50x1_bitm",
|
||||
type=str,
|
||||
help="Name of the BiT timm model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model to the hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,114 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Blenderbot checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
PATTERNS = [
|
||||
["attention", "attn"],
|
||||
["encoder_attention", "encoder_attn"],
|
||||
["q_lin", "q_proj"],
|
||||
["k_lin", "k_proj"],
|
||||
["v_lin", "v_proj"],
|
||||
["out_lin", "out_proj"],
|
||||
["norm_embeddings", "layernorm_embedding"],
|
||||
["position_embeddings", "embed_positions"],
|
||||
["embeddings", "embed_tokens"],
|
||||
["ffn.lin", "fc"],
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k):
|
||||
if k == "embeddings.weight":
|
||||
return "shared.weight"
|
||||
|
||||
for parlai_name, hf_name in PATTERNS:
|
||||
k = k.replace(parlai_name, hf_name)
|
||||
|
||||
if k.startswith("encoder"):
|
||||
k = k.replace(".attn", ".self_attn")
|
||||
k = k.replace("norm1", "self_attn_layer_norm")
|
||||
k = k.replace("norm2", "final_layer_norm")
|
||||
elif k.startswith("decoder"):
|
||||
k = k.replace("norm1", "self_attn_layer_norm")
|
||||
k = k.replace("norm2", "encoder_attn_layer_norm")
|
||||
k = k.replace("norm3", "final_layer_norm")
|
||||
return k
|
||||
|
||||
|
||||
def rename_layernorm_keys(sd):
|
||||
keys = [
|
||||
"model.encoder.layernorm_embedding.weight",
|
||||
"model.encoder.layernorm_embedding.bias",
|
||||
"model.decoder.layernorm_embedding.weight",
|
||||
"model.decoder.layernorm_embedding.bias",
|
||||
]
|
||||
for k in keys:
|
||||
v = sd.pop(k)
|
||||
new_k = k.replace("layernorm_embedding", "layer_norm")
|
||||
assert new_k not in sd
|
||||
sd[new_k] = v
|
||||
|
||||
|
||||
IGNORE_KEYS = ["START"]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BERT structure.
|
||||
"""
|
||||
model = torch.load(checkpoint_path, map_location="cpu")
|
||||
sd = model["model"]
|
||||
cfg = BlenderbotConfig.from_json_file(config_json_path)
|
||||
m = BlenderbotForConditionalGeneration(cfg)
|
||||
valid_keys = m.model.state_dict().keys()
|
||||
failures = []
|
||||
mapping = {}
|
||||
for k, v in sd.items():
|
||||
if k in IGNORE_KEYS:
|
||||
continue
|
||||
|
||||
new_k = rename_state_dict_key(k)
|
||||
if new_k not in valid_keys:
|
||||
failures.append([k, new_k])
|
||||
else:
|
||||
mapping[new_k] = v
|
||||
if cfg.normalize_before: # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
|
||||
rename_layernorm_keys(sd)
|
||||
m.model.load_state_dict(mapping, strict=True)
|
||||
m.half()
|
||||
m.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
|
||||
parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
|
||||
parser.add_argument(
|
||||
"--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)
|
@ -1,191 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# git clone https://github.com/salesforce/BLIP.git
|
||||
from models.blip import blip_decoder
|
||||
from models.blip_itm import blip_itm
|
||||
from models.blip_vqa import blip_vqa
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from transformers import (
|
||||
BertTokenizer,
|
||||
BlipConfig,
|
||||
BlipForConditionalGeneration,
|
||||
BlipForImageTextRetrieval,
|
||||
BlipForQuestionAnswering,
|
||||
)
|
||||
|
||||
|
||||
def load_demo_image(image_size, device):
|
||||
img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
]
|
||||
)
|
||||
image = transform(raw_image).unsqueeze(0).to(device)
|
||||
return image
|
||||
|
||||
|
||||
def rename_key(key):
|
||||
if "visual_encoder" in key:
|
||||
key = re.sub("visual_encoder*", "vision_model.encoder", key)
|
||||
if "blocks" in key:
|
||||
key = re.sub(r"blocks", "layers", key)
|
||||
if "attn" in key:
|
||||
key = re.sub(r"attn", "self_attn", key)
|
||||
if "norm1" in key:
|
||||
key = re.sub(r"norm1", "layer_norm1", key)
|
||||
if "norm2" in key:
|
||||
key = re.sub(r"norm2", "layer_norm2", key)
|
||||
if "encoder.norm" in key:
|
||||
key = re.sub(r"encoder.norm", "post_layernorm", key)
|
||||
if "encoder.patch_embed.proj" in key:
|
||||
key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)
|
||||
|
||||
if "encoder.pos_embed" in key:
|
||||
key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
|
||||
if "encoder.cls_token" in key:
|
||||
key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)
|
||||
|
||||
if "self_attn" in key:
|
||||
key = re.sub(r"self_attn.proj", "self_attn.projection", key)
|
||||
|
||||
return key
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
if config_path is not None:
|
||||
config = BlipConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = BlipConfig(projection_dim=512, text_config={}, vision_config={})
|
||||
|
||||
hf_model = BlipForConditionalGeneration(config).eval()
|
||||
|
||||
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
|
||||
|
||||
pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
|
||||
pt_model = pt_model.eval()
|
||||
|
||||
modified_state_dict = pt_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_model.load_state_dict(modified_state_dict)
|
||||
|
||||
image_size = 384
|
||||
image = load_demo_image(image_size=image_size, device="cpu")
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
input_ids = tokenizer(["a picture of"]).input_ids
|
||||
|
||||
out = hf_model.generate(image, input_ids)
|
||||
|
||||
assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
|
||||
|
||||
out = hf_model.generate(image)
|
||||
|
||||
assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
# model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
|
||||
model_url = (
|
||||
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
|
||||
)
|
||||
|
||||
vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
|
||||
vqa_model.eval()
|
||||
|
||||
modified_state_dict = vqa_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_vqa_model = BlipForQuestionAnswering(config)
|
||||
|
||||
hf_vqa_model.load_state_dict(modified_state_dict)
|
||||
|
||||
question = ["How many dogs are in this image?"]
|
||||
question_input_ids = tokenizer(question, return_tensors="pt").input_ids
|
||||
|
||||
answer = hf_vqa_model.generate(question_input_ids, image)
|
||||
print(tokenizer.decode(answer[0]))
|
||||
|
||||
assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")
|
||||
|
||||
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"
|
||||
|
||||
itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
|
||||
itm_model.eval()
|
||||
|
||||
modified_state_dict = itm_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_itm_model = BlipForImageTextRetrieval(config)
|
||||
|
||||
question = ["A picture of a woman with a dog sitting in a beach"]
|
||||
question_input_ids = tokenizer(
|
||||
question,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=35,
|
||||
).input_ids
|
||||
|
||||
hf_itm_model.load_state_dict(modified_state_dict)
|
||||
hf_itm_model.eval()
|
||||
|
||||
out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
|
||||
out = hf_itm_model(question_input_ids, image, use_itm_head=False)
|
||||
|
||||
assert out[0].item() == 0.2110687494277954
|
||||
assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
|
@ -1,390 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert BLIP-2 checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# pip3 install salesforce-lavis
|
||||
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
|
||||
# to make sure we can compare both original and HF implementation in float32
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BertTokenizer,
|
||||
Blip2Config,
|
||||
Blip2ForConditionalGeneration,
|
||||
Blip2ForImageTextRetrieval,
|
||||
Blip2Processor,
|
||||
Blip2QFormerConfig,
|
||||
Blip2VisionConfig,
|
||||
BlipImageProcessor,
|
||||
OPTConfig,
|
||||
T5Config,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
|
||||
|
||||
def load_demo_image():
|
||||
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
|
||||
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, model_name):
|
||||
rename_keys = []
|
||||
# fmt: off
|
||||
|
||||
# vision encoder
|
||||
rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
|
||||
rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
|
||||
rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
|
||||
rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
|
||||
|
||||
# QFormer
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
|
||||
if "itm" in model_name:
|
||||
rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
|
||||
rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
|
||||
rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
|
||||
rename_keys.append(("text_proj.weight", "text_projection.weight"))
|
||||
rename_keys.append(("text_proj.bias", "text_projection.bias"))
|
||||
|
||||
# fmt: on
|
||||
return rename_keys
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def read_in_q_v_bias(state_dict, config):
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
# read in original q and v biases
|
||||
q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")
|
||||
|
||||
# next, set bias in the state dict
|
||||
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
|
||||
state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
|
||||
|
||||
|
||||
def get_blip2_config(model_name, eos_token_id):
|
||||
image_size = 364 if "coco" in model_name else 224
|
||||
vision_config = Blip2VisionConfig(image_size=image_size).to_dict()
|
||||
|
||||
# make sure the models have proper bos_token_id and eos_token_id set (important for generation)
|
||||
# seems like flan-T5 models don't have bos_token_id properly set?
|
||||
if "opt-2.7b" in model_name:
|
||||
text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
|
||||
elif "opt-6.7b" in model_name:
|
||||
text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
|
||||
elif "t5-xl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "t5-xxl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "itm" in model_name:
|
||||
text_config = {}
|
||||
else:
|
||||
raise ValueError("Model name not supported")
|
||||
|
||||
if "itm" in model_name:
|
||||
config = Blip2Config(
|
||||
vision_config=vision_config,
|
||||
qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
|
||||
)
|
||||
else:
|
||||
config = Blip2Config(vision_config=vision_config, text_config=text_config)
|
||||
|
||||
return config, image_size
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip2_checkpoint(
|
||||
model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to Transformers design.
|
||||
"""
|
||||
if "opt" in model_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
|
||||
elif "itm" in model_name:
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
|
||||
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
|
||||
|
||||
if "itm" in model_name:
|
||||
eos_token_id = None
|
||||
else:
|
||||
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
|
||||
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
|
||||
|
||||
if "itm" in model_name:
|
||||
hf_model = Blip2ForImageTextRetrieval(config).eval()
|
||||
else:
|
||||
hf_model = Blip2ForConditionalGeneration(config).eval()
|
||||
|
||||
model_name_to_original = {
|
||||
"blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
|
||||
"blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
|
||||
"blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
|
||||
"blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
|
||||
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
|
||||
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
|
||||
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
|
||||
"blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
|
||||
"blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
|
||||
}
|
||||
|
||||
name, type = model_name_to_original[model_name]
|
||||
|
||||
# load original model
|
||||
print("Loading original model...")
|
||||
original_model, vis_processors, _ = load_model_and_preprocess(
|
||||
name=name, model_type=type, is_eval=True, device=lavis_device
|
||||
)
|
||||
original_model.eval()
|
||||
print("Done!")
|
||||
|
||||
# update state dict keys
|
||||
state_dict = original_model.state_dict()
|
||||
rename_keys = create_rename_keys(config, model_name)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
|
||||
# some keys can be renamed efficiently
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("Qformer.bert"):
|
||||
key = key.replace("Qformer.bert", "qformer")
|
||||
if "attention.self" in key:
|
||||
key = key.replace("self", "attention")
|
||||
if "opt_proj" in key:
|
||||
key = key.replace("opt_proj", "language_projection")
|
||||
if "t5_proj" in key:
|
||||
key = key.replace("t5_proj", "language_projection")
|
||||
if key.startswith("opt"):
|
||||
key = key.replace("opt", "language")
|
||||
if key.startswith("t5"):
|
||||
key = key.replace("t5", "language")
|
||||
state_dict[key] = val
|
||||
|
||||
# read in qv biases
|
||||
read_in_q_v_bias(state_dict, config)
|
||||
|
||||
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
|
||||
assert len(missing_keys) == 0
|
||||
|
||||
if "itm" in model_name:
|
||||
unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
|
||||
assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
|
||||
else:
|
||||
assert unexpected_keys == ["qformer.embeddings.position_ids"]
|
||||
|
||||
image = load_demo_image()
|
||||
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
|
||||
|
||||
# create processor
|
||||
image_processor = BlipImageProcessor(
|
||||
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
|
||||
)
|
||||
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
|
||||
|
||||
# make sure processor creates exact same pixel values
|
||||
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
|
||||
|
||||
original_model.to(lavis_device)
|
||||
hf_model.to(hf_model_device)
|
||||
|
||||
if "itm" in model_name:
|
||||
caption = "a large fountain spewing water into the air"
|
||||
input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
|
||||
attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)
|
||||
|
||||
with torch.no_grad():
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
|
||||
)
|
||||
logits = hf_model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_image_text_matching_head=True,
|
||||
)
|
||||
|
||||
assert original_logits.shape == logits.logits_per_image.shape
|
||||
print("First values of original logits:", original_logits[0, :3])
|
||||
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||
|
||||
# assert values
|
||||
# cast to same type
|
||||
target_dtype = logits.logits_per_image.dtype
|
||||
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||
|
||||
original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
|
||||
itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
|
||||
assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
with torch.no_grad():
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
|
||||
)
|
||||
logits = hf_model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_image_text_matching_head=False,
|
||||
)
|
||||
|
||||
assert original_logits.shape == logits.logits_per_image.shape
|
||||
print("First values of original logits:", original_logits[0, :3])
|
||||
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||
|
||||
# assert values
|
||||
# cast to same type
|
||||
target_dtype = logits.logits_per_image.dtype
|
||||
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
else:
|
||||
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
|
||||
|
||||
with torch.no_grad():
|
||||
if "opt" in model_name:
|
||||
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
|
||||
logits = hf_model(pixel_values, input_ids).logits
|
||||
else:
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
|
||||
).logits
|
||||
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
|
||||
logits = hf_model(pixel_values, input_ids, labels=labels).logits
|
||||
|
||||
assert original_logits.shape == logits.shape
|
||||
print("First values of original logits:", original_logits[0, :3, :3])
|
||||
print("First values of HF logits:", logits[0, :3, :3])
|
||||
|
||||
# assert values
|
||||
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
print("Generating a caption...")
|
||||
prompt = "Question: what object is in this image? Answer:"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
|
||||
|
||||
set_seed(42)
|
||||
|
||||
original_outputs = original_model.generate(
|
||||
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
|
||||
)
|
||||
outputs = hf_model.generate(
|
||||
pixel_values,
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=5,
|
||||
max_length=30,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=1.0,
|
||||
temperature=1,
|
||||
)
|
||||
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
|
||||
output_text = [text.strip() for text in output_text]
|
||||
print("Original generation:", original_outputs)
|
||||
print("HF generation:", output_text)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
processor.push_to_hub(f"nielsr/{model_name}")
|
||||
hf_model.push_to_hub(f"nielsr/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
choices = [
|
||||
"blip2-opt-2.7b",
|
||||
"blip2-opt-6.7b",
|
||||
"blip2-opt-2.7b-coco",
|
||||
"blip2-opt-6.7b-coco",
|
||||
"blip2-flan-t5-xl",
|
||||
"blip2-flan-t5-xl-coco",
|
||||
"blip2-flan-t5-xxl",
|
||||
"blip2-itm-vit-g",
|
||||
"blip2-itm-vit-g-coco",
|
||||
]
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="blip2-opt-2.7b",
|
||||
choices=choices,
|
||||
type=str,
|
||||
help="Path to hf config.json of model to convert",
|
||||
)
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model and processor to the hub after converting",
|
||||
)
|
||||
# note: this script is tested on 2 GPUs, as models are compared in float32,
|
||||
# which requires quite some memory. Hence loading both on a
|
||||
# separate device is the easiest to compare
|
||||
parser.add_argument(
|
||||
"--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_blip2_checkpoint(
|
||||
args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
|
||||
)
|
@ -1,254 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BigScience BLOOM checkpoint."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BloomConfig, BloomModel
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
WEIGHTS_TO_AVERAGE_ENDSWITH = [
|
||||
"word_embeddings_layernorm.weight",
|
||||
"word_embeddings_layernorm.bias",
|
||||
"input_layernorm.weight",
|
||||
"input_layernorm.bias",
|
||||
"post_attention_layernorm.weight",
|
||||
"post_attention_layernorm.bias",
|
||||
"self_attention.dense.bias",
|
||||
"mlp.dense_4h_to_h.bias",
|
||||
"ln_f.weight",
|
||||
"ln_f.bias",
|
||||
]
|
||||
|
||||
WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
|
||||
"mlp.dense_4h_to_h.weight",
|
||||
"self_attention.dense.weight",
|
||||
]
|
||||
|
||||
|
||||
def layer_name_mapping(key, file):
|
||||
"""Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
|
||||
# Handle first and last layers
|
||||
layer_rename_map = {
|
||||
"word_embeddings.weight": "word_embeddings.weight",
|
||||
"word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
|
||||
"word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
|
||||
"weight": "ln_f.weight",
|
||||
"bias": "ln_f.bias",
|
||||
}
|
||||
|
||||
if key in layer_rename_map:
|
||||
return layer_rename_map[key]
|
||||
|
||||
# Handle transformer blocks
|
||||
layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
|
||||
layer_number -= 3
|
||||
return f"h.{layer_number}." + key
|
||||
|
||||
|
||||
def get_dtype_size(dtype):
|
||||
if dtype == torch.bool:
|
||||
return 1 / 8
|
||||
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
|
||||
if bit_search is None:
|
||||
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||
bit_size = int(bit_search.groups()[0])
|
||||
return bit_size // 8
|
||||
|
||||
|
||||
def convert_bloom_checkpoint_to_pytorch(
|
||||
bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
|
||||
):
|
||||
# Construct model
|
||||
if bloom_config_file == "":
|
||||
config = BloomConfig()
|
||||
else:
|
||||
config = BloomConfig.from_json_file(bloom_config_file)
|
||||
|
||||
if shard_model:
|
||||
file_names = os.listdir(bloom_checkpoint_path)
|
||||
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
|
||||
|
||||
index_dict = {"weight_map": {}, "metadata": {}}
|
||||
total_size = 0
|
||||
|
||||
missing_keys = None
|
||||
|
||||
config = BloomConfig()
|
||||
|
||||
for j, file in enumerate(file_names):
|
||||
print("Processing file: {}".format(file))
|
||||
tensors = None
|
||||
|
||||
for i in range(pretraining_tp):
|
||||
# load all TP files
|
||||
f_name = file.replace("model_00", f"model_0{i}")
|
||||
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
|
||||
|
||||
# Rename keys in the transformers names
|
||||
keys = list(temp.keys())
|
||||
for key in keys:
|
||||
temp[layer_name_mapping(key, file)] = temp.pop(key)
|
||||
|
||||
if tensors is None:
|
||||
tensors = temp
|
||||
else:
|
||||
for key in tensors.keys():
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
# We average (sum and then divide) some weights accross TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
|
||||
tensors[key] += temp[key]
|
||||
else:
|
||||
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
|
||||
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
|
||||
# We concatenate these weights accross TP ranks
|
||||
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
|
||||
|
||||
# Divide by the number of TP the weights we want to average
|
||||
for key in tensors.keys():
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] = tensors[key] / pretraining_tp
|
||||
torch.save(
|
||||
tensors,
|
||||
os.path.join(
|
||||
pytorch_dump_folder_path,
|
||||
"pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(len(file_names)).zfill(5)),
|
||||
),
|
||||
)
|
||||
|
||||
for key in tensors.keys():
|
||||
value = tensors[key]
|
||||
total_size += value.numel() * get_dtype_size(value.dtype)
|
||||
if key not in index_dict["weight_map"]:
|
||||
index_dict["weight_map"][key] = "pytorch_model_{}-of-{}.bin".format(
|
||||
str(j + 1).zfill(5), str(len(file_names)).zfill(5)
|
||||
)
|
||||
|
||||
config = BloomConfig()
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
index_dict["metadata"]["total_size"] = total_size
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
|
||||
json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
|
||||
f.write(json_config)
|
||||
else:
|
||||
model = BloomModel(config)
|
||||
|
||||
file_names = os.listdir(bloom_checkpoint_path)
|
||||
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
|
||||
|
||||
missing_keys = None
|
||||
for i, file in enumerate(file_names):
|
||||
tensors = None
|
||||
for i in range(pretraining_tp):
|
||||
# load all TP files
|
||||
f_name = file.replace("model_00", f"model_0{i}")
|
||||
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
|
||||
|
||||
# Rename keys in the transformers names
|
||||
keys = list(temp.keys())
|
||||
for key in keys:
|
||||
temp[layer_name_mapping(key, file)] = temp.pop(key)
|
||||
|
||||
if tensors is None:
|
||||
tensors = temp
|
||||
else:
|
||||
for key in tensors.keys():
|
||||
# We average (sum and then divide) some weights accross TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] += temp[key]
|
||||
else:
|
||||
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
|
||||
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
|
||||
# We concatenate these weights accross TP ranks
|
||||
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
|
||||
|
||||
# Divide by the number of TP the weights we want to average
|
||||
for key in tensors.keys():
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] = tensors[key] / pretraining_tp
|
||||
|
||||
other_keys = model.load_state_dict(tensors, strict=False)
|
||||
assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
|
||||
if missing_keys is None:
|
||||
missing_keys = set(other_keys.missing_keys)
|
||||
else:
|
||||
missing_keys = missing_keys.intersection(set(other_keys.missing_keys))
|
||||
|
||||
assert not missing_keys, f"The keys {missing_keys} are missing"
|
||||
|
||||
# Save pytorch-model
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
|
||||
if config.torch_dtype is not None:
|
||||
model = model.to(config.torch_dtype)
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print(f"Save configuration file to {pytorch_config_dump_path}")
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--bloom_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the Megatron-LM checkpoint path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bloom_config_file",
|
||||
default="",
|
||||
type=str,
|
||||
help=(
|
||||
"An optional config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard_model",
|
||||
action="store_true",
|
||||
help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretraining_tp",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_bloom_checkpoint_to_pytorch(
|
||||
args.bloom_checkpoint_path,
|
||||
args.bloom_config_file,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.shard_model,
|
||||
args.pretraining_tp,
|
||||
)
|
@ -1,145 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Bros checkpoints."""
|
||||
|
||||
import argparse
|
||||
|
||||
import bros # original repo
|
||||
import torch
|
||||
|
||||
from transformers import BrosConfig, BrosModel, BrosProcessor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_configs(model_name):
|
||||
bros_config = BrosConfig.from_pretrained(model_name)
|
||||
return bros_config
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"embeddings.bbox_sinusoid_emb.inv_freq",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if name == "embeddings.bbox_projection.weight":
|
||||
name = "bbox_embeddings.bbox_projection.weight"
|
||||
|
||||
if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
|
||||
name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"
|
||||
|
||||
if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
|
||||
name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, model):
|
||||
# rename keys
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
# remove ignore keys
|
||||
remove_ignore_keys_(orig_state_dict)
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
|
||||
# load original model
|
||||
original_model = bros.BrosModel.from_pretrained(model_name).eval()
|
||||
|
||||
# load HuggingFace Model
|
||||
bros_config = get_configs(model_name)
|
||||
model = BrosModel.from_pretrained(model_name, config=bros_config)
|
||||
model.eval()
|
||||
|
||||
state_dict = original_model.state_dict()
|
||||
new_state_dict = convert_state_dict(state_dict, model)
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# verify results
|
||||
|
||||
# original BROS model require 4 points (8 float values) for each bbox, prepare bbox with [batch_size, seq_len, 8] shape
|
||||
bbox = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
||||
[0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
|
||||
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
|
||||
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
|
||||
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
|
||||
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
|
||||
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
processor = BrosProcessor.from_pretrained(model_name)
|
||||
|
||||
encoding = processor("His name is Rocco.", return_tensors="pt")
|
||||
encoding["bbox"] = bbox
|
||||
|
||||
original_hidden_states = original_model(**encoding).last_hidden_state
|
||||
# pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
last_hidden_states = model(**encoding).last_hidden_state
|
||||
|
||||
assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
print(f"Saving model and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
|
||||
processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="jinho8345/bros-base-uncased",
|
||||
required=False,
|
||||
type=str,
|
||||
help="Name of the original model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
required=False,
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether or not to push the converted model and processor to the 🤗 hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,59 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert T5 checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = T5Config.from_json_file(config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = T5ForConditionalGeneration(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
@ -1,65 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert CANINE checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path):
|
||||
# Initialize PyTorch model
|
||||
config = CanineConfig()
|
||||
model = CanineModel(config)
|
||||
model.eval()
|
||||
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_canine(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model (weights and configuration)
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
# Save tokenizer files
|
||||
tokenizer = CanineTokenizer()
|
||||
print(f"Save tokenizer files to {pytorch_dump_path}")
|
||||
tokenizer.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the TensorFlow checkpoint. Should end with model.ckpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a folder where the PyTorch model will be placed.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path)
|
@ -1,476 +0,0 @@
|
||||
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate import init_empty_weights
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
ChameleonConfig,
|
||||
ChameleonForConditionalGeneration,
|
||||
ChameleonImageProcessor,
|
||||
ChameleonProcessor,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from transformers import LlamaTokenizerFast
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Chameleon conversion supports only FastTokenizer and LlamaTokenizerFast can't be imported! "
|
||||
"Update your `tokenizers` library and re-run the tokenizer conversion."
|
||||
)
|
||||
|
||||
"""
|
||||
Sample usage:
|
||||
|
||||
```
|
||||
python src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/chameleon/weights --model_size 7B --output_dir /output/path
|
||||
```
|
||||
|
||||
Thereafter, models can be loaded via:
|
||||
|
||||
```py
|
||||
from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast
|
||||
|
||||
model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
|
||||
```
|
||||
|
||||
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
|
||||
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
|
||||
"""
|
||||
|
||||
NUM_SHARDS = {
|
||||
"7B": 1,
|
||||
"30B": 4,
|
||||
}
|
||||
|
||||
VOCAB_SIZE = 65536
|
||||
|
||||
|
||||
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
|
||||
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
|
||||
|
||||
|
||||
def read_json(path):
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(text, path):
|
||||
with open(path, "w") as f:
|
||||
json.dump(text, f)
|
||||
|
||||
|
||||
def write_model(model_path, input_base_path, model_size, chameleon_version=1):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
input_model_path = os.path.join(input_base_path, "models", model_size.lower())
|
||||
params_path = os.path.join(input_model_path, "params.json")
|
||||
consolidate_params_path = os.path.join(input_model_path, "consolidate_params.json")
|
||||
|
||||
params = read_json(params_path)
|
||||
if os.path.isfile(consolidate_params_path):
|
||||
params = {**params, **read_json(consolidate_params_path)}
|
||||
num_shards = NUM_SHARDS[model_size]
|
||||
model_parallel_size = params["model_parallel_size"]
|
||||
params = params.get("model", params)
|
||||
n_layers = params["n_layers"]
|
||||
n_heads = params["n_heads"]
|
||||
n_heads_per_shard = n_heads // num_shards
|
||||
dim = params["dim"]
|
||||
dims_per_head = dim // n_heads
|
||||
base = params.get("rope_theta", 10000.0)
|
||||
swin_norm = params["swin_norm"]
|
||||
if base > 10000.0:
|
||||
max_position_embeddings = 16384
|
||||
else:
|
||||
# Depending on the Chameleon version, the default max_position_embeddings has different values.
|
||||
if chameleon_version == 1:
|
||||
max_position_embeddings = 4096
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Version {chameleon_version} of chameleon is not supported yet. "
|
||||
"Current supported versions of chameleon are [1]."
|
||||
)
|
||||
|
||||
if params.get("n_kv_heads", None) is not None:
|
||||
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
|
||||
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
|
||||
key_value_dim = dim // num_key_value_heads
|
||||
else: # compatibility with other checkpoints
|
||||
num_key_value_heads = n_heads
|
||||
num_local_key_value_heads = n_heads_per_shard
|
||||
key_value_dim = dim
|
||||
|
||||
print(f"Fetching all parameters from the checkpoint at {input_model_path}.")
|
||||
# Load weights
|
||||
if num_shards == 1:
|
||||
# Not sharded
|
||||
# (The sharded implementation would also work, but this is simpler.)
|
||||
loaded = None
|
||||
for possible_name in ["consolidated.pth", "consolidated.00.pth"]:
|
||||
possible_path = os.path.join(input_model_path, possible_name)
|
||||
if os.path.exists(possible_path):
|
||||
loaded = torch.load(possible_path, map_location="cpu")
|
||||
break
|
||||
assert loaded is not None
|
||||
else:
|
||||
# Sharded
|
||||
loaded = [
|
||||
torch.load(os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
|
||||
for i in range(num_shards)
|
||||
]
|
||||
|
||||
# permute for sliced rotary
|
||||
def permute(w, n_heads, dim1=dim, dim2=dim):
|
||||
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||
|
||||
# Load weights to the state dict
|
||||
state_dict = {}
|
||||
for layer_i in range(n_layers):
|
||||
if num_shards == 1:
|
||||
# Unsharded
|
||||
state_dict.update(
|
||||
{
|
||||
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
|
||||
loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
|
||||
loaded[f"layers.{layer_i}.attention.wk.weight"],
|
||||
n_heads=num_key_value_heads,
|
||||
dim1=key_value_dim,
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
|
||||
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
|
||||
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
|
||||
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
|
||||
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": loaded[
|
||||
f"layers.{layer_i}.attention_norm.weight"
|
||||
],
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
|
||||
f"layers.{layer_i}.ffn_norm.weight"
|
||||
],
|
||||
}
|
||||
)
|
||||
# qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
|
||||
loaded[f"layers.{layer_i}.attention.q_normalization.weight"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(n_heads, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
|
||||
loaded[f"layers.{layer_i}.attention.q_normalization.bias"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(n_heads, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
|
||||
loaded[f"layers.{layer_i}.attention.k_normalization.weight"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(num_key_value_heads, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
|
||||
loaded[f"layers.{layer_i}.attention.k_normalization.bias"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(num_key_value_heads, 0)
|
||||
)
|
||||
|
||||
else:
|
||||
# Sharded
|
||||
state_dict.update(
|
||||
{
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": torch.stack(
|
||||
[l[f"layers.{layer_i}.attention_norm.weight"] for l in loaded]
|
||||
).mean(dim=0),
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": torch.stack(
|
||||
[l[f"layers.{layer_i}.ffn_norm.weight"] for l in loaded]
|
||||
).mean(dim=0),
|
||||
}
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(dim, dim),
|
||||
n_heads=n_heads,
|
||||
)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
|
||||
num_local_key_value_heads, dims_per_head, dim
|
||||
)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(key_value_dim, dim),
|
||||
n_heads=num_key_value_heads,
|
||||
dim1=key_value_dim,
|
||||
)
|
||||
|
||||
# qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(n_heads // num_shards, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(n_heads // num_shards, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(num_key_value_heads // num_shards, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(num_key_value_heads // num_shards, 0)
|
||||
)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
|
||||
num_local_key_value_heads, dims_per_head, dim
|
||||
)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(key_value_dim, dim)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
|
||||
)
|
||||
|
||||
if num_shards == 1:
|
||||
# Unsharded
|
||||
state_dict.update(
|
||||
{
|
||||
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
|
||||
"model.norm.weight": loaded["norm.weight"],
|
||||
"lm_head.weight": loaded["output.weight"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
state_dict.update(
|
||||
{
|
||||
"model.embed_tokens.weight": torch.cat(
|
||||
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
|
||||
),
|
||||
"model.norm.weight": torch.stack([loaded[i]["norm.weight"] for i in range(num_shards)]).mean(dim=0),
|
||||
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
|
||||
}
|
||||
)
|
||||
|
||||
# Load VQGAN weights
|
||||
vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt")
|
||||
vqgan_state_dict = torch.load(vqgan_path, map_location="cpu")["state_dict"]
|
||||
for k, v in vqgan_state_dict.items():
|
||||
if "decoder" in k:
|
||||
continue # we dont do image generation yet
|
||||
state_dict[f"model.vqmodel.{k}"] = v
|
||||
|
||||
# Write configs
|
||||
ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
|
||||
multiple_of = params["multiple_of"] if "multiple_of" in params else 256
|
||||
|
||||
with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file:
|
||||
tokenizer_config = json.load(tokenizer_file)
|
||||
vocabulary_map = tokenizer_config["model"]["vocab"]
|
||||
vocabulary_map["<image>"] = vocabulary_map[
|
||||
"<reserved08707>"
|
||||
] # use a reserved token instead of adding a new one
|
||||
del vocabulary_map["<reserved08707>"]
|
||||
|
||||
for token in tokenizer_config["added_tokens"]:
|
||||
if token["content"] == "<reserved08707>":
|
||||
token["content"] = "<image>"
|
||||
|
||||
with open(os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), "w") as f:
|
||||
json.dump(tokenizer_config, f) # save the new file to init tokenizer later
|
||||
|
||||
vq_keys_to_replace = [
|
||||
("ch", "base_channels"),
|
||||
("out_ch", "out_channels"),
|
||||
("n_embed", "num_embeddings"),
|
||||
("ch_mult", "channel_multiplier"),
|
||||
("double_z", "double_latent"),
|
||||
("z_channels", "latent_channels"),
|
||||
]
|
||||
with open(os.path.join(input_base_path, "tokenizer/vqgan.yaml")) as vqgan_cfg_file:
|
||||
vq_config = yaml.safe_load(vqgan_cfg_file)["model"]["params"]
|
||||
vq_config.update(**vq_config["ddconfig"])
|
||||
for old, new in vq_keys_to_replace:
|
||||
vq_config[new] = vq_config[old]
|
||||
del vq_config["ddconfig"]
|
||||
del vq_config["ckpt_path"]
|
||||
del vq_config["lossconfig"]
|
||||
|
||||
config = ChameleonConfig(
|
||||
hidden_size=dim,
|
||||
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
|
||||
num_attention_heads=params["n_heads"],
|
||||
num_hidden_layers=params["n_layers"],
|
||||
rms_norm_eps=params["norm_eps"],
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
rope_theta=base,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
model_parallel_size=model_parallel_size,
|
||||
swin_norm=swin_norm,
|
||||
vq_config=vq_config,
|
||||
vocabulary_map=vocabulary_map,
|
||||
)
|
||||
with init_empty_weights():
|
||||
model = ChameleonForConditionalGeneration(config)
|
||||
|
||||
model.load_state_dict(state_dict, assign=True, strict=False)
|
||||
model.save_pretrained(model_path, safe_serialization=True)
|
||||
|
||||
# Load and save the processor
|
||||
tokenizer = LlamaTokenizerFast(
|
||||
tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False
|
||||
)
|
||||
tokenizer.sep_token_id = 8710 # assign <reserved08706> to sep so that we can append it after input text
|
||||
tokenizer.pad_token_id = 1 # assing <pad> to special pad_token
|
||||
image_processor = ChameleonImageProcessor()
|
||||
processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
processor.save_pretrained(model_path)
|
||||
|
||||
# Make space so we can load the model properly now.
|
||||
del state_dict
|
||||
del loaded
|
||||
del vqgan_state_dict
|
||||
gc.collect()
|
||||
|
||||
# Short inference on a few examples to check if generation makes sense
|
||||
# taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
|
||||
print("Loading the checkpoint in a Chameleon model...")
|
||||
print("*" * 100)
|
||||
model = ChameleonForConditionalGeneration.from_pretrained(
|
||||
model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
processor = ChameleonProcessor.from_pretrained(model_path)
|
||||
|
||||
prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
|
||||
image = Image.open(
|
||||
requests.get(
|
||||
"https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
|
||||
).raw
|
||||
)
|
||||
inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
|
||||
length = inputs.input_ids.shape[1]
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
|
||||
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
|
||||
|
||||
print(f"Generation for single-image: {generated_text}")
|
||||
print("*" * 100)
|
||||
|
||||
# Multi-image example
|
||||
prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
|
||||
image = Image.open(
|
||||
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
|
||||
)
|
||||
image_2 = Image.open(
|
||||
requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
|
||||
)
|
||||
|
||||
inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
length = inputs.input_ids.shape[1]
|
||||
out = model.generate(**inputs, max_new_tokens=50, do_sample=False)
|
||||
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
|
||||
|
||||
print(f"Generation for multi-image: {generated_text}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
help="Location of Chameleon weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
choices=["7B", "30B"],
|
||||
help=""
|
||||
" models correspond to the finetuned versions, and are specific to the Chameleon official release. For more details on Chameleon, checkout the original repo: https://github.com/facebookresearch/chameleon",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_inference",
|
||||
action="store_true",
|
||||
help="Whether to load the model for generation to test it's converted correctly.",
|
||||
)
|
||||
# Different Chameleon versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
|
||||
parser.add_argument(
|
||||
"--chameleon_version",
|
||||
choices=[1],
|
||||
default=1,
|
||||
type=int,
|
||||
help="Version of the Chameleon model to convert",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
write_model(
|
||||
model_path=args.output_dir,
|
||||
input_base_path=args.input_dir,
|
||||
model_size=args.model_size,
|
||||
chameleon_version=args.chameleon_version,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -62,6 +62,7 @@ class ChameleonProcessor(ProcessorMixin):
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
||||
valid_kwargs = ["image_seq_length", "image_token"]
|
||||
image_processor_class = "ChameleonImageProcessor"
|
||||
|
||||
def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
|
||||
|
@ -1,134 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import ChineseCLIPConfig, ChineseCLIPModel
|
||||
|
||||
|
||||
def copy_attn_layer(hf_attn_layer, pt_weights, prefix):
|
||||
q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0)
|
||||
q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0)
|
||||
|
||||
out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"]
|
||||
out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"]
|
||||
|
||||
hf_attn_layer.q_proj.weight.data = q_proj
|
||||
hf_attn_layer.q_proj.bias.data = q_proj_bias
|
||||
|
||||
hf_attn_layer.k_proj.weight.data = k_proj
|
||||
hf_attn_layer.k_proj.bias.data = k_proj_bias
|
||||
|
||||
hf_attn_layer.v_proj.weight.data = v_proj
|
||||
hf_attn_layer.v_proj.bias.data = v_proj_bias
|
||||
|
||||
hf_attn_layer.out_proj.weight.data = out_proj_weights
|
||||
hf_attn_layer.out_proj.bias.data = out_proj_bias
|
||||
|
||||
|
||||
def copy_mlp(hf_mlp, pt_weights, prefix):
|
||||
copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc")
|
||||
copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj")
|
||||
|
||||
|
||||
def copy_linear(hf_linear, pt_weights, prefix):
|
||||
hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data
|
||||
hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data
|
||||
|
||||
|
||||
def copy_layer(hf_layer, pt_weights, prefix):
|
||||
# copy layer norms
|
||||
copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1")
|
||||
copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2")
|
||||
|
||||
# copy MLP
|
||||
copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp")
|
||||
|
||||
# copy attn
|
||||
copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn")
|
||||
|
||||
|
||||
def copy_layers(hf_layers, pt_weights, prefix):
|
||||
for layer_id, hf_layer in enumerate(hf_layers):
|
||||
copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}")
|
||||
|
||||
|
||||
def copy_text_model_and_projection(hf_model, pt_weights):
|
||||
# copy projection
|
||||
hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T
|
||||
|
||||
# copy text encoder
|
||||
for name, param in hf_model.text_model.named_parameters():
|
||||
param.data = pt_weights[f"bert.{name}"].data
|
||||
|
||||
|
||||
def copy_vision_model_and_projection(hf_model, pt_weights):
|
||||
# copy projection
|
||||
hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T
|
||||
|
||||
# copy layer norms
|
||||
copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre")
|
||||
copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post")
|
||||
|
||||
# copy embeddings
|
||||
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data
|
||||
hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data
|
||||
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data
|
||||
|
||||
# copy encoder
|
||||
copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
|
||||
assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size."
|
||||
config = ChineseCLIPConfig.from_pretrained(config_path)
|
||||
|
||||
hf_model = ChineseCLIPModel(config).eval()
|
||||
|
||||
pt_weights = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
|
||||
pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
|
||||
|
||||
copy_text_model_and_projection(hf_model, pt_weights)
|
||||
copy_vision_model_and_projection(hf_model, pt_weights)
|
||||
hf_model.logit_scale.data = pt_weights["logit_scale"].data
|
||||
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the output folder storing converted hf PyTorch model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
|
||||
print("The conversion is finished!")
|
@ -1,133 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
from laion_clap import CLAP_Module
|
||||
|
||||
from transformers import AutoFeatureExtractor, ClapConfig, ClapModel
|
||||
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"text_branch": "text_model",
|
||||
"audio_branch": "audio_model.audio_encoder",
|
||||
"attn": "attention.self",
|
||||
"self.proj": "output.dense",
|
||||
"attention.self_mask": "attn_mask",
|
||||
"mlp.fc1": "intermediate.dense",
|
||||
"mlp.fc2": "output.dense",
|
||||
"norm1": "layernorm_before",
|
||||
"norm2": "layernorm_after",
|
||||
"bn0": "batch_norm",
|
||||
}
|
||||
|
||||
processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")
|
||||
|
||||
|
||||
def init_clap(checkpoint_path, model_type, enable_fusion=False):
|
||||
model = CLAP_Module(
|
||||
amodel=model_type,
|
||||
enable_fusion=enable_fusion,
|
||||
)
|
||||
model.load_ckpt(checkpoint_path)
|
||||
return model
|
||||
|
||||
|
||||
def get_config_from_original(clap_model):
|
||||
audio_config = {
|
||||
"patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
|
||||
"depths": clap_model.model.audio_branch.depths,
|
||||
"hidden_size": clap_model.model.audio_projection[0].in_features,
|
||||
}
|
||||
|
||||
text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}
|
||||
|
||||
return ClapConfig(audio_config=audio_config, text_config=text_config)
|
||||
|
||||
|
||||
def rename_state_dict(state_dict):
|
||||
model_state_dict = {}
|
||||
|
||||
sequential_layers_pattern = r".*sequential.(\d+).*"
|
||||
text_projection_pattern = r".*_projection.(\d+).*"
|
||||
|
||||
for key, value in state_dict.items():
|
||||
# check if any key needs to be modified
|
||||
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in key:
|
||||
key = key.replace(key_to_modify, new_key)
|
||||
|
||||
if re.match(sequential_layers_pattern, key):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
# Because in CLAP they use `nn.Sequential`...
|
||||
transformers_projection_layer = 1 if projecton_layer == 0 else 2
|
||||
|
||||
key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")
|
||||
|
||||
if "audio" and "qkv" in key:
|
||||
# split qkv into query key and value
|
||||
mixed_qkv = value
|
||||
qkv_dim = mixed_qkv.size(0) // 3
|
||||
|
||||
query_layer = mixed_qkv[:qkv_dim]
|
||||
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
|
||||
value_layer = mixed_qkv[qkv_dim * 2 :]
|
||||
|
||||
model_state_dict[key.replace("qkv", "query")] = query_layer
|
||||
model_state_dict[key.replace("qkv", "key")] = key_layer
|
||||
model_state_dict[key.replace("qkv", "value")] = value_layer
|
||||
else:
|
||||
model_state_dict[key] = value
|
||||
|
||||
return model_state_dict
|
||||
|
||||
|
||||
def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
|
||||
clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)
|
||||
|
||||
clap_model.eval()
|
||||
state_dict = clap_model.model.state_dict()
|
||||
state_dict = rename_state_dict(state_dict)
|
||||
|
||||
transformers_config = get_config_from_original(clap_model)
|
||||
transformers_config.audio_config.enable_fusion = enable_fusion
|
||||
model = ClapModel(transformers_config)
|
||||
|
||||
# ignore the spectrogram embedding layer
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
transformers_config.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
|
||||
parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Whether to enable fusion or not")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_clap_checkpoint(
|
||||
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
|
||||
)
|
@ -1,156 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from clip import load
|
||||
|
||||
from transformers import CLIPConfig, CLIPModel
|
||||
|
||||
|
||||
def copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
||||
q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
|
||||
q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
|
||||
|
||||
out_proj_weights = pt_attn_layer.out_proj.weight
|
||||
out_proj_bias = pt_attn_layer.out_proj.bias
|
||||
|
||||
hf_attn_layer.q_proj.weight.data = q_proj
|
||||
hf_attn_layer.q_proj.bias.data = q_proj_bias
|
||||
|
||||
hf_attn_layer.k_proj.weight.data = k_proj
|
||||
hf_attn_layer.k_proj.bias.data = k_proj_bias
|
||||
|
||||
hf_attn_layer.v_proj.weight.data = v_proj
|
||||
hf_attn_layer.v_proj.bias.data = v_proj_bias
|
||||
|
||||
hf_attn_layer.out_proj.weight = out_proj_weights
|
||||
hf_attn_layer.out_proj.bias = out_proj_bias
|
||||
|
||||
|
||||
def copy_mlp(hf_mlp, pt_mlp):
|
||||
copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
|
||||
copy_linear(hf_mlp.fc2, pt_mlp.c_proj)
|
||||
|
||||
|
||||
def copy_linear(hf_linear, pt_linear):
|
||||
hf_linear.weight = pt_linear.weight
|
||||
hf_linear.bias = pt_linear.bias
|
||||
|
||||
|
||||
def copy_layer(hf_layer, pt_layer):
|
||||
# copy layer norms
|
||||
copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
|
||||
copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)
|
||||
|
||||
# copy MLP
|
||||
copy_mlp(hf_layer.mlp, pt_layer.mlp)
|
||||
|
||||
# copy attn
|
||||
copy_attn_layer(hf_layer.self_attn, pt_layer.attn)
|
||||
|
||||
|
||||
def copy_layers(hf_layers, pt_layers):
|
||||
for hf_layer, pt_layer in zip(hf_layers, pt_layers):
|
||||
copy_layer(hf_layer, pt_layer)
|
||||
|
||||
|
||||
def copy_encoder(hf_encoder, pt_model):
|
||||
# copy embeds
|
||||
hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
|
||||
hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding
|
||||
|
||||
# copy layer norm
|
||||
copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
|
||||
|
||||
# copy hidden layers
|
||||
copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)
|
||||
|
||||
|
||||
def copy_text_model_and_projection(hf_model, pt_model):
|
||||
# copy projection
|
||||
hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()
|
||||
|
||||
# copy text encoder
|
||||
copy_encoder(hf_model.text_model, pt_model)
|
||||
|
||||
|
||||
def copy_vison_model_and_projection(hf_model, pt_model):
|
||||
# copy projection
|
||||
hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous()
|
||||
|
||||
# copy layer norms
|
||||
copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
|
||||
copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)
|
||||
|
||||
# copy embeds
|
||||
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
|
||||
hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
|
||||
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data
|
||||
|
||||
# copy encoder
|
||||
copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
if config_path is not None:
|
||||
config = CLIPConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})
|
||||
|
||||
hf_model = CLIPModel(config).eval()
|
||||
|
||||
pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
|
||||
pt_model = pt_model.eval()
|
||||
|
||||
copy_text_model_and_projection(hf_model, pt_model)
|
||||
copy_vison_model_and_projection(hf_model, pt_model)
|
||||
hf_model.logit_scale = pt_model.logit_scale
|
||||
|
||||
# Use `eos_token` so the example is more meaningful
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[config.text_config.bos_token_id]
|
||||
+ list(range(3, 77))
|
||||
+ [config.text_config.eos_token_id]
|
||||
+ [config.text_config.pad_token_id]
|
||||
]
|
||||
)
|
||||
pixel_values = torch.randn(1, 3, 224, 224)
|
||||
|
||||
hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
|
||||
hf_logits_per_image = hf_outputs.logits_per_image
|
||||
hf_logits_per_text = hf_outputs.logits_per_text
|
||||
pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)
|
||||
|
||||
assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
|
||||
assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)
|
||||
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
|
@ -1,264 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert CLIPSeg checkpoints from the original repository. URL: https://github.com/timojl/clipseg."""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
CLIPSegConfig,
|
||||
CLIPSegForImageSegmentation,
|
||||
CLIPSegProcessor,
|
||||
CLIPSegTextConfig,
|
||||
CLIPSegVisionConfig,
|
||||
CLIPTokenizer,
|
||||
ViTImageProcessor,
|
||||
)
|
||||
|
||||
|
||||
def get_clipseg_config(model_name):
|
||||
text_config = CLIPSegTextConfig()
|
||||
vision_config = CLIPSegVisionConfig(patch_size=16)
|
||||
|
||||
use_complex_transposed_convolution = True if "refined" in model_name else False
|
||||
reduce_dim = 16 if "rd16" in model_name else 64
|
||||
|
||||
config = CLIPSegConfig.from_text_vision_configs(
|
||||
text_config,
|
||||
vision_config,
|
||||
use_complex_transposed_convolution=use_complex_transposed_convolution,
|
||||
reduce_dim=reduce_dim,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
# update prefixes
|
||||
if "clip_model" in name:
|
||||
name = name.replace("clip_model", "clip")
|
||||
if "transformer" in name:
|
||||
if "visual" in name:
|
||||
name = name.replace("visual.transformer", "vision_model")
|
||||
else:
|
||||
name = name.replace("transformer", "text_model")
|
||||
if "resblocks" in name:
|
||||
name = name.replace("resblocks", "encoder.layers")
|
||||
if "ln_1" in name:
|
||||
name = name.replace("ln_1", "layer_norm1")
|
||||
if "ln_2" in name:
|
||||
name = name.replace("ln_2", "layer_norm2")
|
||||
if "c_fc" in name:
|
||||
name = name.replace("c_fc", "fc1")
|
||||
if "c_proj" in name:
|
||||
name = name.replace("c_proj", "fc2")
|
||||
if "attn" in name and "self" not in name:
|
||||
name = name.replace("attn", "self_attn")
|
||||
# text encoder
|
||||
if "token_embedding" in name:
|
||||
name = name.replace("token_embedding", "text_model.embeddings.token_embedding")
|
||||
if "positional_embedding" in name and "visual" not in name:
|
||||
name = name.replace("positional_embedding", "text_model.embeddings.position_embedding.weight")
|
||||
if "ln_final" in name:
|
||||
name = name.replace("ln_final", "text_model.final_layer_norm")
|
||||
# vision encoder
|
||||
if "visual.class_embedding" in name:
|
||||
name = name.replace("visual.class_embedding", "vision_model.embeddings.class_embedding")
|
||||
if "visual.conv1" in name:
|
||||
name = name.replace("visual.conv1", "vision_model.embeddings.patch_embedding")
|
||||
if "visual.positional_embedding" in name:
|
||||
name = name.replace("visual.positional_embedding", "vision_model.embeddings.position_embedding.weight")
|
||||
if "visual.ln_pre" in name:
|
||||
name = name.replace("visual.ln_pre", "vision_model.pre_layrnorm")
|
||||
if "visual.ln_post" in name:
|
||||
name = name.replace("visual.ln_post", "vision_model.post_layernorm")
|
||||
# projection layers
|
||||
if "visual.proj" in name:
|
||||
name = name.replace("visual.proj", "visual_projection.weight")
|
||||
if "text_projection" in name:
|
||||
name = name.replace("text_projection", "text_projection.weight")
|
||||
# decoder
|
||||
if "trans_conv" in name:
|
||||
name = name.replace("trans_conv", "transposed_convolution")
|
||||
if "film_mul" in name or "film_add" in name or "reduce" in name or "transposed_convolution" in name:
|
||||
name = "decoder." + name
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "decoder.layers")
|
||||
if "linear1" in name:
|
||||
name = name.replace("linear1", "mlp.fc1")
|
||||
if "linear2" in name:
|
||||
name = name.replace("linear2", "mlp.fc2")
|
||||
if "norm1" in name and "layer_" not in name:
|
||||
name = name.replace("norm1", "layer_norm1")
|
||||
if "norm2" in name and "layer_" not in name:
|
||||
name = name.replace("norm2", "layer_norm2")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, config):
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
if key.startswith("clip_model") and "attn.in_proj" in key:
|
||||
key_split = key.split(".")
|
||||
if "visual" in key:
|
||||
layer_num = int(key_split[4])
|
||||
dim = config.vision_config.hidden_size
|
||||
prefix = "vision_model"
|
||||
else:
|
||||
layer_num = int(key_split[3])
|
||||
dim = config.text_config.hidden_size
|
||||
prefix = "text_model"
|
||||
|
||||
if "weight" in key:
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[
|
||||
dim : dim * 2, :
|
||||
]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
|
||||
elif "self_attn" in key and "out_proj" not in key:
|
||||
key_split = key.split(".")
|
||||
layer_num = int(key_split[1])
|
||||
dim = config.reduce_dim
|
||||
if "weight" in key:
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[dim : dim * 2, :]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
|
||||
else:
|
||||
new_name = rename_key(key)
|
||||
if "visual_projection" in new_name or "text_projection" in new_name:
|
||||
val = val.T
|
||||
orig_state_dict[new_name] = val
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
return image
|
||||
|
||||
|
||||
def convert_clipseg_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub):
|
||||
config = get_clipseg_config(model_name)
|
||||
model = CLIPSegForImageSegmentation(config)
|
||||
model.eval()
|
||||
|
||||
state_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
# remove some keys
|
||||
for key in state_dict.copy().keys():
|
||||
if key.startswith("model"):
|
||||
state_dict.pop(key, None)
|
||||
|
||||
# rename some keys
|
||||
state_dict = convert_state_dict(state_dict, config)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys != ["clip.text_model.embeddings.position_ids", "clip.vision_model.embeddings.position_ids"]:
|
||||
raise ValueError("Missing keys that are not expected: {}".format(missing_keys))
|
||||
if unexpected_keys != ["decoder.reduce.weight", "decoder.reduce.bias"]:
|
||||
raise ValueError(f"Unexpected keys: {unexpected_keys}")
|
||||
|
||||
image_processor = ViTImageProcessor(size=352)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
|
||||
image = prepare_img()
|
||||
text = ["a glass", "something to fill", "wood", "a jar"]
|
||||
|
||||
inputs = processor(text=text, images=[image] * len(text), padding="max_length", return_tensors="pt")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify values
|
||||
expected_conditional = torch.tensor([0.1110, -0.1882, 0.1645])
|
||||
expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328])
|
||||
if model_name == "clipseg-rd64-refined":
|
||||
expected_masks_slice = torch.tensor(
|
||||
[[-10.0407, -9.9431, -10.2646], [-9.9751, -9.7064, -9.9586], [-9.6891, -9.5645, -9.9618]]
|
||||
)
|
||||
elif model_name == "clipseg-rd64":
|
||||
expected_masks_slice = torch.tensor(
|
||||
[[-7.2877, -7.2711, -7.2463], [-7.2652, -7.2780, -7.2520], [-7.2239, -7.2204, -7.2001]]
|
||||
)
|
||||
elif model_name == "clipseg-rd16":
|
||||
expected_masks_slice = torch.tensor(
|
||||
[[-6.3955, -6.4055, -6.4151], [-6.3911, -6.4033, -6.4100], [-6.3474, -6.3702, -6.3762]]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model name {model_name} not supported.")
|
||||
|
||||
assert torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)
|
||||
assert torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)
|
||||
assert torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
print(f"Saving model and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model and processor for {model_name} to the hub")
|
||||
model.push_to_hub(f"CIDAS/{model_name}")
|
||||
processor.push_to_hub(f"CIDAS/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="clipseg-rd64",
|
||||
type=str,
|
||||
choices=["clipseg-rd16", "clipseg-rd64", "clipseg-rd64-refined"],
|
||||
help=(
|
||||
"Name of the model. Supported models are: clipseg-rd64, clipseg-rd16 and clipseg-rd64-refined (rd meaning"
|
||||
" reduce dimension)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default="/Users/nielsrogge/Documents/CLIPSeg/clip_plus_rd64-uni.pth",
|
||||
type=str,
|
||||
help=(
|
||||
"Path to the original checkpoint. Note that the script assumes that the checkpoint includes both CLIP and"
|
||||
" the decoder weights."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_clipseg_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,234 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Weights conversion script for CLVP
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import ClvpConfig, ClvpModelForConditionalGeneration
|
||||
|
||||
|
||||
_MODELS = {
|
||||
"clvp": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/clvp2.pth",
|
||||
"decoder": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/autoregressive.pth",
|
||||
}
|
||||
|
||||
dim = 1024
|
||||
sub_dim = dim // 16
|
||||
|
||||
CLVP_ENCODERS_MAPPING = {
|
||||
"text_transformer.transformer.attn_layers": "text_encoder_model",
|
||||
"speech_transformer.transformer.attn_layers": "speech_encoder_model",
|
||||
"text_transformer.transformer.norm": "text_encoder_model.final_layer_norm",
|
||||
"speech_transformer.transformer.norm": "speech_encoder_model.final_layer_norm",
|
||||
"to_text_latent": "text_encoder_model.projection",
|
||||
"to_speech_latent": "speech_encoder_model.projection",
|
||||
"text_emb": "text_encoder_model.token_embedding",
|
||||
"speech_emb": "speech_encoder_model.token_embedding",
|
||||
"1.wrap.net.0": "mlp.fc1",
|
||||
"1.wrap.net.3": "mlp.fc2",
|
||||
"1.wrap": "self_attn",
|
||||
"to_out": "out_proj",
|
||||
"to_q": "q_proj",
|
||||
"to_k": "k_proj",
|
||||
"to_v": "v_proj",
|
||||
"temperature": "logit_scale",
|
||||
}
|
||||
|
||||
CLVP_DECODER_MAPPING = {
|
||||
"conditioning_encoder.init": "conditioning_encoder.mel_conv",
|
||||
"conditioning_encoder.attn": "conditioning_encoder.mel_attn_blocks",
|
||||
"mel_attn_blocks": "group_norms",
|
||||
".norm.weight": ".weight",
|
||||
".norm.bias": ".bias",
|
||||
"text_embedding": "conditioning_encoder.text_token_embedding",
|
||||
"text_pos_embedding.emb": "conditioning_encoder.text_position_embedding",
|
||||
"final_norm": "speech_decoder_model.final_norm",
|
||||
"mel_head": "speech_decoder_model.lm_head",
|
||||
"gpt.ln_f": "speech_decoder_model.model.decoder.layer_norm",
|
||||
"mel_embedding": "speech_decoder_model.model.decoder.input_embeds_layer",
|
||||
"mel_pos_embedding.emb": "speech_decoder_model.model.decoder.position_embeds_layer",
|
||||
"gpt.h": "speech_decoder_model.model.decoder.layers",
|
||||
"ln_1": "input_layernorm",
|
||||
"ln_2": "post_attention_layernorm",
|
||||
}
|
||||
|
||||
|
||||
def update_index(present_index):
|
||||
if present_index % 2 == 0:
|
||||
return int(present_index / 2)
|
||||
else:
|
||||
return int((present_index - 1) / 2)
|
||||
|
||||
|
||||
def convert_encoder_weights(original_weights):
|
||||
converted_weights = {}
|
||||
original_weights_keys = sorted(original_weights.keys())
|
||||
for original_key in original_weights_keys:
|
||||
updated_key = original_key
|
||||
# for input_rmsnorm.weight and post_attention_rmsnorm.weight
|
||||
if "0.0.g" in updated_key:
|
||||
present_index = updated_key.split(".")[4]
|
||||
if int(present_index) % 2 == 0:
|
||||
updated_key = updated_key.replace("0.0.g", "input_rmsnorm.weight")
|
||||
else:
|
||||
updated_key = updated_key.replace("0.0.g", "post_attention_rmsnorm.weight")
|
||||
|
||||
if "transformer.attn_layers.layers" in updated_key:
|
||||
present_index = updated_key.split(".")[4]
|
||||
updated_index = update_index(int(present_index))
|
||||
updated_key = updated_key.replace(
|
||||
f"transformer.attn_layers.layers.{present_index}", f"transformer.attn_layers.layers.{updated_index}"
|
||||
)
|
||||
|
||||
for k, v in CLVP_ENCODERS_MAPPING.items():
|
||||
if k in updated_key:
|
||||
updated_key = updated_key.replace(k, v)
|
||||
|
||||
converted_weights[updated_key] = original_weights.pop(original_key)
|
||||
|
||||
return converted_weights
|
||||
|
||||
|
||||
def convert_decoder_weights(original_weights):
|
||||
converted_weights = {}
|
||||
original_weights_keys = sorted(original_weights.keys())
|
||||
for original_key in original_weights_keys:
|
||||
updated_key = original_key
|
||||
if len(updated_key.split(".")) > 3:
|
||||
index, attr = updated_key.split(".")[2], updated_key.split(".")[-1]
|
||||
|
||||
# for decoder attention
|
||||
if "attn.c_attn" in updated_key:
|
||||
if attr == "weight":
|
||||
slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).T.split(split_size=dim, dim=0)
|
||||
else:
|
||||
slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.q_proj.{attr}"] = slice1
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.k_proj.{attr}"] = slice2
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.v_proj.{attr}"] = slice3
|
||||
continue
|
||||
|
||||
if "attn.c_proj" in updated_key:
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.out_proj.{attr}"] = (
|
||||
original_weights[updated_key].squeeze(-1).T
|
||||
)
|
||||
continue
|
||||
|
||||
if "attn.bias" in updated_key or "attn.masked_bias" in updated_key or "text_head" in updated_key:
|
||||
original_weights.pop(updated_key)
|
||||
continue
|
||||
|
||||
# conditional encoder attention
|
||||
if "qkv" in updated_key:
|
||||
if attr == "weight":
|
||||
slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).split(split_size=dim, dim=0)
|
||||
else:
|
||||
slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
|
||||
|
||||
indices = torch.arange(dim)
|
||||
index1, index2, index3 = (
|
||||
indices.unfold(0, sub_dim, sub_dim * 3).flatten(),
|
||||
indices[sub_dim:].unfold(0, sub_dim, sub_dim * 3).flatten(),
|
||||
indices[2 * sub_dim :].unfold(0, sub_dim, sub_dim * 3).flatten(),
|
||||
)
|
||||
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.q_proj.{attr}"] = torch.concatenate(
|
||||
[slice1[index1], slice2[index3], slice3[index2]],
|
||||
axis=0,
|
||||
)
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.k_proj.{attr}"] = torch.concatenate(
|
||||
[slice1[index2], slice2[index1], slice3[index3]],
|
||||
axis=0,
|
||||
)
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.v_proj.{attr}"] = torch.concatenate(
|
||||
[slice1[index3], slice2[index2], slice3[index1]],
|
||||
axis=0,
|
||||
)
|
||||
continue
|
||||
|
||||
if "proj_out" in updated_key:
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.out_proj.{attr}"] = original_weights[
|
||||
updated_key
|
||||
].squeeze(-1)
|
||||
continue
|
||||
|
||||
for k, v in CLVP_DECODER_MAPPING.items():
|
||||
if k in updated_key:
|
||||
updated_key = updated_key.replace(k, v)
|
||||
|
||||
converted_weights[updated_key] = original_weights.pop(original_key)
|
||||
|
||||
return converted_weights
|
||||
|
||||
|
||||
def _download(url: str, root: str):
|
||||
repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
|
||||
filename = f"{url.split('/')[-2]}/{url.split('/')[-1]}"
|
||||
hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
force_filename=root,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
|
||||
def convert_clvp_weights(checkpoint_path, pytorch_dump_folder_path):
|
||||
converted_checkpoint = {}
|
||||
|
||||
for each_model_name, each_model_url in _MODELS.items():
|
||||
each_model_path = os.path.join(checkpoint_path, each_model_url.split("/")[-1])
|
||||
if not os.path.exists(each_model_path):
|
||||
print(f"\n{each_model_name} was not found! Downloading it to {each_model_path}")
|
||||
_download(url=each_model_url, root=each_model_path)
|
||||
|
||||
if each_model_name == "clvp":
|
||||
clvp_checkpoint = torch.load(each_model_path, map_location="cpu")
|
||||
else:
|
||||
decoder_checkpoint = torch.load(each_model_path, map_location="cpu")
|
||||
|
||||
# Converting the weights
|
||||
converted_checkpoint.update(**convert_encoder_weights(clvp_checkpoint))
|
||||
converted_checkpoint.update(**convert_decoder_weights(decoder_checkpoint))
|
||||
|
||||
config = ClvpConfig.from_pretrained("susnato/clvp_dev")
|
||||
model = ClvpModelForConditionalGeneration(config)
|
||||
|
||||
model.load_state_dict(converted_checkpoint, strict=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Model saved at {pytorch_dump_folder_path}!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# # Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the output PyTorch model. (Please enter full path)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_clvp_weights(args.checkpoint_path, args.pytorch_dump_folder_path)
|
@ -1,214 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert ColPali weights from the original repository to the HF model format.
|
||||
|
||||
Original repository: https://github.com/illuin-tech/colpali.
|
||||
|
||||
NOTE: This script was originally run using `torch==2.5.1` and with:
|
||||
|
||||
```bash
|
||||
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
|
||||
--model_id vidore/colpali-v1.2-merged \
|
||||
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
|
||||
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
|
||||
--output_dir vidore/colpali-v1.2-hf-internal \
|
||||
--push_to_hub
|
||||
|
||||
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
|
||||
--model_id vidore/colpali-v1.3-merged \
|
||||
--revision 5b955e3415a7c5468ab33119d98d6d45c3a5b2c3 \
|
||||
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
|
||||
--output_dir vidore/colpali-v1.3-hf \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.colpali import ColPaliForRetrieval
|
||||
from transformers.models.colpali.configuration_colpali import ColPaliConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
ORIGINAL_DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
def rename_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
if key.startswith("custom_text_proj"):
|
||||
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
|
||||
if key.startswith("model."):
|
||||
new_key = key.replace("model.", "vlm.", 1)
|
||||
new_state_dict[new_key] = value
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> Dict[str, torch.Tensor]:
|
||||
directory_path = snapshot_download(
|
||||
repo_id=model_id,
|
||||
revision=revision,
|
||||
allow_patterns=["*.safetensors"],
|
||||
)
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
# Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
|
||||
if "lm_head.weight" not in original_state_dict:
|
||||
original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[
|
||||
"model.language_model.model.embed_tokens.weight"
|
||||
].clone()
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_colpali_weights_to_hf(
|
||||
model_id: str,
|
||||
output_dir: str,
|
||||
push_to_hub: bool,
|
||||
revision: Optional[str] = None,
|
||||
original_vlm_name_or_path: Optional[str] = None,
|
||||
):
|
||||
# Load the original model data
|
||||
original_config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
)
|
||||
if original_vlm_name_or_path is not None:
|
||||
original_config._name_or_path = original_vlm_name_or_path
|
||||
if hasattr(original_config, "architectures"):
|
||||
delattr(original_config, "architectures")
|
||||
|
||||
original_state_dict = load_original_state_dict(model_id, revision=revision)
|
||||
|
||||
# Format the state_dict keys
|
||||
original_state_dict = rename_state_dict_keys(original_state_dict)
|
||||
|
||||
# Create the new config
|
||||
config = ColPaliConfig(
|
||||
vlm_config=original_config,
|
||||
embedding_dim=128, # hardcoded in the original model
|
||||
)
|
||||
config.model_type = "colpali"
|
||||
config.is_composition = False
|
||||
|
||||
# Load the untrained model
|
||||
model = ColPaliForRetrieval(config=config).to("cpu").eval()
|
||||
print("Created model with new config and randomly initialized weights")
|
||||
|
||||
# NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision.
|
||||
# There are two ways to set the model's dtype:
|
||||
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
|
||||
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
|
||||
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
|
||||
# the new weights' dtypes match the original model.
|
||||
for param in model.parameters():
|
||||
param.data = param.data.to(ORIGINAL_DTYPE)
|
||||
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
|
||||
|
||||
# Load the original weights
|
||||
model.load_state_dict(original_state_dict)
|
||||
print("Loaded original model weights")
|
||||
|
||||
# Tie the weights (following ColPali's `__init__`` step)
|
||||
if model.vlm.language_model._tied_weights_keys is not None:
|
||||
model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys]
|
||||
|
||||
# Sanity check: ensure all keys are the same
|
||||
state_dict_keys_old = set(original_state_dict.keys())
|
||||
state_dict_keys_new = set(model.state_dict().keys())
|
||||
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
|
||||
if disjoint_keys:
|
||||
raise ValueError(f"Incompatible keys: {disjoint_keys}")
|
||||
|
||||
# Save the model
|
||||
if push_to_hub:
|
||||
model.push_to_hub(output_dir, private=True)
|
||||
print(f"Model pushed to the hub at `{output_dir}`")
|
||||
else:
|
||||
Path(output_dir).mkdir(exist_ok=True, parents=True)
|
||||
model.save_pretrained(output_dir)
|
||||
print(f"Model saved to `{output_dir}`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""
|
||||
This script converts the original ColPali model to the HF model format.
|
||||
|
||||
Example usage:
|
||||
```bash
|
||||
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
|
||||
--model_id vidore/colpali-v1.2-merged \
|
||||
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
|
||||
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
|
||||
--output_dir vidore/colpali-v1.2-hf \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
help="Model ID of the original model to convert",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
help="Revision of the model to download",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--original_vlm_name_or_path",
|
||||
help="Name or path of the original VLM backbone model",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_colpali_weights_to_hf(
|
||||
model_id=args.model_id,
|
||||
output_dir=args.output_dir,
|
||||
push_to_hub=args.push_to_hub,
|
||||
revision=args.revision,
|
||||
original_vlm_name_or_path=args.original_vlm_name_or_path,
|
||||
)
|
@ -1,324 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Conditional DETR checkpoints."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
ConditionalDetrConfig,
|
||||
ConditionalDetrForObjectDetection,
|
||||
ConditionalDetrForSegmentation,
|
||||
ConditionalDetrImageProcessor,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
rename_keys = []
|
||||
for i in range(6):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append(
|
||||
(f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
|
||||
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"transformer.decoder.layers.{i}.cross_attn.out_proj.weight",
|
||||
f"decoder.layers.{i}.encoder_attn.out_proj.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"transformer.decoder.layers.{i}.cross_attn.out_proj.bias",
|
||||
f"decoder.layers.{i}.encoder_attn.out_proj.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))
|
||||
|
||||
# q, k, v projections in self/cross-attention in decoder for conditional DETR
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_qcontent_proj.weight", f"decoder.layers.{i}.sa_qcontent_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_kcontent_proj.weight", f"decoder.layers.{i}.sa_kcontent_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_qpos_proj.weight", f"decoder.layers.{i}.sa_qpos_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_kpos_proj.weight", f"decoder.layers.{i}.sa_kpos_proj.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.weight", f"decoder.layers.{i}.sa_v_proj.weight"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qcontent_proj.weight", f"decoder.layers.{i}.ca_qcontent_proj.weight")
|
||||
)
|
||||
# rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.weight", f"decoder.layers.{i}.ca_qpos_proj.weight"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_kcontent_proj.weight", f"decoder.layers.{i}.ca_kcontent_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_kpos_proj.weight", f"decoder.layers.{i}.ca_kpos_proj.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.weight", f"decoder.layers.{i}.ca_v_proj.weight"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.weight", f"decoder.layers.{i}.ca_qpos_sine_proj.weight")
|
||||
)
|
||||
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_qcontent_proj.bias", f"decoder.layers.{i}.sa_qcontent_proj.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_kcontent_proj.bias", f"decoder.layers.{i}.sa_kcontent_proj.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_qpos_proj.bias", f"decoder.layers.{i}.sa_qpos_proj.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_kpos_proj.bias", f"decoder.layers.{i}.sa_kpos_proj.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.bias", f"decoder.layers.{i}.sa_v_proj.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qcontent_proj.bias", f"decoder.layers.{i}.ca_qcontent_proj.bias")
|
||||
)
|
||||
# rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.bias", f"decoder.layers.{i}.ca_qpos_proj.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_kcontent_proj.bias", f"decoder.layers.{i}.ca_kcontent_proj.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.ca_kpos_proj.bias", f"decoder.layers.{i}.ca_kpos_proj.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.bias", f"decoder.layers.{i}.ca_v_proj.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.bias", f"decoder.layers.{i}.ca_qpos_sine_proj.bias")
|
||||
)
|
||||
|
||||
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
|
||||
# for conditional DETR, also convert reference point head and query scale MLP
|
||||
rename_keys.extend(
|
||||
[
|
||||
("input_proj.weight", "input_projection.weight"),
|
||||
("input_proj.bias", "input_projection.bias"),
|
||||
("query_embed.weight", "query_position_embeddings.weight"),
|
||||
("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
|
||||
("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
|
||||
("class_embed.weight", "class_labels_classifier.weight"),
|
||||
("class_embed.bias", "class_labels_classifier.bias"),
|
||||
("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
|
||||
("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
|
||||
("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
|
||||
("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
|
||||
("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
|
||||
("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
|
||||
("transformer.decoder.ref_point_head.layers.0.weight", "decoder.ref_point_head.layers.0.weight"),
|
||||
("transformer.decoder.ref_point_head.layers.0.bias", "decoder.ref_point_head.layers.0.bias"),
|
||||
("transformer.decoder.ref_point_head.layers.1.weight", "decoder.ref_point_head.layers.1.weight"),
|
||||
("transformer.decoder.ref_point_head.layers.1.bias", "decoder.ref_point_head.layers.1.bias"),
|
||||
("transformer.decoder.query_scale.layers.0.weight", "decoder.query_scale.layers.0.weight"),
|
||||
("transformer.decoder.query_scale.layers.0.bias", "decoder.query_scale.layers.0.bias"),
|
||||
("transformer.decoder.query_scale.layers.1.weight", "decoder.query_scale.layers.1.weight"),
|
||||
("transformer.decoder.query_scale.layers.1.bias", "decoder.query_scale.layers.1.bias"),
|
||||
("transformer.decoder.layers.0.ca_qpos_proj.weight", "decoder.layers.0.ca_qpos_proj.weight"),
|
||||
("transformer.decoder.layers.0.ca_qpos_proj.bias", "decoder.layers.0.ca_qpos_proj.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def rename_key(state_dict, old, new):
|
||||
val = state_dict.pop(old)
|
||||
state_dict[new] = val
|
||||
|
||||
|
||||
def rename_backbone_keys(state_dict):
|
||||
new_state_dict = OrderedDict()
|
||||
for key, value in state_dict.items():
|
||||
if "backbone.0.body" in key:
|
||||
new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model")
|
||||
new_state_dict[new_key] = value
|
||||
else:
|
||||
new_state_dict[key] = value
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def read_in_q_k_v(state_dict, is_panoptic=False):
|
||||
prefix = ""
|
||||
if is_panoptic:
|
||||
prefix = "conditional_detr."
|
||||
|
||||
# first: transformer encoder
|
||||
for i in range(6):
|
||||
# read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_conditional_detr_checkpoint(model_name, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our CONDITIONAL_DETR structure.
|
||||
"""
|
||||
|
||||
# load default config
|
||||
config = ConditionalDetrConfig()
|
||||
# set backbone and dilation attributes
|
||||
if "resnet101" in model_name:
|
||||
config.backbone = "resnet101"
|
||||
if "dc5" in model_name:
|
||||
config.dilation = True
|
||||
is_panoptic = "panoptic" in model_name
|
||||
if is_panoptic:
|
||||
config.num_labels = 250
|
||||
else:
|
||||
config.num_labels = 91
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "coco-detection-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
# load image processor
|
||||
format = "coco_panoptic" if is_panoptic else "coco_detection"
|
||||
image_processor = ConditionalDetrImageProcessor(format=format)
|
||||
|
||||
# prepare image
|
||||
img = prepare_img()
|
||||
encoding = image_processor(images=img, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
logger.info(f"Converting model {model_name}...")
|
||||
|
||||
# load original model from torch hub
|
||||
conditional_detr = torch.hub.load("DeppMeng/ConditionalDETR", model_name, pretrained=True).eval()
|
||||
state_dict = conditional_detr.state_dict()
|
||||
# rename keys
|
||||
for src, dest in rename_keys:
|
||||
if is_panoptic:
|
||||
src = "conditional_detr." + src
|
||||
rename_key(state_dict, src, dest)
|
||||
state_dict = rename_backbone_keys(state_dict)
|
||||
# query, key and value matrices need special treatment
|
||||
read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
|
||||
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
|
||||
prefix = "conditional_detr.model." if is_panoptic else "model."
|
||||
for key in state_dict.copy().keys():
|
||||
if is_panoptic:
|
||||
if (
|
||||
key.startswith("conditional_detr")
|
||||
and not key.startswith("class_labels_classifier")
|
||||
and not key.startswith("bbox_predictor")
|
||||
):
|
||||
val = state_dict.pop(key)
|
||||
state_dict["conditional_detr.model" + key[4:]] = val
|
||||
elif "class_labels_classifier" in key or "bbox_predictor" in key:
|
||||
val = state_dict.pop(key)
|
||||
state_dict["conditional_detr." + key] = val
|
||||
elif key.startswith("bbox_attention") or key.startswith("mask_head"):
|
||||
continue
|
||||
else:
|
||||
val = state_dict.pop(key)
|
||||
state_dict[prefix + key] = val
|
||||
else:
|
||||
if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
|
||||
val = state_dict.pop(key)
|
||||
state_dict[prefix + key] = val
|
||||
# finally, create HuggingFace model and load state dict
|
||||
model = ConditionalDetrForSegmentation(config) if is_panoptic else ConditionalDetrForObjectDetection(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
model.push_to_hub(repo_id=model_name, organization="DepuMeng", commit_message="Add model")
|
||||
# verify our conversion
|
||||
original_outputs = conditional_detr(pixel_values)
|
||||
outputs = model(pixel_values)
|
||||
assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
|
||||
assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
|
||||
if is_panoptic:
|
||||
assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
|
||||
|
||||
# Save model and image processor
|
||||
logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="conditional_detr_resnet50",
|
||||
type=str,
|
||||
help="Name of the CONDITIONAL_DETR model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_conditional_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
|
@ -1,57 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ConvBERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path):
|
||||
conf = ConvBertConfig.from_json_file(convbert_config_file)
|
||||
model = ConvBertModel(conf)
|
||||
|
||||
model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path)
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True)
|
||||
tf_model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--convbert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained ConvBERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path)
|
@ -1,242 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ConvNext checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/facebookresearch/ConvNeXt"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import ConvNextConfig, ConvNextForImageClassification, ConvNextImageProcessor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_convnext_config(checkpoint_url):
|
||||
config = ConvNextConfig()
|
||||
|
||||
if "tiny" in checkpoint_url:
|
||||
depths = [3, 3, 9, 3]
|
||||
hidden_sizes = [96, 192, 384, 768]
|
||||
if "small" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [96, 192, 384, 768]
|
||||
if "base" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [128, 256, 512, 1024]
|
||||
if "large" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [192, 384, 768, 1536]
|
||||
if "xlarge" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [256, 512, 1024, 2048]
|
||||
|
||||
if "1k" in checkpoint_url:
|
||||
num_labels = 1000
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
expected_shape = (1, 1000)
|
||||
else:
|
||||
num_labels = 21841
|
||||
filename = "imagenet-22k-id2label.json"
|
||||
expected_shape = (1, 21841)
|
||||
|
||||
repo_id = "huggingface/label-files"
|
||||
config.num_labels = num_labels
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
if "1k" not in checkpoint_url:
|
||||
# this dataset contains 21843 labels but the model only has 21841
|
||||
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
|
||||
del id2label[9205]
|
||||
del id2label[15027]
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
config.hidden_sizes = hidden_sizes
|
||||
config.depths = depths
|
||||
|
||||
return config, expected_shape
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "downsample_layers.0.0" in name:
|
||||
name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings")
|
||||
if "downsample_layers.0.1" in name:
|
||||
name = name.replace("downsample_layers.0.1", "embeddings.norm") # we rename to layernorm later on
|
||||
if "downsample_layers.1.0" in name:
|
||||
name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0")
|
||||
if "downsample_layers.1.1" in name:
|
||||
name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1")
|
||||
if "downsample_layers.2.0" in name:
|
||||
name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0")
|
||||
if "downsample_layers.2.1" in name:
|
||||
name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1")
|
||||
if "downsample_layers.3.0" in name:
|
||||
name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0")
|
||||
if "downsample_layers.3.1" in name:
|
||||
name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1")
|
||||
if "stages" in name and "downsampling_layer" not in name:
|
||||
# stages.0.0. for instance should be renamed to stages.0.layers.0.
|
||||
name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
|
||||
if "stages" in name:
|
||||
name = name.replace("stages", "encoder.stages")
|
||||
if "norm" in name:
|
||||
name = name.replace("norm", "layernorm")
|
||||
if "gamma" in name:
|
||||
name = name.replace("gamma", "layer_scale_parameter")
|
||||
if "head" in name:
|
||||
name = name.replace("head", "classifier")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_convnext_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our ConvNext structure.
|
||||
"""
|
||||
|
||||
# define ConvNext configuration based on URL
|
||||
config, expected_shape = get_convnext_config(checkpoint_url)
|
||||
# load original state_dict from URL
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]
|
||||
# rename keys
|
||||
for key in state_dict.copy().keys():
|
||||
val = state_dict.pop(key)
|
||||
state_dict[rename_key(key)] = val
|
||||
# add prefix to all keys expect classifier head
|
||||
for key in state_dict.copy().keys():
|
||||
val = state_dict.pop(key)
|
||||
if not key.startswith("classifier"):
|
||||
key = "convnext." + key
|
||||
state_dict[key] = val
|
||||
|
||||
# load HuggingFace model
|
||||
model = ConvNextForImageClassification(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
# Check outputs on an image, prepared by ConvNextImageProcessor
|
||||
size = 224 if "224" in checkpoint_url else 384
|
||||
image_processor = ConvNextImageProcessor(size=size)
|
||||
pixel_values = image_processor(images=prepare_img(), return_tensors="pt").pixel_values
|
||||
|
||||
logits = model(pixel_values).logits
|
||||
|
||||
# note: the logits below were obtained without center cropping
|
||||
if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth":
|
||||
expected_logits = torch.tensor([-0.1210, -0.6605, 0.1918])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth":
|
||||
expected_logits = torch.tensor([-0.4473, -0.1847, -0.6365])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth":
|
||||
expected_logits = torch.tensor([0.4525, 0.7539, 0.0308])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth":
|
||||
expected_logits = torch.tensor([0.3561, 0.6350, -0.0384])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth":
|
||||
expected_logits = torch.tensor([0.4174, -0.0989, 0.1489])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth":
|
||||
expected_logits = torch.tensor([0.2513, -0.1349, -0.1613])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth":
|
||||
expected_logits = torch.tensor([1.2980, 0.3631, -0.1198])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth":
|
||||
expected_logits = torch.tensor([1.2963, 0.1227, 0.1723])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth":
|
||||
expected_logits = torch.tensor([1.7956, 0.8390, 0.2820])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth":
|
||||
expected_logits = torch.tensor([-0.2822, -0.0502, -0.0878])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth":
|
||||
expected_logits = torch.tensor([-0.5672, -0.0730, -0.4348])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth":
|
||||
expected_logits = torch.tensor([0.2681, 0.2365, 0.6246])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth":
|
||||
expected_logits = torch.tensor([-0.2642, 0.3931, 0.5116])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth":
|
||||
expected_logits = torch.tensor([-0.6677, -0.1873, -0.8379])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth":
|
||||
expected_logits = torch.tensor([-0.7749, -0.2967, -0.6444])
|
||||
else:
|
||||
raise ValueError(f"Unknown URL: {checkpoint_url}")
|
||||
|
||||
assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3)
|
||||
assert logits.shape == expected_shape
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
print("Pushing model to the hub...")
|
||||
model_name = "convnext"
|
||||
if "tiny" in checkpoint_url:
|
||||
model_name += "-tiny"
|
||||
elif "small" in checkpoint_url:
|
||||
model_name += "-small"
|
||||
elif "base" in checkpoint_url:
|
||||
model_name += "-base"
|
||||
elif "xlarge" in checkpoint_url:
|
||||
model_name += "-xlarge"
|
||||
elif "large" in checkpoint_url:
|
||||
model_name += "-large"
|
||||
if "224" in checkpoint_url:
|
||||
model_name += "-224"
|
||||
elif "384" in checkpoint_url:
|
||||
model_name += "-384"
|
||||
if "22k" in checkpoint_url and "1k" not in checkpoint_url:
|
||||
model_name += "-22k"
|
||||
if "22k" in checkpoint_url and "1k" in checkpoint_url:
|
||||
model_name += "-22k-1k"
|
||||
|
||||
model.push_to_hub(
|
||||
repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
|
||||
organization="nielsr",
|
||||
commit_message="Add model",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
|
||||
type=str,
|
||||
help="URL of the original ConvNeXT checkpoint you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_convnext_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user