mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 01:23:56 +08:00
Compare commits
26 Commits
reference_
...
v4.48.3
Author | SHA1 | Date | |
---|---|---|---|
298b3f1930 | |||
d28f0207d5 | |||
3d6e55c7e7 | |||
093bebcdd9 | |||
97a6cf9072 | |||
11e31ec24f | |||
b673c16cad | |||
aa3e590100 | |||
f3fad5755a | |||
e5f88ae076 | |||
163c8bbdc9 | |||
b17abf9519 | |||
f7b6047a4e | |||
2e752ead46 | |||
785b5cf444 | |||
3b09464364 | |||
b00807fac2 | |||
612bfd0801 | |||
6bc0fbcfa7 | |||
59e28c30fa | |||
7cf6230e25 | |||
d6f446ffa7 | |||
8ce1e9578a | |||
af2d7caff3 | |||
42b8e7916b | |||
e39c9f7a78 |
@ -505,7 +505,7 @@
|
||||
- local: model_doc/mobilebert
|
||||
title: MobileBERT
|
||||
- local: model_doc/modernbert
|
||||
title: ModernBERT
|
||||
title: ModernBert
|
||||
- local: model_doc/mpnet
|
||||
title: MPNet
|
||||
- local: model_doc/mpt
|
||||
@ -768,6 +768,8 @@
|
||||
title: Mimi
|
||||
- local: model_doc/mms
|
||||
title: MMS
|
||||
- local: model_doc/moonshine
|
||||
title: Moonshine
|
||||
- local: model_doc/moshi
|
||||
title: Moshi
|
||||
- local: model_doc/musicgen
|
||||
@ -858,6 +860,8 @@
|
||||
title: DePlot
|
||||
- local: model_doc/donut
|
||||
title: Donut
|
||||
- local: model_doc/emu3
|
||||
title: Emu3
|
||||
- local: model_doc/flava
|
||||
title: FLAVA
|
||||
- local: model_doc/git
|
||||
|
@ -137,6 +137,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [EfficientFormer](model_doc/efficientformer) | ✅ | ✅ | ❌ |
|
||||
| [EfficientNet](model_doc/efficientnet) | ✅ | ❌ | ❌ |
|
||||
| [ELECTRA](model_doc/electra) | ✅ | ✅ | ✅ |
|
||||
| [Emu3](model_doc/emu3) | ✅ | ❌ | ❌ |
|
||||
| [EnCodec](model_doc/encodec) | ✅ | ❌ | ❌ |
|
||||
| [Encoder decoder](model_doc/encoder-decoder) | ✅ | ✅ | ✅ |
|
||||
| [ERNIE](model_doc/ernie) | ✅ | ❌ | ❌ |
|
||||
@ -235,6 +236,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [MobileViT](model_doc/mobilevit) | ✅ | ✅ | ❌ |
|
||||
| [MobileViTV2](model_doc/mobilevitv2) | ✅ | ❌ | ❌ |
|
||||
| [ModernBERT](model_doc/modernbert) | ✅ | ❌ | ❌ |
|
||||
| [Moonshine](model_doc/moonshine) | ✅ | ❌ | ❌ |
|
||||
| [Moshi](model_doc/moshi) | ✅ | ❌ | ❌ |
|
||||
| [MPNet](model_doc/mpnet) | ✅ | ✅ | ❌ |
|
||||
| [MPT](model_doc/mpt) | ✅ | ❌ | ❌ |
|
||||
|
179
docs/source/en/model_doc/emu3.md
Normal file
179
docs/source/en/model_doc/emu3.md
Normal file
@ -0,0 +1,179 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Emu3
|
||||
|
||||
## Overview
|
||||
|
||||
The Emu3 model was proposed in [Emu3: Next-Token Prediction is All You Need](https://arxiv.org/abs/2409.18869) by Xinlong Wang, Xiaosong Zhang, Zhengxiong Luo, Quan Sun, Yufeng Cui, Jinsheng Wang, Fan Zhang, Yueze Wang, Zhen Li, Qiying Yu, Yingli Zhao, Yulong Ao, Xuebin Min, Tao Li, Boya Wu, Bo Zhao, Bowen Zhang, Liangdong Wang, Guang Liu, Zheqi He, Xi Yang, Jingjing Liu, Yonghua Lin, Tiejun Huang, Zhongyuan Wang.
|
||||
|
||||
Emu3 is a multimodal LLM that uses vector quantization to tokenize images into discrete tokens. Discretized image tokens are later fused with text token ids for image and text generation. The model can additionally generate images by predicting image token ids.
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*While next-token prediction is considered a promising path towards artificial general intelligence, it has struggled to excel in multimodal tasks, which are still dominated by diffusion models (e.g., Stable Diffusion) and compositional approaches (e.g., CLIP combined with LLMs). In this paper, we introduce Emu3, a new suite of state-of-the-art multimodal models trained solely with next-token prediction. By tokenizing images, text, and videos into a discrete space, we train a single transformer from scratch on a mixture of multimodal sequences. Emu3 outperforms several well-established task-specific models in both generation and perception tasks, surpassing flagship models such as SDXL and LLaVA-1.6, while eliminating the need for diffusion or compositional architectures. Emu3 is also capable of generating high-fidelity video via predicting the next token in a video sequence. We simplify complex multimodal model designs by converging on a singular focus: tokens, unlocking great potential for scaling both during training and inference. Our results demonstrate that next-token prediction is a promising path towards building general multimodal intelligence beyond language. We open-source key techniques and models to support further research in this direction.*
|
||||
|
||||
Tips:
|
||||
|
||||
- We advise users to set `processor.tokenizer.padding_side = "left"` before batched generation as it leads to more accurate results.
|
||||
|
||||
- Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.
|
||||
|
||||
- Emu3 has two different checkpoints for image-generation and text-generation, make sure to use the correct checkpoint when loading the model. To generate an image, it is advised to use `prefix_constraints` so that the generated tokens are sampled only from possible image tokens. See more below for usage examples.
|
||||
|
||||
> [!TIP]
|
||||
> Emu3 implementation in Transformers uses a special image token to indicate where to merge image embeddings. The special image token isn't new and uses one of the reserved tokens: `<|extra_0|>`. You have to add `<image>` to your prompt in the place where the image should be embedded for correct generation.
|
||||
|
||||
|
||||
This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
|
||||
The original code can be found [here](https://github.com/baaivision/Emu3).
|
||||
|
||||
|
||||
## Usage example
|
||||
|
||||
### Text generation inference
|
||||
|
||||
Here's how to load the model and perform inference in half-precision (`torch.bfloat16`) to generate textual output from text or text and image inputs:
|
||||
|
||||
```python
|
||||
from transformers import Emu3Processor, Emu3ForConditionalGeneration
|
||||
import torch
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf")
|
||||
model = Emu3ForConditionalGeneration.from_pretrained("Emu3-community/Emu3-Chat-hf", torch_dtype=torch.bfloat16, device_map="cuda")
|
||||
|
||||
# prepare image and text prompt
|
||||
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
prompt = "What do you see in this image?<image>"
|
||||
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
# autoregressively complete prompt
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
### Image generation inference
|
||||
|
||||
Emu3 can also generate images from textual input. Here is how you can do it:
|
||||
|
||||
```python
|
||||
processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Gen-hf")
|
||||
model = Emu3ForConditionalGeneration.from_pretrained("Emu3-community/Emu3-Gen-hf", torch_dtype="bfloat16", device_map="auto", attn_implementation="flash_attention_2")
|
||||
|
||||
|
||||
inputs = processor(
|
||||
text=["a portrait of young girl. masterpiece, film grained, best quality.", "a dog running under the rain"],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
return_for_image_generation=True,
|
||||
)
|
||||
inputs = inputs.to(device="cuda:0", dtype=torch.bfloat16)
|
||||
|
||||
neg_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
|
||||
neg_inputs = processor(text=[neg_prompt] * 2, return_tensors="pt").to(device="cuda:0")
|
||||
|
||||
image_sizes = inputs.pop("image_sizes")
|
||||
HEIGHT, WIDTH = image_sizes[0]
|
||||
VISUAL_TOKENS = model.vocabulary_mapping.image_tokens
|
||||
|
||||
def prefix_allowed_tokens_fn(batch_id, input_ids):
|
||||
height, width = HEIGHT, WIDTH
|
||||
visual_tokens = VISUAL_TOKENS
|
||||
image_wrapper_token_id = torch.tensor([processor.tokenizer.image_wrapper_token_id], device=model.device)
|
||||
eoi_token_id = torch.tensor([processor.tokenizer.eoi_token_id], device=model.device)
|
||||
eos_token_id = torch.tensor([processor.tokenizer.eos_token_id], device=model.device)
|
||||
pad_token_id = torch.tensor([processor.tokenizer.pad_token_id], device=model.device)
|
||||
eof_token_id = torch.tensor([processor.tokenizer.eof_token_id], device=model.device)
|
||||
eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0]
|
||||
|
||||
position = torch.nonzero(input_ids == image_wrapper_token_id, as_tuple=True)[0][0]
|
||||
offset = input_ids.shape[0] - position
|
||||
if offset % (width + 1) == 0:
|
||||
return (eol_token_id, )
|
||||
elif offset == (width + 1) * height + 1:
|
||||
return (eof_token_id, )
|
||||
elif offset == (width + 1) * height + 2:
|
||||
return (eoi_token_id, )
|
||||
elif offset == (width + 1) * height + 3:
|
||||
return (eos_token_id, )
|
||||
elif offset > (width + 1) * height + 3:
|
||||
return (pad_token_id, )
|
||||
else:
|
||||
return visual_tokens
|
||||
|
||||
|
||||
out = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=50_000, # make sure to have enough tokens for one image
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
return_dict_in_generate=True,
|
||||
negative_prompt_ids=neg_inputs.input_ids, # indicate for Classifier-Free Guidance
|
||||
negative_prompt_attention_mask=neg_inputs.attention_mask,
|
||||
)
|
||||
|
||||
image = model.decode_image_tokens(out.sequences[:, inputs.input_ids.shape[1]: ], height=HEIGHT, width=WIDTH)
|
||||
images = processor.postprocess(list(image.float()), return_tensors="PIL.Image.Image") # internally we convert to np but it's not supported in bf16 precision
|
||||
for i, image in enumerate(images['pixel_values']):
|
||||
image.save(f"result{i}.png")
|
||||
|
||||
```
|
||||
|
||||
|
||||
## Emu3Config
|
||||
|
||||
[[autodoc]] Emu3Config
|
||||
|
||||
## Emu3VQVAEConfig
|
||||
|
||||
[[autodoc]] Emu3VQVAEConfig
|
||||
|
||||
## Emu3TextConfig
|
||||
|
||||
[[autodoc]] Emu3TextConfig
|
||||
|
||||
## Emu3Processor
|
||||
|
||||
[[autodoc]] Emu3Processor
|
||||
|
||||
## Emu3ImageProcessor
|
||||
|
||||
[[autodoc]] Emu3ImageProcessor
|
||||
- preprocess
|
||||
|
||||
## Emu3VQVAE
|
||||
|
||||
[[autodoc]] Emu3VQVAE
|
||||
- forward
|
||||
|
||||
## Emu3TextModel
|
||||
|
||||
[[autodoc]] Emu3TextModel
|
||||
- forward
|
||||
|
||||
## Emu3ForCausalLM
|
||||
|
||||
[[autodoc]] Emu3ForCausalLM
|
||||
- forward
|
||||
|
||||
## Emu3ForConditionalGeneration
|
||||
|
||||
[[autodoc]] Emu3ForConditionalGeneration
|
||||
- forward
|
56
docs/source/en/model_doc/moonshine.md
Normal file
56
docs/source/en/model_doc/moonshine.md
Normal file
@ -0,0 +1,56 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Moonshine
|
||||
|
||||
## Overview
|
||||
|
||||
The Moonshine model was proposed in [Moonshine: Speech Recognition for Live Transcription and Voice Commands
|
||||
](https://arxiv.org/abs/2410.15608) by Nat Jeffries, Evan King, Manjunath Kudlur, Guy Nicholson, James Wang, Pete Warden.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*This paper introduces Moonshine, a family of speech recognition models optimized for live transcription and voice command processing. Moonshine is based on an encoder-decoder transformer architecture and employs Rotary Position Embedding (RoPE) instead of traditional absolute position embeddings. The model is trained on speech segments of various lengths, but without using zero-padding, leading to greater efficiency for the encoder during inference time. When benchmarked against OpenAI's Whisper tiny-en, Moonshine Tiny demonstrates a 5x reduction in compute requirements for transcribing a 10-second speech segment while incurring no increase in word error rates across standard evaluation datasets. These results highlight Moonshine's potential for real-time and resource-constrained applications.*
|
||||
|
||||
Tips:
|
||||
|
||||
- Moonshine improves upon Whisper's architecture:
|
||||
1. It uses SwiGLU activation instead of GELU in the decoder layers
|
||||
2. Most importantly, it replaces absolute position embeddings with Rotary Position Embeddings (RoPE). This allows Moonshine to handle audio inputs of any length, unlike Whisper which is restricted to fixed 30-second windows.
|
||||
|
||||
This model was contributed by [Eustache Le Bihan (eustlb)](https://huggingface.co/eustlb).
|
||||
The original code can be found [here](https://github.com/usefulsensors/moonshine).
|
||||
|
||||
## Resources
|
||||
|
||||
- [Automatic speech recognition task guide](../tasks/asr)
|
||||
|
||||
## MoonshineConfig
|
||||
|
||||
[[autodoc]] MoonshineConfig
|
||||
|
||||
## MoonshineModel
|
||||
|
||||
[[autodoc]] MoonshineModel
|
||||
- forward
|
||||
- _mask_input_features
|
||||
|
||||
## MoonshineForConditionalGeneration
|
||||
|
||||
[[autodoc]] MoonshineForConditionalGeneration
|
||||
- forward
|
||||
- generate
|
||||
|
@ -49,6 +49,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
|
||||
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
|
||||
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
|
||||
* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3)
|
||||
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
||||
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
|
||||
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
|
||||
@ -68,6 +69,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
|
||||
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
|
||||
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
|
||||
* [Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine#transformers.MoonshineModel)
|
||||
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
|
||||
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
|
||||
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
|
||||
@ -244,6 +246,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
|
||||
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
|
||||
* [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel)
|
||||
* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3)
|
||||
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
|
||||
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
||||
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
|
||||
@ -265,6 +268,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
|
||||
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
|
||||
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100#transformers.M2M100Model)
|
||||
* [Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine#transformers.MoonshineModel)
|
||||
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration)
|
||||
@ -283,8 +287,8 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
|
||||
* [PhiMoE](https://huggingface.co/docs/transformers/model_doc/phimoe#transformers.PhimoeModel)
|
||||
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
|
||||
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
|
||||
* [mBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
|
||||
* [Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine#transformers.MoonshineModel)
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
||||
|
@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.48.0.dev0")
|
||||
check_min_version("4.48.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -437,7 +437,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.48.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.48.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.48.0.dev0"
|
||||
__version__ = "4.48.3"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -428,6 +428,12 @@ _import_structure = {
|
||||
"ElectraConfig",
|
||||
"ElectraTokenizer",
|
||||
],
|
||||
"models.emu3": [
|
||||
"Emu3Config",
|
||||
"Emu3Processor",
|
||||
"Emu3TextConfig",
|
||||
"Emu3VQVAEConfig",
|
||||
],
|
||||
"models.encodec": [
|
||||
"EncodecConfig",
|
||||
"EncodecFeatureExtractor",
|
||||
@ -610,6 +616,7 @@ _import_structure = {
|
||||
"models.mobilevit": ["MobileViTConfig"],
|
||||
"models.mobilevitv2": ["MobileViTV2Config"],
|
||||
"models.modernbert": ["ModernBertConfig"],
|
||||
"models.moonshine": ["MoonshineConfig"],
|
||||
"models.moshi": [
|
||||
"MoshiConfig",
|
||||
"MoshiDepthConfig",
|
||||
@ -1221,6 +1228,7 @@ else:
|
||||
_import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"])
|
||||
_import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
|
||||
_import_structure["models.efficientnet"].append("EfficientNetImageProcessor")
|
||||
_import_structure["models.emu3"].append("Emu3ImageProcessor")
|
||||
_import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
|
||||
_import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"])
|
||||
_import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
|
||||
@ -2242,6 +2250,15 @@ else:
|
||||
"load_tf_weights_in_electra",
|
||||
]
|
||||
)
|
||||
_import_structure["models.emu3"].extend(
|
||||
[
|
||||
"Emu3ForCausalLM",
|
||||
"Emu3ForConditionalGeneration",
|
||||
"Emu3PreTrainedModel",
|
||||
"Emu3TextModel",
|
||||
"Emu3VQVAE",
|
||||
]
|
||||
)
|
||||
_import_structure["models.encodec"].extend(
|
||||
[
|
||||
"EncodecModel",
|
||||
@ -2907,6 +2924,13 @@ else:
|
||||
"ModernBertPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.moonshine"].extend(
|
||||
[
|
||||
"MoonshineForConditionalGeneration",
|
||||
"MoonshineModel",
|
||||
"MoonshinePreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.moshi"].extend(
|
||||
[
|
||||
"MoshiForCausalLM",
|
||||
@ -5432,6 +5456,12 @@ if TYPE_CHECKING:
|
||||
ElectraConfig,
|
||||
ElectraTokenizer,
|
||||
)
|
||||
from .models.emu3 import (
|
||||
Emu3Config,
|
||||
Emu3Processor,
|
||||
Emu3TextConfig,
|
||||
Emu3VQVAEConfig,
|
||||
)
|
||||
from .models.encodec import (
|
||||
EncodecConfig,
|
||||
EncodecFeatureExtractor,
|
||||
@ -5633,6 +5663,7 @@ if TYPE_CHECKING:
|
||||
MobileViTV2Config,
|
||||
)
|
||||
from .models.modernbert import ModernBertConfig
|
||||
from .models.moonshine import MoonshineConfig
|
||||
from .models.moshi import (
|
||||
MoshiConfig,
|
||||
MoshiDepthConfig,
|
||||
@ -6261,6 +6292,7 @@ if TYPE_CHECKING:
|
||||
from .models.donut import DonutFeatureExtractor, DonutImageProcessor
|
||||
from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
|
||||
from .models.efficientnet import EfficientNetImageProcessor
|
||||
from .models.emu3 import Emu3ImageProcessor
|
||||
from .models.flava import (
|
||||
FlavaFeatureExtractor,
|
||||
FlavaImageProcessor,
|
||||
@ -7130,6 +7162,13 @@ if TYPE_CHECKING:
|
||||
ElectraPreTrainedModel,
|
||||
load_tf_weights_in_electra,
|
||||
)
|
||||
from .models.emu3 import (
|
||||
Emu3ForCausalLM,
|
||||
Emu3ForConditionalGeneration,
|
||||
Emu3PreTrainedModel,
|
||||
Emu3TextModel,
|
||||
Emu3VQVAE,
|
||||
)
|
||||
from .models.encodec import (
|
||||
EncodecModel,
|
||||
EncodecPreTrainedModel,
|
||||
@ -7652,6 +7691,11 @@ if TYPE_CHECKING:
|
||||
ModernBertModel,
|
||||
ModernBertPreTrainedModel,
|
||||
)
|
||||
from .models.moonshine import (
|
||||
MoonshineForConditionalGeneration,
|
||||
MoonshineModel,
|
||||
MoonshinePreTrainedModel,
|
||||
)
|
||||
from .models.moshi import (
|
||||
MoshiForCausalLM,
|
||||
MoshiForConditionalGeneration,
|
||||
|
@ -249,7 +249,7 @@ def squad_convert_example_to_features(
|
||||
else:
|
||||
p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
|
||||
|
||||
pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
|
||||
pad_token_indices = np.where(np.atleast_1d(span["input_ids"] == tokenizer.pad_token_id))
|
||||
special_token_indices = np.asarray(
|
||||
tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
|
||||
).nonzero()
|
||||
|
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from queue import Queue
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
@ -1634,17 +1634,18 @@ class GenerationMixin:
|
||||
cache_dtype = self.get_output_embeddings().weight.dtype
|
||||
|
||||
def get_layer_device_map(execution_device_map: Optional[dict] = None):
|
||||
num_hidden_layers = self.config.get_text_config().num_hidden_layers
|
||||
if execution_device_map is None:
|
||||
return None
|
||||
elif len(execution_device_map) == 1 and "" in execution_device_map:
|
||||
return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)}
|
||||
return {idx: execution_device_map[""] for idx in range(num_hidden_layers)}
|
||||
layer_device_map = {}
|
||||
for layer in execution_device_map:
|
||||
for idx in range(self.config.num_hidden_layers):
|
||||
for idx in range(num_hidden_layers):
|
||||
if f".{idx}." in f"{layer}.":
|
||||
layer_device_map[idx] = execution_device_map[layer]
|
||||
break
|
||||
for idx in range(self.config.num_hidden_layers):
|
||||
for idx in range(num_hidden_layers):
|
||||
if idx not in layer_device_map:
|
||||
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
|
||||
return layer_device_map
|
||||
|
@ -27,7 +27,7 @@ def flex_attention_forward(
|
||||
if softcap is not None:
|
||||
score = softcap * torch.tanh(score / softcap)
|
||||
if causal_mask is not None:
|
||||
score += causal_mask[b][0][q_idx][kv_idx]
|
||||
score = score + causal_mask[b][0][q_idx][kv_idx]
|
||||
return score
|
||||
|
||||
attn_output, attention_weights = flex_attention(
|
||||
|
@ -45,6 +45,11 @@ def sdpa_attention_forward(
|
||||
if is_causal is None:
|
||||
is_causal = causal_mask is None and query.shape[2] > 1
|
||||
|
||||
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
|
||||
# We convert it to a bool for the SDPA kernel that only accepts bools.
|
||||
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
|
||||
is_causal = is_causal.item()
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
|
@ -4020,10 +4020,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
elif hasattr(torch, torch_dtype):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
|
||||
)
|
||||
for sub_config_key in config.sub_configs.keys():
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.torch_dtype = torch_dtype
|
||||
elif isinstance(torch_dtype, torch.dtype):
|
||||
for sub_config_key in config.sub_configs.keys():
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.torch_dtype = torch_dtype
|
||||
elif isinstance(torch_dtype, dict):
|
||||
for key, curr_dtype in torch_dtype.items():
|
||||
if hasattr(config, key):
|
||||
value = getattr(config, key)
|
||||
value.torch_dtype = curr_dtype
|
||||
# main torch dtype for modules that aren't part of any sub-config
|
||||
torch_dtype = torch_dtype.get("")
|
||||
config.torch_dtype = torch_dtype
|
||||
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
elif torch_dtype is None:
|
||||
torch_dtype = torch.float32
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
|
||||
f"for each sub-config in composite configs, but received {torch_dtype}"
|
||||
)
|
||||
|
||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
@ -5121,6 +5142,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
@property
|
||||
def loss_function(self):
|
||||
if hasattr(self, "_loss_function"):
|
||||
return self._loss_function
|
||||
|
||||
loss_type = getattr(self, "loss_type", None)
|
||||
|
||||
if loss_type is None or loss_type not in LOSS_MAPPING:
|
||||
@ -5131,6 +5155,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
loss_type = "ForCausalLM"
|
||||
return LOSS_MAPPING[loss_type]
|
||||
|
||||
@loss_function.setter
|
||||
def loss_function(self, value):
|
||||
self._loss_function = value
|
||||
|
||||
def get_compiled_call(self, compile_config: CompileConfig):
|
||||
"""Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
|
||||
non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
|
||||
|
@ -86,6 +86,7 @@ from . import (
|
||||
dpt,
|
||||
efficientnet,
|
||||
electra,
|
||||
emu3,
|
||||
encodec,
|
||||
encoder_decoder,
|
||||
ernie,
|
||||
@ -170,6 +171,7 @@ from . import (
|
||||
mobilevit,
|
||||
mobilevitv2,
|
||||
modernbert,
|
||||
moonshine,
|
||||
moshi,
|
||||
mpnet,
|
||||
mpt,
|
||||
|
@ -1,62 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ALBERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = AlbertConfig.from_json_file(albert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = AlbertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_albert(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--albert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained ALBERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
|
@ -1,389 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ALIGN checkpoints from the original repository."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import align
|
||||
import numpy as np
|
||||
import requests
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
from transformers import (
|
||||
AlignConfig,
|
||||
AlignModel,
|
||||
AlignProcessor,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
EfficientNetConfig,
|
||||
EfficientNetImageProcessor,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
image = tf.image.resize(image, (346, 346))
|
||||
image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
|
||||
return image
|
||||
|
||||
|
||||
def get_align_config():
|
||||
vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
|
||||
vision_config.image_size = 289
|
||||
vision_config.hidden_dim = 640
|
||||
vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
|
||||
vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
|
||||
vision_config.depthwise_padding = []
|
||||
|
||||
text_config = BertConfig()
|
||||
config = AlignConfig.from_text_vision_configs(
|
||||
text_config=text_config, vision_config=vision_config, projection_dim=640
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
def get_processor():
|
||||
image_processor = EfficientNetImageProcessor(
|
||||
do_center_crop=True,
|
||||
rescale_factor=1 / 127.5,
|
||||
rescale_offset=True,
|
||||
do_normalize=False,
|
||||
include_top=False,
|
||||
resample=Image.BILINEAR,
|
||||
)
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
tokenizer.model_max_length = 64
|
||||
processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
return processor
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def rename_keys(original_param_names):
|
||||
# EfficientNet image encoder
|
||||
block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
|
||||
block_names = list(set(block_names))
|
||||
block_names = sorted(block_names)
|
||||
num_blocks = len(block_names)
|
||||
block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}
|
||||
|
||||
rename_keys = []
|
||||
rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
|
||||
rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
|
||||
rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
|
||||
rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
|
||||
rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))
|
||||
|
||||
for b in block_names:
|
||||
hf_b = block_name_mapping[b]
|
||||
rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
|
||||
rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
|
||||
rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
|
||||
)
|
||||
rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
|
||||
rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
|
||||
)
|
||||
|
||||
rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
|
||||
rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
|
||||
rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
|
||||
rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
|
||||
)
|
||||
rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
|
||||
rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
|
||||
)
|
||||
|
||||
key_mapping = {}
|
||||
for item in rename_keys:
|
||||
if item[0] in original_param_names:
|
||||
key_mapping[item[0]] = "vision_model." + item[1]
|
||||
|
||||
# BERT text encoder
|
||||
rename_keys = []
|
||||
old = "tf_bert_model/bert"
|
||||
new = "text_model"
|
||||
for i in range(12):
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.query.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/query/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.query.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.key.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/key/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.key.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.value.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/value/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.value.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.dense.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.dense.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.intermediate.dense.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0",
|
||||
f"{new}.encoder.layer.{i}.intermediate.dense.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias")
|
||||
)
|
||||
|
||||
rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight"))
|
||||
rename_keys.append(
|
||||
(f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight")
|
||||
)
|
||||
rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight"))
|
||||
rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias"))
|
||||
|
||||
rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight"))
|
||||
rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias"))
|
||||
rename_keys.append(("dense/kernel:0", "text_projection.weight"))
|
||||
rename_keys.append(("dense/bias:0", "text_projection.bias"))
|
||||
rename_keys.append(("dense/bias:0", "text_projection.bias"))
|
||||
rename_keys.append(("temperature:0", "temperature"))
|
||||
|
||||
for item in rename_keys:
|
||||
if item[0] in original_param_names:
|
||||
key_mapping[item[0]] = item[1]
|
||||
return key_mapping
|
||||
|
||||
|
||||
def replace_params(hf_params, tf_params, key_mapping):
|
||||
list(hf_params.keys())
|
||||
|
||||
for key, value in tf_params.items():
|
||||
if key not in key_mapping:
|
||||
continue
|
||||
|
||||
hf_key = key_mapping[key]
|
||||
if "_conv" in key and "kernel" in key:
|
||||
new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
|
||||
elif "embeddings" in key:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
elif "depthwise_kernel" in key:
|
||||
new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
|
||||
elif "kernel" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value))
|
||||
elif "temperature" in key:
|
||||
new_hf_value = value
|
||||
elif "bn/gamma" or "bn/beta" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
|
||||
else:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
|
||||
# Replace HF parameters with original TF model parameters
|
||||
hf_params[hf_key].copy_(new_hf_value)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our ALIGN structure.
|
||||
"""
|
||||
# Load original model
|
||||
seq_length = 64
|
||||
tok = Tokenizer(seq_length)
|
||||
original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
|
||||
original_model.compile()
|
||||
original_model.load_weights(checkpoint_path)
|
||||
|
||||
tf_params = original_model.trainable_variables
|
||||
tf_non_train_params = original_model.non_trainable_variables
|
||||
tf_params = {param.name: param.numpy() for param in tf_params}
|
||||
for param in tf_non_train_params:
|
||||
tf_params[param.name] = param.numpy()
|
||||
tf_param_names = list(tf_params.keys())
|
||||
|
||||
# Load HuggingFace model
|
||||
config = get_align_config()
|
||||
hf_model = AlignModel(config).eval()
|
||||
hf_params = hf_model.state_dict()
|
||||
|
||||
# Create src-to-dst parameter name mapping dictionary
|
||||
print("Converting parameters...")
|
||||
key_mapping = rename_keys(tf_param_names)
|
||||
replace_params(hf_params, tf_params, key_mapping)
|
||||
|
||||
# Initialize processor
|
||||
processor = get_processor()
|
||||
inputs = processor(
|
||||
images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
|
||||
)
|
||||
|
||||
# HF model inference
|
||||
hf_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = hf_model(**inputs)
|
||||
|
||||
hf_image_features = outputs.image_embeds.detach().numpy()
|
||||
hf_text_features = outputs.text_embeds.detach().numpy()
|
||||
|
||||
# Original model inference
|
||||
original_model.trainable = False
|
||||
tf_image_processor = EfficientNetImageProcessor(
|
||||
do_center_crop=True,
|
||||
do_rescale=False,
|
||||
do_normalize=False,
|
||||
include_top=False,
|
||||
resample=Image.BILINEAR,
|
||||
)
|
||||
image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"]
|
||||
text = tok(tf.constant(["A picture of a cat"]))
|
||||
|
||||
image_features = original_model.image_encoder(image, training=False)
|
||||
text_features = original_model.text_encoder(text, training=False)
|
||||
|
||||
image_features = tf.nn.l2_normalize(image_features, axis=-1)
|
||||
text_features = tf.nn.l2_normalize(text_features, axis=-1)
|
||||
|
||||
# Check whether original and HF model outputs match -> np.allclose
|
||||
if not np.allclose(image_features, hf_image_features, atol=1e-3):
|
||||
raise ValueError("The predicted image features are not the same.")
|
||||
if not np.allclose(text_features, hf_text_features, atol=1e-3):
|
||||
raise ValueError("The predicted text features are not the same.")
|
||||
print("Model outputs match!")
|
||||
|
||||
if save_model:
|
||||
# Create folder to save model
|
||||
if not os.path.isdir(pytorch_dump_folder_path):
|
||||
os.mkdir(pytorch_dump_folder_path)
|
||||
# Save converted model and image processor
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
# Push model and image processor to hub
|
||||
print("Pushing converted ALIGN to the hub...")
|
||||
processor.push_to_hub("align-base")
|
||||
hf_model.push_to_hub("align-base")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default="./weights/model-weights",
|
||||
type=str,
|
||||
help="Path to the pretrained TF ALIGN checkpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default="hf_model",
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument("--save_model", action="store_true", help="Save model to local")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)
|
@ -1,162 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AriaForConditionalGeneration,
|
||||
AriaProcessor,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
|
||||
EPILOG_TXT = """Example:
|
||||
python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
|
||||
|
||||
Example for creating the old state dict file with Python:
|
||||
|
||||
import torch
|
||||
from aria.model.language_model.aria_llama import AriaTextForCausalLM
|
||||
|
||||
# load model
|
||||
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
|
||||
model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", low_cpu_mem_usage=True, **kwargs)
|
||||
|
||||
# load vision tower
|
||||
model.get_vision_tower().load_model()
|
||||
|
||||
# Save state dict
|
||||
torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
|
||||
"""
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"vision_tower.vision_model": "vision_tower",
|
||||
"ln_ffn": "layer_norm",
|
||||
"ffn": "feed_forward",
|
||||
"ln_kv": "layer_norm_kv",
|
||||
}
|
||||
|
||||
|
||||
def load_original_state_dict(model_id):
|
||||
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_hf(state_dict):
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key.endswith(".inv_freq"):
|
||||
continue
|
||||
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in key:
|
||||
key = key.replace(key_to_modify, new_key)
|
||||
|
||||
new_state_dict[key] = value
|
||||
new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
|
||||
new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
text_model_id,
|
||||
extra_special_tokens={
|
||||
"image_token": "<|img|>",
|
||||
"pad_token": "<pad>",
|
||||
},
|
||||
)
|
||||
tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
|
||||
processor = AriaProcessor.from_pretrained(
|
||||
text_model_id,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(text_model_id)
|
||||
config.vision_config.hidden_size = 1152
|
||||
config.vision_config.attention_heads = 16
|
||||
config.pad_token_id = 2
|
||||
config.image_token_index = 9
|
||||
config.intermediate_size = config.moe_intermediate_size
|
||||
config.auto_map = {
|
||||
"AutoConfig": "modeling_aria.AriaConfig",
|
||||
"AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
|
||||
}
|
||||
|
||||
with torch.device("meta"):
|
||||
model = AriaForConditionalGeneration(config)
|
||||
|
||||
state_dict = load_original_state_dict(old_state_dict_id)
|
||||
|
||||
state_dict = convert_state_dict_to_hf(state_dict)
|
||||
model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
|
||||
# print("Saving models")
|
||||
# model.save_pretrained("local_aria", safe_serialization=False)
|
||||
# processor.save_pretrained("local_aria")
|
||||
print("Pushing to hub")
|
||||
model.push_to_hub(output_hub_path, create_pr=True)
|
||||
processor.push_to_hub(output_hub_path, create_pr=True)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog=EPILOG_TXT,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_model_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Hub location of the text model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vision_model_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Hub location of the vision model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_hub_path",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Location on the hub of the converted model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old_state_dict_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -754,6 +754,9 @@ class AriaTextRotaryEmbedding(nn.Module):
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||
# the buffer is automatically moved, but not the original copy)
|
||||
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
|
@ -1,279 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_audio_spectrogram_transformer_config(model_name):
|
||||
config = ASTConfig()
|
||||
|
||||
if "10-10" in model_name:
|
||||
pass
|
||||
elif "speech-commands" in model_name:
|
||||
config.max_length = 128
|
||||
elif "12-12" in model_name:
|
||||
config.time_stride = 12
|
||||
config.frequency_stride = 12
|
||||
elif "14-14" in model_name:
|
||||
config.time_stride = 14
|
||||
config.frequency_stride = 14
|
||||
elif "16-16" in model_name:
|
||||
config.time_stride = 16
|
||||
config.frequency_stride = 16
|
||||
else:
|
||||
raise ValueError("Model not supported")
|
||||
|
||||
repo_id = "huggingface/label-files"
|
||||
if "speech-commands" in model_name:
|
||||
config.num_labels = 35
|
||||
filename = "speech-commands-v2-id2label.json"
|
||||
else:
|
||||
config.num_labels = 527
|
||||
filename = "audioset-id2label.json"
|
||||
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "module.v" in name:
|
||||
name = name.replace("module.v", "audio_spectrogram_transformer")
|
||||
if "cls_token" in name:
|
||||
name = name.replace("cls_token", "embeddings.cls_token")
|
||||
if "dist_token" in name:
|
||||
name = name.replace("dist_token", "embeddings.distillation_token")
|
||||
if "pos_embed" in name:
|
||||
name = name.replace("pos_embed", "embeddings.position_embeddings")
|
||||
if "patch_embed.proj" in name:
|
||||
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
|
||||
# transformer blocks
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "encoder.layer")
|
||||
if "attn.proj" in name:
|
||||
name = name.replace("attn.proj", "attention.output.dense")
|
||||
if "attn" in name:
|
||||
name = name.replace("attn", "attention.self")
|
||||
if "norm1" in name:
|
||||
name = name.replace("norm1", "layernorm_before")
|
||||
if "norm2" in name:
|
||||
name = name.replace("norm2", "layernorm_after")
|
||||
if "mlp.fc1" in name:
|
||||
name = name.replace("mlp.fc1", "intermediate.dense")
|
||||
if "mlp.fc2" in name:
|
||||
name = name.replace("mlp.fc2", "output.dense")
|
||||
# final layernorm
|
||||
if "audio_spectrogram_transformer.norm" in name:
|
||||
name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm")
|
||||
# classifier head
|
||||
if "module.mlp_head.0" in name:
|
||||
name = name.replace("module.mlp_head.0", "classifier.layernorm")
|
||||
if "module.mlp_head.1" in name:
|
||||
name = name.replace("module.mlp_head.1", "classifier.dense")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, config):
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
if "qkv" in key:
|
||||
key_split = key.split(".")
|
||||
layer_num = int(key_split[3])
|
||||
dim = config.hidden_size
|
||||
if "weight" in key:
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight"
|
||||
] = val[:dim, :]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight"
|
||||
] = val[dim : dim * 2, :]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight"
|
||||
] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias"
|
||||
] = val[:dim]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias"
|
||||
] = val[dim : dim * 2]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias"
|
||||
] = val[-dim:]
|
||||
else:
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
def remove_keys(state_dict):
|
||||
ignore_keys = [
|
||||
"module.v.head.weight",
|
||||
"module.v.head.bias",
|
||||
"module.v.head_dist.weight",
|
||||
"module.v.head_dist.bias",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure.
|
||||
"""
|
||||
config = get_audio_spectrogram_transformer_config(model_name)
|
||||
|
||||
model_name_to_url = {
|
||||
"ast-finetuned-audioset-10-10-0.4593": (
|
||||
"https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.450": (
|
||||
"https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.448": (
|
||||
"https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.448-v2": (
|
||||
"https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-12-12-0.447": (
|
||||
"https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-14-14-0.443": (
|
||||
"https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-16-16-0.442": (
|
||||
"https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-speech-commands-v2": (
|
||||
"https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1"
|
||||
),
|
||||
}
|
||||
|
||||
# load original state_dict
|
||||
checkpoint_url = model_name_to_url[model_name]
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
|
||||
# remove some keys
|
||||
remove_keys(state_dict)
|
||||
# rename some keys
|
||||
new_state_dict = convert_state_dict(state_dict, config)
|
||||
|
||||
# load 🤗 model
|
||||
model = ASTForAudioClassification(config)
|
||||
model.eval()
|
||||
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# verify outputs on dummy input
|
||||
# source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62
|
||||
mean = -4.2677393 if "speech-commands" not in model_name else -6.845978
|
||||
std = 4.5689974 if "speech-commands" not in model_name else 5.5654526
|
||||
max_length = 1024 if "speech-commands" not in model_name else 128
|
||||
feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length)
|
||||
|
||||
if "speech-commands" in model_name:
|
||||
# TODO: Convert dataset to Parquet
|
||||
dataset = load_dataset("google/speech_commands", "v0.02", split="validation", trust_remote_code=True)
|
||||
waveform = dataset[0]["audio"]["array"]
|
||||
else:
|
||||
filepath = hf_hub_download(
|
||||
repo_id="nielsr/audio-spectogram-transformer-checkpoint",
|
||||
filename="sample_audio.flac",
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
waveform, _ = torchaudio.load(filepath)
|
||||
waveform = waveform.squeeze().numpy()
|
||||
|
||||
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
if model_name == "ast-finetuned-audioset-10-10-0.4593":
|
||||
expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.450":
|
||||
expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.448":
|
||||
expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.448-v2":
|
||||
expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917])
|
||||
elif model_name == "ast-finetuned-audioset-12-12-0.447":
|
||||
expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843])
|
||||
elif model_name == "ast-finetuned-audioset-14-14-0.443":
|
||||
expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413])
|
||||
elif model_name == "ast-finetuned-audioset-16-16-0.442":
|
||||
expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470])
|
||||
elif model_name == "ast-finetuned-speech-commands-v2":
|
||||
expected_slice = torch.tensor([6.1589, -8.0566, -8.7984])
|
||||
else:
|
||||
raise ValueError("Unknown model name")
|
||||
if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4):
|
||||
raise ValueError("Logits don't match")
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print("Pushing model and feature extractor to the hub...")
|
||||
model.push_to_hub(f"MIT/{model_name}")
|
||||
feature_extractor.push_to_hub(f"MIT/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="ast-finetuned-audioset-10-10-0.4593",
|
||||
type=str,
|
||||
help="Name of the Audio Spectrogram Transformer model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -103,6 +103,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("efficientformer", "EfficientFormerConfig"),
|
||||
("efficientnet", "EfficientNetConfig"),
|
||||
("electra", "ElectraConfig"),
|
||||
("emu3", "Emu3Config"),
|
||||
("encodec", "EncodecConfig"),
|
||||
("encoder-decoder", "EncoderDecoderConfig"),
|
||||
("ernie", "ErnieConfig"),
|
||||
@ -190,6 +191,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevit", "MobileViTConfig"),
|
||||
("mobilevitv2", "MobileViTV2Config"),
|
||||
("modernbert", "ModernBertConfig"),
|
||||
("moonshine", "MoonshineConfig"),
|
||||
("moshi", "MoshiConfig"),
|
||||
("mpnet", "MPNetConfig"),
|
||||
("mpt", "MptConfig"),
|
||||
@ -419,6 +421,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("efficientformer", "EfficientFormer"),
|
||||
("efficientnet", "EfficientNet"),
|
||||
("electra", "ELECTRA"),
|
||||
("emu3", "Emu3"),
|
||||
("encodec", "EnCodec"),
|
||||
("encoder-decoder", "Encoder decoder"),
|
||||
("ernie", "ERNIE"),
|
||||
@ -519,6 +522,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("mobilevit", "MobileViT"),
|
||||
("mobilevitv2", "MobileViTV2"),
|
||||
("modernbert", "ModernBERT"),
|
||||
("moonshine", "Moonshine"),
|
||||
("moshi", "Moshi"),
|
||||
("mpnet", "MPNet"),
|
||||
("mpt", "MPT"),
|
||||
|
@ -73,6 +73,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("mobilenet_v1", "MobileNetV1FeatureExtractor"),
|
||||
("mobilenet_v2", "MobileNetV2FeatureExtractor"),
|
||||
("mobilevit", "MobileViTFeatureExtractor"),
|
||||
("moonshine", "Wav2Vec2FeatureExtractor"),
|
||||
("moshi", "EncodecFeatureExtractor"),
|
||||
("nat", "ViTFeatureExtractor"),
|
||||
("owlvit", "OwlViTFeatureExtractor"),
|
||||
|
@ -179,6 +179,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevit", "MobileViTModel"),
|
||||
("mobilevitv2", "MobileViTV2Model"),
|
||||
("modernbert", "ModernBertModel"),
|
||||
("moonshine", "MoonshineModel"),
|
||||
("moshi", "MoshiModel"),
|
||||
("mpnet", "MPNetModel"),
|
||||
("mpt", "MptModel"),
|
||||
@ -436,6 +437,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("mega", "MegaForMaskedLM"),
|
||||
("megatron-bert", "MegatronBertForCausalLM"),
|
||||
("mobilebert", "MobileBertForMaskedLM"),
|
||||
("moonshine", "MoonshineForConditionalGeneration"),
|
||||
("mpnet", "MPNetForMaskedLM"),
|
||||
("mpt", "MptForCausalLM"),
|
||||
("mra", "MraForMaskedLM"),
|
||||
@ -497,6 +499,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("dbrx", "DbrxForCausalLM"),
|
||||
("diffllama", "DiffLlamaForCausalLM"),
|
||||
("electra", "ElectraForCausalLM"),
|
||||
("emu3", "Emu3ForCausalLM"),
|
||||
("ernie", "ErnieForCausalLM"),
|
||||
("falcon", "FalconForCausalLM"),
|
||||
("falcon_mamba", "FalconMambaForCausalLM"),
|
||||
@ -798,6 +801,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
||||
("blip", "BlipForConditionalGeneration"),
|
||||
("blip-2", "Blip2ForConditionalGeneration"),
|
||||
("chameleon", "ChameleonForConditionalGeneration"),
|
||||
("emu3", "Emu3ForConditionalGeneration"),
|
||||
("fuyu", "FuyuForCausalLM"),
|
||||
("git", "GitForCausalLM"),
|
||||
("idefics", "IdeficsForVisionText2Text"),
|
||||
@ -937,6 +941,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("moonshine", "MoonshineForConditionalGeneration"),
|
||||
("pop2piano", "Pop2PianoForConditionalGeneration"),
|
||||
("seamless_m4t", "SeamlessM4TForSpeechToText"),
|
||||
("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
|
||||
@ -1425,6 +1430,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
||||
("deberta-v2", "DebertaV2Model"),
|
||||
("distilbert", "DistilBertModel"),
|
||||
("electra", "ElectraModel"),
|
||||
("emu3", "Emu3TextModel"),
|
||||
("flaubert", "FlaubertModel"),
|
||||
("ibert", "IBertModel"),
|
||||
("longformer", "LongformerModel"),
|
||||
|
@ -59,6 +59,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("clipseg", "CLIPSegProcessor"),
|
||||
("clvp", "ClvpProcessor"),
|
||||
("colpali", "ColPaliProcessor"),
|
||||
("emu3", "Emu3Processor"),
|
||||
("flava", "FlavaProcessor"),
|
||||
("fuyu", "FuyuProcessor"),
|
||||
("git", "GitProcessor"),
|
||||
@ -81,6 +82,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("mctct", "MCTCTProcessor"),
|
||||
("mgp-str", "MgpstrProcessor"),
|
||||
("mllama", "MllamaProcessor"),
|
||||
("moonshine", "Wav2Vec2Processor"),
|
||||
("oneformer", "OneFormerProcessor"),
|
||||
("owlv2", "Owlv2Processor"),
|
||||
("owlvit", "OwlViTProcessor"),
|
||||
|
@ -186,6 +186,7 @@ else:
|
||||
),
|
||||
),
|
||||
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("esm", ("EsmTokenizer", None)),
|
||||
@ -321,6 +322,7 @@ else:
|
||||
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
@ -1,273 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from os import path
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from .configuration_bamba import BambaConfig
|
||||
|
||||
|
||||
def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
|
||||
state_dict = {}
|
||||
|
||||
for orig_k, param in original_sd.items():
|
||||
k = orig_k.replace("backbone", "model")
|
||||
|
||||
# for embeddings
|
||||
k = k.replace("embedding", "embed_tokens")
|
||||
|
||||
# for mixer
|
||||
k = k.replace("mixer", "mamba")
|
||||
|
||||
# for final layernorm
|
||||
k = k.replace("norm_f", "final_layernorm")
|
||||
|
||||
# for block layernorm
|
||||
k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
|
||||
k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
|
||||
|
||||
# for mlp
|
||||
k = k.replace("mlp.fc2", "feed_forward.down_proj")
|
||||
|
||||
if "mlp.fc1" in k:
|
||||
param, param2 = torch.chunk(param, 2, dim=0)
|
||||
k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
|
||||
state_dict[k2] = param2
|
||||
k = k.replace("mlp.fc1", "feed_forward.up_proj")
|
||||
|
||||
if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
|
||||
"out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
|
||||
):
|
||||
# then this must be a mamba
|
||||
pass
|
||||
else:
|
||||
# for attn
|
||||
# - because mixer was replaced to mamba above
|
||||
k = k.replace("mamba.out_proj", "self_attn.o_proj")
|
||||
if "mamba.in_proj" in k:
|
||||
m, n = param.shape
|
||||
d = (m - n) // 2
|
||||
param, param2, param3 = torch.split(param, [n, d, d], dim=0)
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
|
||||
state_dict[k2] = param2
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
|
||||
state_dict[k2] = param3
|
||||
k = k.replace("mamba.in_proj", "self_attn.q_proj")
|
||||
|
||||
state_dict[k] = param
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_ssm_config_to_hf_config(
|
||||
config_ssm: Dict,
|
||||
**kwargs,
|
||||
) -> BambaConfig:
|
||||
"""Convert a config from mamba_ssm to a BambaConfig from here."""
|
||||
hf_config: BambaConfig = BambaConfig(**kwargs)
|
||||
|
||||
hf_config.architectures = ["BambaForCausalLM"]
|
||||
|
||||
# Set important values from config and recalculate other resulting entries
|
||||
hf_config.hidden_size = config_ssm["d_model"]
|
||||
hf_config.intermediate_size = config_ssm["d_intermediate"]
|
||||
hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
|
||||
hf_config.num_hidden_layers = config_ssm["n_layer"]
|
||||
hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
|
||||
|
||||
# currently this script assumes config_ssm belongs to v2
|
||||
if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
|
||||
raise ValueError("Conversion script only supports Mamba2")
|
||||
|
||||
# Set attention values
|
||||
attn_cfg = config_ssm.get("attn_cfg")
|
||||
if attn_cfg:
|
||||
assert attn_cfg["causal"], "Only support non-causal attention."
|
||||
assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
|
||||
assert not attn_cfg["out_proj_bias"], "Only support no out bias."
|
||||
hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
|
||||
hf_config.num_attention_heads = attn_cfg["num_heads"]
|
||||
hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
|
||||
|
||||
attention_layer_indices = config_ssm.get("attn_layer_idx")
|
||||
if attention_layer_indices:
|
||||
hf_config.attn_layer_indices = attention_layer_indices
|
||||
|
||||
# Padded vocab size, mostly of 16 but 32 is also very common in different models
|
||||
vocab_size = config_ssm["vocab_size"]
|
||||
pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
|
||||
if (vocab_size % pad_vocab_size_multiple) != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
||||
hf_config.vocab_size = vocab_size
|
||||
|
||||
return hf_config
|
||||
|
||||
|
||||
def save_single_safetensor(
|
||||
state_dict: Dict,
|
||||
save_directory: str,
|
||||
metadata: Dict,
|
||||
):
|
||||
save_file(
|
||||
state_dict,
|
||||
os.path.join(save_directory, SAFE_WEIGHTS_NAME),
|
||||
metadata,
|
||||
)
|
||||
|
||||
|
||||
def save_sharded_safetensors(
|
||||
state_dict: Dict,
|
||||
save_directory: str,
|
||||
metadata: Dict,
|
||||
max_shard_size: Union[int, str] = "5GB",
|
||||
):
|
||||
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||
)
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
# Save the index
|
||||
with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
mamba_ssm_checkpoint_path: str,
|
||||
precision: str,
|
||||
output_dir: str,
|
||||
tokenizer_path: str = None,
|
||||
save_model: Union[bool, str] = True,
|
||||
) -> None:
|
||||
# load tokenizer if provided, this will be used to set the
|
||||
# token_ids in the config file
|
||||
token_ids = {}
|
||||
if tokenizer_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
for key in [
|
||||
"bos_token_id",
|
||||
"eos_token_id",
|
||||
"pad_token_id",
|
||||
]:
|
||||
id = getattr(tokenizer, key, None)
|
||||
if id:
|
||||
token_ids[key] = id
|
||||
|
||||
# there are some configs unsettable by mamba_ssn config, so
|
||||
# if there are changes from the defaults, have to pass them into
|
||||
# the function
|
||||
unsettables = {
|
||||
"mamba_d_head": 64,
|
||||
"mamba_d_state": 128,
|
||||
"mamba_n_groups": 1,
|
||||
"rms_norm_eps": 1e-5,
|
||||
}
|
||||
|
||||
# Load and save config based on name
|
||||
config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
|
||||
with open(config_path, "r", encoding="utf-8") as json_file:
|
||||
config = json.load(json_file)
|
||||
|
||||
# convert the config
|
||||
hf_config = convert_ssm_config_to_hf_config(
|
||||
config_ssm=config,
|
||||
**token_ids,
|
||||
**unsettables,
|
||||
)
|
||||
hf_config.save_pretrained(output_dir)
|
||||
|
||||
# Load state dict of the original model and transfer to hf model
|
||||
state_dict = torch.load(
|
||||
path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
# FIXME: allow other parameters to pass in
|
||||
state_dict = convert_state_dict_from_mamba_ssm(state_dict)
|
||||
|
||||
# Save new model to pytorch_dump_path
|
||||
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
|
||||
|
||||
save_file_fn = None
|
||||
if isinstance(save_model, bool) and save_model:
|
||||
save_file_fn = save_single_safetensor
|
||||
elif isinstance(save_model, str) and save_model == "sharded":
|
||||
save_file_fn = save_sharded_safetensors
|
||||
|
||||
if save_file_fn:
|
||||
save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--mamba_ssm_checkpoint_directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
const="fp16",
|
||||
required=True,
|
||||
choices=("fp32", "fp16", "bf16"),
|
||||
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tokenizer_model_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Path to a the tokenizer file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
args.mamba2_checkpoint_directory,
|
||||
args.precision,
|
||||
args.output_dir,
|
||||
)
|
@ -150,6 +150,9 @@ class BambaRotaryEmbedding(nn.Module):
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||
# the buffer is automatically moved, but not the original copy)
|
||||
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
@ -1197,6 +1200,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs, # NOOP kwargs, for now
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -947,6 +947,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs, # NOOP kwargs, for now
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -1,263 +0,0 @@
|
||||
"""Convert Bark checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from bark.generation import _load_model as _bark_load_model
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import EncodecConfig, EncodecModel, set_seed
|
||||
from transformers.models.bark.configuration_bark import (
|
||||
BarkCoarseConfig,
|
||||
BarkConfig,
|
||||
BarkFineConfig,
|
||||
BarkSemanticConfig,
|
||||
)
|
||||
from transformers.models.bark.generation_configuration_bark import (
|
||||
BarkCoarseGenerationConfig,
|
||||
BarkFineGenerationConfig,
|
||||
BarkGenerationConfig,
|
||||
BarkSemanticGenerationConfig,
|
||||
)
|
||||
from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
set_seed(770)
|
||||
|
||||
|
||||
new_layer_name_dict = {
|
||||
"c_attn": "att_proj",
|
||||
"c_proj": "out_proj",
|
||||
"c_fc": "in_proj",
|
||||
"transformer.": "",
|
||||
"h.": "layers.",
|
||||
"ln_1": "layernorm_1",
|
||||
"ln_2": "layernorm_2",
|
||||
"ln_f": "layernorm_final",
|
||||
"wpe": "position_embeds_layer",
|
||||
"wte": "input_embeds_layer",
|
||||
}
|
||||
|
||||
|
||||
REMOTE_MODEL_PATHS = {
|
||||
"text_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "text.pt",
|
||||
},
|
||||
"coarse_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "coarse.pt",
|
||||
},
|
||||
"fine_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "fine.pt",
|
||||
},
|
||||
"text": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "text_2.pt",
|
||||
},
|
||||
"coarse": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "coarse_2.pt",
|
||||
},
|
||||
"fine": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "fine_2.pt",
|
||||
},
|
||||
}
|
||||
|
||||
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
|
||||
|
||||
|
||||
def _get_ckpt_path(model_type, use_small=False):
|
||||
key = model_type
|
||||
if use_small:
|
||||
key += "_small"
|
||||
return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
|
||||
|
||||
|
||||
def _download(from_hf_path, file_name):
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
|
||||
|
||||
|
||||
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
||||
if model_type == "text":
|
||||
ModelClass = BarkSemanticModel
|
||||
ConfigClass = BarkSemanticConfig
|
||||
GenerationConfigClass = BarkSemanticGenerationConfig
|
||||
elif model_type == "coarse":
|
||||
ModelClass = BarkCoarseModel
|
||||
ConfigClass = BarkCoarseConfig
|
||||
GenerationConfigClass = BarkCoarseGenerationConfig
|
||||
elif model_type == "fine":
|
||||
ModelClass = BarkFineModel
|
||||
ConfigClass = BarkFineConfig
|
||||
GenerationConfigClass = BarkFineGenerationConfig
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
model_key = f"{model_type}_small" if use_small else model_type
|
||||
model_info = REMOTE_MODEL_PATHS[model_key]
|
||||
if not os.path.exists(ckpt_path):
|
||||
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
||||
_download(model_info["repo_id"], model_info["file_name"])
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
# this is a hack
|
||||
model_args = checkpoint["model_args"]
|
||||
if "input_vocab_size" not in model_args:
|
||||
model_args["input_vocab_size"] = model_args["vocab_size"]
|
||||
model_args["output_vocab_size"] = model_args["vocab_size"]
|
||||
del model_args["vocab_size"]
|
||||
|
||||
# convert Bark model arguments to HF Bark model arguments
|
||||
model_args["num_heads"] = model_args.pop("n_head")
|
||||
model_args["hidden_size"] = model_args.pop("n_embd")
|
||||
model_args["num_layers"] = model_args.pop("n_layer")
|
||||
|
||||
model_config = ConfigClass(**checkpoint["model_args"])
|
||||
model = ModelClass(config=model_config)
|
||||
model_generation_config = GenerationConfigClass()
|
||||
|
||||
model.generation_config = model_generation_config
|
||||
state_dict = checkpoint["model"]
|
||||
# fixup checkpoint
|
||||
unwanted_prefix = "_orig_mod."
|
||||
for k, v in list(state_dict.items()):
|
||||
if k.startswith(unwanted_prefix):
|
||||
# replace part of the key with corresponding layer name in HF implementation
|
||||
new_k = k[len(unwanted_prefix) :]
|
||||
for old_layer_name in new_layer_name_dict:
|
||||
new_k = new_k.replace(old_layer_name, new_layer_name_dict[old_layer_name])
|
||||
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
|
||||
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
|
||||
extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")}
|
||||
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
||||
missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")}
|
||||
if len(extra_keys) != 0:
|
||||
raise ValueError(f"extra keys found: {extra_keys}")
|
||||
if len(missing_keys) != 0:
|
||||
raise ValueError(f"missing keys: {missing_keys}")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
n_params = model.num_parameters(exclude_embeddings=True)
|
||||
val_loss = checkpoint["best_val_loss"].item()
|
||||
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
|
||||
model.eval()
|
||||
model.to(device)
|
||||
del checkpoint, state_dict
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"):
|
||||
if model_type not in ("text", "coarse", "fine"):
|
||||
raise NotImplementedError()
|
||||
|
||||
device = "cpu" # do conversion on cpu
|
||||
|
||||
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
|
||||
model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small)
|
||||
|
||||
# load bark initial model
|
||||
bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small)
|
||||
|
||||
if model_type == "text":
|
||||
bark_model = bark_model["model"]
|
||||
|
||||
if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params():
|
||||
raise ValueError("initial and new models don't have the same number of parameters")
|
||||
|
||||
# check if same output as the bark model
|
||||
batch_size = 5
|
||||
sequence_length = 10
|
||||
|
||||
if model_type in ["text", "coarse"]:
|
||||
vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int)
|
||||
output_old_model = bark_model(vec)[0]
|
||||
|
||||
output_new_model_total = model(vec)
|
||||
|
||||
# take last logits
|
||||
output_new_model = output_new_model_total.logits[:, [-1], :]
|
||||
|
||||
else:
|
||||
prediction_codeboook_channel = 3
|
||||
n_codes_total = 8
|
||||
vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int)
|
||||
|
||||
output_new_model_total = model(prediction_codeboook_channel, vec)
|
||||
output_old_model = bark_model(prediction_codeboook_channel, vec)
|
||||
|
||||
output_new_model = output_new_model_total.logits
|
||||
|
||||
# output difference should come from the difference of self-attention implementation design
|
||||
if output_new_model.shape != output_old_model.shape:
|
||||
raise ValueError("initial and new outputs don't have the same shape")
|
||||
if (output_new_model - output_old_model).abs().max().item() > 1e-3:
|
||||
raise ValueError("initial and new outputs are not equal")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
def load_whole_bark_model(
|
||||
semantic_path,
|
||||
coarse_path,
|
||||
fine_path,
|
||||
append_text,
|
||||
hub_path,
|
||||
folder_path,
|
||||
):
|
||||
pytorch_dump_folder_path = os.path.join(folder_path, append_text)
|
||||
|
||||
semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json"))
|
||||
coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json"))
|
||||
fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json"))
|
||||
codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz")
|
||||
|
||||
semantic = BarkSemanticModel.from_pretrained(semantic_path)
|
||||
coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path)
|
||||
fineAcoustic = BarkFineModel.from_pretrained(fine_path)
|
||||
codec = EncodecModel.from_pretrained("facebook/encodec_24khz")
|
||||
|
||||
bark_config = BarkConfig.from_sub_model_configs(
|
||||
semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig
|
||||
)
|
||||
|
||||
bark_generation_config = BarkGenerationConfig.from_sub_model_configs(
|
||||
semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config
|
||||
)
|
||||
|
||||
bark = BarkModel(bark_config)
|
||||
|
||||
bark.semantic = semantic
|
||||
bark.coarse_acoustics = coarseAcoustic
|
||||
bark.fine_acoustics = fineAcoustic
|
||||
bark.codec_model = codec
|
||||
|
||||
bark.generation_config = bark_generation_config
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
|
||||
parser.add_argument("model_type", type=str, help="text, coarse or fine.")
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small)
|
@ -1,156 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BART checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import fairseq
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
BartConfig,
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartTokenizer,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
||||
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
|
||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
raise Exception("requires fairseq >= 0.9.0")
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SAMPLE_TEXT = " Hello world! cécé herlolip"
|
||||
|
||||
mnli_rename_keys = [
|
||||
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
|
||||
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
|
||||
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
|
||||
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
|
||||
]
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"encoder.version",
|
||||
"decoder.version",
|
||||
"model.encoder.version",
|
||||
"model.decoder.version",
|
||||
"_float_tensor",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def load_xsum_checkpoint(checkpoint_path):
|
||||
"""Checkpoint path should end in model.pt"""
|
||||
sd = torch.load(checkpoint_path, map_location="cpu")
|
||||
hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
|
||||
hub_interface.model.load_state_dict(sd["model"])
|
||||
return hub_interface
|
||||
|
||||
|
||||
def make_linear_from_emb(emb):
|
||||
vocab_size, emb_size = emb.weight.shape
|
||||
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||
lin_layer.weight.data = emb.weight.data
|
||||
return lin_layer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BERT structure.
|
||||
"""
|
||||
if not os.path.exists(checkpoint_path):
|
||||
bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
|
||||
else:
|
||||
bart = load_xsum_checkpoint(checkpoint_path)
|
||||
|
||||
bart.model.upgrade_state_dict(bart.model.state_dict())
|
||||
if hf_checkpoint_name is None:
|
||||
hf_checkpoint_name = checkpoint_path.replace(".", "-")
|
||||
config = BartConfig.from_pretrained(hf_checkpoint_name)
|
||||
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
|
||||
tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
|
||||
if not torch.eq(tokens, tokens2).all():
|
||||
raise ValueError(
|
||||
f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
|
||||
)
|
||||
|
||||
if checkpoint_path == "bart.large.mnli":
|
||||
state_dict = bart.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
|
||||
for src, dest in mnli_rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
model = BartForSequenceClassification(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
fairseq_output = bart.predict("mnli", tokens, return_logits=True)
|
||||
new_model_outputs = model(tokens)[0] # logits
|
||||
else: # no classification heads to worry about
|
||||
state_dict = bart.model.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||
fairseq_output = bart.extract_features(tokens)
|
||||
if hf_checkpoint_name == "facebook/bart-large":
|
||||
model = BartModel(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
new_model_outputs = model(tokens).model[0]
|
||||
else:
|
||||
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
|
||||
model.model.load_state_dict(state_dict)
|
||||
if hasattr(model, "lm_head"):
|
||||
model.lm_head = make_linear_from_emb(model.model.shared)
|
||||
new_model_outputs = model.model(tokens)[0]
|
||||
|
||||
# Check results
|
||||
if fairseq_output.shape != new_model_outputs.shape:
|
||||
raise ValueError(
|
||||
f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
|
||||
)
|
||||
if (fairseq_output != new_model_outputs).any().item():
|
||||
raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
|
||||
)
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
|
@ -1,373 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BEiT checkpoints from the unilm repository."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
BeitConfig,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
BeitImageProcessor,
|
||||
)
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
|
||||
rename_keys = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
|
||||
)
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
|
||||
|
||||
# projection layer + position embeddings
|
||||
rename_keys.extend(
|
||||
[
|
||||
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
|
||||
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
|
||||
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
if has_lm_head:
|
||||
# mask token + shared relative position bias + layernorm
|
||||
rename_keys.extend(
|
||||
[
|
||||
("mask_token", "beit.embeddings.mask_token"),
|
||||
(
|
||||
"rel_pos_bias.relative_position_bias_table",
|
||||
"beit.encoder.relative_position_bias.relative_position_bias_table",
|
||||
),
|
||||
(
|
||||
"rel_pos_bias.relative_position_index",
|
||||
"beit.encoder.relative_position_bias.relative_position_index",
|
||||
),
|
||||
("norm.weight", "layernorm.weight"),
|
||||
("norm.bias", "layernorm.bias"),
|
||||
]
|
||||
)
|
||||
elif is_semantic:
|
||||
# semantic segmentation classification heads
|
||||
rename_keys.extend(
|
||||
[
|
||||
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
|
||||
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
|
||||
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
|
||||
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
# layernorm + classification head
|
||||
rename_keys.extend(
|
||||
[
|
||||
("fc_norm.weight", "beit.pooler.layernorm.weight"),
|
||||
("fc_norm.bias", "beit.pooler.layernorm.bias"),
|
||||
("head.weight", "classifier.weight"),
|
||||
("head.bias", "classifier.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
|
||||
for i in range(config.num_hidden_layers):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
# queries, keys and values
|
||||
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
|
||||
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||
: config.hidden_size, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
config.hidden_size : config.hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||
-config.hidden_size :, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||
|
||||
# gamma_1 and gamma_2
|
||||
# we call them lambda because otherwise they are renamed when using .from_pretrained
|
||||
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
|
||||
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
|
||||
|
||||
# relative_position bias table + index
|
||||
if not has_lm_head:
|
||||
# each layer has its own relative position bias
|
||||
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
|
||||
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
|
||||
|
||||
state_dict[
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
|
||||
] = table
|
||||
state_dict[
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
|
||||
] = index
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BEiT structure.
|
||||
"""
|
||||
|
||||
# define default BEiT configuration
|
||||
config = BeitConfig()
|
||||
has_lm_head = False
|
||||
is_semantic = False
|
||||
repo_id = "huggingface/label-files"
|
||||
# set config parameters based on URL
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
# masked image modeling
|
||||
config.use_shared_relative_position_bias = True
|
||||
config.use_mask_token = True
|
||||
has_lm_head = True
|
||||
elif checkpoint_url[-9:-4] == "ft22k":
|
||||
# intermediate fine-tuning on ImageNet-22k
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 21841
|
||||
filename = "imagenet-22k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
# this dataset contains 21843 labels but the model only has 21841
|
||||
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
|
||||
del id2label[9205]
|
||||
del id2label[15027]
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
elif checkpoint_url[-8:-4] == "to1k":
|
||||
# fine-tuning on ImageNet-1k
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 1000
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
if "384" in checkpoint_url:
|
||||
config.image_size = 384
|
||||
if "512" in checkpoint_url:
|
||||
config.image_size = 512
|
||||
elif "ade20k" in checkpoint_url:
|
||||
# fine-tuning
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 150
|
||||
filename = "ade20k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
config.image_size = 640
|
||||
is_semantic = True
|
||||
else:
|
||||
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
|
||||
|
||||
# size of the architecture
|
||||
if "base" in checkpoint_url:
|
||||
pass
|
||||
elif "large" in checkpoint_url:
|
||||
config.hidden_size = 1024
|
||||
config.intermediate_size = 4096
|
||||
config.num_hidden_layers = 24
|
||||
config.num_attention_heads = 16
|
||||
if "ade20k" in checkpoint_url:
|
||||
config.image_size = 640
|
||||
config.out_indices = [7, 11, 15, 23]
|
||||
else:
|
||||
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
|
||||
|
||||
# load state_dict of original model, remove and rename some keys
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
|
||||
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
|
||||
|
||||
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
if is_semantic:
|
||||
# add prefix to decoder keys
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("backbone.fpn"):
|
||||
key = key.replace("backbone.fpn", "fpn")
|
||||
state_dict[key] = val
|
||||
|
||||
# load HuggingFace model
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
model = BeitForMaskedImageModeling(config)
|
||||
elif "ade20k" in checkpoint_url:
|
||||
model = BeitForSemanticSegmentation(config)
|
||||
else:
|
||||
model = BeitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# Check outputs on an image
|
||||
if is_semantic:
|
||||
image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
|
||||
image = Image.open(ds[0]["file"])
|
||||
else:
|
||||
image_processor = BeitImageProcessor(
|
||||
size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
|
||||
)
|
||||
image = prepare_img()
|
||||
|
||||
encoding = image_processor(images=image, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify logits
|
||||
expected_shape = torch.Size([1, 1000])
|
||||
if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
|
||||
expected_shape = torch.Size([1, 196, 8192])
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
|
||||
expected_shape = torch.Size([1, 196, 8192])
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
|
||||
expected_shape = torch.Size([1, 21841])
|
||||
expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
|
||||
expected_class_idx = 2397
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
|
||||
expected_shape = torch.Size([1, 21841])
|
||||
expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
|
||||
expected_class_idx = 2396
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
|
||||
expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
|
||||
expected_class_idx = 285
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
|
||||
expected_class_idx = 281
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
|
||||
expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
|
||||
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
|
||||
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
|
||||
]
|
||||
)
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
|
||||
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
|
||||
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError("Can't verify logits as model is not supported")
|
||||
|
||||
if logits.shape != expected_shape:
|
||||
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
|
||||
if not has_lm_head:
|
||||
if is_semantic:
|
||||
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
else:
|
||||
print("Predicted class idx:", logits.argmax(-1).item())
|
||||
|
||||
if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
if logits.argmax(-1).item() != expected_class_idx:
|
||||
raise ValueError("Predicted class index not as expected")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
|
||||
type=str,
|
||||
help="URL to the original PyTorch checkpoint (.pth file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
|
@ -1,246 +0,0 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now
|
||||
deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert
|
||||
|
||||
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
|
||||
weight names to the original names, so the model can be imported with Huggingface/transformer.
|
||||
|
||||
You may adapt this script to include classification/MLM/NSP/etc. heads.
|
||||
|
||||
Note: This script is only working with an older version of the TensorFlow models repository (<= v2.3.0).
|
||||
Models trained with never versions are not compatible with this script.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
layer_depth = []
|
||||
for full_name, shape in init_vars:
|
||||
# logger.info(f"Loading TF weight {name} with shape {shape}")
|
||||
name = full_name.split("/")
|
||||
if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
|
||||
logger.info(f"Skipping non-model layer {full_name}")
|
||||
continue
|
||||
if "optimizer" in full_name:
|
||||
logger.info(f"Skipping optimization layer {full_name}")
|
||||
continue
|
||||
if name[0] == "model":
|
||||
# ignore initial 'model'
|
||||
name = name[1:]
|
||||
# figure out how many levels deep the name is
|
||||
depth = 0
|
||||
for _name in name:
|
||||
if _name.startswith("layer_with_weights"):
|
||||
depth += 1
|
||||
else:
|
||||
break
|
||||
layer_depth.append(depth)
|
||||
# read data
|
||||
array = tf.train.load_variable(tf_path, full_name)
|
||||
names.append("/".join(name))
|
||||
arrays.append(array)
|
||||
logger.info(f"Read a total of {len(arrays):,} layers")
|
||||
|
||||
# Sanity check
|
||||
if len(set(layer_depth)) != 1:
|
||||
raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
|
||||
layer_depth = list(set(layer_depth))[0]
|
||||
if layer_depth != 1:
|
||||
raise ValueError(
|
||||
"The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
|
||||
" heads."
|
||||
)
|
||||
|
||||
# convert layers
|
||||
logger.info("Converting weights...")
|
||||
for full_name, array in zip(names, arrays):
|
||||
name = full_name.split("/")
|
||||
pointer = model
|
||||
trace = []
|
||||
for i, m_name in enumerate(name):
|
||||
if m_name == ".ATTRIBUTES":
|
||||
# variable names end with .ATTRIBUTES/VARIABLE_VALUE
|
||||
break
|
||||
if m_name.startswith("layer_with_weights"):
|
||||
layer_num = int(m_name.split("-")[-1])
|
||||
if layer_num <= 2:
|
||||
# embedding layers
|
||||
# layer_num 0: word_embeddings
|
||||
# layer_num 1: position_embeddings
|
||||
# layer_num 2: token_type_embeddings
|
||||
continue
|
||||
elif layer_num == 3:
|
||||
# embedding LayerNorm
|
||||
trace.extend(["embeddings", "LayerNorm"])
|
||||
pointer = getattr(pointer, "embeddings")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
|
||||
# encoder layers
|
||||
trace.extend(["encoder", "layer", str(layer_num - 4)])
|
||||
pointer = getattr(pointer, "encoder")
|
||||
pointer = getattr(pointer, "layer")
|
||||
pointer = pointer[layer_num - 4]
|
||||
elif layer_num == config.num_hidden_layers + 4:
|
||||
# pooler layer
|
||||
trace.extend(["pooler", "dense"])
|
||||
pointer = getattr(pointer, "pooler")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "embeddings":
|
||||
trace.append("embeddings")
|
||||
pointer = getattr(pointer, "embeddings")
|
||||
if layer_num == 0:
|
||||
trace.append("word_embeddings")
|
||||
pointer = getattr(pointer, "word_embeddings")
|
||||
elif layer_num == 1:
|
||||
trace.append("position_embeddings")
|
||||
pointer = getattr(pointer, "position_embeddings")
|
||||
elif layer_num == 2:
|
||||
trace.append("token_type_embeddings")
|
||||
pointer = getattr(pointer, "token_type_embeddings")
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding layer with name {full_name}")
|
||||
trace.append("weight")
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif m_name == "_attention_layer":
|
||||
# self-attention layer
|
||||
trace.extend(["attention", "self"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "self")
|
||||
elif m_name == "_attention_layer_norm":
|
||||
# output attention norm
|
||||
trace.extend(["attention", "output", "LayerNorm"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif m_name == "_attention_output_dense":
|
||||
# output attention dense
|
||||
trace.extend(["attention", "output", "dense"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_dense":
|
||||
# output dense
|
||||
trace.extend(["output", "dense"])
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_layer_norm":
|
||||
# output dense
|
||||
trace.extend(["output", "LayerNorm"])
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif m_name == "_key_dense":
|
||||
# attention key
|
||||
trace.append("key")
|
||||
pointer = getattr(pointer, "key")
|
||||
elif m_name == "_query_dense":
|
||||
# attention query
|
||||
trace.append("query")
|
||||
pointer = getattr(pointer, "query")
|
||||
elif m_name == "_value_dense":
|
||||
# attention value
|
||||
trace.append("value")
|
||||
pointer = getattr(pointer, "value")
|
||||
elif m_name == "_intermediate_dense":
|
||||
# attention intermediate dense
|
||||
trace.extend(["intermediate", "dense"])
|
||||
pointer = getattr(pointer, "intermediate")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_layer_norm":
|
||||
# output layer norm
|
||||
trace.append("output")
|
||||
pointer = getattr(pointer, "output")
|
||||
# weights & biases
|
||||
elif m_name in ["bias", "beta"]:
|
||||
trace.append("bias")
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif m_name in ["kernel", "gamma"]:
|
||||
trace.append("weight")
|
||||
pointer = getattr(pointer, "weight")
|
||||
else:
|
||||
logger.warning(f"Ignored {m_name}")
|
||||
# for certain layers reshape is necessary
|
||||
trace = ".".join(trace)
|
||||
if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
|
||||
r"(\S+)\.attention\.output\.dense\.weight", trace
|
||||
):
|
||||
array = array.reshape(pointer.data.shape)
|
||||
if "kernel" in full_name:
|
||||
array = array.transpose()
|
||||
if pointer.shape == array.shape:
|
||||
pointer.data = torch.from_numpy(array)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
|
||||
f" {array.shape}"
|
||||
)
|
||||
logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
|
||||
return model
|
||||
|
||||
|
||||
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
|
||||
# Instantiate model
|
||||
logger.info(f"Loading model based on config from {config_path}...")
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
model = BertModel(config)
|
||||
|
||||
# Load weights from checkpoint
|
||||
logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
|
||||
load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
|
||||
|
||||
# Save pytorch-model
|
||||
logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model (must include filename).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,62 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(bert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = BertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,112 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertModel
|
||||
|
||||
|
||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
||||
"""
|
||||
Args:
|
||||
model: BertModel Pytorch model instance to be converted
|
||||
ckpt_dir: Tensorflow model directory
|
||||
model_name: model name
|
||||
|
||||
Currently supported HF models:
|
||||
|
||||
- Y BertModel
|
||||
- N BertForMaskedLM
|
||||
- N BertForPreTraining
|
||||
- N BertForMultipleChoice
|
||||
- N BertForNextSentencePrediction
|
||||
- N BertForSequenceClassification
|
||||
- N BertForQuestionAnswering
|
||||
"""
|
||||
|
||||
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
|
||||
|
||||
var_map = (
|
||||
("layer.", "layer_"),
|
||||
("word_embeddings.weight", "word_embeddings"),
|
||||
("position_embeddings.weight", "position_embeddings"),
|
||||
("token_type_embeddings.weight", "token_type_embeddings"),
|
||||
(".", "/"),
|
||||
("LayerNorm/weight", "LayerNorm/gamma"),
|
||||
("LayerNorm/bias", "LayerNorm/beta"),
|
||||
("weight", "kernel"),
|
||||
)
|
||||
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
os.makedirs(ckpt_dir)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
|
||||
def to_tf_var_name(name: str):
|
||||
for patt, repl in iter(var_map):
|
||||
name = name.replace(patt, repl)
|
||||
return f"bert/{name}"
|
||||
|
||||
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
|
||||
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
|
||||
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
|
||||
session.run(tf.variables_initializer([tf_var]))
|
||||
session.run(tf_var)
|
||||
return tf_var
|
||||
|
||||
tf.reset_default_graph()
|
||||
with tf.Session() as session:
|
||||
for var_name in state_dict:
|
||||
tf_name = to_tf_var_name(var_name)
|
||||
torch_tensor = state_dict[var_name].numpy()
|
||||
if any(x in var_name for x in tensors_to_transpose):
|
||||
torch_tensor = torch_tensor.T
|
||||
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
|
||||
tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
|
||||
tf_weight = session.run(tf_var)
|
||||
print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")
|
||||
|
||||
saver = tf.train.Saver(tf.trainable_variables())
|
||||
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
|
||||
|
||||
|
||||
def main(raw_args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
|
||||
parser.add_argument(
|
||||
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
|
||||
)
|
||||
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
|
||||
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
|
||||
args = parser.parse_args(raw_args)
|
||||
|
||||
model = BertModel.from_pretrained(
|
||||
pretrained_model_name_or_path=args.model_name,
|
||||
state_dict=torch.load(args.pytorch_model_path),
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
|
||||
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,188 +0,0 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT
|
||||
model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository:
|
||||
|
||||
https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForMaskedLM
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertIntermediate,
|
||||
BertLayer,
|
||||
BertOutput,
|
||||
BertPooler,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
|
||||
def get_masked_lm_array(name: str):
|
||||
full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_array(name: str):
|
||||
full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_layer_array(layer_index: int, name: str):
|
||||
full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_attention_layer_array(layer_index: int, name: str, orginal_shape):
|
||||
full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
array = array.reshape(orginal_shape)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
print(f"Loading model based on config from {config_path}...")
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
model = BertForMaskedLM(config)
|
||||
|
||||
# Layers
|
||||
for layer_index in range(0, config.num_hidden_layers):
|
||||
layer: BertLayer = model.bert.encoder.layer[layer_index]
|
||||
|
||||
# Self-attention
|
||||
self_attn: BertSelfAttention = layer.attention.self
|
||||
|
||||
self_attn.query.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
|
||||
)
|
||||
self_attn.query.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
|
||||
)
|
||||
self_attn.key.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
|
||||
)
|
||||
self_attn.key.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
|
||||
)
|
||||
self_attn.value.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
|
||||
)
|
||||
self_attn.value.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
|
||||
)
|
||||
|
||||
# Self-attention Output
|
||||
self_output: BertSelfOutput = layer.attention.output
|
||||
|
||||
self_output.dense.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
|
||||
)
|
||||
self_output.dense.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
|
||||
)
|
||||
|
||||
self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
|
||||
self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
|
||||
|
||||
# Intermediate
|
||||
intermediate: BertIntermediate = layer.intermediate
|
||||
|
||||
intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
|
||||
intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
|
||||
|
||||
# Output
|
||||
bert_output: BertOutput = layer.output
|
||||
|
||||
bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
|
||||
bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
|
||||
|
||||
bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
|
||||
bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
|
||||
|
||||
# Embeddings
|
||||
model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
|
||||
model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
|
||||
model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
|
||||
model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
|
||||
|
||||
# LM Head
|
||||
lm_head = model.cls.predictions.transform
|
||||
|
||||
lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
|
||||
lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
|
||||
|
||||
lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
|
||||
lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
|
||||
|
||||
model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
|
||||
|
||||
# Pooling
|
||||
model.bert.pooler = BertPooler(config=config)
|
||||
model.bert.pooler.dense.weight.data: BertPooler = get_encoder_array("_pooler_layer/kernel")
|
||||
model.bert.pooler.dense.bias.data: BertPooler = get_encoder_array("_pooler_layer/bias")
|
||||
|
||||
# Export final model
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
# Integration test - should load without any errors ;)
|
||||
new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
|
||||
print(new_model.eval())
|
||||
|
||||
print("Model conversion was done sucessfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -20,7 +20,6 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
@ -734,6 +733,7 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs, # NOOP kwargs, for now
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
@ -901,6 +901,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
@ -963,6 +964,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -970,11 +972,12 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin):
|
||||
|
||||
lm_loss = None
|
||||
if labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
lm_loss = self.loss_function(
|
||||
prediction_scores,
|
||||
labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[1:]
|
||||
|
@ -1,69 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BigBird checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
|
||||
# Initialise PyTorch model
|
||||
config = BigBirdConfig.from_json_file(big_bird_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
|
||||
if is_trivia_qa:
|
||||
model = BigBirdForQuestionAnswering(config)
|
||||
else:
|
||||
model = BigBirdForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--big_bird_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(
|
||||
args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
|
||||
)
|
@ -1983,6 +1983,7 @@ class BigBirdModel(BigBirdPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs, # NOOP kwargs, for now
|
||||
) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple[torch.FloatTensor]]:
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
@ -2540,6 +2541,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.FloatTensor]]:
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
@ -2580,6 +2582,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -2587,11 +2590,12 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin):
|
||||
|
||||
lm_loss = None
|
||||
if labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
lm_loss = self.loss_function(
|
||||
prediction_scores,
|
||||
labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
|
@ -1,170 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from typing import Dict
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration
|
||||
|
||||
|
||||
INIT_COMMON = [
|
||||
# tf -> hf
|
||||
("/", "."),
|
||||
("layer_", "layers."),
|
||||
("kernel", "weight"),
|
||||
("beta", "bias"),
|
||||
("gamma", "weight"),
|
||||
("pegasus", "model"),
|
||||
]
|
||||
END_COMMON = [
|
||||
(".output.dense", ".fc2"),
|
||||
("intermediate.LayerNorm", "final_layer_norm"),
|
||||
("intermediate.dense", "fc1"),
|
||||
]
|
||||
|
||||
DECODER_PATTERNS = (
|
||||
INIT_COMMON
|
||||
+ [
|
||||
("attention.self.LayerNorm", "self_attn_layer_norm"),
|
||||
("attention.output.dense", "self_attn.out_proj"),
|
||||
("attention.self", "self_attn"),
|
||||
("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
|
||||
("attention.encdec_output.dense", "encoder_attn.out_proj"),
|
||||
("attention.encdec", "encoder_attn"),
|
||||
("key", "k_proj"),
|
||||
("value", "v_proj"),
|
||||
("query", "q_proj"),
|
||||
("decoder.LayerNorm", "decoder.layernorm_embedding"),
|
||||
]
|
||||
+ END_COMMON
|
||||
)
|
||||
|
||||
REMAINING_PATTERNS = (
|
||||
INIT_COMMON
|
||||
+ [
|
||||
("embeddings.word_embeddings", "shared.weight"),
|
||||
("embeddings.position_embeddings", "embed_positions.weight"),
|
||||
("attention.self.LayerNorm", "self_attn_layer_norm"),
|
||||
("attention.output.dense", "self_attn.output"),
|
||||
("attention.self", "self_attn.self"),
|
||||
("encoder.LayerNorm", "encoder.layernorm_embedding"),
|
||||
]
|
||||
+ END_COMMON
|
||||
)
|
||||
|
||||
KEYS_TO_IGNORE = [
|
||||
"encdec/key/bias",
|
||||
"encdec/query/bias",
|
||||
"encdec/value/bias",
|
||||
"self/key/bias",
|
||||
"self/query/bias",
|
||||
"self/value/bias",
|
||||
"encdec_output/dense/bias",
|
||||
"attention/output/dense/bias",
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k, patterns):
|
||||
for tf_name, hf_name in patterns:
|
||||
k = k.replace(tf_name, hf_name)
|
||||
return k
|
||||
|
||||
|
||||
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
|
||||
cfg = BigBirdPegasusConfig(**config_update)
|
||||
torch_model = BigBirdPegasusForConditionalGeneration(cfg)
|
||||
state_dict = torch_model.state_dict()
|
||||
mapping = {}
|
||||
|
||||
# separating decoder weights
|
||||
decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
|
||||
remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
|
||||
|
||||
for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
|
||||
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
|
||||
if any(conditions):
|
||||
continue
|
||||
patterns = DECODER_PATTERNS
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict:
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
|
||||
for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
|
||||
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
|
||||
if any(conditions):
|
||||
continue
|
||||
patterns = REMAINING_PATTERNS
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
if k != "pegasus/embeddings/position_embeddings":
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
|
||||
mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
|
||||
mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
|
||||
missing, extra = torch_model.load_state_dict(mapping, strict=False)
|
||||
unexpected_missing = [
|
||||
k
|
||||
for k in missing
|
||||
if k
|
||||
not in [
|
||||
"final_logits_bias",
|
||||
"model.encoder.embed_tokens.weight",
|
||||
"model.decoder.embed_tokens.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
]
|
||||
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
|
||||
assert extra == [], f"no matches found for the following tf keys {extra}"
|
||||
return torch_model
|
||||
|
||||
|
||||
def get_tf_weights_as_numpy(path) -> Dict:
|
||||
init_vars = tf.train.list_variables(path)
|
||||
tf_weights = {}
|
||||
ignore_name = ["global_step"]
|
||||
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
|
||||
skip_key = any(pat in name for pat in ignore_name)
|
||||
if skip_key:
|
||||
continue
|
||||
array = tf.train.load_variable(path, name)
|
||||
tf_weights[name] = array
|
||||
return tf_weights
|
||||
|
||||
|
||||
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
|
||||
tf_weights = get_tf_weights_as_numpy(ckpt_path)
|
||||
torch_model = convert_bigbird_pegasus(tf_weights, config_update)
|
||||
torch_model.save_pretrained(save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
|
||||
parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
config_update = {}
|
||||
convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)
|
@ -1,292 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BioGptConfig, BioGptForCausalLM
|
||||
from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from transformers.utils import WEIGHTS_NAME, logging
|
||||
|
||||
|
||||
logging.set_verbosity_warning()
|
||||
|
||||
json_indent = 2
|
||||
|
||||
|
||||
# modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18
|
||||
class Dictionary:
|
||||
"""A mapping from symbols to consecutive integers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*, # begin keyword-only arguments
|
||||
bos="<s>",
|
||||
pad="<pad>",
|
||||
eos="</s>",
|
||||
unk="<unk>",
|
||||
extra_special_symbols=None,
|
||||
):
|
||||
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
||||
self.symbols = []
|
||||
self.count = []
|
||||
self.indices = {}
|
||||
self.bos_index = self.add_symbol(bos)
|
||||
self.pad_index = self.add_symbol(pad)
|
||||
self.eos_index = self.add_symbol(eos)
|
||||
self.unk_index = self.add_symbol(unk)
|
||||
if extra_special_symbols:
|
||||
for s in extra_special_symbols:
|
||||
self.add_symbol(s)
|
||||
self.nspecial = len(self.symbols)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.indices == other.indices
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < len(self.symbols):
|
||||
return self.symbols[idx]
|
||||
return self.unk_word
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of symbols in the dictionary"""
|
||||
return len(self.symbols)
|
||||
|
||||
def __contains__(self, sym):
|
||||
return sym in self.indices
|
||||
|
||||
@classmethod
|
||||
def load(cls, f):
|
||||
"""Loads the dictionary from a text file with the format:
|
||||
|
||||
```
|
||||
<symbol0> <count0>
|
||||
<symbol1> <count1>
|
||||
...
|
||||
```
|
||||
"""
|
||||
d = cls()
|
||||
d.add_from_file(f)
|
||||
return d
|
||||
|
||||
def add_symbol(self, word, n=1, overwrite=False):
|
||||
"""Adds a word to the dictionary"""
|
||||
if word in self.indices and not overwrite:
|
||||
idx = self.indices[word]
|
||||
self.count[idx] = self.count[idx] + n
|
||||
return idx
|
||||
else:
|
||||
idx = len(self.symbols)
|
||||
self.indices[word] = idx
|
||||
self.symbols.append(word)
|
||||
self.count.append(n)
|
||||
return idx
|
||||
|
||||
def _load_meta(self, lines):
|
||||
return 0
|
||||
|
||||
def add_from_file(self, f):
|
||||
"""
|
||||
Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
|
||||
"""
|
||||
if isinstance(f, str):
|
||||
try:
|
||||
with open(f, "r", encoding="utf-8") as fd:
|
||||
self.add_from_file(fd)
|
||||
except FileNotFoundError as fnfe:
|
||||
raise fnfe
|
||||
except UnicodeError:
|
||||
raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f))
|
||||
return
|
||||
|
||||
lines = f.readlines()
|
||||
indices_start_line = self._load_meta(lines)
|
||||
|
||||
for line in lines[indices_start_line:]:
|
||||
try:
|
||||
line, field = line.rstrip().rsplit(" ", 1)
|
||||
if field == "#fairseq:overwrite":
|
||||
overwrite = True
|
||||
line, field = line.rsplit(" ", 1)
|
||||
else:
|
||||
overwrite = False
|
||||
count = int(field)
|
||||
word = line
|
||||
if word in self and not overwrite:
|
||||
raise RuntimeError(
|
||||
"Duplicate word found when loading Dictionary: '{}'. "
|
||||
"Duplicate words can overwrite earlier ones by adding the "
|
||||
"#fairseq:overwrite flag at the end of the corresponding row "
|
||||
"in the dictionary file. If using the Camembert model, please "
|
||||
"download an updated copy of the model file.".format(word)
|
||||
)
|
||||
self.add_symbol(word, n=count, overwrite=overwrite)
|
||||
except ValueError:
|
||||
raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'")
|
||||
|
||||
|
||||
def rewrite_dict_keys(d):
|
||||
# (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
|
||||
# e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
|
||||
d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
|
||||
keep_keys = "<s> <pad> </s> <unk>".split()
|
||||
# restore the special tokens
|
||||
for k in keep_keys:
|
||||
del d2[f"{k}</w>"]
|
||||
d2[k] = d[k] # restore
|
||||
return d2
|
||||
|
||||
|
||||
def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path):
|
||||
# prep
|
||||
if not os.path.exists(biogpt_checkpoint_path):
|
||||
raise ValueError(f"path {biogpt_checkpoint_path} does not exist!")
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
print(f"Writing results to {pytorch_dump_folder_path}")
|
||||
|
||||
# handle various types of models
|
||||
|
||||
checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
|
||||
if not os.path.isfile(checkpoint_file):
|
||||
raise ValueError(f"path to the file {checkpoint_file} does not exist!")
|
||||
chkpt = torch.load(checkpoint_file, map_location="cpu")
|
||||
|
||||
args = chkpt["cfg"]["model"]
|
||||
|
||||
# dicts
|
||||
dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt")
|
||||
if not os.path.isfile(dict_file):
|
||||
raise ValueError(f"path to the file {dict_file} does not exist!")
|
||||
src_dict = Dictionary.load(dict_file)
|
||||
src_vocab = rewrite_dict_keys(src_dict.indices)
|
||||
src_vocab_size = len(src_vocab)
|
||||
src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"])
|
||||
print(f"Generating {src_vocab_file} of {src_vocab_size} records")
|
||||
with open(src_vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# merges_file (bpecodes)
|
||||
bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes")
|
||||
if not os.path.isfile(bpecodes_file):
|
||||
raise ValueError(f"path to the file {bpecodes_file} does not exist!")
|
||||
|
||||
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
|
||||
shutil.copyfile(bpecodes_file, merges_file)
|
||||
|
||||
# model config
|
||||
biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")
|
||||
|
||||
model_conf = {
|
||||
"activation_dropout": args["activation_dropout"],
|
||||
"architectures": ["BioGptForCausalLM"],
|
||||
"attention_probs_dropout_prob": args["attention_dropout"],
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": args["activation_fn"],
|
||||
"hidden_dropout_prob": args["dropout"],
|
||||
"hidden_size": args["decoder_embed_dim"],
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": args["decoder_ffn_embed_dim"],
|
||||
"layer_norm_eps": 1e-12,
|
||||
"layerdrop": args["decoder_layerdrop"],
|
||||
"max_position_embeddings": args["max_target_positions"],
|
||||
"model_type": "biogpt",
|
||||
"num_attention_heads": args["decoder_attention_heads"],
|
||||
"num_hidden_layers": args["decoder_layers"],
|
||||
"pad_token_id": 1,
|
||||
"scale_embedding": not args["no_scale_embedding"],
|
||||
"tie_word_embeddings": args["share_decoder_input_output_embed"],
|
||||
"vocab_size": src_vocab_size,
|
||||
}
|
||||
|
||||
# good hparam defaults to start with
|
||||
|
||||
print(f"Generating {biogpt_model_config_file}")
|
||||
with open(biogpt_model_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# tokenizer config
|
||||
biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)
|
||||
|
||||
tokenizer_conf = {
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"model_max_length": 1024,
|
||||
"pad_token": "<pad>",
|
||||
"special_tokens_map_file": None,
|
||||
"tokenizer_class": "BioGptTokenizer",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
|
||||
print(f"Generating {biogpt_tokenizer_config_file}")
|
||||
with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# model
|
||||
model_state_dict = chkpt["model"]
|
||||
|
||||
# remove unneeded keys
|
||||
ignore_keys = [
|
||||
"decoder.version",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
model_state_dict.pop(k, None)
|
||||
|
||||
layer_names = list(model_state_dict.keys())
|
||||
for layer_name in layer_names:
|
||||
if layer_name.endswith("output_projection.weight"):
|
||||
model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name)
|
||||
else:
|
||||
model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name)
|
||||
|
||||
config = BioGptConfig.from_pretrained(pytorch_dump_folder_path)
|
||||
model_new = BioGptForCausalLM(config)
|
||||
|
||||
# check that it loads ok
|
||||
model_new.load_state_dict(model_state_dict)
|
||||
|
||||
# save
|
||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||
print(f"Generating {pytorch_weights_dump_path}")
|
||||
torch.save(model_state_dict, pytorch_weights_dump_path)
|
||||
|
||||
print("Conversion is done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--biogpt_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
|
||||
" bpecodes, etc."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path)
|
@ -588,6 +588,7 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs, # NOOP kwargs, for now
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -757,6 +758,7 @@ class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -783,11 +785,12 @@ class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
|
||||
|
||||
lm_loss = None
|
||||
if labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
lm_loss = self.loss_function(
|
||||
prediction_scores,
|
||||
labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[1:]
|
||||
|
@ -1,177 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BiT checkpoints from the timm library."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from timm import create_model
|
||||
from timm.data import resolve_data_config
|
||||
from timm.data.transforms_factory import create_transform
|
||||
|
||||
from transformers import BitConfig, BitForImageClassification, BitImageProcessor
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_config(model_name):
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
conv_layer = "std_conv" if "bit" in model_name else False
|
||||
|
||||
# note that when using BiT as backbone for ViT-hybrid checkpoints,
|
||||
# one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same",
|
||||
# config.conv_layer = "std_conv_same"
|
||||
config = BitConfig(
|
||||
conv_layer=conv_layer,
|
||||
num_labels=1000,
|
||||
id2label=id2label,
|
||||
label2id=label2id,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "stem.conv" in name:
|
||||
name = name.replace("stem.conv", "bit.embedder.convolution")
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "layers")
|
||||
if "head.fc" in name:
|
||||
name = name.replace("head.fc", "classifier.1")
|
||||
if name.startswith("norm"):
|
||||
name = "bit." + name
|
||||
if "bit" not in name and "classifier" not in name:
|
||||
name = "bit.encoder." + name
|
||||
|
||||
return name
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BiT structure.
|
||||
"""
|
||||
|
||||
# define default BiT configuration
|
||||
config = get_config(model_name)
|
||||
|
||||
# load original model from timm
|
||||
timm_model = create_model(model_name, pretrained=True)
|
||||
timm_model.eval()
|
||||
|
||||
# load state_dict of original model
|
||||
state_dict = timm_model.state_dict()
|
||||
for key in state_dict.copy().keys():
|
||||
val = state_dict.pop(key)
|
||||
state_dict[rename_key(key)] = val.squeeze() if "head" in key else val
|
||||
|
||||
# load HuggingFace model
|
||||
model = BitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# create image processor
|
||||
transform = create_transform(**resolve_data_config({}, model=timm_model))
|
||||
timm_transforms = transform.transforms
|
||||
|
||||
pillow_resamplings = {
|
||||
"bilinear": PILImageResampling.BILINEAR,
|
||||
"bicubic": PILImageResampling.BICUBIC,
|
||||
"nearest": PILImageResampling.NEAREST,
|
||||
}
|
||||
|
||||
processor = BitImageProcessor(
|
||||
do_resize=True,
|
||||
size={"shortest_edge": timm_transforms[0].size},
|
||||
resample=pillow_resamplings[timm_transforms[0].interpolation.value],
|
||||
do_center_crop=True,
|
||||
crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
|
||||
do_normalize=True,
|
||||
image_mean=timm_transforms[-1].mean.tolist(),
|
||||
image_std=timm_transforms[-1].std.tolist(),
|
||||
)
|
||||
|
||||
image = prepare_img()
|
||||
timm_pixel_values = transform(image).unsqueeze(0)
|
||||
pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
# verify pixel values
|
||||
assert torch.allclose(timm_pixel_values, pixel_values)
|
||||
|
||||
# verify logits
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
print("Logits:", logits[0, :3])
|
||||
print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
|
||||
timm_logits = timm_model(pixel_values)
|
||||
assert timm_logits.shape == outputs.logits.shape
|
||||
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model {model_name} and processor to the hub")
|
||||
model.push_to_hub(f"ybelkada/{model_name}")
|
||||
processor.push_to_hub(f"ybelkada/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="resnetv2_50x1_bitm",
|
||||
type=str,
|
||||
help="Name of the BiT timm model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model to the hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,114 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Blenderbot checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
PATTERNS = [
|
||||
["attention", "attn"],
|
||||
["encoder_attention", "encoder_attn"],
|
||||
["q_lin", "q_proj"],
|
||||
["k_lin", "k_proj"],
|
||||
["v_lin", "v_proj"],
|
||||
["out_lin", "out_proj"],
|
||||
["norm_embeddings", "layernorm_embedding"],
|
||||
["position_embeddings", "embed_positions"],
|
||||
["embeddings", "embed_tokens"],
|
||||
["ffn.lin", "fc"],
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k):
|
||||
if k == "embeddings.weight":
|
||||
return "shared.weight"
|
||||
|
||||
for parlai_name, hf_name in PATTERNS:
|
||||
k = k.replace(parlai_name, hf_name)
|
||||
|
||||
if k.startswith("encoder"):
|
||||
k = k.replace(".attn", ".self_attn")
|
||||
k = k.replace("norm1", "self_attn_layer_norm")
|
||||
k = k.replace("norm2", "final_layer_norm")
|
||||
elif k.startswith("decoder"):
|
||||
k = k.replace("norm1", "self_attn_layer_norm")
|
||||
k = k.replace("norm2", "encoder_attn_layer_norm")
|
||||
k = k.replace("norm3", "final_layer_norm")
|
||||
return k
|
||||
|
||||
|
||||
def rename_layernorm_keys(sd):
|
||||
keys = [
|
||||
"model.encoder.layernorm_embedding.weight",
|
||||
"model.encoder.layernorm_embedding.bias",
|
||||
"model.decoder.layernorm_embedding.weight",
|
||||
"model.decoder.layernorm_embedding.bias",
|
||||
]
|
||||
for k in keys:
|
||||
v = sd.pop(k)
|
||||
new_k = k.replace("layernorm_embedding", "layer_norm")
|
||||
assert new_k not in sd
|
||||
sd[new_k] = v
|
||||
|
||||
|
||||
IGNORE_KEYS = ["START"]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BERT structure.
|
||||
"""
|
||||
model = torch.load(checkpoint_path, map_location="cpu")
|
||||
sd = model["model"]
|
||||
cfg = BlenderbotConfig.from_json_file(config_json_path)
|
||||
m = BlenderbotForConditionalGeneration(cfg)
|
||||
valid_keys = m.model.state_dict().keys()
|
||||
failures = []
|
||||
mapping = {}
|
||||
for k, v in sd.items():
|
||||
if k in IGNORE_KEYS:
|
||||
continue
|
||||
|
||||
new_k = rename_state_dict_key(k)
|
||||
if new_k not in valid_keys:
|
||||
failures.append([k, new_k])
|
||||
else:
|
||||
mapping[new_k] = v
|
||||
if cfg.normalize_before: # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
|
||||
rename_layernorm_keys(sd)
|
||||
m.model.load_state_dict(mapping, strict=True)
|
||||
m.half()
|
||||
m.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
|
||||
parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
|
||||
parser.add_argument(
|
||||
"--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)
|
@ -1,191 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# git clone https://github.com/salesforce/BLIP.git
|
||||
from models.blip import blip_decoder
|
||||
from models.blip_itm import blip_itm
|
||||
from models.blip_vqa import blip_vqa
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from transformers import (
|
||||
BertTokenizer,
|
||||
BlipConfig,
|
||||
BlipForConditionalGeneration,
|
||||
BlipForImageTextRetrieval,
|
||||
BlipForQuestionAnswering,
|
||||
)
|
||||
|
||||
|
||||
def load_demo_image(image_size, device):
|
||||
img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
]
|
||||
)
|
||||
image = transform(raw_image).unsqueeze(0).to(device)
|
||||
return image
|
||||
|
||||
|
||||
def rename_key(key):
|
||||
if "visual_encoder" in key:
|
||||
key = re.sub("visual_encoder*", "vision_model.encoder", key)
|
||||
if "blocks" in key:
|
||||
key = re.sub(r"blocks", "layers", key)
|
||||
if "attn" in key:
|
||||
key = re.sub(r"attn", "self_attn", key)
|
||||
if "norm1" in key:
|
||||
key = re.sub(r"norm1", "layer_norm1", key)
|
||||
if "norm2" in key:
|
||||
key = re.sub(r"norm2", "layer_norm2", key)
|
||||
if "encoder.norm" in key:
|
||||
key = re.sub(r"encoder.norm", "post_layernorm", key)
|
||||
if "encoder.patch_embed.proj" in key:
|
||||
key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)
|
||||
|
||||
if "encoder.pos_embed" in key:
|
||||
key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
|
||||
if "encoder.cls_token" in key:
|
||||
key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)
|
||||
|
||||
if "self_attn" in key:
|
||||
key = re.sub(r"self_attn.proj", "self_attn.projection", key)
|
||||
|
||||
return key
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
if config_path is not None:
|
||||
config = BlipConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = BlipConfig(projection_dim=512, text_config={}, vision_config={})
|
||||
|
||||
hf_model = BlipForConditionalGeneration(config).eval()
|
||||
|
||||
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
|
||||
|
||||
pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
|
||||
pt_model = pt_model.eval()
|
||||
|
||||
modified_state_dict = pt_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_model.load_state_dict(modified_state_dict)
|
||||
|
||||
image_size = 384
|
||||
image = load_demo_image(image_size=image_size, device="cpu")
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
input_ids = tokenizer(["a picture of"]).input_ids
|
||||
|
||||
out = hf_model.generate(image, input_ids)
|
||||
|
||||
assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
|
||||
|
||||
out = hf_model.generate(image)
|
||||
|
||||
assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
# model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
|
||||
model_url = (
|
||||
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
|
||||
)
|
||||
|
||||
vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
|
||||
vqa_model.eval()
|
||||
|
||||
modified_state_dict = vqa_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_vqa_model = BlipForQuestionAnswering(config)
|
||||
|
||||
hf_vqa_model.load_state_dict(modified_state_dict)
|
||||
|
||||
question = ["How many dogs are in this image?"]
|
||||
question_input_ids = tokenizer(question, return_tensors="pt").input_ids
|
||||
|
||||
answer = hf_vqa_model.generate(question_input_ids, image)
|
||||
print(tokenizer.decode(answer[0]))
|
||||
|
||||
assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")
|
||||
|
||||
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"
|
||||
|
||||
itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
|
||||
itm_model.eval()
|
||||
|
||||
modified_state_dict = itm_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_itm_model = BlipForImageTextRetrieval(config)
|
||||
|
||||
question = ["A picture of a woman with a dog sitting in a beach"]
|
||||
question_input_ids = tokenizer(
|
||||
question,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=35,
|
||||
).input_ids
|
||||
|
||||
hf_itm_model.load_state_dict(modified_state_dict)
|
||||
hf_itm_model.eval()
|
||||
|
||||
out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
|
||||
out = hf_itm_model(question_input_ids, image, use_itm_head=False)
|
||||
|
||||
assert out[0].item() == 0.2110687494277954
|
||||
assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
|
@ -1,390 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert BLIP-2 checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# pip3 install salesforce-lavis
|
||||
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
|
||||
# to make sure we can compare both original and HF implementation in float32
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BertTokenizer,
|
||||
Blip2Config,
|
||||
Blip2ForConditionalGeneration,
|
||||
Blip2ForImageTextRetrieval,
|
||||
Blip2Processor,
|
||||
Blip2QFormerConfig,
|
||||
Blip2VisionConfig,
|
||||
BlipImageProcessor,
|
||||
OPTConfig,
|
||||
T5Config,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
|
||||
|
||||
def load_demo_image():
|
||||
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
|
||||
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, model_name):
|
||||
rename_keys = []
|
||||
# fmt: off
|
||||
|
||||
# vision encoder
|
||||
rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
|
||||
rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
|
||||
rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
|
||||
rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
|
||||
|
||||
# QFormer
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
|
||||
if "itm" in model_name:
|
||||
rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
|
||||
rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
|
||||
rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
|
||||
rename_keys.append(("text_proj.weight", "text_projection.weight"))
|
||||
rename_keys.append(("text_proj.bias", "text_projection.bias"))
|
||||
|
||||
# fmt: on
|
||||
return rename_keys
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def read_in_q_v_bias(state_dict, config):
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
# read in original q and v biases
|
||||
q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")
|
||||
|
||||
# next, set bias in the state dict
|
||||
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
|
||||
state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
|
||||
|
||||
|
||||
def get_blip2_config(model_name, eos_token_id):
|
||||
image_size = 364 if "coco" in model_name else 224
|
||||
vision_config = Blip2VisionConfig(image_size=image_size).to_dict()
|
||||
|
||||
# make sure the models have proper bos_token_id and eos_token_id set (important for generation)
|
||||
# seems like flan-T5 models don't have bos_token_id properly set?
|
||||
if "opt-2.7b" in model_name:
|
||||
text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
|
||||
elif "opt-6.7b" in model_name:
|
||||
text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
|
||||
elif "t5-xl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "t5-xxl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "itm" in model_name:
|
||||
text_config = {}
|
||||
else:
|
||||
raise ValueError("Model name not supported")
|
||||
|
||||
if "itm" in model_name:
|
||||
config = Blip2Config(
|
||||
vision_config=vision_config,
|
||||
qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
|
||||
)
|
||||
else:
|
||||
config = Blip2Config(vision_config=vision_config, text_config=text_config)
|
||||
|
||||
return config, image_size
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip2_checkpoint(
|
||||
model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to Transformers design.
|
||||
"""
|
||||
if "opt" in model_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
|
||||
elif "itm" in model_name:
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
|
||||
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
|
||||
|
||||
if "itm" in model_name:
|
||||
eos_token_id = None
|
||||
else:
|
||||
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
|
||||
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
|
||||
|
||||
if "itm" in model_name:
|
||||
hf_model = Blip2ForImageTextRetrieval(config).eval()
|
||||
else:
|
||||
hf_model = Blip2ForConditionalGeneration(config).eval()
|
||||
|
||||
model_name_to_original = {
|
||||
"blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
|
||||
"blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
|
||||
"blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
|
||||
"blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
|
||||
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
|
||||
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
|
||||
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
|
||||
"blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
|
||||
"blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
|
||||
}
|
||||
|
||||
name, type = model_name_to_original[model_name]
|
||||
|
||||
# load original model
|
||||
print("Loading original model...")
|
||||
original_model, vis_processors, _ = load_model_and_preprocess(
|
||||
name=name, model_type=type, is_eval=True, device=lavis_device
|
||||
)
|
||||
original_model.eval()
|
||||
print("Done!")
|
||||
|
||||
# update state dict keys
|
||||
state_dict = original_model.state_dict()
|
||||
rename_keys = create_rename_keys(config, model_name)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
|
||||
# some keys can be renamed efficiently
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("Qformer.bert"):
|
||||
key = key.replace("Qformer.bert", "qformer")
|
||||
if "attention.self" in key:
|
||||
key = key.replace("self", "attention")
|
||||
if "opt_proj" in key:
|
||||
key = key.replace("opt_proj", "language_projection")
|
||||
if "t5_proj" in key:
|
||||
key = key.replace("t5_proj", "language_projection")
|
||||
if key.startswith("opt"):
|
||||
key = key.replace("opt", "language")
|
||||
if key.startswith("t5"):
|
||||
key = key.replace("t5", "language")
|
||||
state_dict[key] = val
|
||||
|
||||
# read in qv biases
|
||||
read_in_q_v_bias(state_dict, config)
|
||||
|
||||
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
|
||||
assert len(missing_keys) == 0
|
||||
|
||||
if "itm" in model_name:
|
||||
unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
|
||||
assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
|
||||
else:
|
||||
assert unexpected_keys == ["qformer.embeddings.position_ids"]
|
||||
|
||||
image = load_demo_image()
|
||||
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
|
||||
|
||||
# create processor
|
||||
image_processor = BlipImageProcessor(
|
||||
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
|
||||
)
|
||||
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
|
||||
|
||||
# make sure processor creates exact same pixel values
|
||||
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
|
||||
|
||||
original_model.to(lavis_device)
|
||||
hf_model.to(hf_model_device)
|
||||
|
||||
if "itm" in model_name:
|
||||
caption = "a large fountain spewing water into the air"
|
||||
input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
|
||||
attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)
|
||||
|
||||
with torch.no_grad():
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
|
||||
)
|
||||
logits = hf_model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_image_text_matching_head=True,
|
||||
)
|
||||
|
||||
assert original_logits.shape == logits.logits_per_image.shape
|
||||
print("First values of original logits:", original_logits[0, :3])
|
||||
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||
|
||||
# assert values
|
||||
# cast to same type
|
||||
target_dtype = logits.logits_per_image.dtype
|
||||
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||
|
||||
original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
|
||||
itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
|
||||
assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
with torch.no_grad():
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
|
||||
)
|
||||
logits = hf_model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_image_text_matching_head=False,
|
||||
)
|
||||
|
||||
assert original_logits.shape == logits.logits_per_image.shape
|
||||
print("First values of original logits:", original_logits[0, :3])
|
||||
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||
|
||||
# assert values
|
||||
# cast to same type
|
||||
target_dtype = logits.logits_per_image.dtype
|
||||
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
else:
|
||||
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
|
||||
|
||||
with torch.no_grad():
|
||||
if "opt" in model_name:
|
||||
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
|
||||
logits = hf_model(pixel_values, input_ids).logits
|
||||
else:
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
|
||||
).logits
|
||||
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
|
||||
logits = hf_model(pixel_values, input_ids, labels=labels).logits
|
||||
|
||||
assert original_logits.shape == logits.shape
|
||||
print("First values of original logits:", original_logits[0, :3, :3])
|
||||
print("First values of HF logits:", logits[0, :3, :3])
|
||||
|
||||
# assert values
|
||||
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
print("Generating a caption...")
|
||||
prompt = "Question: what object is in this image? Answer:"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
|
||||
|
||||
set_seed(42)
|
||||
|
||||
original_outputs = original_model.generate(
|
||||
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
|
||||
)
|
||||
outputs = hf_model.generate(
|
||||
pixel_values,
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=5,
|
||||
max_length=30,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=1.0,
|
||||
temperature=1,
|
||||
)
|
||||
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
|
||||
output_text = [text.strip() for text in output_text]
|
||||
print("Original generation:", original_outputs)
|
||||
print("HF generation:", output_text)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
processor.push_to_hub(f"nielsr/{model_name}")
|
||||
hf_model.push_to_hub(f"nielsr/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
choices = [
|
||||
"blip2-opt-2.7b",
|
||||
"blip2-opt-6.7b",
|
||||
"blip2-opt-2.7b-coco",
|
||||
"blip2-opt-6.7b-coco",
|
||||
"blip2-flan-t5-xl",
|
||||
"blip2-flan-t5-xl-coco",
|
||||
"blip2-flan-t5-xxl",
|
||||
"blip2-itm-vit-g",
|
||||
"blip2-itm-vit-g-coco",
|
||||
]
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="blip2-opt-2.7b",
|
||||
choices=choices,
|
||||
type=str,
|
||||
help="Path to hf config.json of model to convert",
|
||||
)
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model and processor to the hub after converting",
|
||||
)
|
||||
# note: this script is tested on 2 GPUs, as models are compared in float32,
|
||||
# which requires quite some memory. Hence loading both on a
|
||||
# separate device is the easiest to compare
|
||||
parser.add_argument(
|
||||
"--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_blip2_checkpoint(
|
||||
args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
|
||||
)
|
@ -1,254 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BigScience BLOOM checkpoint."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BloomConfig, BloomModel
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
WEIGHTS_TO_AVERAGE_ENDSWITH = [
|
||||
"word_embeddings_layernorm.weight",
|
||||
"word_embeddings_layernorm.bias",
|
||||
"input_layernorm.weight",
|
||||
"input_layernorm.bias",
|
||||
"post_attention_layernorm.weight",
|
||||
"post_attention_layernorm.bias",
|
||||
"self_attention.dense.bias",
|
||||
"mlp.dense_4h_to_h.bias",
|
||||
"ln_f.weight",
|
||||
"ln_f.bias",
|
||||
]
|
||||
|
||||
WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
|
||||
"mlp.dense_4h_to_h.weight",
|
||||
"self_attention.dense.weight",
|
||||
]
|
||||
|
||||
|
||||
def layer_name_mapping(key, file):
|
||||
"""Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
|
||||
# Handle first and last layers
|
||||
layer_rename_map = {
|
||||
"word_embeddings.weight": "word_embeddings.weight",
|
||||
"word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
|
||||
"word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
|
||||
"weight": "ln_f.weight",
|
||||
"bias": "ln_f.bias",
|
||||
}
|
||||
|
||||
if key in layer_rename_map:
|
||||
return layer_rename_map[key]
|
||||
|
||||
# Handle transformer blocks
|
||||
layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
|
||||
layer_number -= 3
|
||||
return f"h.{layer_number}." + key
|
||||
|
||||
|
||||
def get_dtype_size(dtype):
|
||||
if dtype == torch.bool:
|
||||
return 1 / 8
|
||||
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
|
||||
if bit_search is None:
|
||||
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||
bit_size = int(bit_search.groups()[0])
|
||||
return bit_size // 8
|
||||
|
||||
|
||||
def convert_bloom_checkpoint_to_pytorch(
|
||||
bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
|
||||
):
|
||||
# Construct model
|
||||
if bloom_config_file == "":
|
||||
config = BloomConfig()
|
||||
else:
|
||||
config = BloomConfig.from_json_file(bloom_config_file)
|
||||
|
||||
if shard_model:
|
||||
file_names = os.listdir(bloom_checkpoint_path)
|
||||
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
|
||||
|
||||
index_dict = {"weight_map": {}, "metadata": {}}
|
||||
total_size = 0
|
||||
|
||||
missing_keys = None
|
||||
|
||||
config = BloomConfig()
|
||||
|
||||
for j, file in enumerate(file_names):
|
||||
print("Processing file: {}".format(file))
|
||||
tensors = None
|
||||
|
||||
for i in range(pretraining_tp):
|
||||
# load all TP files
|
||||
f_name = file.replace("model_00", f"model_0{i}")
|
||||
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
|
||||
|
||||
# Rename keys in the transformers names
|
||||
keys = list(temp.keys())
|
||||
for key in keys:
|
||||
temp[layer_name_mapping(key, file)] = temp.pop(key)
|
||||
|
||||
if tensors is None:
|
||||
tensors = temp
|
||||
else:
|
||||
for key in tensors.keys():
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
# We average (sum and then divide) some weights accross TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
|
||||
tensors[key] += temp[key]
|
||||
else:
|
||||
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
|
||||
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
|
||||
# We concatenate these weights accross TP ranks
|
||||
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
|
||||
|
||||
# Divide by the number of TP the weights we want to average
|
||||
for key in tensors.keys():
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] = tensors[key] / pretraining_tp
|
||||
torch.save(
|
||||
tensors,
|
||||
os.path.join(
|
||||
pytorch_dump_folder_path,
|
||||
"pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(len(file_names)).zfill(5)),
|
||||
),
|
||||
)
|
||||
|
||||
for key in tensors.keys():
|
||||
value = tensors[key]
|
||||
total_size += value.numel() * get_dtype_size(value.dtype)
|
||||
if key not in index_dict["weight_map"]:
|
||||
index_dict["weight_map"][key] = "pytorch_model_{}-of-{}.bin".format(
|
||||
str(j + 1).zfill(5), str(len(file_names)).zfill(5)
|
||||
)
|
||||
|
||||
config = BloomConfig()
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
index_dict["metadata"]["total_size"] = total_size
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
|
||||
json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
|
||||
f.write(json_config)
|
||||
else:
|
||||
model = BloomModel(config)
|
||||
|
||||
file_names = os.listdir(bloom_checkpoint_path)
|
||||
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
|
||||
|
||||
missing_keys = None
|
||||
for i, file in enumerate(file_names):
|
||||
tensors = None
|
||||
for i in range(pretraining_tp):
|
||||
# load all TP files
|
||||
f_name = file.replace("model_00", f"model_0{i}")
|
||||
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
|
||||
|
||||
# Rename keys in the transformers names
|
||||
keys = list(temp.keys())
|
||||
for key in keys:
|
||||
temp[layer_name_mapping(key, file)] = temp.pop(key)
|
||||
|
||||
if tensors is None:
|
||||
tensors = temp
|
||||
else:
|
||||
for key in tensors.keys():
|
||||
# We average (sum and then divide) some weights accross TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] += temp[key]
|
||||
else:
|
||||
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
|
||||
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
|
||||
# We concatenate these weights accross TP ranks
|
||||
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
|
||||
|
||||
# Divide by the number of TP the weights we want to average
|
||||
for key in tensors.keys():
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] = tensors[key] / pretraining_tp
|
||||
|
||||
other_keys = model.load_state_dict(tensors, strict=False)
|
||||
assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
|
||||
if missing_keys is None:
|
||||
missing_keys = set(other_keys.missing_keys)
|
||||
else:
|
||||
missing_keys = missing_keys.intersection(set(other_keys.missing_keys))
|
||||
|
||||
assert not missing_keys, f"The keys {missing_keys} are missing"
|
||||
|
||||
# Save pytorch-model
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
|
||||
if config.torch_dtype is not None:
|
||||
model = model.to(config.torch_dtype)
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print(f"Save configuration file to {pytorch_config_dump_path}")
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--bloom_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the Megatron-LM checkpoint path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bloom_config_file",
|
||||
default="",
|
||||
type=str,
|
||||
help=(
|
||||
"An optional config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard_model",
|
||||
action="store_true",
|
||||
help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretraining_tp",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_bloom_checkpoint_to_pytorch(
|
||||
args.bloom_checkpoint_path,
|
||||
args.bloom_config_file,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.shard_model,
|
||||
args.pretraining_tp,
|
||||
)
|
@ -958,6 +958,8 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
# Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly
|
||||
num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
|
||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
||||
warnings.warn(
|
||||
@ -990,14 +992,12 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(lm_logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
batch_size, seq_length, vocab_size = shift_logits.shape
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
|
||||
loss = self.loss_function(
|
||||
lm_logits,
|
||||
labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
|
@ -1,145 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Bros checkpoints."""
|
||||
|
||||
import argparse
|
||||
|
||||
import bros # original repo
|
||||
import torch
|
||||
|
||||
from transformers import BrosConfig, BrosModel, BrosProcessor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_configs(model_name):
|
||||
bros_config = BrosConfig.from_pretrained(model_name)
|
||||
return bros_config
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"embeddings.bbox_sinusoid_emb.inv_freq",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if name == "embeddings.bbox_projection.weight":
|
||||
name = "bbox_embeddings.bbox_projection.weight"
|
||||
|
||||
if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
|
||||
name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"
|
||||
|
||||
if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
|
||||
name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, model):
|
||||
# rename keys
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
# remove ignore keys
|
||||
remove_ignore_keys_(orig_state_dict)
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
|
||||
# load original model
|
||||
original_model = bros.BrosModel.from_pretrained(model_name).eval()
|
||||
|
||||
# load HuggingFace Model
|
||||
bros_config = get_configs(model_name)
|
||||
model = BrosModel.from_pretrained(model_name, config=bros_config)
|
||||
model.eval()
|
||||
|
||||
state_dict = original_model.state_dict()
|
||||
new_state_dict = convert_state_dict(state_dict, model)
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# verify results
|
||||
|
||||
# original BROS model require 4 points (8 float values) for each bbox, prepare bbox with [batch_size, seq_len, 8] shape
|
||||
bbox = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
||||
[0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
|
||||
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
|
||||
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
|
||||
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
|
||||
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
|
||||
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
processor = BrosProcessor.from_pretrained(model_name)
|
||||
|
||||
encoding = processor("His name is Rocco.", return_tensors="pt")
|
||||
encoding["bbox"] = bbox
|
||||
|
||||
original_hidden_states = original_model(**encoding).last_hidden_state
|
||||
# pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
last_hidden_states = model(**encoding).last_hidden_state
|
||||
|
||||
assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
print(f"Saving model and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
|
||||
processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="jinho8345/bros-base-uncased",
|
||||
required=False,
|
||||
type=str,
|
||||
help="Name of the original model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
required=False,
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether or not to push the converted model and processor to the 🤗 hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,59 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert T5 checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = T5Config.from_json_file(config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = T5ForConditionalGeneration(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
@ -1584,6 +1584,7 @@ class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
@ -1655,11 +1656,12 @@ class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin):
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(prediction_scores.device)
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
lm_loss = self.loss_function(
|
||||
prediction_scores,
|
||||
labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user