mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 09:44:02 +08:00
Compare commits
15 Commits
check_temp...v4.51.1
| Author | SHA1 | Date |
|---|---|---|
| | 10baffb599 | |
| | 4a88ffae40 | |
| | f19aec737e | |
| | d8f0695e84 | |
| | d27c8c38f4 | |
| | 04c0cedcdf | |
| | 4f536ba0ae | |
| | 6b82af0a5b | |
| | 2bf3d4aca8 | |
| | a79b7abede | |
| | 0720e206c6 | |
| | 25b7f27234 | |
| | aa40fda346 | |
| | e94571580b | |
| | 84aa13dd85 | |
@ -507,6 +507,8 @@
|
||||
title: Llama2
|
||||
- local: model_doc/llama3
|
||||
title: Llama3
|
||||
- local: model_doc/llama4
|
||||
title: Llama4
|
||||
- local: model_doc/longformer
|
||||
title: Longformer
|
||||
- local: model_doc/longt5
|
||||
|
442
docs/source/en/model_doc/llama4.md
Normal file
@ -0,0 +1,442 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Llama4
|
||||
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture.
|
||||
This generation includes two models:
|
||||
- The highly capable Llama 4 Maverick, with 17B active parameters out of ~400B total, using 128 experts.
|
||||
- The efficient Llama 4 Scout also has 17B active parameters out of ~109B total, using just 16 experts.
|
||||
|
||||
Both models leverage early fusion for native multimodality, enabling them to process text and image inputs.
|
||||
Maverick and Scout are both trained on up to 40 trillion tokens of data encompassing 200 languages
|
||||
(with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).
|
||||
|
||||
For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via
|
||||
on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats.
|
||||
These models are released under the custom Llama 4 Community License Agreement, available on the model repositories.
|
||||
|
||||
You can find all the original Llama checkpoints under the [meta-llama](https://huggingface.co/meta-llama) organization.
|
||||
|
||||
> [!TIP]
|
||||
> The Llama 4 family of models comes in two flavors: 109B and 402B parameters. Both of these flavors are extremely
|
||||
> large and won't fit on your run-of-the-mill device. See below for some examples to reduce the memory usage of the
|
||||
> model.
|
||||
>
|
||||
> For the download to be faster and more resilient, we recommend installing the `hf_xet` dependency as follows:
|
||||
> `pip install transformers[hf_xet]`
|
||||
|
||||
The examples below demonstrate how to generate with [`Pipeline`] or the [`AutoModel`] classes. We additionally include an example
|
||||
showcasing how to toggle the right attributes to enable very long-context generations, as some flavors of Llama 4
|
||||
have context lengths going up to 10 million tokens.
|
||||
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "what is the recipe of mayonnaise?"},
|
||||
]
|
||||
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model=model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
output = pipe(messages, do_sample=False, max_new_tokens=200)
|
||||
print(output[0]["generated_text"][-1]["content"])
|
||||
```
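
The pipeline above is text-only. To route images through a pipeline as well, the `image-text-to-text` task accepts the same chat-style messages. The sketch below is illustrative: it assumes the `image-text-to-text` pipeline supports this checkpoint in your installed version, and it reuses a documentation image URL.

```py
from transformers import pipeline
import torch

pipe = pipeline(
    "image-text-to-text",
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"},
            {"type": "text", "text": "Describe this image in one sentence."},
        ],
    },
]

# The multimodal pipeline handles image fetching, chat templating and decoding.
output = pipe(text=messages, max_new_tokens=60)
print(output[0]["generated_text"])
```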
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Text only">
|
||||
|
||||
```py
|
||||
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
|
||||
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
|
||||
print(outputs[0])
|
||||
```
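
If you would rather see tokens appear as they are generated, you can optionally pass a [`TextStreamer`] to `generate`. This is a minimal sketch reusing the `tokenizer`, `inputs` and `model` objects created above:

```py
from transformers import TextStreamer

# Streams decoded tokens to stdout as they are produced; skip_prompt avoids re-printing the chat template.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
_ = model.generate(**inputs.to(model.device), max_new_tokens=100, streamer=streamer)
```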
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Multimodal">
|
||||
|
||||
```py
|
||||
from transformers import AutoProcessor, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": img_url},
|
||||
{"type": "text", "text": "Describe this image in two sentences."},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=256,
|
||||
)
|
||||
|
||||
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
|
||||
print(response)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Multimodal with multiple images">
|
||||
|
||||
```py
|
||||
from transformers import AutoProcessor, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
|
||||
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": url1},
|
||||
{"type": "image", "url": url2},
|
||||
{"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=256,
|
||||
)
|
||||
|
||||
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
|
||||
print(response)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Long context">
|
||||
|
||||
Beware: the example below uses both `device_map="auto"` and flex-attention.
|
||||
Please use `torchrun` to run this example in tensor-parallel mode.
|
||||
|
||||
We will work to enable running with `device_map="auto"` and flex-attention without
|
||||
tensor-parallel in the future.
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration, AutoTokenizer
|
||||
import torch
|
||||
import time
|
||||
|
||||
file = "very_long_context_prompt.txt"
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
with open(file, "r") as f:
|
||||
very_long_text = f.read()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
attn_implementation="flex_attention",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."},
|
||||
]
|
||||
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
out = model.generate(
|
||||
input_ids.to(model.device),
|
||||
prefill_chunk_size=2048*8,
|
||||
max_new_tokens=300,
|
||||
cache_implementation="hybrid",
|
||||
)
|
||||
print(time.time()-start)
|
||||
print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:]))
|
||||
print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB")
|
||||
```
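
Note that for Llama 4 checkpoints, requesting `cache_implementation="hybrid"` is remapped internally to the new chunked hybrid cache (`hybrid_chunked`), so the call above exercises the chunked-attention cache introduced in this release together with `prefill_chunk_size`.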
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Efficiency: how to get the best out of Llama 4
|
||||
|
||||
### The Attention methods
|
||||
|
||||
Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the [Attention Interface](../attention_interface) overview for an in-depth explanation of our interface.
|
||||
|
||||
At the time of release, Llama 4 supports the following attention methods: `eager`, `flex_attention`, and `sdpa`. We recommend using `flex_attention` for best results.
|
||||
Switching the attention mechanism is done at model initialization:
|
||||
|
||||
|
||||
<hfoptions id="Attention">
|
||||
<hfoption id="Flex Attention">
|
||||
|
||||
Setting Flex Attention ensures the best results with the very long context the model can handle.
|
||||
|
||||
> [!TIP]
> Beware: the example below uses both `device_map="auto"` and flex-attention.
|
||||
> Please use `torchrun` to run this example in tensor-parallel mode.
|
||||
>
|
||||
> We will work to enable running with `device_map="auto"` and flex-attention without
|
||||
> tensor-parallel in the future.
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="flex_attention",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="SDPA">
|
||||
The `sdpa` attention method is generally more compute-efficient than the `eager` method.
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="sdpa",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="Eager">
|
||||
The `eager` attention method is set by default, so nothing extra is needed when loading the model:
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
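
To double-check which attention implementation was actually picked up, you can inspect the model config. Note that `_attn_implementation` is a private attribute, so this is only a debugging aid and may change between releases:

```py
print(model.config._attn_implementation)  # e.g. "eager", "sdpa" or "flex_attention"
```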
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
### Quantization
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for available quantization backends.
|
||||
At the time of release, both FBGEMM and LLM-Compressor are supported; more quantization methods will be added in the days following the release.
|
||||
|
||||
See below for examples using both:
|
||||
|
||||
|
||||
|
||||
Here is an example loading a BF16 model in FP8 using the FBGEMM approach:
|
||||
|
||||
<hfoptions id="Quantization">
|
||||
<hfoption id="FBGEMM">
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
quantization_config=FbgemmFp8Config()
|
||||
)
|
||||
|
||||
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
|
||||
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
|
||||
print(outputs[0])
|
||||
```
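
If you plan to reuse the on-the-fly quantized weights, they can be serialized once and reloaded directly in later runs. This is a hedged sketch: it assumes the FBGEMM FP8 scheme is serializable in your installed version, and the local path is only illustrative.

```python
# Save the quantized weights so future runs skip the on-the-fly conversion (illustrative path).
model.save_pretrained("llama4-scout-17b-16e-instruct-fp8")
tokenizer.save_pretrained("llama4-scout-17b-16e-instruct-fp8")

# Later: reload the already-quantized checkpoint directly.
model = Llama4ForConditionalGeneration.from_pretrained(
    "llama4-scout-17b-16e-instruct-fp8",
    device_map="auto",
)
```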
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="LLM-Compressor">
|
||||
|
||||
To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
tp_plan="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
|
||||
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
|
||||
print(outputs[0])
|
||||
```
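
Note that `tp_plan="auto"` shards the model with tensor parallelism and is meant to be launched through `torchrun` (for example `torchrun --nproc-per-node 8 your_script.py`, where the script name and GPU count are placeholders for your setup) rather than as a plain single-process Python run.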
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Offloading
|
||||
|
||||
Enabling CPU offloading means that components of the model are moved to the CPU instead of the GPU when the available GPU memory isn't sufficient to load the entire model.
|
||||
At inference, different components will be loaded/unloaded from/to the GPU on the fly. This ensures that the model can be loaded on smaller machines as long as the CPU-memory is sufficient.
|
||||
However, this also slows down inference as it adds communication overhead.
|
||||
|
||||
To enable CPU offloading, simply set `device_map="auto"` when loading the model:
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
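
With `device_map="auto"`, offloading only kicks in when the GPUs are actually too small for the checkpoint. If you want to bound GPU usage explicitly, you can additionally pass a `max_memory` map; the limits below are placeholders to adapt to your hardware.

```py
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    # Illustrative limits: cap GPU 0 and let the remaining weights spill over to CPU RAM.
    max_memory={0: "30GiB", "cpu": "200GiB"},
)
```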
|
||||
|
||||
## Llama4Config
|
||||
|
||||
[[autodoc]] Llama4Config
|
||||
|
||||
## Llama4TextConfig
|
||||
|
||||
[[autodoc]] Llama4TextConfig
|
||||
|
||||
## Llama4VisionConfig
|
||||
|
||||
[[autodoc]] Llama4VisionConfig
|
||||
|
||||
## Llama4Processor
|
||||
|
||||
[[autodoc]] Llama4Processor
|
||||
|
||||
## Llama4ImageProcessorFast
|
||||
|
||||
[[autodoc]] Llama4ImageProcessorFast
|
||||
|
||||
## Llama4ForConditionalGeneration
|
||||
|
||||
[[autodoc]] Llama4ForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## Llama4ForCausalLM
|
||||
|
||||
[[autodoc]] Llama4ForCausalLM
|
||||
- forward
|
||||
|
||||
## Llama4TextModel
|
||||
|
||||
[[autodoc]] Llama4TextModel
|
||||
- forward
|
||||
|
||||
## Llama4VisionModel
|
||||
|
||||
[[autodoc]] Llama4VisionModel
|
||||
- forward
|
@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.51.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
4
setup.py
@ -117,6 +117,7 @@ _deps = [
|
||||
"fugashi>=1.0",
|
||||
"GitPython<3.1.19",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"hf_xet",
|
||||
"huggingface-hub>=0.30.0,<1.0",
|
||||
"importlib_metadata",
|
||||
"ipadic>=1.0.0,<2.0",
|
||||
@ -283,6 +284,7 @@ extras["tf-cpu"] = deps_list(
|
||||
|
||||
extras["torch"] = deps_list("torch", "accelerate")
|
||||
extras["accelerate"] = deps_list("accelerate")
|
||||
extras["hf_xet"] = deps_list("hf_xet")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["retrieval"] = deps_list("datasets") # faiss is not supported on windows
|
||||
@ -451,7 +453,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.51.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.51.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.51.0.dev0"
|
||||
__version__ = "4.51.1"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -562,6 +562,12 @@ _import_structure = {
|
||||
"models.levit": ["LevitConfig"],
|
||||
"models.lilt": ["LiltConfig"],
|
||||
"models.llama": ["LlamaConfig"],
|
||||
"models.llama4": [
|
||||
"Llama4Config",
|
||||
"Llama4Processor",
|
||||
"Llama4TextConfig",
|
||||
"Llama4VisionConfig",
|
||||
],
|
||||
"models.llava": [
|
||||
"LlavaConfig",
|
||||
"LlavaProcessor",
|
||||
@ -1354,6 +1360,7 @@ else:
|
||||
_import_structure["models.detr"].append("DetrImageProcessorFast")
|
||||
_import_structure["models.gemma3"].append("Gemma3ImageProcessorFast")
|
||||
_import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
|
||||
_import_structure["models.llama4"].append("Llama4ImageProcessorFast")
|
||||
_import_structure["models.llava"].append("LlavaImageProcessorFast")
|
||||
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
|
||||
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
|
||||
@ -2510,6 +2517,15 @@ else:
|
||||
"GlmPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.llama4"].extend(
|
||||
[
|
||||
"Llama4ForCausalLM",
|
||||
"Llama4ForConditionalGeneration",
|
||||
"Llama4TextModel",
|
||||
"Llama4VisionModel",
|
||||
"Llama4PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.glpn"].extend(
|
||||
[
|
||||
"GLPNForDepthEstimation",
|
||||
@ -5807,6 +5823,12 @@ if TYPE_CHECKING:
|
||||
from .models.levit import LevitConfig
|
||||
from .models.lilt import LiltConfig
|
||||
from .models.llama import LlamaConfig
|
||||
from .models.llama4 import (
|
||||
Llama4Config,
|
||||
Llama4Processor,
|
||||
Llama4TextConfig,
|
||||
Llama4VisionConfig,
|
||||
)
|
||||
from .models.llava import (
|
||||
LlavaConfig,
|
||||
LlavaProcessor,
|
||||
@ -6646,6 +6668,7 @@ if TYPE_CHECKING:
|
||||
from .models.detr import DetrImageProcessorFast
|
||||
from .models.gemma3 import Gemma3ImageProcessorFast
|
||||
from .models.got_ocr2 import GotOcr2ImageProcessorFast
|
||||
from .models.llama4 import Llama4ImageProcessorFast
|
||||
from .models.llava import LlavaImageProcessorFast
|
||||
from .models.llava_next import LlavaNextImageProcessorFast
|
||||
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
|
||||
@ -7827,6 +7850,13 @@ if TYPE_CHECKING:
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
)
|
||||
from .models.llama4 import (
|
||||
Llama4ForCausalLM,
|
||||
Llama4ForConditionalGeneration,
|
||||
Llama4PreTrainedModel,
|
||||
Llama4TextModel,
|
||||
Llama4VisionModel,
|
||||
)
|
||||
from .models.llava import (
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaPreTrainedModel,
|
||||
|
@ -1811,6 +1811,204 @@ class HybridCache(Cache):
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
|
||||
class HybridChunkedCache(Cache):
|
||||
"""
|
||||
Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention
|
||||
and global attention in every other layer. Under the hood, Hybrid Cache leverages [`SlidingWindowCache`] for sliding window attention
|
||||
and [`StaticCache`] for global attention. For more information, see the documentation of each subcomponent cache class.
|
||||
|
||||
Parameters:
|
||||
config (`PretrainedConfig`):
|
||||
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||
max_batch_size (`int`):
|
||||
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
|
||||
smaller batch size is used.
|
||||
max_cache_len (`int`, *optional*):
|
||||
The maximum sequence length with which the model will be used.
|
||||
device (`torch.device` or `str`, *optional*):
|
||||
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
||||
should pass the `layer_device_map` argument instead.
|
||||
dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
||||
The default `dtype` to use when initializing the layer.
|
||||
layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]`, *optional*):
|
||||
Mapping between the layers and their devices. This is required when you are manually initializing the cache
|
||||
and the model is split between different GPUs. You can find out which layers are mapped to which device by
|
||||
checking the associated device_map: `model.hf_device_map`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
|
||||
|
||||
>>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
|
||||
|
||||
>>> # Prepare a cache class and pass it to model's forward
|
||||
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
||||
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
||||
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||
>>> outputs.past_key_values # access cache filled with key/values from generation
|
||||
HybridCache()
|
||||
```
|
||||
"""
|
||||
|
||||
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
|
||||
# ALL changes from the PR that commented the line below when reactivating it.
|
||||
is_compileable = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
max_batch_size: int,
|
||||
max_cache_len: Optional[int] = None,
|
||||
device: Union[torch.device, str, None] = None,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||
self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192)
|
||||
else:
|
||||
self.sliding_window = config.sliding_window
|
||||
self.max_cache_len = max_cache_len
|
||||
self.max_batch_size = max_batch_size
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self._dtype = dtype
|
||||
|
||||
if hasattr(config.get_text_config(), "no_rope_layers"):
|
||||
self.is_sliding = config.no_rope_layers
|
||||
else:
|
||||
layer_switch = getattr(config, "sliding_window_pattern", 2)
|
||||
self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
|
||||
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
self.cumulative_length = [0 for _ in range(config.num_hidden_layers)]
|
||||
|
||||
def initialise_cache_layer(self, layer_idx, key_states):
|
||||
if len(self.key_cache) > layer_idx:
|
||||
return
|
||||
|
||||
num_key_value_heads = key_states.shape[1]
|
||||
device = key_states.device
|
||||
global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||
sliding_cache_shape = (
|
||||
self.max_batch_size,
|
||||
num_key_value_heads,
|
||||
self.sliding_window,
|
||||
self.head_dim,
|
||||
)
|
||||
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
|
||||
# breaks when updating the cache.
|
||||
cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
|
||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
|
||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
|
||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||
torch._dynamo.mark_static_address(new_layer_value_cache)
|
||||
self.key_cache.append(new_layer_key_cache)
|
||||
self.value_cache.append(new_layer_value_cache)
|
||||
|
||||
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
||||
if cache_position.shape[0] > max_cache_len:
|
||||
cache_position = cache_position.clamp(0, max_cache_len - 1)
|
||||
k_out = key_states[:, :, -max_cache_len:, :]
|
||||
v_out = value_states[:, :, -max_cache_len:, :]
|
||||
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
self.key_cache[layer_idx] += k_out
|
||||
self.value_cache[layer_idx] += v_out
|
||||
# we should return the whole states instead of k_out, v_out to take the whole prompt
|
||||
# into consideration when building kv cache instead of just throwing away tokens outside of the window
|
||||
return key_states, value_states
|
||||
|
||||
# otherwise we are decoding. Most efficient way to cat 1 token
|
||||
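# When the window is already full, `indices` rolls the cached keys/values left by one position so the
# oldest entry is dropped and the last slot receives the incoming token; otherwise it is the identity permutation.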
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
|
||||
cache_position = cache_position.clamp(0, max_cache_len - 1)
|
||||
to_shift = cache_position >= max_cache_len - 1
|
||||
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
|
||||
k_out = k_out[:, :, indices]
|
||||
v_out = v_out[:, :, indices]
|
||||
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
# `_.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
self.key_cache[layer_idx] += k_out
|
||||
self.value_cache[layer_idx] += v_out
|
||||
return k_out, v_out
|
||||
|
||||
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
||||
self.key_cache[layer_idx] = k_out
|
||||
self.value_cache[layer_idx] = v_out
|
||||
return k_out, v_out
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: int,
|
||||
cache_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if cache_kwargs is None:
|
||||
cache_kwargs = {}
|
||||
cache_position = cache_kwargs.get("cache_position")
|
||||
self.initialise_cache_layer(layer_idx, key_states)
|
||||
|
||||
k_out = self.key_cache[layer_idx]
|
||||
v_out = self.value_cache[layer_idx]
|
||||
key_states = key_states.to(k_out.dtype)
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
|
||||
if self.is_sliding[layer_idx]:
|
||||
update_fn = self._sliding_update
|
||||
else:
|
||||
update_fn = self._static_update
|
||||
|
||||
return update_fn(
|
||||
cache_position,
|
||||
layer_idx,
|
||||
key_states,
|
||||
value_states,
|
||||
k_out,
|
||||
v_out,
|
||||
k_out.shape[2],
|
||||
)
|
||||
|
||||
def get_max_cache_shape(self) -> Optional[int]:
|
||||
return self.max_cache_len
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0):
|
||||
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
|
||||
# limit the check to the first batch member and head dimension.
|
||||
# TODO: deprecate this function in favor of `cache_position`
|
||||
if layer_idx != 0:
|
||||
raise ValueError(
|
||||
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
|
||||
"Using the `layer_idx` argument is not supported."
|
||||
)
|
||||
if len(self.key_cache) == 0:
|
||||
return 0
|
||||
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
|
||||
|
||||
def reset(self):
|
||||
"""Resets the cache values while preserving the objects"""
|
||||
for layer_idx in range(len(self.key_cache)):
|
||||
# In-place ops prevent breaking the static address
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
|
||||
|
||||
|
||||
class MambaCache:
|
||||
"""
|
||||
Cache for mamba model which does not have attention mechanism and key value states.
|
||||
|
@ -801,18 +801,19 @@ class PretrainedConfig(PushToHubMixin):
|
||||
|
||||
def to_diff_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Removes all attributes from config which correspond to the default config attributes for better readability and
|
||||
serializes to a Python dictionary.
|
||||
Removes all attributes from the configuration that correspond to the default config attributes for
|
||||
better readability, while always retaining the `config` attribute from the class. Serializes to a
|
||||
Python dictionary.
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
config_dict = self.to_dict()
|
||||
|
||||
# get the default config dict
|
||||
# Get the default config dict (from a fresh PreTrainedConfig instance)
|
||||
default_config_dict = PretrainedConfig().to_dict()
|
||||
|
||||
# get class specific config dict
|
||||
# Get class-specific config dict if not part of a composition
|
||||
class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
|
||||
|
||||
serializable_config_dict = {}
|
||||
@ -847,8 +848,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
if not isinstance(self.quantization_config, dict)
|
||||
else self.quantization_config
|
||||
)
|
||||
|
||||
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
|
||||
# Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
|
||||
_ = serializable_config_dict.pop("_pre_quantization_dtype", None)
|
||||
|
||||
self.dict_torch_dtype_to_str(serializable_config_dict)
|
||||
|
@ -24,6 +24,7 @@ deps = {
|
||||
"fugashi": "fugashi>=1.0",
|
||||
"GitPython": "GitPython<3.1.19",
|
||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||
"hf_xet": "hf_xet",
|
||||
"huggingface-hub": "huggingface-hub>=0.30.0,<1.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||
|
@ -52,6 +52,7 @@ if is_torch_available():
|
||||
from ..cache_utils import (
|
||||
HQQQuantizedCache,
|
||||
HybridCache,
|
||||
HybridChunkedCache,
|
||||
MambaCache,
|
||||
OffloadedStaticCache,
|
||||
QuantizedCacheConfig,
|
||||
@ -69,6 +70,7 @@ if is_torch_available():
|
||||
"offloaded_static": OffloadedStaticCache,
|
||||
"sliding_window": SlidingWindowCache,
|
||||
"hybrid": HybridCache,
|
||||
"hybrid_chunked": HybridChunkedCache,
|
||||
"mamba": MambaCache,
|
||||
}
|
||||
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
||||
@ -416,6 +418,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
if isinstance(self.cache_config, dict):
|
||||
self.cache_config = cache_config_class.from_dict(self.cache_config)
|
||||
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
|
||||
self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)
|
||||
|
||||
# Parameters for manipulation of the model output logits
|
||||
self.temperature = kwargs.pop("temperature", 1.0)
|
||||
|
@ -1830,6 +1830,9 @@ class GenerationMixin:
|
||||
|
||||
Returns the resulting cache object.
|
||||
"""
|
||||
if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
|
||||
cache_implementation = "hybrid_chunked"
|
||||
|
||||
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
|
||||
requires_cross_attention_cache = (
|
||||
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
||||
@ -1958,6 +1961,9 @@ class GenerationMixin:
|
||||
)
|
||||
generation_config.cache_implementation = None
|
||||
|
||||
generation_config.cache_implementation = generation_config.cache_implementation or getattr(
|
||||
self.config.get_text_config(), "cache_implementation", None
|
||||
)
|
||||
if generation_config.cache_implementation is not None:
|
||||
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
||||
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
|
||||
@ -3405,7 +3411,12 @@ class GenerationMixin:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
||||
model_forward = self.get_compiled_call(generation_config.compile_config)
|
||||
|
||||
is_prefill = True
|
||||
if generation_config.prefill_chunk_size is not None:
|
||||
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
|
||||
is_prefill = False
|
||||
else:
|
||||
is_prefill = True
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
# prepare model inputs
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
@@ -4855,6 +4866,45 @@ class GenerationMixin:
else:
return input_ids

def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
# Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may
# end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings
torch._dynamo.config.cache_size_limit = 64

chunk_size = generation_config.prefill_chunk_size
# Only chunk up to the token just before the last one, so that decoding is completely performed outside this function
# (here we simply prefill the cache)
input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)

if "past_key_values" not in model_kwargs:
raise ValueError("Cannot use prefill chunking without a cache")

model_forward = self.get_compiled_call(generation_config.compile_config)
attention_mask = model_kwargs.pop("attention_mask", None)

past_length = 0
for input_chunk in input_chunks:
current_length = past_length + input_chunk.shape[-1]
# Prepare inputs
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask[:, :current_length]
model_kwargs["cache_position"] = torch.arange(
past_length, current_length, dtype=torch.long, device=input_chunk.device
)
model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)

outputs = model_forward(**model_inputs, return_dict=True)

model_kwargs["past_key_values"] = outputs.past_key_values
past_length = current_length

model_kwargs["attention_mask"] = attention_mask
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
_ = model_kwargs.pop("position_ids", None)

return model_kwargs

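A hedged usage sketch of the new option: `prefill_chunk_size` is read from the `GenerationConfig` (added in the configuration hunk above) and only takes effect when a cache object exists for the prompt to be prefilled into. The checkpoint id and prompt below are placeholders:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

checkpoint = "<any-decoder-only-checkpoint>"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")

generation_config = GenerationConfig(
    max_new_tokens=32,
    cache_implementation="hybrid",  # make sure a cache is set up, otherwise the ValueError above is raised
    prefill_chunk_size=2048,        # the prompt (minus its last token) is prefilled 2048 tokens at a time
)
inputs = tokenizer("<a very long prompt>", return_tensors="pt").to(model.device)
output = model.generate(**inputs, generation_config=generation_config)
```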
def _speculative_sampling(
candidate_input_ids,
@@ -53,7 +53,7 @@ _import_structure = {
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
"finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
"fsdp": ["is_fsdp_managed_module"],
"ggml": [
@@ -192,7 +192,7 @@ if TYPE_CHECKING:
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
from .fsdp import is_fsdp_managed_module
from .ggml import (
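The two hunks above register the new `FbgemmFp8Llama4TextExperts` symbol both in the lazy `_import_structure` and in the `TYPE_CHECKING` branch. For readers unfamiliar with the pattern, here is a generic sketch of lazy symbol resolution (not transformers' exact `_LazyModule` implementation), shown only as an assumption-laden illustration:

```python
import importlib

_import_structure = {
    "fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
}
_symbol_to_module = {sym: mod for mod, syms in _import_structure.items() for sym in syms}


def __getattr__(name):  # PEP 562: called when a module attribute is not found
    if name in _symbol_to_module:
        module = importlib.import_module(f".{_symbol_to_module[name]}", __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```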
src/transformers/integrations/accelerate.py (new file, 196 lines)
@@ -0,0 +1,196 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Since https://github.com/huggingface/transformers/pull/36963, loading is always performed with models on meta
|
||||
device. But since the `init_empty_weights` and `find_tied_parameters` functions are from accelerate, and accelerate is
|
||||
somewhat still a soft dependency, we copy the functions here to be used natively in Transformers.
|
||||
|
||||
The `init_empty_weights` and `init_on_device` functions were copied from `accelerate.big_modeling.py`, and the
|
||||
`find_tied_parameters` was copied from `accelerate.utils.modeling.py`
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ..utils import is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def init_empty_weights(include_buffers: bool = False):
|
||||
"""
|
||||
A context manager under which models are initialized with all parameters on the meta device, therefore creating an
|
||||
empty model. Useful when just initializing the model would blow the available RAM.
|
||||
|
||||
Args:
|
||||
include_buffers (`bool`, *optional*):
|
||||
Whether or not to also put all buffers on the meta device while initializing.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
# Initialize a model with 100 billion parameters in no time and without using any RAM.
|
||||
with init_empty_weights():
|
||||
tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Any model created under this context manager has no weights. As such you can't do something like
|
||||
`model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
|
||||
Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
|
||||
called.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
|
||||
yield f
|
||||
|
||||
|
||||
@contextmanager
|
||||
def init_on_device(device: "torch.device", include_buffers: bool = False):
|
||||
"""
|
||||
A context manager under which models are initialized with all parameters on the specified device.
|
||||
|
||||
Args:
|
||||
device (`torch.device`):
|
||||
Device to initialize all parameters on.
|
||||
include_buffers (`bool`, *optional*):
|
||||
Whether or not to also put all buffers on the meta device while initializing.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
from accelerate import init_on_device
|
||||
|
||||
with init_on_device(device=torch.device("cuda")):
|
||||
tst = nn.Linear(100, 100) # on `cuda` device
|
||||
```
|
||||
"""
|
||||
if include_buffers:
|
||||
with device:
|
||||
yield
|
||||
return
|
||||
|
||||
old_register_parameter = nn.Module.register_parameter
|
||||
if include_buffers:
|
||||
old_register_buffer = nn.Module.register_buffer
|
||||
|
||||
def register_empty_parameter(module, name, param):
|
||||
old_register_parameter(module, name, param)
|
||||
if param is not None:
|
||||
param_cls = type(module._parameters[name])
|
||||
kwargs = module._parameters[name].__dict__
|
||||
kwargs["requires_grad"] = param.requires_grad
|
||||
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||
|
||||
def register_empty_buffer(module, name, buffer, persistent=True):
|
||||
old_register_buffer(module, name, buffer, persistent=persistent)
|
||||
if buffer is not None:
|
||||
module._buffers[name] = module._buffers[name].to(device)
|
||||
|
||||
# Patch tensor creation
|
||||
if include_buffers:
|
||||
tensor_constructors_to_patch = {
|
||||
torch_function_name: getattr(torch, torch_function_name)
|
||||
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
||||
}
|
||||
else:
|
||||
tensor_constructors_to_patch = {}
|
||||
|
||||
def patch_tensor_constructor(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
kwargs["device"] = device
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
try:
|
||||
nn.Module.register_parameter = register_empty_parameter
|
||||
if include_buffers:
|
||||
nn.Module.register_buffer = register_empty_buffer
|
||||
for torch_function_name in tensor_constructors_to_patch.keys():
|
||||
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
||||
yield
|
||||
finally:
|
||||
nn.Module.register_parameter = old_register_parameter
|
||||
if include_buffers:
|
||||
nn.Module.register_buffer = old_register_buffer
|
||||
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
||||
setattr(torch, torch_function_name, old_torch_function)
|
||||
|
||||
|
||||
def find_tied_parameters(model: "nn.Module", **kwargs):
|
||||
"""
|
||||
Find the tied parameters in a given model.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
|
||||
them.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model to inspect.
|
||||
|
||||
Returns:
|
||||
List[List[str]]: A list of lists of parameter names being all tied together.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> from collections import OrderedDict
|
||||
>>> import torch.nn as nn
|
||||
|
||||
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
|
||||
>>> model.linear2.weight = model.linear1.weight
|
||||
>>> find_tied_parameters(model)
|
||||
[['linear1.weight', 'linear2.weight']]
|
||||
```
|
||||
"""
|
||||
|
||||
# get ALL model parameters and their names
|
||||
all_named_parameters = dict(model.named_parameters(remove_duplicate=False))
|
||||
|
||||
# get ONLY unique named parameters,
|
||||
# if a parameter is tied and has multiple names, it will be included only once
|
||||
no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))
|
||||
|
||||
# the difference of the two sets will give us the tied parameters
|
||||
tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
|
||||
|
||||
# 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
|
||||
# which names refer to the same parameter. To identify this, we need to group them together.
|
||||
tied_param_groups = {}
|
||||
for tied_param_name in tied_param_names:
|
||||
tied_param = all_named_parameters[tied_param_name]
|
||||
for param_name, param in no_duplicate_named_parameters.items():
|
||||
# check if the parameters are the same; if so, group their names together
|
||||
if param is tied_param:
|
||||
if param_name not in tied_param_groups:
|
||||
tied_param_groups[param_name] = []
|
||||
tied_param_groups[param_name].append(tied_param_name)
|
||||
|
||||
return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
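A short usage sketch of the two helpers defined in this new module, assuming the module path `src/transformers/integrations/accelerate.py` added here is importable as `transformers.integrations.accelerate`:

```python
import torch.nn as nn

from transformers.integrations.accelerate import find_tied_parameters, init_empty_weights

with init_empty_weights():
    model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))  # parameters live on the meta device

model[1].weight = model[0].weight  # tie the two weights
print(find_tied_parameters(model))  # [['0.weight', '1.weight']]
```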
|
src/transformers/integrations/compressed_tensors.py (new file, 54 lines)
@@ -0,0 +1,54 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
|
||||
|
||||
|
||||
def skip(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class CompressedExpertsLinear(nn.Module):
|
||||
"""
|
||||
A module that implements a compressed version of a list of expert modules.
|
||||
This is specifically designed to work with Llama4TextExperts in MoE layers.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
# Skip random weight initialization for experts. Otherwise,
# initializing this module would take minutes. For a model
# with tens of expert layers, it can easily exceed 20 minutes.
|
||||
nn.init.kaiming_uniform_ = skip
|
||||
nn.init.uniform_ = skip
|
||||
nn.init.normal_ = skip
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.expert_modules = nn.ModuleList([Llama4TextMLP(config) for _ in range(self.num_experts)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
|
||||
expert_routed_out_list = []
|
||||
for expert_idx in range(self.num_experts):
|
||||
expert_routed_out_list.append(self.expert_modules[expert_idx](hidden_states[expert_idx]))
|
||||
routed_out = torch.cat(expert_routed_out_list, dim=0)
|
||||
return routed_out
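The forward above assumes the router has already grouped tokens by expert into equal-sized slices, so the flat input can simply be viewed as `(num_experts, tokens_per_expert, hidden)`. A small shape-only sketch of that contract, where the multiply-by-two stands in for each expert MLP:

```python
import torch

num_experts, tokens_per_expert, hidden = 4, 3, 8
flat = torch.randn(num_experts * tokens_per_expert, hidden)
per_expert = flat.reshape(num_experts, -1, hidden)            # what forward() does first
outputs = [per_expert[i] * 2.0 for i in range(num_experts)]   # stand-in for each expert MLP
routed_out = torch.cat(outputs, dim=0)                        # back to (num_experts * tokens_per_expert, hidden)
assert routed_out.shape == flat.shape
```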
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..activations import ACT2FN
|
||||
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
|
||||
|
||||
|
||||
@ -28,36 +29,36 @@ if is_fbgemm_gpu_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FbgemmFp8Linear(torch.nn.Module):
|
||||
class FbgemmFp8Linear(torch.nn.Linear):
|
||||
def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
|
||||
super().__init__()
|
||||
super().__init__(in_features, out_features, bias)
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
|
||||
self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
|
||||
self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
|
||||
self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
|
||||
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
|
||||
|
||||
if bias:
|
||||
self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
|
||||
self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x):
|
||||
num_tokens = None
|
||||
# quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
|
||||
output_shape = (*x.shape[:-1], -1)
|
||||
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
|
||||
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
|
||||
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
|
||||
x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
|
||||
)
|
||||
# moving x_quantized, x_scale here creates gibberish output ... However, if we move the output, it works
|
||||
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
|
||||
|
||||
# The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
|
||||
weight_scale_float32 = self.weight_scale.to(torch.float32)
|
||||
output = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
|
||||
x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
|
||||
)
|
||||
output = output + self.bias if self.bias is not None else output
|
||||
# Hacky for now: we move the output to the device of x
|
||||
@ -67,6 +68,92 @@ class FbgemmFp8Linear(torch.nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class FbgemmFp8Llama4TextExperts(nn.Module):
|
||||
def __init__(self, config, dtype=torch.float32):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.expert_dim = self.intermediate_size
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
# Register FP8 buffers for gate_up_proj
|
||||
self.gate_up_proj = torch.nn.Parameter(
|
||||
torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn)
|
||||
)
|
||||
self.gate_up_proj_scale = torch.nn.Parameter(
|
||||
torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32)
|
||||
)
|
||||
# Register FP8 buffers for down_proj
|
||||
self.down_proj = torch.nn.Parameter(
|
||||
torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn)
|
||||
)
|
||||
self.down_proj_scale = torch.nn.Parameter(
|
||||
torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32)
|
||||
)
|
||||
# Register input scale upper bound
|
||||
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
|
||||
Returns:
|
||||
torch.Tensor: (batch_size * token_num, hidden_size)
|
||||
"""
|
||||
# Reshape hidden states for expert computation
|
||||
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
|
||||
num_tokens = None
|
||||
|
||||
# Pre-allocate tensor for all expert outputs with same shape as hidden_states
|
||||
next_states = torch.empty_like(hidden_states)
|
||||
|
||||
for i in range(self.num_experts):
|
||||
# Extract expert's hidden states
|
||||
expert_hidden = hidden_states[i]
|
||||
expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
|
||||
# Quantize for this expert
|
||||
expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
expert_hidden_reshaped, num_tokens, self.input_scale_ub
|
||||
)
|
||||
sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
|
||||
gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
|
||||
|
||||
gate = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
expert_quantized,
|
||||
self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
|
||||
expert_scale,
|
||||
gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
|
||||
use_fast_accum=True,
|
||||
)
|
||||
|
||||
up = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
expert_quantized,
|
||||
self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
|
||||
expert_scale,
|
||||
gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
|
||||
use_fast_accum=True,
|
||||
)
|
||||
|
||||
activated = up * self.act_fn(gate)
|
||||
|
||||
activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
activated, num_tokens, self.input_scale_ub
|
||||
)
|
||||
|
||||
down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
|
||||
expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
activated_quantized,
|
||||
self.down_proj[i].transpose(0, 1).contiguous(),
|
||||
activated_scale,
|
||||
down_proj_scale_float32[i].view(-1, 1).contiguous(),
|
||||
use_fast_accum=True,
|
||||
)
|
||||
|
||||
next_states[i] = expert_output
|
||||
next_states = next_states.to(hidden_states.device)
|
||||
return next_states.view(-1, self.hidden_size)
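For readers without `fbgemm_gpu`, here is a rough pure-PyTorch emulation of the per-row FP8 quantization the expert forward above relies on. The real work is done by `torch.ops.fbgemm.quantize_fp8_per_row` and `f8f8bf16_rowwise`; this sketch only illustrates the math:

```python
import torch

x = torch.randn(16, 128, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max                       # ~448 for e4m3fn
x_scale = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / fp8_max
x_q = (x.float() / x_scale).to(torch.float8_e4m3fn)                  # row-wise quantized activations
x_deq = x_q.to(torch.float32) * x_scale                              # dequantize for a reference matmul
```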
|
||||
|
||||
|
||||
def _replace_with_fbgemm_fp8_linear(
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
@ -74,12 +161,17 @@ def _replace_with_fbgemm_fp8_linear(
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
pre_quantized=False,
|
||||
config=None,
|
||||
tp_plan=None,
|
||||
):
|
||||
"""
|
||||
Private method that wraps the recursion for module replacement.
|
||||
|
||||
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
|
||||
@ -105,9 +197,27 @@ def _replace_with_fbgemm_fp8_linear(
|
||||
# Force requires grad to False to avoid unexpected errors
|
||||
model._modules[name].requires_grad_(False)
|
||||
# set non-persistent buffer outside of init_empty_weights
|
||||
model._modules[name].input_scale_ub = torch.tensor(
|
||||
[quantization_config.activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
)
|
||||
if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
if not any(
|
||||
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
|
||||
):
|
||||
with init_empty_weights(include_buffers=True):
|
||||
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[
|
||||
re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")
|
||||
]
|
||||
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
|
||||
model._modules[name] = FbgemmFp8Llama4TextExperts(
|
||||
config.text_config,
|
||||
)
|
||||
model._modules[name].input_scale_ub = torch.tensor(
|
||||
[quantization_config.activation_scale_ub], dtype=torch.float
|
||||
)
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_fbgemm_fp8_linear(
|
||||
module,
|
||||
@ -116,6 +226,8 @@ def _replace_with_fbgemm_fp8_linear(
|
||||
quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
pre_quantized=pre_quantized,
|
||||
config=config,
|
||||
tp_plan=tp_plan,
|
||||
)
|
||||
# Remove the last key for recursion
|
||||
current_key_name.pop(-1)
|
||||
@ -123,7 +235,13 @@ def _replace_with_fbgemm_fp8_linear(
|
||||
|
||||
|
||||
def replace_with_fbgemm_fp8_linear(
|
||||
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
|
||||
model,
|
||||
modules_to_not_convert=None,
|
||||
current_key_name=None,
|
||||
quantization_config=None,
|
||||
pre_quantized=False,
|
||||
config=None,
|
||||
tp_plan=None,
|
||||
):
|
||||
"""
|
||||
A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
|
||||
@ -151,9 +269,14 @@ def replace_with_fbgemm_fp8_linear(
|
||||
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
||||
modules_to_not_convert = list(set(modules_to_not_convert))
|
||||
model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
|
||||
model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
|
||||
model,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
pre_quantized=pre_quantized,
|
||||
config=config,
|
||||
tp_plan=tp_plan,
|
||||
)
|
||||
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
"You are loading your model using FP8 quantization but no linear modules were found in your model."
|
||||
|
@ -31,13 +31,11 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
from ..utils import is_torch_flex_attn_available
|
||||
from ..utils.import_utils import _torch_version
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import (
|
||||
BlockMask,
|
||||
flex_attention,
|
||||
)
|
||||
from torch.nn.attention.flex_attention import BlockMask, flex_attention
|
||||
from torch.nn.attention.flex_attention import (
|
||||
create_block_mask as create_block_causal_mask_flex,
|
||||
)
|
||||
@ -59,19 +57,36 @@ class WrappedFlexAttention:
|
||||
return cls._instance
|
||||
|
||||
@torch.compiler.disable(recursive=False)
|
||||
def __init__(self):
|
||||
def __init__(self, training):
|
||||
"""
|
||||
Initialize or update the singleton instance.
|
||||
"""
|
||||
if self._is_flex_compiled is False:
|
||||
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
|
||||
if not self._is_flex_compiled:
|
||||
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
|
||||
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
|
||||
# see https://github.com/pytorch/pytorch/issues/146260 for training
|
||||
if _torch_version == "2.6.0" and training:
|
||||
self._compiled_flex_attention = torch.compile(
|
||||
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
|
||||
)
|
||||
else:
|
||||
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
|
||||
self._is_flex_compiled = True
|
||||
|
||||
def __call__(self):
|
||||
return self._compiled_flex_attention
|
||||
|
||||
|
||||
def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
|
||||
Offset = Union[torch.Tensor, int]
|
||||
|
||||
|
||||
def make_flex_block_causal_mask(
|
||||
attention_mask_2d: torch.Tensor,
|
||||
attention_chunk_size: Optional[int] = None,
|
||||
query_length=None,
|
||||
key_length=None,
|
||||
offsets: Optional[Tuple[Offset, Offset]] = None,
|
||||
) -> "BlockMask":
|
||||
"""
|
||||
Create a block causal document mask for a batch of sequences, both packed and unpacked.
|
||||
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
|
||||
@ -94,10 +109,18 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
|
||||
Returns:
|
||||
BlockMask
|
||||
"""
|
||||
batch_size, total_seq_len = attention_mask_2d.shape
|
||||
if not key_length:
|
||||
key_length = total_seq_len
|
||||
if not query_length:
|
||||
query_length = total_seq_len
|
||||
attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
|
||||
device = attention_mask_2d.device
|
||||
document_ids = attention_mask_2d.clone()
|
||||
|
||||
document_ids = attention_mask_2d
|
||||
batch_size, total_seq_len = document_ids.shape
|
||||
if attention_chunk_size is not None:
|
||||
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
|
||||
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
|
||||
|
||||
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
||||
# that determines which elements of QK^T should be included in the attention
|
||||
@ -112,18 +135,30 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
|
||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
||||
for an illustration.
|
||||
"""
|
||||
causal_mask = q_idx >= kv_idx
|
||||
causal_mask = q_idx >= kv_idx # not valid when decoding
|
||||
document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
|
||||
padding_mask = document_ids[batch_idx, q_idx] > 0
|
||||
return causal_mask & document_mask & padding_mask
|
||||
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
|
||||
final_mask = causal_mask & padding_mask & document_mask
|
||||
return final_mask
|
||||
|
||||
if offsets is not None:
|
||||
q_offset = offsets[0]
|
||||
kv_offset = offsets[1]
|
||||
|
||||
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
offset_q = q_idx + q_offset
|
||||
offset_kv = kv_idx + kv_offset
|
||||
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
|
||||
else:
|
||||
mask_mod = causal_mask_mod
|
||||
return create_block_causal_mask_flex(
|
||||
mask_mod=causal_mask_mod,
|
||||
mask_mod=mask_mod,
|
||||
B=batch_size,
|
||||
H=None, # attention head
|
||||
Q_LEN=total_seq_len,
|
||||
KV_LEN=total_seq_len,
|
||||
Q_LEN=query_length,
|
||||
KV_LEN=key_length,
|
||||
device=device,
|
||||
_compile=True,
|
||||
)
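The `attention_chunk_size` branch turns the all-ones mask into chunk ids via a cumulative sum, which for an unpadded row is just `arange(seq_len) // chunk_size`; the mask_mod then only keeps pairs that are causal, in the same chunk, and not padding. A small sketch of the grouping (padding and offsets omitted):

```python
import torch

attention_chunk_size = 4
seq_len = 10
chunk_ids = torch.arange(seq_len) // attention_chunk_size  # tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2])

q_idx, kv_idx = 6, 3
# causal but cross-chunk, so the pair is masked out
allowed = (q_idx >= kv_idx) and bool(chunk_ids[q_idx] == chunk_ids[kv_idx])  # False
```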
|
||||
|
||||
|
||||
@ -132,10 +167,11 @@ def compile_friendly_flex_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
training=False,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
|
||||
flex_attention_compiled = WrappedFlexAttention()()
|
||||
flex_attention_compiled = WrappedFlexAttention(training)()
|
||||
return flex_attention_compiled(
|
||||
query,
|
||||
key,
|
||||
@ -144,6 +180,18 @@ def compile_friendly_flex_attention(
|
||||
)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
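A quick shape check of what `repeat_kv` produces, matching the expand/reshape in the function body:

```python
import torch

key = torch.randn(2, 4, 16, 64)  # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 2
expanded = key[:, :, None, :, :].expand(2, 4, n_rep, 16, 64).reshape(2, 4 * n_rep, 16, 64)
assert expanded.shape == (2, 8, 16, 64)  # same result as repeat_kv(key, n_rep=2)
```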
|
||||
|
||||
|
||||
def flex_attention_forward(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
@ -174,17 +222,29 @@ def flex_attention_forward(
|
||||
score = score + head_mask[batch_idx][head_idx][0][0]
|
||||
return score
|
||||
|
||||
enable_gqa = True
|
||||
num_local_query_heads = query.shape[1]
|
||||
|
||||
# When running TP this helps:
|
||||
if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
|
||||
key = repeat_kv(key, query.shape[1] // key.shape[1])
|
||||
value = repeat_kv(value, query.shape[1] // value.shape[1])
|
||||
enable_gqa = False
|
||||
|
||||
kernel_options = kwargs.get("kernel_options", None)
|
||||
attn_output, attention_weights = compile_friendly_flex_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
score_mod=score_mod,
|
||||
block_mask=block_mask,
|
||||
enable_gqa=True,
|
||||
enable_gqa=enable_gqa,
|
||||
scale=scaling,
|
||||
kernel_options=kernel_options,
|
||||
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
|
||||
# For simplification, we thus always return it as no additional computations are introduced.
|
||||
return_lse=True,
|
||||
training=module.training,
|
||||
)
|
||||
# lse is returned in float32
|
||||
attention_weights = attention_weights.to(value.dtype)
|
||||
|
@ -31,7 +31,7 @@ def sdpa_attention_forward(
|
||||
value = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
if attention_mask is not None and causal_mask.ndim == 4:
|
||||
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
|
||||
|
@ -61,6 +61,21 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
|
||||
return [single_size] * blocks
|
||||
|
||||
|
||||
str_to_torch_dtype = {
|
||||
"BOOL": torch.bool,
|
||||
"U8": torch.uint8,
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I32": torch.int32,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
"I64": torch.int64,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
}
|
||||
|
||||
|
||||
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
||||
"""
|
||||
When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
|
||||
@ -106,6 +121,12 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
||||
tensors_slices += range(block_offset + start, block_offset + stop)
|
||||
block_offset += block_size
|
||||
|
||||
slice_dtype = slice_.get_dtype()
|
||||
# Handle F8_E4M3 dtype by converting to float16 before slicing
|
||||
# Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
|
||||
if slice_dtype == "F8_E4M3":
|
||||
slice_ = slice_[...].to(torch.float16)
|
||||
|
||||
if dim == 0:
|
||||
tensor = slice_[tensors_slices, ...]
|
||||
elif dim == 1 or dim == -2:
|
||||
@ -114,7 +135,7 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
||||
tensor = slice_[..., tensors_slices]
|
||||
else:
|
||||
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
|
||||
return tensor
|
||||
return tensor.to(str_to_torch_dtype[slice_dtype])
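The F8_E4M3 special case exists because fancy indexing of float8 tensors is not implemented on CPU in some PyTorch versions, so the slice is upcast before indexing and the result is cast back through `str_to_torch_dtype`. A minimal reproduction of the workaround (the exact failing versions are not verified here):

```python
import torch

w8 = torch.randn(8, 4).to(torch.float8_e4m3fn)
rows = torch.tensor([0, 2, 5])
# w8[rows] may raise: RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
picked = w8.to(torch.float16)[rows].to(torch.float8_e4m3fn)  # the workaround used above
```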
|
||||
|
||||
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
@ -199,11 +220,12 @@ class GatherParallel(TensorParallelLayer):
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
if isinstance(inputs[0], DTensor):
|
||||
inputs[0] = inputs[0].to_local()
|
||||
inputs = inputs[0].to_local()
|
||||
return inputs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
# this op cannot be async, otherwise it completely breaks the outputs of models
|
||||
torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
|
||||
return outputs
|
||||
|
||||
@ -266,7 +288,7 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
|
||||
# transform the input layouts to the desired layouts of ColwiseParallel
|
||||
if input_layouts != desired_input_layouts:
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
|
||||
return input_tensor
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
@ -291,7 +313,7 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
# outputs is a shard on last dimension DTensor, i.e. Shard(-1)
|
||||
if outputs.placements != output_layouts:
|
||||
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
|
||||
outputs = outputs.redistribute(placements=output_layouts, async_op=False)
|
||||
# back to local tensor
|
||||
return outputs.to_local() if use_local_output else outputs
|
||||
|
||||
@ -343,16 +365,6 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = use_dtensor
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
input_tensor = inputs[0]
|
||||
if not isinstance(input_tensor, DTensor):
|
||||
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
|
||||
|
||||
if input_layouts != desired_input_layouts:
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
|
||||
return input_tensor
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
|
||||
# means Rowwise as nn.Linear is input * weight^T + bias, where
|
||||
@ -371,6 +383,20 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
|
||||
return nn.Parameter(parameter)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
if hasattr(mod, "bias") and mod.bias is not None:
|
||||
mod._bias = mod.bias
|
||||
mod.bias = None
|
||||
|
||||
input_tensor = inputs[0]
|
||||
if not isinstance(input_tensor, DTensor):
|
||||
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
|
||||
|
||||
if input_layouts != desired_input_layouts:
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
|
||||
return input_tensor
|
||||
|
||||
@staticmethod
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
# Rowwise sharding produces partial output, depending on output layouts:
|
||||
@ -378,6 +404,8 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
# 2. to shard -> reduce_scatter
|
||||
if outputs.placements != output_layouts:
|
||||
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
|
||||
if hasattr(mod, "_bias"):
|
||||
outputs += mod._bias
|
||||
# back to local tensor if use_local_output is True
|
||||
return outputs.to_local() if use_local_output else outputs
|
||||
|
||||
@ -418,6 +446,90 @@ class PackedRowwiseParallel(RowwiseParallel):
|
||||
return nn.Parameter(parameter)
|
||||
|
||||
|
||||
class SequenceParallel(TensorParallelLayer):
|
||||
"""
|
||||
SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with
|
||||
input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
|
||||
`RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__
|
||||
|
||||
This style implements the operation that is described in the paper
|
||||
`Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__
|
||||
|
||||
If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded
|
||||
on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input
|
||||
passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would
|
||||
redistribute the input to be sharded on the sequence dimension.
|
||||
|
||||
The output of the ``nn.Module`` will be sharded on the sequence dimension.
|
||||
|
||||
Keyword Args:
|
||||
sequence_dim (int, optional):
|
||||
The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to
|
||||
become a DTensor that is sharded on the sequence dimension, default: 1.
|
||||
use_local_output (bool, optional):
|
||||
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
|
||||
Returns:
|
||||
A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
|
||||
|
||||
Example::
|
||||
>>> # xdoctest: +SKIP(failing)
|
||||
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
|
||||
>>> from torch.distributed.device_mesh import init_device_mesh
|
||||
>>> ...
|
||||
>>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
|
||||
>>> tp_mesh = init_device_mesh("cuda", (8,))
|
||||
>>>
|
||||
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
|
||||
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
|
||||
>>>
|
||||
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
|
||||
>>> ...
|
||||
|
||||
.. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.
|
||||
``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom
|
||||
inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
|
||||
to ensure that they are replicated.
|
||||
"""
|
||||
|
||||
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
|
||||
super().__init__()
|
||||
self.input_layouts = (Replicate(),)
|
||||
self.desired_input_layouts = (Shard(1),)
|
||||
self.output_layouts = (Replicate(),)
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = True
|
||||
self.sequence_sharding = (Shard(sequence_dim),)
|
||||
self.use_local_output = use_local_output
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
input_tensor = inputs[0]
|
||||
if not isinstance(input_tensor, DTensor):
|
||||
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
|
||||
if input_layouts != desired_input_layouts:
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
|
||||
return input_tensor
|
||||
|
||||
@staticmethod
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
outputs = outputs.redistribute(
|
||||
placements=(Replicate(),), async_op=True
|
||||
) # maybe we have to replicate ? because next layer is not sharded
|
||||
return outputs.to_local() # if use_local_output else outputs
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
# weight would become Shard(1)
|
||||
parameter = param[:]
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
parameter = parameter.contiguous()
|
||||
if self.use_dtensor:
|
||||
parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
|
||||
return nn.Parameter(parameter)
|
||||
|
||||
|
||||
SUPPORTED_TP_STYLES = {
|
||||
"colwise",
|
||||
"rowwise",
|
||||
@ -428,6 +540,7 @@ SUPPORTED_TP_STYLES = {
|
||||
"local",
|
||||
"gather",
|
||||
"local_packed_rowwise",
|
||||
"sequence_parallel",
|
||||
}
|
||||
|
||||
|
||||
@ -459,6 +572,8 @@ def translate_to_torch_parallel_style(style: str):
|
||||
return GatherParallel()
|
||||
elif style == "local_packed_rowwise":
|
||||
return PackedRowwiseParallel(use_dtensor=False)
|
||||
elif style == "sequence_parallel":
|
||||
return SequenceParallel()
|
||||
else:
|
||||
raise ValueError(f"Unsupported parallel style value: {style}")
|
||||
|
||||
@ -518,6 +633,7 @@ def shard_and_distribute_module(
|
||||
tp_plan = model._tp_plan
|
||||
module_to_tp = model.get_submodule(param_name)
|
||||
current_module_plan = None
|
||||
rank = int(rank)
|
||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||
if generic_param_name in tp_plan:
|
||||
current_module_plan = tp_plan[generic_param_name]
|
||||
@ -531,12 +647,18 @@ def shard_and_distribute_module(
|
||||
module_to_tp._is_hooked = True
|
||||
|
||||
if current_module_plan is not None:
|
||||
tp_layer = translate_to_torch_parallel_style(current_module_plan)
|
||||
param = tp_layer.partition_tensor(
|
||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
)
|
||||
try:
|
||||
tp_layer = translate_to_torch_parallel_style(current_module_plan)
|
||||
param = tp_layer.partition_tensor(
|
||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
)
|
||||
except NotImplementedError as e:
|
||||
print(
|
||||
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
|
||||
)
|
||||
else:
|
||||
# TODO log no plan modules in set
|
||||
# print("No plan for", parameter_name,end ="\n")
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if is_contiguous:
|
||||
param = param.contiguous()
|
||||
|
@ -57,7 +57,8 @@ from .configuration_utils import PretrainedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .generation import CompileConfig, GenerationConfig, GenerationMixin
|
||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||
from .integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||
from .integrations.accelerate import find_tied_parameters, init_empty_weights
|
||||
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
|
||||
from .integrations.flash_attention import flash_attention_forward
|
||||
from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
@ -131,12 +132,11 @@ XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
|
||||
from accelerate import dispatch_model, infer_auto_device_map
|
||||
from accelerate.hooks import add_hook_to_module
|
||||
from accelerate.utils import (
|
||||
check_tied_parameters_on_same_device,
|
||||
extract_model_from_parallel,
|
||||
find_tied_parameters,
|
||||
get_balanced_memory,
|
||||
get_max_memory,
|
||||
load_offloaded_weights,
|
||||
@ -153,6 +153,10 @@ if is_safetensors_available():
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -480,6 +484,7 @@ str_to_torch_dtype = {
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
"I64": torch.int64,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
}
|
||||
|
||||
if is_torch_greater_or_equal("2.1.0"):
|
||||
@ -1910,16 +1915,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
|
||||
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
||||
if self.base_model is self:
|
||||
self._pp_plan = (
|
||||
self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
|
||||
)
|
||||
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
|
||||
else:
|
||||
self._tp_plan = self._tp_plan or {}
|
||||
for name, module in self.named_children():
|
||||
if plan := getattr(module, "_tp_plan", None):
|
||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
|
||||
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
|
||||
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
|
||||
for name, module in self.named_children():
|
||||
if plan := getattr(module, "_tp_plan", None):
|
||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||
|
||||
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
|
||||
for _, v in self._tp_plan.items():
|
||||
@ -2021,8 +2021,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
|
||||
import deepspeed
|
||||
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
||||
# and memory copying it on CPU or each GPU first
|
||||
@ -2662,8 +2660,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Since we are basically reusing the same old embeddings with new weight values, gathering is required
|
||||
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
|
||||
vocab_size = model_embeds.weight.shape[0]
|
||||
else:
|
||||
@ -2694,8 +2690,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Update new_num_tokens with the actual size of new_embeddings
|
||||
if pad_to_multiple_of is not None:
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
|
||||
new_num_tokens = new_embeddings.weight.shape[0]
|
||||
else:
|
||||
@ -2784,8 +2778,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
|
||||
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
||||
else:
|
||||
@ -2830,8 +2822,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
added_num_tokens = new_num_tokens - old_num_tokens
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
|
||||
self._init_added_embeddings_weights_with_mean(
|
||||
old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
|
||||
@ -2847,8 +2837,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
n = min(old_num_tokens, new_num_tokens)
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
params = [old_embeddings.weight, new_embeddings.weight]
|
||||
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
||||
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
|
||||
@ -2859,8 +2847,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# This ensures correct functionality when a Custom Embedding class is passed as input.
|
||||
# The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
params = [old_embeddings.weight, new_embeddings.weight]
|
||||
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
||||
old_embeddings.weight = new_embeddings.weight
|
||||
@ -2918,8 +2904,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
|
||||
old_num_tokens, old_lm_head_dim = (
|
||||
old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
|
||||
@ -2970,8 +2954,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
added_num_tokens = new_num_tokens - old_num_tokens
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
params = [old_lm_head.weight]
|
||||
if has_new_lm_head_bias:
|
||||
params += [old_lm_head.bias]
|
||||
@ -2992,8 +2974,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
|
||||
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
||||
self._copy_lm_head_original_to_resized(
|
||||
@ -3738,25 +3718,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
return super().float(*args)
|
||||
|
||||
@classmethod
|
||||
def get_init_context(
|
||||
cls: Type[SpecificPreTrainedModelType],
|
||||
is_quantized=None,
|
||||
_is_ds_init_called=None,
|
||||
):
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
|
||||
import deepspeed
|
||||
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
init_contexts = [
|
||||
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
|
||||
set_zero3_state(),
|
||||
no_init_weights(),
|
||||
]
|
||||
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
|
||||
if is_deepspeed_zero3_enabled():
|
||||
init_contexts = [no_init_weights()]
|
||||
# We cannot initialize the model on meta device with deepspeed when not quantized
|
||||
if not is_quantized and not _is_ds_init_called:
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
|
||||
elif is_quantized:
|
||||
init_contexts.extend([init_empty_weights(), set_quantized_state()])
|
||||
else:
|
||||
init_contexts = [no_init_weights(), init_empty_weights()]
|
||||
|
||||
if is_deepspeed_zero3_enabled() and is_quantized:
|
||||
init_contexts.append(set_quantized_state())
|
||||
return init_contexts
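The list returned above is a set of context managers that the loading path enters together. A sketch of how such a list is typically consumed; the class, config, and arguments are placeholders, not the exact call site:

```python
from contextlib import ExitStack

init_contexts = cls.get_init_context(is_quantized=False, _is_ds_init_called=False)  # cls: a PreTrainedModel subclass
with ExitStack() as stack:
    for ctx in init_contexts:
        stack.enter_context(ctx)
    model = cls(config)  # weights land on the meta device unless DeepSpeed ZeRO-3 is active
```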
|
||||
|
||||
@classmethod
|
||||
@ -4072,6 +4045,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
import sys
|
||||
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
# This is the easiest way to dispatch to the current process device
|
||||
device_map = tp_device
|
||||
# Assuming sharding the model onto the world
|
||||
@ -4161,6 +4135,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if device_map is not None:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
|
||||
if not is_accelerate_available():
|
||||
raise ValueError(
|
||||
"Using a `device_map` or `tp_plan` requires `accelerate`. You can install it with `pip install accelerate`"
|
||||
)
|
||||
|
||||
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
|
||||
if load_in_4bit or load_in_8bit:
|
||||
@ -4256,6 +4234,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
device_map = hf_quantizer.update_device_map(device_map)
|
||||
config = hf_quantizer.update_tp_plan(config)
|
||||
|
||||
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
||||
if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
|
||||
@ -4388,9 +4367,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules
|
||||
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config
|
||||
)
|
||||
|
||||
# We store the original dtype for quantized models as we cannot easily retrieve it
|
||||
# once the weights have been quantized
|
||||
# Note that once you have loaded a quantized model, you can't change its dtype so this will
|
||||
@ -4644,6 +4622,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
):
|
||||
# Useful flags
|
||||
is_quantized = hf_quantizer is not None
|
||||
is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
|
||||
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
]
|
||||
|
||||
# Get all the keys of the state dicts that we have to initialize the model
|
||||
if sharded_metadata is not None:
|
||||
@ -4805,15 +4788,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
||||
|
||||
# Warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None: # and hf_quantizer is None:
|
||||
if device_map is not None and not is_hqq:
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||
|
||||
error_msgs = []
|
||||
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
]
|
||||
# Iterate on all the shards to load the weights
|
||||
for shard_file in checkpoint_files:
|
||||
# Skip the load for shards that only contain disk-offloaded weights
|
||||
@ -4821,7 +4800,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
continue
|
||||
|
||||
map_location = "cpu"
|
||||
if shard_file.endswith(".safetensors") and not is_hqq_or_bnb:
|
||||
if (
|
||||
shard_file.endswith(".safetensors")
|
||||
and not is_hqq_or_bnb
|
||||
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||
):
|
||||
map_location = "meta"
|
||||
elif (
|
||||
device_map is not None
|
||||
@ -4843,7 +4826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Fix the key names
|
||||
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||
@ -4919,7 +4902,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
tp_device.index,
|
||||
os.environ["RANK"],
|
||||
device_mesh,
|
||||
)
|
||||
|
||||
@ -5192,6 +5175,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
|
||||
(where we want the speed-ups of compiled version with static shapes)."""
|
||||
# Only reset it if not present or different from previous config
|
||||
if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
|
||||
return self.__call__
|
||||
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
|
||||
if (
|
||||
not hasattr(self, "_compiled_call")
|
||||
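As a side note, a minimal stand-in (not the library method) for the dispatch implemented above, assuming a loaded model with a populated `config.model_type`:

```python
import torch

def resolve_decoding_forward(model):
    # Hybrid-cache models such as llama4 keep the eager __call__ until full compile
    # support for hybrid caches lands; other models may use a torch.compile'd forward
    # for static-shape iterative decoding.
    if "llama4" in model.config.model_type:
        return model.__call__
    return torch.compile(model.__call__)
```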
@ -5267,8 +5252,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
not_initialized_submodules = dict(self.named_modules())
|
||||
# This will only initialize submodules that are not marked as initialized by the line above.
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
not_initialized_parameters = list(
|
||||
set(
|
||||
itertools.chain.from_iterable(
|
||||
|
@ -148,6 +148,7 @@ from . import (
|
||||
levit,
|
||||
lilt,
|
||||
llama,
|
||||
llama4,
|
||||
llava,
|
||||
llava_next,
|
||||
llava_next_video,
|
||||
|
@ -1,62 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ALBERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = AlbertConfig.from_json_file(albert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = AlbertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_albert(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--albert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained ALBERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
|
@ -579,6 +579,8 @@ class AlbertPreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, AlbertMLMHead):
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1,389 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ALIGN checkpoints from the original repository."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import align
|
||||
import numpy as np
|
||||
import requests
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
from transformers import (
|
||||
AlignConfig,
|
||||
AlignModel,
|
||||
AlignProcessor,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
EfficientNetConfig,
|
||||
EfficientNetImageProcessor,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
image = tf.image.resize(image, (346, 346))
|
||||
image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
|
||||
return image
|
||||
|
||||
|
||||
def get_align_config():
|
||||
vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
|
||||
vision_config.image_size = 289
|
||||
vision_config.hidden_dim = 640
|
||||
vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
|
||||
vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
|
||||
vision_config.depthwise_padding = []
|
||||
|
||||
text_config = BertConfig()
|
||||
config = AlignConfig.from_text_vision_configs(
|
||||
text_config=text_config, vision_config=vision_config, projection_dim=640
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
def get_processor():
|
||||
image_processor = EfficientNetImageProcessor(
|
||||
do_center_crop=True,
|
||||
rescale_factor=1 / 127.5,
|
||||
rescale_offset=True,
|
||||
do_normalize=False,
|
||||
include_top=False,
|
||||
resample=Image.BILINEAR,
|
||||
)
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
tokenizer.model_max_length = 64
|
||||
processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
return processor
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def rename_keys(original_param_names):
|
||||
# EfficientNet image encoder
|
||||
block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
|
||||
block_names = list(set(block_names))
|
||||
block_names = sorted(block_names)
|
||||
num_blocks = len(block_names)
|
||||
block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}
|
||||
|
||||
rename_keys = []
|
||||
rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
|
||||
rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
|
||||
rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
|
||||
rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
|
||||
rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))
|
||||
|
||||
for b in block_names:
|
||||
hf_b = block_name_mapping[b]
|
||||
rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
|
||||
rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
|
||||
rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
|
||||
)
|
||||
rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
|
||||
rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
|
||||
)
|
||||
|
||||
rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
|
||||
rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
|
||||
rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
|
||||
rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
|
||||
)
|
||||
rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
|
||||
rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
|
||||
)
|
||||
|
||||
key_mapping = {}
|
||||
for item in rename_keys:
|
||||
if item[0] in original_param_names:
|
||||
key_mapping[item[0]] = "vision_model." + item[1]
|
||||
|
||||
# BERT text encoder
|
||||
rename_keys = []
|
||||
old = "tf_bert_model/bert"
|
||||
new = "text_model"
|
||||
for i in range(12):
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.query.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/query/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.query.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.key.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/key/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.key.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.value.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/value/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.value.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.dense.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.dense.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.intermediate.dense.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0",
|
||||
f"{new}.encoder.layer.{i}.intermediate.dense.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias")
|
||||
)
|
||||
|
||||
rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight"))
|
||||
rename_keys.append(
|
||||
(f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight")
|
||||
)
|
||||
rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight"))
|
||||
rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias"))
|
||||
|
||||
rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight"))
|
||||
rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias"))
|
||||
rename_keys.append(("dense/kernel:0", "text_projection.weight"))
|
||||
rename_keys.append(("dense/bias:0", "text_projection.bias"))
|
||||
rename_keys.append(("dense/bias:0", "text_projection.bias"))
|
||||
rename_keys.append(("temperature:0", "temperature"))
|
||||
|
||||
for item in rename_keys:
|
||||
if item[0] in original_param_names:
|
||||
key_mapping[item[0]] = item[1]
|
||||
return key_mapping
|
||||
|
||||
|
||||
def replace_params(hf_params, tf_params, key_mapping):
|
||||
list(hf_params.keys())
|
||||
|
||||
for key, value in tf_params.items():
|
||||
if key not in key_mapping:
|
||||
continue
|
||||
|
||||
hf_key = key_mapping[key]
|
||||
if "_conv" in key and "kernel" in key:
|
||||
new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
|
||||
elif "embeddings" in key:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
elif "depthwise_kernel" in key:
|
||||
new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
|
||||
elif "kernel" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value))
|
||||
elif "temperature" in key:
|
||||
new_hf_value = value
|
||||
elif "bn/gamma" or "bn/beta" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
|
||||
else:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
|
||||
# Replace HF parameters with original TF model parameters
|
||||
hf_params[hf_key].copy_(new_hf_value)
|
||||
|
||||
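The permutes above follow the usual TF-to-PyTorch layout conventions; a small self-contained check (the shapes are made up for illustration):

```python
import numpy as np
import torch

tf_conv = np.zeros((3, 3, 64, 128), dtype=np.float32)    # TF conv kernel: (kh, kw, in, out)
pt_conv = torch.from_numpy(tf_conv).permute(3, 2, 0, 1)   # PyTorch expects (out, in, kh, kw)
assert tuple(pt_conv.shape) == (128, 64, 3, 3)

tf_dense = np.zeros((640, 512), dtype=np.float32)         # TF dense kernel: (in, out)
pt_dense = torch.from_numpy(np.transpose(tf_dense))       # nn.Linear stores (out, in)
assert tuple(pt_dense.shape) == (512, 640)
```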
|
||||
@torch.no_grad()
|
||||
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our ALIGN structure.
|
||||
"""
|
||||
# Load original model
|
||||
seq_length = 64
|
||||
tok = Tokenizer(seq_length)
|
||||
original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
|
||||
original_model.compile()
|
||||
original_model.load_weights(checkpoint_path)
|
||||
|
||||
tf_params = original_model.trainable_variables
|
||||
tf_non_train_params = original_model.non_trainable_variables
|
||||
tf_params = {param.name: param.numpy() for param in tf_params}
|
||||
for param in tf_non_train_params:
|
||||
tf_params[param.name] = param.numpy()
|
||||
tf_param_names = list(tf_params.keys())
|
||||
|
||||
# Load HuggingFace model
|
||||
config = get_align_config()
|
||||
hf_model = AlignModel(config).eval()
|
||||
hf_params = hf_model.state_dict()
|
||||
|
||||
# Create src-to-dst parameter name mapping dictionary
|
||||
print("Converting parameters...")
|
||||
key_mapping = rename_keys(tf_param_names)
|
||||
replace_params(hf_params, tf_params, key_mapping)
|
||||
|
||||
# Initialize processor
|
||||
processor = get_processor()
|
||||
inputs = processor(
|
||||
images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
|
||||
)
|
||||
|
||||
# HF model inference
|
||||
hf_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = hf_model(**inputs)
|
||||
|
||||
hf_image_features = outputs.image_embeds.detach().numpy()
|
||||
hf_text_features = outputs.text_embeds.detach().numpy()
|
||||
|
||||
# Original model inference
|
||||
original_model.trainable = False
|
||||
tf_image_processor = EfficientNetImageProcessor(
|
||||
do_center_crop=True,
|
||||
do_rescale=False,
|
||||
do_normalize=False,
|
||||
include_top=False,
|
||||
resample=Image.BILINEAR,
|
||||
)
|
||||
image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"]
|
||||
text = tok(tf.constant(["A picture of a cat"]))
|
||||
|
||||
image_features = original_model.image_encoder(image, training=False)
|
||||
text_features = original_model.text_encoder(text, training=False)
|
||||
|
||||
image_features = tf.nn.l2_normalize(image_features, axis=-1)
|
||||
text_features = tf.nn.l2_normalize(text_features, axis=-1)
|
||||
|
||||
# Check whether original and HF model outputs match -> np.allclose
|
||||
if not np.allclose(image_features, hf_image_features, atol=1e-3):
|
||||
raise ValueError("The predicted image features are not the same.")
|
||||
if not np.allclose(text_features, hf_text_features, atol=1e-3):
|
||||
raise ValueError("The predicted text features are not the same.")
|
||||
print("Model outputs match!")
|
||||
|
||||
if save_model:
|
||||
# Create folder to save model
|
||||
if not os.path.isdir(pytorch_dump_folder_path):
|
||||
os.mkdir(pytorch_dump_folder_path)
|
||||
# Save converted model and image processor
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
# Push model and image processor to hub
|
||||
print("Pushing converted ALIGN to the hub...")
|
||||
processor.push_to_hub("align-base")
|
||||
hf_model.push_to_hub("align-base")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default="./weights/model-weights",
|
||||
type=str,
|
||||
help="Path to the pretrained TF ALIGN checkpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default="hf_model",
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument("--save_model", action="store_true", help="Save model to local")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)
|
@ -1,162 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AriaForConditionalGeneration,
|
||||
AriaProcessor,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
|
||||
EPILOG_TXT = """Example:
|
||||
python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
|
||||
|
||||
Example for creating the old state dict file with Python:
|
||||
|
||||
import torch
|
||||
from aria.model.language_model.aria_llama import AriaTextForCausalLM
|
||||
|
||||
# load model
|
||||
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
|
||||
model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", low_cpu_mem_usage=True, **kwargs)
|
||||
|
||||
# load vision tower
|
||||
model.get_vision_tower().load_model()
|
||||
|
||||
# Save state dict
|
||||
torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
|
||||
"""
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"vision_tower.vision_model": "vision_tower",
|
||||
"ln_ffn": "layer_norm",
|
||||
"ffn": "feed_forward",
|
||||
"ln_kv": "layer_norm_kv",
|
||||
}
|
||||
|
||||
|
||||
def load_original_state_dict(model_id):
|
||||
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_hf(state_dict):
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key.endswith(".inv_freq"):
|
||||
continue
|
||||
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in key:
|
||||
key = key.replace(key_to_modify, new_key)
|
||||
|
||||
new_state_dict[key] = value
|
||||
new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
|
||||
new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
text_model_id,
|
||||
extra_special_tokens={
|
||||
"image_token": "<|img|>",
|
||||
"pad_token": "<pad>",
|
||||
},
|
||||
)
|
||||
tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
|
||||
processor = AriaProcessor.from_pretrained(
|
||||
text_model_id,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(text_model_id)
|
||||
config.vision_config.hidden_size = 1152
|
||||
config.vision_config.attention_heads = 16
|
||||
config.pad_token_id = 2
|
||||
config.image_token_index = 9
|
||||
config.intermediate_size = config.moe_intermediate_size
|
||||
config.auto_map = {
|
||||
"AutoConfig": "modeling_aria.AriaConfig",
|
||||
"AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
|
||||
}
|
||||
|
||||
with torch.device("meta"):
|
||||
model = AriaForConditionalGeneration(config)
|
||||
|
||||
state_dict = load_original_state_dict(old_state_dict_id)
|
||||
|
||||
state_dict = convert_state_dict_to_hf(state_dict)
|
||||
model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
|
||||
# print("Saving models")
|
||||
# model.save_pretrained("local_aria", safe_serialization=False)
|
||||
# processor.save_pretrained("local_aria")
|
||||
print("Pushing to hub")
|
||||
model.push_to_hub(output_hub_path, create_pr=True)
|
||||
processor.push_to_hub(output_hub_path, create_pr=True)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog=EPILOG_TXT,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_model_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Hub location of the text model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vision_model_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Hub location of the vision model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_hub_path",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Location on the hub of the converted model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old_state_dict_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,279 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_audio_spectrogram_transformer_config(model_name):
|
||||
config = ASTConfig()
|
||||
|
||||
if "10-10" in model_name:
|
||||
pass
|
||||
elif "speech-commands" in model_name:
|
||||
config.max_length = 128
|
||||
elif "12-12" in model_name:
|
||||
config.time_stride = 12
|
||||
config.frequency_stride = 12
|
||||
elif "14-14" in model_name:
|
||||
config.time_stride = 14
|
||||
config.frequency_stride = 14
|
||||
elif "16-16" in model_name:
|
||||
config.time_stride = 16
|
||||
config.frequency_stride = 16
|
||||
else:
|
||||
raise ValueError("Model not supported")
|
||||
|
||||
repo_id = "huggingface/label-files"
|
||||
if "speech-commands" in model_name:
|
||||
config.num_labels = 35
|
||||
filename = "speech-commands-v2-id2label.json"
|
||||
else:
|
||||
config.num_labels = 527
|
||||
filename = "audioset-id2label.json"
|
||||
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "module.v" in name:
|
||||
name = name.replace("module.v", "audio_spectrogram_transformer")
|
||||
if "cls_token" in name:
|
||||
name = name.replace("cls_token", "embeddings.cls_token")
|
||||
if "dist_token" in name:
|
||||
name = name.replace("dist_token", "embeddings.distillation_token")
|
||||
if "pos_embed" in name:
|
||||
name = name.replace("pos_embed", "embeddings.position_embeddings")
|
||||
if "patch_embed.proj" in name:
|
||||
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
|
||||
# transformer blocks
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "encoder.layer")
|
||||
if "attn.proj" in name:
|
||||
name = name.replace("attn.proj", "attention.output.dense")
|
||||
if "attn" in name:
|
||||
name = name.replace("attn", "attention.self")
|
||||
if "norm1" in name:
|
||||
name = name.replace("norm1", "layernorm_before")
|
||||
if "norm2" in name:
|
||||
name = name.replace("norm2", "layernorm_after")
|
||||
if "mlp.fc1" in name:
|
||||
name = name.replace("mlp.fc1", "intermediate.dense")
|
||||
if "mlp.fc2" in name:
|
||||
name = name.replace("mlp.fc2", "output.dense")
|
||||
# final layernorm
|
||||
if "audio_spectrogram_transformer.norm" in name:
|
||||
name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm")
|
||||
# classifier head
|
||||
if "module.mlp_head.0" in name:
|
||||
name = name.replace("module.mlp_head.0", "classifier.layernorm")
|
||||
if "module.mlp_head.1" in name:
|
||||
name = name.replace("module.mlp_head.1", "classifier.dense")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, config):
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
if "qkv" in key:
|
||||
key_split = key.split(".")
|
||||
layer_num = int(key_split[3])
|
||||
dim = config.hidden_size
|
||||
if "weight" in key:
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight"
|
||||
] = val[:dim, :]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight"
|
||||
] = val[dim : dim * 2, :]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight"
|
||||
] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias"
|
||||
] = val[:dim]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias"
|
||||
] = val[dim : dim * 2]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias"
|
||||
] = val[-dim:]
|
||||
else:
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
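The slicing above unpacks a fused `qkv` projection into separate query/key/value weights; a tiny illustration (the hidden size is an assumption):

```python
import torch

dim = 768                                  # assumed hidden size
qkv_weight = torch.randn(3 * dim, dim)     # original checkpoints stack q, k, v along dim 0
q = qkv_weight[:dim, :]
k = qkv_weight[dim : dim * 2, :]
v = qkv_weight[-dim:, :]
assert q.shape == k.shape == v.shape == (dim, dim)
```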
|
||||
def remove_keys(state_dict):
|
||||
ignore_keys = [
|
||||
"module.v.head.weight",
|
||||
"module.v.head.bias",
|
||||
"module.v.head_dist.weight",
|
||||
"module.v.head_dist.bias",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure.
|
||||
"""
|
||||
config = get_audio_spectrogram_transformer_config(model_name)
|
||||
|
||||
model_name_to_url = {
|
||||
"ast-finetuned-audioset-10-10-0.4593": (
|
||||
"https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.450": (
|
||||
"https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.448": (
|
||||
"https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.448-v2": (
|
||||
"https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-12-12-0.447": (
|
||||
"https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-14-14-0.443": (
|
||||
"https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-16-16-0.442": (
|
||||
"https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-speech-commands-v2": (
|
||||
"https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1"
|
||||
),
|
||||
}
|
||||
|
||||
# load original state_dict
|
||||
checkpoint_url = model_name_to_url[model_name]
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
|
||||
# remove some keys
|
||||
remove_keys(state_dict)
|
||||
# rename some keys
|
||||
new_state_dict = convert_state_dict(state_dict, config)
|
||||
|
||||
# load 🤗 model
|
||||
model = ASTForAudioClassification(config)
|
||||
model.eval()
|
||||
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# verify outputs on dummy input
|
||||
# source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62
|
||||
mean = -4.2677393 if "speech-commands" not in model_name else -6.845978
|
||||
std = 4.5689974 if "speech-commands" not in model_name else 5.5654526
|
||||
max_length = 1024 if "speech-commands" not in model_name else 128
|
||||
feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length)
|
||||
|
||||
if "speech-commands" in model_name:
|
||||
# TODO: Convert dataset to Parquet
|
||||
dataset = load_dataset("google/speech_commands", "v0.02", split="validation", trust_remote_code=True)
|
||||
waveform = dataset[0]["audio"]["array"]
|
||||
else:
|
||||
filepath = hf_hub_download(
|
||||
repo_id="nielsr/audio-spectogram-transformer-checkpoint",
|
||||
filename="sample_audio.flac",
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
waveform, _ = torchaudio.load(filepath)
|
||||
waveform = waveform.squeeze().numpy()
|
||||
|
||||
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
if model_name == "ast-finetuned-audioset-10-10-0.4593":
|
||||
expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.450":
|
||||
expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.448":
|
||||
expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.448-v2":
|
||||
expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917])
|
||||
elif model_name == "ast-finetuned-audioset-12-12-0.447":
|
||||
expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843])
|
||||
elif model_name == "ast-finetuned-audioset-14-14-0.443":
|
||||
expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413])
|
||||
elif model_name == "ast-finetuned-audioset-16-16-0.442":
|
||||
expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470])
|
||||
elif model_name == "ast-finetuned-speech-commands-v2":
|
||||
expected_slice = torch.tensor([6.1589, -8.0566, -8.7984])
|
||||
else:
|
||||
raise ValueError("Unknown model name")
|
||||
if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4):
|
||||
raise ValueError("Logits don't match")
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print("Pushing model and feature extractor to the hub...")
|
||||
model.push_to_hub(f"MIT/{model_name}")
|
||||
feature_extractor.push_to_hub(f"MIT/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="ast-finetuned-audioset-10-10-0.4593",
|
||||
type=str,
|
||||
help="Name of the Audio Spectrogram Transformer model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -544,10 +544,6 @@ class _BaseAutoModelClass:
|
||||
if kwargs_orig.get("quantization_config", None) is not None:
|
||||
kwargs["quantization_config"] = kwargs_orig["quantization_config"]
|
||||
|
||||
# AutoClass-specific config manipulation
|
||||
config = copy.deepcopy(config)
|
||||
config = cls._prepare_config_for_auto_class(config)
|
||||
|
||||
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
|
||||
has_local_code = type(config) in cls._model_mapping.keys()
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
@ -570,6 +566,8 @@ class _BaseAutoModelClass:
|
||||
)
|
||||
elif type(config) in cls._model_mapping.keys():
|
||||
model_class = _get_model_class(config, cls._model_mapping)
|
||||
if model_class.config_class == config.sub_configs.get("text_config", None):
|
||||
config = config.get_text_config()
|
||||
return model_class.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
|
||||
)
|
||||
|
@ -170,6 +170,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("levit", "LevitConfig"),
|
||||
("lilt", "LiltConfig"),
|
||||
("llama", "LlamaConfig"),
|
||||
("llama4", "Llama4Config"),
|
||||
("llama4_text", "Llama4TextConfig"),
|
||||
("llava", "LlavaConfig"),
|
||||
("llava_next", "LlavaNextConfig"),
|
||||
("llava_next_video", "LlavaNextVideoConfig"),
|
||||
@ -519,6 +521,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("llama", "LLaMA"),
|
||||
("llama2", "Llama2"),
|
||||
("llama3", "Llama3"),
|
||||
("llama4", "Llama4"),
|
||||
("llama4_text", "Llama4ForCausalLM"),
|
||||
("llava", "LLaVa"),
|
||||
("llava_next", "LLaVA-NeXT"),
|
||||
("llava_next_video", "LLaVa-NeXT-Video"),
|
||||
@ -776,6 +780,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
||||
("rt_detr_resnet", "rt_detr"),
|
||||
("granitevision", "llava_next"),
|
||||
("sam_vision_model", "sam"),
|
||||
("llama4_text", "llama4"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -104,6 +104,7 @@ else:
|
||||
("layoutlmv2", ("LayoutLMv2ImageProcessor",)),
|
||||
("layoutlmv3", ("LayoutLMv3ImageProcessor",)),
|
||||
("levit", ("LevitImageProcessor",)),
|
||||
("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
|
||||
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
|
||||
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
|
||||
("llava_next_video", ("LlavaNextVideoImageProcessor",)),
|
||||
|
@ -17,7 +17,6 @@
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from .auto_factory import (
|
||||
_BaseAutoBackboneClass,
|
||||
@ -161,6 +160,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("levit", "LevitModel"),
|
||||
("lilt", "LiltModel"),
|
||||
("llama", "LlamaModel"),
|
||||
("llama4", "Llama4ForConditionalGeneration"),
|
||||
("longformer", "LongformerModel"),
|
||||
("longt5", "LongT5Model"),
|
||||
("luke", "LukeModel"),
|
||||
@ -547,6 +547,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("jamba", "JambaForCausalLM"),
|
||||
("jetmoe", "JetMoeForCausalLM"),
|
||||
("llama", "LlamaForCausalLM"),
|
||||
("llama4", "Llama4ForCausalLM"),
|
||||
("llama4_text", "Llama4ForCausalLM"),
|
||||
("mamba", "MambaForCausalLM"),
|
||||
("mamba2", "Mamba2ForCausalLM"),
|
||||
("marian", "MarianForCausalLM"),
|
||||
@ -634,6 +636,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
|
||||
("ijepa", "IJepaModel"),
|
||||
("imagegpt", "ImageGPTModel"),
|
||||
("levit", "LevitModel"),
|
||||
("llama4", "Llama4VisionModel"),
|
||||
("mllama", "MllamaVisionModel"),
|
||||
("mobilenet_v1", "MobileNetV1Model"),
|
||||
("mobilenet_v2", "MobileNetV2Model"),
|
||||
@ -849,6 +852,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
||||
("idefics3", "Idefics3ForConditionalGeneration"),
|
||||
("instructblip", "InstructBlipForConditionalGeneration"),
|
||||
("kosmos-2", "Kosmos2ForConditionalGeneration"),
|
||||
("llama4", "Llama4ForConditionalGeneration"),
|
||||
("llava", "LlavaForConditionalGeneration"),
|
||||
("llava_next", "LlavaNextForConditionalGeneration"),
|
||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||
@ -1492,6 +1496,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
||||
("emu3", "Emu3TextModel"),
|
||||
("flaubert", "FlaubertModel"),
|
||||
("ibert", "IBertModel"),
|
||||
("llama4", "Llama4TextModel"),
|
||||
("longformer", "LongformerModel"),
|
||||
("mllama", "MllamaTextModel"),
|
||||
("mobilebert", "MobileBertModel"),
|
||||
@ -1678,30 +1683,6 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag
|
||||
class AutoModelForCausalLM(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
|
||||
@classmethod
|
||||
def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
|
||||
"""
|
||||
Additional autoclass-specific config post-loading manipulation. In this specific autoclass, if the config has
|
||||
a nested text decoder section, uses that section instead.
|
||||
|
||||
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
|
||||
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
|
||||
"""
|
||||
possible_text_config_names = ("decoder", "generator", "text_config")
|
||||
text_config_names = []
|
||||
for text_config_name in possible_text_config_names:
|
||||
if hasattr(config, text_config_name):
|
||||
text_config_names += [text_config_name]
|
||||
|
||||
text_config = config.get_text_config(decoder=True)
|
||||
if text_config_names and type(text_config) in cls._model_mapping.keys():
|
||||
warnings.warn(
|
||||
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
|
||||
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
|
||||
FutureWarning,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
||||
|
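For context, a hedged sketch of the usage the removed docstring and deprecation warning describe; the checkpoint id is a placeholder, not a real repository:

```python
# Illustrative only: loading just the text decoder of a multimodal checkpoint.
# "org/some-vlm-checkpoint" is a placeholder id.
from transformers import AutoModelForCausalLM

# The auto class resolves the nested text config (e.g. `text_config`) and
# instantiates only the text-to-text generation module of the VLM.
text_lm = AutoModelForCausalLM.from_pretrained("org/some-vlm-checkpoint")
```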
||||
|
@ -77,6 +77,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("kosmos-2", "Kosmos2Processor"),
|
||||
("layoutlmv2", "LayoutLMv2Processor"),
|
||||
("layoutlmv3", "LayoutLMv3Processor"),
|
||||
("llama4", "Llama4Processor"),
|
||||
("llava", "LlavaProcessor"),
|
||||
("llava_next", "LlavaNextProcessor"),
|
||||
("llava_next_video", "LlavaNextVideoProcessor"),
|
||||
|
@ -292,6 +292,20 @@ else:
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"llama4",
|
||||
(
|
||||
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"llama4_text",
|
||||
(
|
||||
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
@ -1,273 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from os import path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from .configuration_bamba import BambaConfig
|
||||
|
||||
|
||||
def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
|
||||
state_dict = {}
|
||||
|
||||
for orig_k, param in original_sd.items():
|
||||
k = orig_k.replace("backbone", "model")
|
||||
|
||||
# for embeddings
|
||||
k = k.replace("embedding", "embed_tokens")
|
||||
|
||||
# for mixer
|
||||
k = k.replace("mixer", "mamba")
|
||||
|
||||
# for final layernorm
|
||||
k = k.replace("norm_f", "final_layernorm")
|
||||
|
||||
# for block layernorm
|
||||
k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
|
||||
k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
|
||||
|
||||
# for mlp
|
||||
k = k.replace("mlp.fc2", "feed_forward.down_proj")
|
||||
|
||||
if "mlp.fc1" in k:
|
||||
param, param2 = torch.chunk(param, 2, dim=0)
|
||||
k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
|
||||
state_dict[k2] = param2
|
||||
k = k.replace("mlp.fc1", "feed_forward.up_proj")
|
||||
|
||||
if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
|
||||
"out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
|
||||
):
|
||||
# then this must be a mamba
|
||||
pass
|
||||
else:
|
||||
# for attn
|
||||
# - because mixer was replaced with mamba above
|
||||
k = k.replace("mamba.out_proj", "self_attn.o_proj")
|
||||
if "mamba.in_proj" in k:
|
||||
m, n = param.shape
|
||||
d = (m - n) // 2
|
||||
param, param2, param3 = torch.split(param, [n, d, d], dim=0)
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
|
||||
state_dict[k2] = param2
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
|
||||
state_dict[k2] = param3
|
||||
k = k.replace("mamba.in_proj", "self_attn.q_proj")
|
||||
|
||||
state_dict[k] = param
|
||||
|
||||
return state_dict
|
||||
|
||||
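The `d = (m - n) // 2` arithmetic above assumes the fused `in_proj` stacks the query rows (n) followed by equally sized key and value blocks; a worked example with assumed sizes:

```python
import torch

n = 4096                       # assumed hidden size (query rows)
d = 1024                       # assumed key/value rows (num_kv_heads * head_dim)
fused = torch.randn(n + 2 * d, n)

m = fused.shape[0]
assert (m - n) // 2 == d       # the relation the converter relies on
q_w, k_w, v_w = torch.split(fused, [n, d, d], dim=0)
assert q_w.shape == (n, n) and k_w.shape == (d, n) and v_w.shape == (d, n)
```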
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_ssm_config_to_hf_config(
|
||||
config_ssm: Dict,
|
||||
**kwargs,
|
||||
) -> BambaConfig:
|
||||
"""Convert a config from mamba_ssm to a BambaConfig from here."""
|
||||
hf_config: BambaConfig = BambaConfig(**kwargs)
|
||||
|
||||
hf_config.architectures = ["BambaForCausalLM"]
|
||||
|
||||
# Set important values from config and recalculate other resulting entries
|
||||
hf_config.hidden_size = config_ssm["d_model"]
|
||||
hf_config.intermediate_size = config_ssm["d_intermediate"]
|
||||
hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
|
||||
hf_config.num_hidden_layers = config_ssm["n_layer"]
|
||||
hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
|
||||
|
||||
# currently this script assumes config_ssm belongs to v2
|
||||
if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
|
||||
raise ValueError("Conversion script only supports Mamba2")
|
||||
|
||||
# Set attention values
|
||||
attn_cfg = config_ssm.get("attn_cfg")
|
||||
if attn_cfg:
|
||||
assert attn_cfg["causal"], "Only support non-causal attention."
|
||||
assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
|
||||
assert not attn_cfg["out_proj_bias"], "Only support no out bias."
|
||||
hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
|
||||
hf_config.num_attention_heads = attn_cfg["num_heads"]
|
||||
hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
|
||||
|
||||
attention_layer_indices = config_ssm.get("attn_layer_idx")
|
||||
if attention_layer_indices:
|
||||
hf_config.attn_layer_indices = attention_layer_indices
|
||||
|
||||
# Padded vocab size: usually a multiple of 16, though 32 is also common across models
|
||||
vocab_size = config_ssm["vocab_size"]
|
||||
pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
|
||||
if (vocab_size % pad_vocab_size_multiple) != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
||||
hf_config.vocab_size = vocab_size
|
||||
|
||||
return hf_config
|
||||
|
||||
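The padding logic above rounds the vocabulary up to the next multiple; a quick worked example (the numbers are illustrative):

```python
def pad_vocab(vocab_size: int, multiple: int) -> int:
    # Mirrors the conversion above: round up to the next multiple if needed.
    if vocab_size % multiple != 0:
        vocab_size += multiple - (vocab_size % multiple)
    return vocab_size

assert pad_vocab(50277, 16) == 50288    # 50277 = 16 * 3142 + 5 -> padded by 11
assert pad_vocab(128000, 32) == 128000  # already a multiple, left unchanged
```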
|
||||
def save_single_safetensor(
|
||||
state_dict: Dict,
|
||||
save_directory: str,
|
||||
metadata: Dict,
|
||||
):
|
||||
save_file(
|
||||
state_dict,
|
||||
os.path.join(save_directory, SAFE_WEIGHTS_NAME),
|
||||
metadata,
|
||||
)
|
||||
|
||||
|
||||
def save_sharded_safetensors(
|
||||
state_dict: Dict,
|
||||
save_directory: str,
|
||||
metadata: Dict,
|
||||
max_shard_size: Union[int, str] = "5GB",
|
||||
):
|
||||
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||
)
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
# Save the index
|
||||
with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
mamba_ssm_checkpoint_path: str,
|
||||
precision: str,
|
||||
output_dir: str,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
save_model: Union[bool, str] = True,
|
||||
) -> None:
|
||||
# load tokenizer if provided, this will be used to set the
|
||||
# token_ids in the config file
|
||||
token_ids = {}
|
||||
if tokenizer_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
for key in [
|
||||
"bos_token_id",
|
||||
"eos_token_id",
|
||||
"pad_token_id",
|
||||
]:
|
||||
id = getattr(tokenizer, key, None)
|
||||
if id:
|
||||
token_ids[key] = id
|
||||
|
||||
# there are some configs unsettable by the mamba_ssm config, so
|
||||
# if there are changes from the defaults, have to pass them into
|
||||
# the function
|
||||
unsettables = {
|
||||
"mamba_d_head": 64,
|
||||
"mamba_d_state": 128,
|
||||
"mamba_n_groups": 1,
|
||||
"rms_norm_eps": 1e-5,
|
||||
}
|
||||
|
||||
# Load and save config based on name
|
||||
config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
|
||||
with open(config_path, "r", encoding="utf-8") as json_file:
|
||||
config = json.load(json_file)
|
||||
|
||||
# convert the config
|
||||
hf_config = convert_ssm_config_to_hf_config(
|
||||
config_ssm=config,
|
||||
**token_ids,
|
||||
**unsettables,
|
||||
)
|
||||
hf_config.save_pretrained(output_dir)
|
||||
|
||||
# Load state dict of the original model and transfer to hf model
|
||||
state_dict = torch.load(
|
||||
path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
# FIXME: allow other parameters to pass in
|
||||
state_dict = convert_state_dict_from_mamba_ssm(state_dict)
|
||||
|
||||
# Save new model to pytorch_dump_path
|
||||
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
|
||||
|
||||
save_file_fn = None
|
||||
if isinstance(save_model, bool) and save_model:
|
||||
save_file_fn = save_single_safetensor
|
||||
elif isinstance(save_model, str) and save_model == "sharded":
|
||||
save_file_fn = save_sharded_safetensors
|
||||
|
||||
if save_file_fn:
|
||||
save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--mamba_ssm_checkpoint_directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
const="fp16",
|
||||
required=True,
|
||||
choices=("fp32", "fp16", "bf16"),
|
||||
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tokenizer_model_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Path to a the tokenizer file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
args.mamba_ssm_checkpoint_directory,
|
||||
args.precision,
|
||||
args.output_dir,
|
||||
)
|
@ -1,263 +0,0 @@
|
||||
"""Convert Bark checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from bark.generation import _load_model as _bark_load_model
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import EncodecConfig, EncodecModel, set_seed
|
||||
from transformers.models.bark.configuration_bark import (
|
||||
BarkCoarseConfig,
|
||||
BarkConfig,
|
||||
BarkFineConfig,
|
||||
BarkSemanticConfig,
|
||||
)
|
||||
from transformers.models.bark.generation_configuration_bark import (
|
||||
BarkCoarseGenerationConfig,
|
||||
BarkFineGenerationConfig,
|
||||
BarkGenerationConfig,
|
||||
BarkSemanticGenerationConfig,
|
||||
)
|
||||
from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
set_seed(770)
|
||||
|
||||
|
||||
new_layer_name_dict = {
|
||||
"c_attn": "att_proj",
|
||||
"c_proj": "out_proj",
|
||||
"c_fc": "in_proj",
|
||||
"transformer.": "",
|
||||
"h.": "layers.",
|
||||
"ln_1": "layernorm_1",
|
||||
"ln_2": "layernorm_2",
|
||||
"ln_f": "layernorm_final",
|
||||
"wpe": "position_embeds_layer",
|
||||
"wte": "input_embeds_layer",
|
||||
}
|
||||
|
||||
|
||||
REMOTE_MODEL_PATHS = {
|
||||
"text_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "text.pt",
|
||||
},
|
||||
"coarse_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "coarse.pt",
|
||||
},
|
||||
"fine_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "fine.pt",
|
||||
},
|
||||
"text": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "text_2.pt",
|
||||
},
|
||||
"coarse": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "coarse_2.pt",
|
||||
},
|
||||
"fine": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "fine_2.pt",
|
||||
},
|
||||
}
|
||||
|
||||
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
|
||||
|
||||
|
||||
def _get_ckpt_path(model_type, use_small=False):
|
||||
key = model_type
|
||||
if use_small:
|
||||
key += "_small"
|
||||
return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
|
||||
|
||||
|
||||
def _download(from_hf_path, file_name):
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
|
||||
|
||||
|
||||
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
||||
if model_type == "text":
|
||||
ModelClass = BarkSemanticModel
|
||||
ConfigClass = BarkSemanticConfig
|
||||
GenerationConfigClass = BarkSemanticGenerationConfig
|
||||
elif model_type == "coarse":
|
||||
ModelClass = BarkCoarseModel
|
||||
ConfigClass = BarkCoarseConfig
|
||||
GenerationConfigClass = BarkCoarseGenerationConfig
|
||||
elif model_type == "fine":
|
||||
ModelClass = BarkFineModel
|
||||
ConfigClass = BarkFineConfig
|
||||
GenerationConfigClass = BarkFineGenerationConfig
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
model_key = f"{model_type}_small" if use_small else model_type
|
||||
model_info = REMOTE_MODEL_PATHS[model_key]
|
||||
if not os.path.exists(ckpt_path):
|
||||
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
||||
_download(model_info["repo_id"], model_info["file_name"])
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
# this is a hack
|
||||
model_args = checkpoint["model_args"]
|
||||
if "input_vocab_size" not in model_args:
|
||||
model_args["input_vocab_size"] = model_args["vocab_size"]
|
||||
model_args["output_vocab_size"] = model_args["vocab_size"]
|
||||
del model_args["vocab_size"]
|
||||
|
||||
# convert Bark model arguments to HF Bark model arguments
|
||||
model_args["num_heads"] = model_args.pop("n_head")
|
||||
model_args["hidden_size"] = model_args.pop("n_embd")
|
||||
model_args["num_layers"] = model_args.pop("n_layer")
|
||||
|
||||
model_config = ConfigClass(**checkpoint["model_args"])
|
||||
model = ModelClass(config=model_config)
|
||||
model_generation_config = GenerationConfigClass()
|
||||
|
||||
model.generation_config = model_generation_config
|
||||
state_dict = checkpoint["model"]
|
||||
# fixup checkpoint
|
||||
unwanted_prefix = "_orig_mod."
|
||||
for k, v in list(state_dict.items()):
|
||||
if k.startswith(unwanted_prefix):
|
||||
# replace part of the key with corresponding layer name in HF implementation
|
||||
new_k = k[len(unwanted_prefix) :]
|
||||
for old_layer_name in new_layer_name_dict:
|
||||
new_k = new_k.replace(old_layer_name, new_layer_name_dict[old_layer_name])
|
||||
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
|
||||
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
|
||||
extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")}
|
||||
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
||||
missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")}
|
||||
if len(extra_keys) != 0:
|
||||
raise ValueError(f"extra keys found: {extra_keys}")
|
||||
if len(missing_keys) != 0:
|
||||
raise ValueError(f"missing keys: {missing_keys}")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
n_params = model.num_parameters(exclude_embeddings=True)
|
||||
val_loss = checkpoint["best_val_loss"].item()
|
||||
logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params, {round(val_loss, 3)} loss")
|
||||
model.eval()
|
||||
model.to(device)
|
||||
del checkpoint, state_dict
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"):
|
||||
if model_type not in ("text", "coarse", "fine"):
|
||||
raise NotImplementedError()
|
||||
|
||||
device = "cpu" # do conversion on cpu
|
||||
|
||||
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
|
||||
model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small)
|
||||
|
||||
# load bark initial model
|
||||
bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small)
|
||||
|
||||
if model_type == "text":
|
||||
bark_model = bark_model["model"]
|
||||
|
||||
if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params():
|
||||
raise ValueError("initial and new models don't have the same number of parameters")
|
||||
|
||||
# check if same output as the bark model
|
||||
batch_size = 5
|
||||
sequence_length = 10
|
||||
|
||||
if model_type in ["text", "coarse"]:
|
||||
vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int)
|
||||
output_old_model = bark_model(vec)[0]
|
||||
|
||||
output_new_model_total = model(vec)
|
||||
|
||||
# take last logits
|
||||
output_new_model = output_new_model_total.logits[:, [-1], :]
|
||||
|
||||
else:
|
||||
prediction_codebook_channel = 3
|
||||
n_codes_total = 8
|
||||
vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int)
|
||||
|
||||
output_new_model_total = model(prediction_codebook_channel, vec)
|
||||
output_old_model = bark_model(prediction_codebook_channel, vec)
|
||||
|
||||
output_new_model = output_new_model_total.logits
|
||||
|
||||
# output difference should come from the difference of self-attention implementation design
|
||||
if output_new_model.shape != output_old_model.shape:
|
||||
raise ValueError("initial and new outputs don't have the same shape")
|
||||
if (output_new_model - output_old_model).abs().max().item() > 1e-3:
|
||||
raise ValueError("initial and new outputs are not equal")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
def load_whole_bark_model(
|
||||
semantic_path,
|
||||
coarse_path,
|
||||
fine_path,
|
||||
append_text,
|
||||
hub_path,
|
||||
folder_path,
|
||||
):
|
||||
pytorch_dump_folder_path = os.path.join(folder_path, append_text)
|
||||
|
||||
semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json"))
|
||||
coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json"))
|
||||
fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json"))
|
||||
codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz")
|
||||
|
||||
semantic = BarkSemanticModel.from_pretrained(semantic_path)
|
||||
coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path)
|
||||
fineAcoustic = BarkFineModel.from_pretrained(fine_path)
|
||||
codec = EncodecModel.from_pretrained("facebook/encodec_24khz")
|
||||
|
||||
bark_config = BarkConfig.from_sub_model_configs(
|
||||
semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig
|
||||
)
|
||||
|
||||
bark_generation_config = BarkGenerationConfig.from_sub_model_configs(
|
||||
semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config
|
||||
)
|
||||
|
||||
bark = BarkModel(bark_config)
|
||||
|
||||
bark.semantic = semantic
|
||||
bark.coarse_acoustics = coarseAcoustic
|
||||
bark.fine_acoustics = fineAcoustic
|
||||
bark.codec_model = codec
|
||||
|
||||
bark.generation_config = bark_generation_config
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
|
||||
parser.add_argument("model_type", type=str, help="text, coarse or fine.")
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small)
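# Example invocation (illustrative only; the script filename and the dump path are placeholders):
#   python convert_bark_checkpoint.py text /path/to/bark_dump --is_small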
|
@ -1,156 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BART checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import fairseq
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
BartConfig,
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartTokenizer,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
||||
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
|
||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
raise Exception("requires fairseq >= 0.9.0")
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SAMPLE_TEXT = " Hello world! cécé herlolip"
|
||||
|
||||
mnli_rename_keys = [
|
||||
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
|
||||
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
|
||||
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
|
||||
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
|
||||
]
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"encoder.version",
|
||||
"decoder.version",
|
||||
"model.encoder.version",
|
||||
"model.decoder.version",
|
||||
"_float_tensor",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def load_xsum_checkpoint(checkpoint_path):
|
||||
"""Checkpoint path should end in model.pt"""
|
||||
sd = torch.load(checkpoint_path, map_location="cpu")
|
||||
hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
|
||||
hub_interface.model.load_state_dict(sd["model"])
|
||||
return hub_interface
|
||||
|
||||
|
||||
def make_linear_from_emb(emb):
|
||||
vocab_size, emb_size = emb.weight.shape
|
||||
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||
lin_layer.weight.data = emb.weight.data
|
||||
return lin_layer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BART structure.
|
||||
"""
|
||||
if not os.path.exists(checkpoint_path):
|
||||
bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
|
||||
else:
|
||||
bart = load_xsum_checkpoint(checkpoint_path)
|
||||
|
||||
bart.model.upgrade_state_dict(bart.model.state_dict())
|
||||
if hf_checkpoint_name is None:
|
||||
hf_checkpoint_name = checkpoint_path.replace(".", "-")
|
||||
config = BartConfig.from_pretrained(hf_checkpoint_name)
|
||||
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
|
||||
tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
|
||||
if not torch.eq(tokens, tokens2).all():
|
||||
raise ValueError(
|
||||
f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
|
||||
)
|
||||
|
||||
if checkpoint_path == "bart.large.mnli":
|
||||
state_dict = bart.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
|
||||
for src, dest in mnli_rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
model = BartForSequenceClassification(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
fairseq_output = bart.predict("mnli", tokens, return_logits=True)
|
||||
new_model_outputs = model(tokens)[0] # logits
|
||||
else: # no classification heads to worry about
|
||||
state_dict = bart.model.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||
fairseq_output = bart.extract_features(tokens)
|
||||
if hf_checkpoint_name == "facebook/bart-large":
|
||||
model = BartModel(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
new_model_outputs = model(tokens).model[0]
|
||||
else:
|
||||
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
|
||||
model.model.load_state_dict(state_dict)
|
||||
if hasattr(model, "lm_head"):
|
||||
model.lm_head = make_linear_from_emb(model.model.shared)
|
||||
new_model_outputs = model.model(tokens)[0]
|
||||
|
||||
# Check results
|
||||
if fairseq_output.shape != new_model_outputs.shape:
|
||||
raise ValueError(
|
||||
f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
|
||||
)
|
||||
if (fairseq_output != new_model_outputs).any().item():
|
||||
raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
|
||||
)
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
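# Example invocation (illustrative only; the script filename and the dump path are placeholders):
#   python convert_bart_checkpoint.py bart.large.cnn /path/to/bart_dump --hf_config facebook/bart-large-cnn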
|
@ -1,373 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BEiT checkpoints from the unilm repository."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
BeitConfig,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
BeitImageProcessor,
|
||||
)
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
|
||||
rename_keys = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
|
||||
)
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
|
||||
|
||||
# projection layer + position embeddings
|
||||
rename_keys.extend(
|
||||
[
|
||||
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
|
||||
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
|
||||
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
if has_lm_head:
|
||||
# mask token + shared relative position bias + layernorm
|
||||
rename_keys.extend(
|
||||
[
|
||||
("mask_token", "beit.embeddings.mask_token"),
|
||||
(
|
||||
"rel_pos_bias.relative_position_bias_table",
|
||||
"beit.encoder.relative_position_bias.relative_position_bias_table",
|
||||
),
|
||||
(
|
||||
"rel_pos_bias.relative_position_index",
|
||||
"beit.encoder.relative_position_bias.relative_position_index",
|
||||
),
|
||||
("norm.weight", "layernorm.weight"),
|
||||
("norm.bias", "layernorm.bias"),
|
||||
]
|
||||
)
|
||||
elif is_semantic:
|
||||
# semantic segmentation classification heads
|
||||
rename_keys.extend(
|
||||
[
|
||||
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
|
||||
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
|
||||
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
|
||||
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
# layernorm + classification head
|
||||
rename_keys.extend(
|
||||
[
|
||||
("fc_norm.weight", "beit.pooler.layernorm.weight"),
|
||||
("fc_norm.bias", "beit.pooler.layernorm.bias"),
|
||||
("head.weight", "classifier.weight"),
|
||||
("head.bias", "classifier.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
|
||||
for i in range(config.num_hidden_layers):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
# queries, keys and values
|
||||
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
|
||||
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
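# note: BEiT's attention uses no bias on the key projection, so only q_bias and v_bias
# exist in the original checkpoint and are read here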
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||
: config.hidden_size, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
config.hidden_size : config.hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||
-config.hidden_size :, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||
|
||||
# gamma_1 and gamma_2
|
||||
# we call them lambda because otherwise they are renamed when using .from_pretrained
|
||||
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
|
||||
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
|
||||
|
||||
# relative_position bias table + index
|
||||
if not has_lm_head:
|
||||
# each layer has its own relative position bias
|
||||
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
|
||||
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
|
||||
|
||||
state_dict[
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
|
||||
] = table
|
||||
state_dict[
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
|
||||
] = index
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BEiT structure.
|
||||
"""
|
||||
|
||||
# define default BEiT configuration
|
||||
config = BeitConfig()
|
||||
has_lm_head = False
|
||||
is_semantic = False
|
||||
repo_id = "huggingface/label-files"
|
||||
# set config parameters based on URL
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
# masked image modeling
|
||||
config.use_shared_relative_position_bias = True
|
||||
config.use_mask_token = True
|
||||
has_lm_head = True
|
||||
elif checkpoint_url[-9:-4] == "ft22k":
|
||||
# intermediate fine-tuning on ImageNet-22k
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 21841
|
||||
filename = "imagenet-22k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
# this dataset contains 21843 labels but the model only has 21841
|
||||
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
|
||||
del id2label[9205]
|
||||
del id2label[15027]
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
elif checkpoint_url[-8:-4] == "to1k":
|
||||
# fine-tuning on ImageNet-1k
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 1000
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
if "384" in checkpoint_url:
|
||||
config.image_size = 384
|
||||
if "512" in checkpoint_url:
|
||||
config.image_size = 512
|
||||
elif "ade20k" in checkpoint_url:
|
||||
# fine-tuning
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 150
|
||||
filename = "ade20k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
config.image_size = 640
|
||||
is_semantic = True
|
||||
else:
|
||||
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
|
||||
|
||||
# size of the architecture
|
||||
if "base" in checkpoint_url:
|
||||
pass
|
||||
elif "large" in checkpoint_url:
|
||||
config.hidden_size = 1024
|
||||
config.intermediate_size = 4096
|
||||
config.num_hidden_layers = 24
|
||||
config.num_attention_heads = 16
|
||||
if "ade20k" in checkpoint_url:
|
||||
config.image_size = 640
|
||||
config.out_indices = [7, 11, 15, 23]
|
||||
else:
|
||||
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
|
||||
|
||||
# load state_dict of original model, remove and rename some keys
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
|
||||
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
|
||||
|
||||
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
if is_semantic:
|
||||
# add prefix to decoder keys
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("backbone.fpn"):
|
||||
key = key.replace("backbone.fpn", "fpn")
|
||||
state_dict[key] = val
|
||||
|
||||
# load HuggingFace model
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
model = BeitForMaskedImageModeling(config)
|
||||
elif "ade20k" in checkpoint_url:
|
||||
model = BeitForSemanticSegmentation(config)
|
||||
else:
|
||||
model = BeitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# Check outputs on an image
|
||||
if is_semantic:
|
||||
image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
|
||||
image = Image.open(ds[0]["file"])
|
||||
else:
|
||||
image_processor = BeitImageProcessor(
|
||||
size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
|
||||
)
|
||||
image = prepare_img()
|
||||
|
||||
encoding = image_processor(images=image, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify logits
|
||||
expected_shape = torch.Size([1, 1000])
|
||||
if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
|
||||
expected_shape = torch.Size([1, 196, 8192])
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
|
||||
expected_shape = torch.Size([1, 196, 8192])
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
|
||||
expected_shape = torch.Size([1, 21841])
|
||||
expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
|
||||
expected_class_idx = 2397
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
|
||||
expected_shape = torch.Size([1, 21841])
|
||||
expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
|
||||
expected_class_idx = 2396
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
|
||||
expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
|
||||
expected_class_idx = 285
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
|
||||
expected_class_idx = 281
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
|
||||
expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
|
||||
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
|
||||
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
|
||||
]
|
||||
)
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
|
||||
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
|
||||
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError("Can't verify logits as model is not supported")
|
||||
|
||||
if logits.shape != expected_shape:
|
||||
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
|
||||
if not has_lm_head:
|
||||
if is_semantic:
|
||||
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
else:
|
||||
print("Predicted class idx:", logits.argmax(-1).item())
|
||||
|
||||
if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
if logits.argmax(-1).item() != expected_class_idx:
|
||||
raise ValueError("Predicted class index not as expected")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
|
||||
type=str,
|
||||
help="URL to the original PyTorch checkpoint (.pth file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
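# Example invocation (illustrative only; the script filename and the dump path are placeholders):
#   python convert_beit_checkpoint.py \
#       --checkpoint_url https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth \
#       --pytorch_dump_folder_path /path/to/beit_dump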
|
@ -1,246 +0,0 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now
|
||||
deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert
|
||||
|
||||
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
weight names to the original names, so that the model can be imported with Hugging Face Transformers.

You may adapt this script to include classification/MLM/NSP/etc. heads.

Note: This script only works with an older version of the TensorFlow models repository (<= v2.3.0).
Models trained with newer versions are not compatible with this script.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
layer_depth = []
|
||||
for full_name, shape in init_vars:
|
||||
# logger.info(f"Loading TF weight {name} with shape {shape}")
|
||||
name = full_name.split("/")
|
||||
if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
|
||||
logger.info(f"Skipping non-model layer {full_name}")
|
||||
continue
|
||||
if "optimizer" in full_name:
|
||||
logger.info(f"Skipping optimization layer {full_name}")
|
||||
continue
|
||||
if name[0] == "model":
|
||||
# ignore initial 'model'
|
||||
name = name[1:]
|
||||
# figure out how many levels deep the name is
|
||||
depth = 0
|
||||
for _name in name:
|
||||
if _name.startswith("layer_with_weights"):
|
||||
depth += 1
|
||||
else:
|
||||
break
|
||||
layer_depth.append(depth)
|
||||
# read data
|
||||
array = tf.train.load_variable(tf_path, full_name)
|
||||
names.append("/".join(name))
|
||||
arrays.append(array)
|
||||
logger.info(f"Read a total of {len(arrays):,} layers")
|
||||
|
||||
# Sanity check
|
||||
if len(set(layer_depth)) != 1:
|
||||
raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
|
||||
layer_depth = list(set(layer_depth))[0]
|
||||
if layer_depth != 1:
|
||||
raise ValueError(
|
||||
"The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
|
||||
" heads."
|
||||
)
|
||||
|
||||
# convert layers
|
||||
logger.info("Converting weights...")
|
||||
for full_name, array in zip(names, arrays):
|
||||
name = full_name.split("/")
|
||||
pointer = model
|
||||
trace = []
|
||||
for i, m_name in enumerate(name):
|
||||
if m_name == ".ATTRIBUTES":
|
||||
# variable names end with .ATTRIBUTES/VARIABLE_VALUE
|
||||
break
|
||||
if m_name.startswith("layer_with_weights"):
|
||||
layer_num = int(m_name.split("-")[-1])
|
||||
if layer_num <= 2:
|
||||
# embedding layers
|
||||
# layer_num 0: word_embeddings
|
||||
# layer_num 1: position_embeddings
|
||||
# layer_num 2: token_type_embeddings
|
||||
continue
|
||||
elif layer_num == 3:
|
||||
# embedding LayerNorm
|
||||
trace.extend(["embeddings", "LayerNorm"])
|
||||
pointer = getattr(pointer, "embeddings")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
|
||||
# encoder layers
|
||||
trace.extend(["encoder", "layer", str(layer_num - 4)])
|
||||
pointer = getattr(pointer, "encoder")
|
||||
pointer = getattr(pointer, "layer")
|
||||
pointer = pointer[layer_num - 4]
|
||||
elif layer_num == config.num_hidden_layers + 4:
|
||||
# pooler layer
|
||||
trace.extend(["pooler", "dense"])
|
||||
pointer = getattr(pointer, "pooler")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "embeddings":
|
||||
trace.append("embeddings")
|
||||
pointer = getattr(pointer, "embeddings")
|
||||
if layer_num == 0:
|
||||
trace.append("word_embeddings")
|
||||
pointer = getattr(pointer, "word_embeddings")
|
||||
elif layer_num == 1:
|
||||
trace.append("position_embeddings")
|
||||
pointer = getattr(pointer, "position_embeddings")
|
||||
elif layer_num == 2:
|
||||
trace.append("token_type_embeddings")
|
||||
pointer = getattr(pointer, "token_type_embeddings")
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding layer with name {full_name}")
|
||||
trace.append("weight")
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif m_name == "_attention_layer":
|
||||
# self-attention layer
|
||||
trace.extend(["attention", "self"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "self")
|
||||
elif m_name == "_attention_layer_norm":
|
||||
# output attention norm
|
||||
trace.extend(["attention", "output", "LayerNorm"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif m_name == "_attention_output_dense":
|
||||
# output attention dense
|
||||
trace.extend(["attention", "output", "dense"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_dense":
|
||||
# output dense
|
||||
trace.extend(["output", "dense"])
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_layer_norm":
|
||||
# output LayerNorm
|
||||
trace.extend(["output", "LayerNorm"])
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif m_name == "_key_dense":
|
||||
# attention key
|
||||
trace.append("key")
|
||||
pointer = getattr(pointer, "key")
|
||||
elif m_name == "_query_dense":
|
||||
# attention query
|
||||
trace.append("query")
|
||||
pointer = getattr(pointer, "query")
|
||||
elif m_name == "_value_dense":
|
||||
# attention value
|
||||
trace.append("value")
|
||||
pointer = getattr(pointer, "value")
|
||||
elif m_name == "_intermediate_dense":
|
||||
# attention intermediate dense
|
||||
trace.extend(["intermediate", "dense"])
|
||||
pointer = getattr(pointer, "intermediate")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_layer_norm":
|
||||
# output layer norm
|
||||
trace.append("output")
|
||||
pointer = getattr(pointer, "output")
|
||||
# weights & biases
|
||||
elif m_name in ["bias", "beta"]:
|
||||
trace.append("bias")
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif m_name in ["kernel", "gamma"]:
|
||||
trace.append("weight")
|
||||
pointer = getattr(pointer, "weight")
|
||||
else:
|
||||
logger.warning(f"Ignored {m_name}")
|
||||
# for certain layers reshape is necessary
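# (the TF2 implementation appears to store the attention projections as rank-3 kernels of shape
# [hidden, num_heads, head_dim], so they are reshaped here to the 2D shape the PyTorch module expects)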
|
||||
trace = ".".join(trace)
|
||||
if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
|
||||
r"(\S+)\.attention\.output\.dense\.weight", trace
|
||||
):
|
||||
array = array.reshape(pointer.data.shape)
|
||||
if "kernel" in full_name:
|
||||
array = array.transpose()
|
||||
if pointer.shape == array.shape:
|
||||
pointer.data = torch.from_numpy(array)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
|
||||
f" {array.shape}"
|
||||
)
|
||||
logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
|
||||
return model
|
||||
|
||||
|
||||
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
|
||||
# Instantiate model
|
||||
logger.info(f"Loading model based on config from {config_path}...")
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
model = BertModel(config)
|
||||
|
||||
# Load weights from checkpoint
|
||||
logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
|
||||
load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
|
||||
|
||||
# Save pytorch-model
|
||||
logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model (must include filename).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
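# Example invocation (illustrative only; the script filename and all paths are placeholders):
#   python convert_bert_original_tf2_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path /path/to/tf2_checkpoint \
#       --bert_config_file /path/to/bert_config.json \
#       --pytorch_dump_path /path/to/pytorch_model.bin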
|
@ -1,62 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(bert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = BertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
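# Example invocation (illustrative only; the script filename and all paths are placeholders):
#   python convert_bert_original_tf_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path /path/to/bert_model.ckpt \
#       --bert_config_file /path/to/bert_config.json \
#       --pytorch_dump_path /path/to/pytorch_model.bin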
|
@ -1,112 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertModel
|
||||
|
||||
|
||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
||||
"""
|
||||
Args:
|
||||
model: BertModel Pytorch model instance to be converted
|
||||
ckpt_dir: Tensorflow model directory
|
||||
model_name: model name
|
||||
|
||||
Currently supported HF models:
|
||||
|
||||
- Y BertModel
|
||||
- N BertForMaskedLM
|
||||
- N BertForPreTraining
|
||||
- N BertForMultipleChoice
|
||||
- N BertForNextSentencePrediction
|
||||
- N BertForSequenceClassification
|
||||
- N BertForQuestionAnswering
|
||||
"""
|
||||
|
||||
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
|
||||
|
||||
var_map = (
|
||||
("layer.", "layer_"),
|
||||
("word_embeddings.weight", "word_embeddings"),
|
||||
("position_embeddings.weight", "position_embeddings"),
|
||||
("token_type_embeddings.weight", "token_type_embeddings"),
|
||||
(".", "/"),
|
||||
("LayerNorm/weight", "LayerNorm/gamma"),
|
||||
("LayerNorm/bias", "LayerNorm/beta"),
|
||||
("weight", "kernel"),
|
||||
)
|
||||
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
os.makedirs(ckpt_dir)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
|
||||
def to_tf_var_name(name: str):
|
||||
for patt, repl in iter(var_map):
|
||||
name = name.replace(patt, repl)
|
||||
return f"bert/{name}"
|
||||
|
||||
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
|
||||
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
|
||||
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
|
||||
session.run(tf.variables_initializer([tf_var]))
|
||||
session.run(tf_var)
|
||||
return tf_var
|
||||
|
||||
tf.reset_default_graph()
|
||||
with tf.Session() as session:
|
||||
for var_name in state_dict:
|
||||
tf_name = to_tf_var_name(var_name)
|
||||
torch_tensor = state_dict[var_name].numpy()
|
||||
if any(x in var_name for x in tensors_to_transpose):
|
||||
torch_tensor = torch_tensor.T
|
||||
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
|
||||
tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
|
||||
tf_weight = session.run(tf_var)
|
||||
print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")
|
||||
|
||||
saver = tf.train.Saver(tf.trainable_variables())
|
||||
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
|
||||
|
||||
|
||||
def main(raw_args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
|
||||
parser.add_argument(
|
||||
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
|
||||
)
|
||||
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
|
||||
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
|
||||
args = parser.parse_args(raw_args)
|
||||
|
||||
model = BertModel.from_pretrained(
|
||||
pretrained_model_name_or_path=args.model_name,
|
||||
state_dict=torch.load(args.pytorch_model_path),
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
|
||||
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
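# Example invocation (illustrative only; the script filename and paths are placeholders):
#   python convert_bert_pytorch_checkpoint_to_original_tf.py \
#       --model_name google-bert/bert-base-uncased \
#       --pytorch_model_path /path/to/pytorch_model.bin \
#       --tf_cache_dir /path/to/tf_output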
|
@ -1,188 +0,0 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT
|
||||
model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository:
|
||||
|
||||
https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForMaskedLM
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertIntermediate,
|
||||
BertLayer,
|
||||
BertOutput,
|
||||
BertPooler,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
|
||||
def get_masked_lm_array(name: str):
|
||||
full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_array(name: str):
|
||||
full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_layer_array(layer_index: int, name: str):
|
||||
full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape):
    full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
    array = tf.train.load_variable(tf_checkpoint_path, full_name)
    array = array.reshape(original_shape)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
print(f"Loading model based on config from {config_path}...")
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
model = BertForMaskedLM(config)
|
||||
|
||||
# Layers
|
||||
for layer_index in range(0, config.num_hidden_layers):
|
||||
layer: BertLayer = model.bert.encoder.layer[layer_index]
|
||||
|
||||
# Self-attention
|
||||
self_attn: BertSelfAttention = layer.attention.self
|
||||
|
||||
self_attn.query.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
|
||||
)
|
||||
self_attn.query.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
|
||||
)
|
||||
self_attn.key.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
|
||||
)
|
||||
self_attn.key.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
|
||||
)
|
||||
self_attn.value.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
|
||||
)
|
||||
self_attn.value.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
|
||||
)
|
||||
|
||||
# Self-attention Output
|
||||
self_output: BertSelfOutput = layer.attention.output
|
||||
|
||||
self_output.dense.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
|
||||
)
|
||||
self_output.dense.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
|
||||
)
|
||||
|
||||
self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
|
||||
self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
|
||||
|
||||
# Intermediate
|
||||
intermediate: BertIntermediate = layer.intermediate
|
||||
|
||||
intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
|
||||
intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
|
||||
|
||||
# Output
|
||||
bert_output: BertOutput = layer.output
|
||||
|
||||
bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
|
||||
bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
|
||||
|
||||
bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
|
||||
bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
|
||||
|
||||
# Embeddings
|
||||
model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
|
||||
model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
|
||||
model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
|
||||
model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
|
||||
|
||||
# LM Head
|
||||
lm_head = model.cls.predictions.transform
|
||||
|
||||
lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
|
||||
lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
|
||||
|
||||
lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
|
||||
lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
|
||||
|
||||
model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
|
||||
|
||||
# Pooling
|
||||
model.bert.pooler = BertPooler(config=config)
|
||||
model.bert.pooler.dense.weight.data = get_encoder_array("_pooler_layer/kernel")
|
||||
model.bert.pooler.dense.bias.data = get_encoder_array("_pooler_layer/bias")
|
||||
|
||||
# Export final model
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
# Integration test - should load without any errors ;)
|
||||
new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
|
||||
print(new_model.eval())
|
||||
|
||||
print("Model conversion was done sucessfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
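For reference, a minimal sketch of how the conversion entry point above could be called directly from Python rather than via the CLI; all paths below are hypothetical, and the positional order mirrors the `--tf_checkpoint_path`, `--bert_config_file`, `--pytorch_dump_path` flags:

```python
# Hypothetical paths, for illustration only.
convert_checkpoint_to_pytorch(
    "/ckpts/token_dropping_bert/model.ckpt",        # TensorFlow Token Dropping checkpoint
    "/ckpts/token_dropping_bert/bert_config.json",  # BERT config json
    "/out/token-dropping-bert",                     # output folder for the PyTorch weights
)
```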
@ -608,6 +608,8 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, BertGenerationOnlyLMHead):
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
BERT_GENERATION_START_DOCSTRING = r"""
|
||||
|
@ -1,69 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BigBird checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
|
||||
# Initialise PyTorch model
|
||||
config = BigBirdConfig.from_json_file(big_bird_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
|
||||
if is_trivia_qa:
|
||||
model = BigBirdForQuestionAnswering(config)
|
||||
else:
|
||||
model = BigBirdForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--big_bird_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(
|
||||
args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
|
||||
)
|
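A minimal sketch of calling `convert_tf_checkpoint_to_pytorch` above directly; the paths are hypothetical, and `is_trivia_qa=True` selects the `BigBirdForQuestionAnswering` branch instead of `BigBirdForPreTraining`:

```python
# Hypothetical paths, for illustration only.
convert_tf_checkpoint_to_pytorch(
    "/ckpts/bigbird/model.ckpt",   # TensorFlow checkpoint
    "/ckpts/bigbird/config.json",  # BigBird config json
    "/out/bigbird-trivia-qa",      # output folder
    is_trivia_qa=True,             # build the TriviaQA question-answering head
)
```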
@ -1769,6 +1769,8 @@ class BigBirdPreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, BigBirdLMPredictionHead):
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
BIG_BIRD_START_DOCSTRING = r"""
|
||||
|
@ -1,170 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from typing import Dict
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration
|
||||
|
||||
|
||||
INIT_COMMON = [
|
||||
# tf -> hf
|
||||
("/", "."),
|
||||
("layer_", "layers."),
|
||||
("kernel", "weight"),
|
||||
("beta", "bias"),
|
||||
("gamma", "weight"),
|
||||
("pegasus", "model"),
|
||||
]
|
||||
END_COMMON = [
|
||||
(".output.dense", ".fc2"),
|
||||
("intermediate.LayerNorm", "final_layer_norm"),
|
||||
("intermediate.dense", "fc1"),
|
||||
]
|
||||
|
||||
DECODER_PATTERNS = (
|
||||
INIT_COMMON
|
||||
+ [
|
||||
("attention.self.LayerNorm", "self_attn_layer_norm"),
|
||||
("attention.output.dense", "self_attn.out_proj"),
|
||||
("attention.self", "self_attn"),
|
||||
("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
|
||||
("attention.encdec_output.dense", "encoder_attn.out_proj"),
|
||||
("attention.encdec", "encoder_attn"),
|
||||
("key", "k_proj"),
|
||||
("value", "v_proj"),
|
||||
("query", "q_proj"),
|
||||
("decoder.LayerNorm", "decoder.layernorm_embedding"),
|
||||
]
|
||||
+ END_COMMON
|
||||
)
|
||||
|
||||
REMAINING_PATTERNS = (
|
||||
INIT_COMMON
|
||||
+ [
|
||||
("embeddings.word_embeddings", "shared.weight"),
|
||||
("embeddings.position_embeddings", "embed_positions.weight"),
|
||||
("attention.self.LayerNorm", "self_attn_layer_norm"),
|
||||
("attention.output.dense", "self_attn.output"),
|
||||
("attention.self", "self_attn.self"),
|
||||
("encoder.LayerNorm", "encoder.layernorm_embedding"),
|
||||
]
|
||||
+ END_COMMON
|
||||
)
|
||||
|
||||
KEYS_TO_IGNORE = [
|
||||
"encdec/key/bias",
|
||||
"encdec/query/bias",
|
||||
"encdec/value/bias",
|
||||
"self/key/bias",
|
||||
"self/query/bias",
|
||||
"self/value/bias",
|
||||
"encdec_output/dense/bias",
|
||||
"attention/output/dense/bias",
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k, patterns):
|
||||
for tf_name, hf_name in patterns:
|
||||
k = k.replace(tf_name, hf_name)
|
||||
return k
|
||||
|
||||
|
||||
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
|
||||
cfg = BigBirdPegasusConfig(**config_update)
|
||||
torch_model = BigBirdPegasusForConditionalGeneration(cfg)
|
||||
state_dict = torch_model.state_dict()
|
||||
mapping = {}
|
||||
|
||||
# separating decoder weights
|
||||
decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
|
||||
remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
|
||||
|
||||
for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
|
||||
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
|
||||
if any(conditions):
|
||||
continue
|
||||
patterns = DECODER_PATTERNS
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict:
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(i in k for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
|
||||
for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
|
||||
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
|
||||
if any(conditions):
|
||||
continue
|
||||
patterns = REMAINING_PATTERNS
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(i in k for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
if k != "pegasus/embeddings/position_embeddings":
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
|
||||
mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
|
||||
mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
|
||||
missing, extra = torch_model.load_state_dict(mapping, strict=False)
|
||||
unexpected_missing = [
|
||||
k
|
||||
for k in missing
|
||||
if k
|
||||
not in [
|
||||
"final_logits_bias",
|
||||
"model.encoder.embed_tokens.weight",
|
||||
"model.decoder.embed_tokens.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
]
|
||||
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
|
||||
assert extra == [], f"no matches found for the following tf keys {extra}"
|
||||
return torch_model
|
||||
|
||||
|
||||
def get_tf_weights_as_numpy(path) -> Dict:
|
||||
init_vars = tf.train.list_variables(path)
|
||||
tf_weights = {}
|
||||
ignore_name = ["global_step"]
|
||||
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
|
||||
skip_key = any(pat in name for pat in ignore_name)
|
||||
if skip_key:
|
||||
continue
|
||||
array = tf.train.load_variable(path, name)
|
||||
tf_weights[name] = array
|
||||
return tf_weights
|
||||
|
||||
|
||||
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
|
||||
tf_weights = get_tf_weights_as_numpy(ckpt_path)
|
||||
torch_model = convert_bigbird_pegasus(tf_weights, config_update)
|
||||
torch_model.save_pretrained(save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
|
||||
parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
config_update = {}
|
||||
convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)
|
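To make the rename tables above more concrete, here is a small sketch of `rename_state_dict_key` applied to one representative decoder weight; the exact TensorFlow key layout is an assumption for illustration:

```python
# Assumed TF-style key name, for illustration only; relies on the
# rename_state_dict_key and DECODER_PATTERNS definitions above.
tf_key = "pegasus/decoder/layer_0/attention/self/query/kernel"
print(rename_state_dict_key(tf_key, DECODER_PATTERNS))
# -> model.decoder.layers.0.self_attn.q_proj.weight
```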
@ -1,292 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BioGptConfig, BioGptForCausalLM
|
||||
from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from transformers.utils import WEIGHTS_NAME, logging
|
||||
|
||||
|
||||
logging.set_verbosity_warning()
|
||||
|
||||
json_indent = 2
|
||||
|
||||
|
||||
# modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18
|
||||
class Dictionary:
|
||||
"""A mapping from symbols to consecutive integers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*, # begin keyword-only arguments
|
||||
bos="<s>",
|
||||
pad="<pad>",
|
||||
eos="</s>",
|
||||
unk="<unk>",
|
||||
extra_special_symbols=None,
|
||||
):
|
||||
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
||||
self.symbols = []
|
||||
self.count = []
|
||||
self.indices = {}
|
||||
self.bos_index = self.add_symbol(bos)
|
||||
self.pad_index = self.add_symbol(pad)
|
||||
self.eos_index = self.add_symbol(eos)
|
||||
self.unk_index = self.add_symbol(unk)
|
||||
if extra_special_symbols:
|
||||
for s in extra_special_symbols:
|
||||
self.add_symbol(s)
|
||||
self.nspecial = len(self.symbols)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.indices == other.indices
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < len(self.symbols):
|
||||
return self.symbols[idx]
|
||||
return self.unk_word
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of symbols in the dictionary"""
|
||||
return len(self.symbols)
|
||||
|
||||
def __contains__(self, sym):
|
||||
return sym in self.indices
|
||||
|
||||
@classmethod
|
||||
def load(cls, f):
|
||||
"""Loads the dictionary from a text file with the format:
|
||||
|
||||
```
|
||||
<symbol0> <count0>
|
||||
<symbol1> <count1>
|
||||
...
|
||||
```
|
||||
"""
|
||||
d = cls()
|
||||
d.add_from_file(f)
|
||||
return d
|
||||
|
||||
def add_symbol(self, word, n=1, overwrite=False):
|
||||
"""Adds a word to the dictionary"""
|
||||
if word in self.indices and not overwrite:
|
||||
idx = self.indices[word]
|
||||
self.count[idx] = self.count[idx] + n
|
||||
return idx
|
||||
else:
|
||||
idx = len(self.symbols)
|
||||
self.indices[word] = idx
|
||||
self.symbols.append(word)
|
||||
self.count.append(n)
|
||||
return idx
|
||||
|
||||
def _load_meta(self, lines):
|
||||
return 0
|
||||
|
||||
def add_from_file(self, f):
|
||||
"""
|
||||
Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
|
||||
"""
|
||||
if isinstance(f, str):
|
||||
try:
|
||||
with open(f, "r", encoding="utf-8") as fd:
|
||||
self.add_from_file(fd)
|
||||
except FileNotFoundError as fnfe:
|
||||
raise fnfe
|
||||
except UnicodeError:
|
||||
raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f))
|
||||
return
|
||||
|
||||
lines = f.readlines()
|
||||
indices_start_line = self._load_meta(lines)
|
||||
|
||||
for line in lines[indices_start_line:]:
|
||||
try:
|
||||
line, field = line.rstrip().rsplit(" ", 1)
|
||||
if field == "#fairseq:overwrite":
|
||||
overwrite = True
|
||||
line, field = line.rsplit(" ", 1)
|
||||
else:
|
||||
overwrite = False
|
||||
count = int(field)
|
||||
word = line
|
||||
if word in self and not overwrite:
|
||||
raise RuntimeError(
|
||||
"Duplicate word found when loading Dictionary: '{}'. "
|
||||
"Duplicate words can overwrite earlier ones by adding the "
|
||||
"#fairseq:overwrite flag at the end of the corresponding row "
|
||||
"in the dictionary file. If using the Camembert model, please "
|
||||
"download an updated copy of the model file.".format(word)
|
||||
)
|
||||
self.add_symbol(word, n=count, overwrite=overwrite)
|
||||
except ValueError:
|
||||
raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'")
|
||||
|
||||
|
||||
def rewrite_dict_keys(d):
|
||||
# (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
|
||||
# e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
|
||||
d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
|
||||
keep_keys = "<s> <pad> </s> <unk>".split()
|
||||
# restore the special tokens
|
||||
for k in keep_keys:
|
||||
del d2[f"{k}</w>"]
|
||||
d2[k] = d[k] # restore
|
||||
return d2
|
||||
|
||||
|
||||
def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path):
|
||||
# prep
|
||||
if not os.path.exists(biogpt_checkpoint_path):
|
||||
raise ValueError(f"path {biogpt_checkpoint_path} does not exist!")
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
print(f"Writing results to {pytorch_dump_folder_path}")
|
||||
|
||||
# handle various types of models
|
||||
|
||||
checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
|
||||
if not os.path.isfile(checkpoint_file):
|
||||
raise ValueError(f"path to the file {checkpoint_file} does not exist!")
|
||||
chkpt = torch.load(checkpoint_file, map_location="cpu")
|
||||
|
||||
args = chkpt["cfg"]["model"]
|
||||
|
||||
# dicts
|
||||
dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt")
|
||||
if not os.path.isfile(dict_file):
|
||||
raise ValueError(f"path to the file {dict_file} does not exist!")
|
||||
src_dict = Dictionary.load(dict_file)
|
||||
src_vocab = rewrite_dict_keys(src_dict.indices)
|
||||
src_vocab_size = len(src_vocab)
|
||||
src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"])
|
||||
print(f"Generating {src_vocab_file} of {src_vocab_size} records")
|
||||
with open(src_vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# merges_file (bpecodes)
|
||||
bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes")
|
||||
if not os.path.isfile(bpecodes_file):
|
||||
raise ValueError(f"path to the file {bpecodes_file} does not exist!")
|
||||
|
||||
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
|
||||
shutil.copyfile(bpecodes_file, merges_file)
|
||||
|
||||
# model config
|
||||
biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")
|
||||
|
||||
model_conf = {
|
||||
"activation_dropout": args["activation_dropout"],
|
||||
"architectures": ["BioGptForCausalLM"],
|
||||
"attention_probs_dropout_prob": args["attention_dropout"],
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": args["activation_fn"],
|
||||
"hidden_dropout_prob": args["dropout"],
|
||||
"hidden_size": args["decoder_embed_dim"],
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": args["decoder_ffn_embed_dim"],
|
||||
"layer_norm_eps": 1e-12,
|
||||
"layerdrop": args["decoder_layerdrop"],
|
||||
"max_position_embeddings": args["max_target_positions"],
|
||||
"model_type": "biogpt",
|
||||
"num_attention_heads": args["decoder_attention_heads"],
|
||||
"num_hidden_layers": args["decoder_layers"],
|
||||
"pad_token_id": 1,
|
||||
"scale_embedding": not args["no_scale_embedding"],
|
||||
"tie_word_embeddings": args["share_decoder_input_output_embed"],
|
||||
"vocab_size": src_vocab_size,
|
||||
}
|
||||
|
||||
# good hparam defaults to start with
|
||||
|
||||
print(f"Generating {biogpt_model_config_file}")
|
||||
with open(biogpt_model_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# tokenizer config
|
||||
biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)
|
||||
|
||||
tokenizer_conf = {
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"model_max_length": 1024,
|
||||
"pad_token": "<pad>",
|
||||
"special_tokens_map_file": None,
|
||||
"tokenizer_class": "BioGptTokenizer",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
|
||||
print(f"Generating {biogpt_tokenizer_config_file}")
|
||||
with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# model
|
||||
model_state_dict = chkpt["model"]
|
||||
|
||||
# remove unneeded keys
|
||||
ignore_keys = [
|
||||
"decoder.version",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
model_state_dict.pop(k, None)
|
||||
|
||||
layer_names = list(model_state_dict.keys())
|
||||
for layer_name in layer_names:
|
||||
if layer_name.endswith("output_projection.weight"):
|
||||
model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name)
|
||||
else:
|
||||
model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name)
|
||||
|
||||
config = BioGptConfig.from_pretrained(pytorch_dump_folder_path)
|
||||
model_new = BioGptForCausalLM(config)
|
||||
|
||||
# check that it loads ok
|
||||
model_new.load_state_dict(model_state_dict)
|
||||
|
||||
# save
|
||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||
print(f"Generating {pytorch_weights_dump_path}")
|
||||
torch.save(model_state_dict, pytorch_weights_dump_path)
|
||||
|
||||
print("Conversion is done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--biogpt_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
|
||||
" bpecodes, etc."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path)
|
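The key vocabulary step above is `rewrite_dict_keys`; a minimal sketch with a toy fairseq-style vocab, matching the comment inside the function (word-internal `@@` markers are dropped, `</w>` is appended to word-final pieces, and the special tokens are restored unchanged):

```python
# Relies on the rewrite_dict_keys definition above.
toy_vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "le@@": 5, "tt@@": 6, "er": 7}
print(rewrite_dict_keys(toy_vocab))
# {'le': 5, 'tt': 6, 'er</w>': 7, '<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3}
# (the special tokens are deleted and re-inserted, so they end up at the back of the dict)
```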
@ -1,177 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BiT checkpoints from the timm library."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from timm import create_model
|
||||
from timm.data import resolve_data_config
|
||||
from timm.data.transforms_factory import create_transform
|
||||
|
||||
from transformers import BitConfig, BitForImageClassification, BitImageProcessor
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_config(model_name):
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
conv_layer = "std_conv" if "bit" in model_name else False
|
||||
|
||||
# note that when using BiT as backbone for ViT-hybrid checkpoints,
|
||||
# one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same",
|
||||
# config.conv_layer = "std_conv_same"
|
||||
config = BitConfig(
|
||||
conv_layer=conv_layer,
|
||||
num_labels=1000,
|
||||
id2label=id2label,
|
||||
label2id=label2id,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "stem.conv" in name:
|
||||
name = name.replace("stem.conv", "bit.embedder.convolution")
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "layers")
|
||||
if "head.fc" in name:
|
||||
name = name.replace("head.fc", "classifier.1")
|
||||
if name.startswith("norm"):
|
||||
name = "bit." + name
|
||||
if "bit" not in name and "classifier" not in name:
|
||||
name = "bit.encoder." + name
|
||||
|
||||
return name
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BiT structure.
|
||||
"""
|
||||
|
||||
# define default BiT configuration
|
||||
config = get_config(model_name)
|
||||
|
||||
# load original model from timm
|
||||
timm_model = create_model(model_name, pretrained=True)
|
||||
timm_model.eval()
|
||||
|
||||
# load state_dict of original model
|
||||
state_dict = timm_model.state_dict()
|
||||
for key in state_dict.copy().keys():
|
||||
val = state_dict.pop(key)
|
||||
state_dict[rename_key(key)] = val.squeeze() if "head" in key else val
|
||||
|
||||
# load HuggingFace model
|
||||
model = BitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# create image processor
|
||||
transform = create_transform(**resolve_data_config({}, model=timm_model))
|
||||
timm_transforms = transform.transforms
|
||||
|
||||
pillow_resamplings = {
|
||||
"bilinear": PILImageResampling.BILINEAR,
|
||||
"bicubic": PILImageResampling.BICUBIC,
|
||||
"nearest": PILImageResampling.NEAREST,
|
||||
}
|
||||
|
||||
processor = BitImageProcessor(
|
||||
do_resize=True,
|
||||
size={"shortest_edge": timm_transforms[0].size},
|
||||
resample=pillow_resamplings[timm_transforms[0].interpolation.value],
|
||||
do_center_crop=True,
|
||||
crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
|
||||
do_normalize=True,
|
||||
image_mean=timm_transforms[-1].mean.tolist(),
|
||||
image_std=timm_transforms[-1].std.tolist(),
|
||||
)
|
||||
|
||||
image = prepare_img()
|
||||
timm_pixel_values = transform(image).unsqueeze(0)
|
||||
pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
# verify pixel values
|
||||
assert torch.allclose(timm_pixel_values, pixel_values)
|
||||
|
||||
# verify logits
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
print("Logits:", logits[0, :3])
|
||||
print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
|
||||
timm_logits = timm_model(pixel_values)
|
||||
assert timm_logits.shape == outputs.logits.shape
|
||||
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model {model_name} and processor to the hub")
|
||||
model.push_to_hub(f"ybelkada/{model_name}")
|
||||
processor.push_to_hub(f"ybelkada/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="resnetv2_50x1_bitm",
|
||||
type=str,
|
||||
help="Name of the BiT timm model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model to the hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
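A minimal sketch of the `rename_key` mapping above on a few representative timm-style parameter names (the exact timm key names are an assumption for illustration):

```python
# Relies on the rename_key definition above.
for timm_name in ["stem.conv.weight", "stages.0.blocks.0.conv1.weight", "head.fc.weight"]:
    print(timm_name, "->", rename_key(timm_name))
# stem.conv.weight               -> bit.embedder.convolution.weight
# stages.0.blocks.0.conv1.weight -> bit.encoder.stages.0.layers.0.conv1.weight
# head.fc.weight                 -> classifier.1.weight
```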
@ -1,114 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Blenderbot checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
PATTERNS = [
|
||||
["attention", "attn"],
|
||||
["encoder_attention", "encoder_attn"],
|
||||
["q_lin", "q_proj"],
|
||||
["k_lin", "k_proj"],
|
||||
["v_lin", "v_proj"],
|
||||
["out_lin", "out_proj"],
|
||||
["norm_embeddings", "layernorm_embedding"],
|
||||
["position_embeddings", "embed_positions"],
|
||||
["embeddings", "embed_tokens"],
|
||||
["ffn.lin", "fc"],
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k):
|
||||
if k == "embeddings.weight":
|
||||
return "shared.weight"
|
||||
|
||||
for parlai_name, hf_name in PATTERNS:
|
||||
k = k.replace(parlai_name, hf_name)
|
||||
|
||||
if k.startswith("encoder"):
|
||||
k = k.replace(".attn", ".self_attn")
|
||||
k = k.replace("norm1", "self_attn_layer_norm")
|
||||
k = k.replace("norm2", "final_layer_norm")
|
||||
elif k.startswith("decoder"):
|
||||
k = k.replace("norm1", "self_attn_layer_norm")
|
||||
k = k.replace("norm2", "encoder_attn_layer_norm")
|
||||
k = k.replace("norm3", "final_layer_norm")
|
||||
return k
|
||||
|
||||
|
||||
def rename_layernorm_keys(sd):
|
||||
keys = [
|
||||
"model.encoder.layernorm_embedding.weight",
|
||||
"model.encoder.layernorm_embedding.bias",
|
||||
"model.decoder.layernorm_embedding.weight",
|
||||
"model.decoder.layernorm_embedding.bias",
|
||||
]
|
||||
for k in keys:
|
||||
v = sd.pop(k)
|
||||
new_k = k.replace("layernorm_embedding", "layer_norm")
|
||||
assert new_k not in sd
|
||||
sd[new_k] = v
|
||||
|
||||
|
||||
IGNORE_KEYS = ["START"]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our Blenderbot structure.
|
||||
"""
|
||||
model = torch.load(checkpoint_path, map_location="cpu")
|
||||
sd = model["model"]
|
||||
cfg = BlenderbotConfig.from_json_file(config_json_path)
|
||||
m = BlenderbotForConditionalGeneration(cfg)
|
||||
valid_keys = m.model.state_dict().keys()
|
||||
failures = []
|
||||
mapping = {}
|
||||
for k, v in sd.items():
|
||||
if k in IGNORE_KEYS:
|
||||
continue
|
||||
|
||||
new_k = rename_state_dict_key(k)
|
||||
if new_k not in valid_keys:
|
||||
failures.append([k, new_k])
|
||||
else:
|
||||
mapping[new_k] = v
|
||||
if cfg.normalize_before: # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
|
||||
rename_layernorm_keys(sd)
|
||||
m.model.load_state_dict(mapping, strict=True)
|
||||
m.half()
|
||||
m.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
|
||||
parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
|
||||
parser.add_argument(
|
||||
"--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)
|
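A small sketch of what `rename_state_dict_key` above does to a couple of ParlAI-style keys (the exact ParlAI key names are an assumption for illustration):

```python
# Relies on the rename_state_dict_key definition above.
for parlai_key in ["embeddings.weight", "encoder.layers.0.attention.q_lin.weight"]:
    print(parlai_key, "->", rename_state_dict_key(parlai_key))
# embeddings.weight                       -> shared.weight
# encoder.layers.0.attention.q_lin.weight -> encoder.layers.0.self_attn.q_proj.weight
```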
@ -1,191 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# git clone https://github.com/salesforce/BLIP.git
|
||||
from models.blip import blip_decoder
|
||||
from models.blip_itm import blip_itm
|
||||
from models.blip_vqa import blip_vqa
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from transformers import (
|
||||
BertTokenizer,
|
||||
BlipConfig,
|
||||
BlipForConditionalGeneration,
|
||||
BlipForImageTextRetrieval,
|
||||
BlipForQuestionAnswering,
|
||||
)
|
||||
|
||||
|
||||
def load_demo_image(image_size, device):
|
||||
img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
]
|
||||
)
|
||||
image = transform(raw_image).unsqueeze(0).to(device)
|
||||
return image
|
||||
|
||||
|
||||
def rename_key(key):
|
||||
if "visual_encoder" in key:
|
||||
key = re.sub("visual_encoder*", "vision_model.encoder", key)
|
||||
if "blocks" in key:
|
||||
key = re.sub(r"blocks", "layers", key)
|
||||
if "attn" in key:
|
||||
key = re.sub(r"attn", "self_attn", key)
|
||||
if "norm1" in key:
|
||||
key = re.sub(r"norm1", "layer_norm1", key)
|
||||
if "norm2" in key:
|
||||
key = re.sub(r"norm2", "layer_norm2", key)
|
||||
if "encoder.norm" in key:
|
||||
key = re.sub(r"encoder.norm", "post_layernorm", key)
|
||||
if "encoder.patch_embed.proj" in key:
|
||||
key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)
|
||||
|
||||
if "encoder.pos_embed" in key:
|
||||
key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
|
||||
if "encoder.cls_token" in key:
|
||||
key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)
|
||||
|
||||
if "self_attn" in key:
|
||||
key = re.sub(r"self_attn.proj", "self_attn.projection", key)
|
||||
|
||||
return key
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
if config_path is not None:
|
||||
config = BlipConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = BlipConfig(projection_dim=512, text_config={}, vision_config={})
|
||||
|
||||
hf_model = BlipForConditionalGeneration(config).eval()
|
||||
|
||||
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
|
||||
|
||||
pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
|
||||
pt_model = pt_model.eval()
|
||||
|
||||
modified_state_dict = pt_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_model.load_state_dict(modified_state_dict)
|
||||
|
||||
image_size = 384
|
||||
image = load_demo_image(image_size=image_size, device="cpu")
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
input_ids = tokenizer(["a picture of"]).input_ids
|
||||
|
||||
out = hf_model.generate(image, input_ids)
|
||||
|
||||
assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
|
||||
|
||||
out = hf_model.generate(image)
|
||||
|
||||
assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
# model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
|
||||
model_url = (
|
||||
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
|
||||
)
|
||||
|
||||
vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
|
||||
vqa_model.eval()
|
||||
|
||||
modified_state_dict = vqa_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_vqa_model = BlipForQuestionAnswering(config)
|
||||
|
||||
hf_vqa_model.load_state_dict(modified_state_dict)
|
||||
|
||||
question = ["How many dogs are in this image?"]
|
||||
question_input_ids = tokenizer(question, return_tensors="pt").input_ids
|
||||
|
||||
answer = hf_vqa_model.generate(question_input_ids, image)
|
||||
print(tokenizer.decode(answer[0]))
|
||||
|
||||
assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")
|
||||
|
||||
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"
|
||||
|
||||
itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
|
||||
itm_model.eval()
|
||||
|
||||
modified_state_dict = itm_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_itm_model = BlipForImageTextRetrieval(config)
|
||||
|
||||
question = ["A picture of a woman with a dog sitting in a beach"]
|
||||
question_input_ids = tokenizer(
|
||||
question,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=35,
|
||||
).input_ids
|
||||
|
||||
hf_itm_model.load_state_dict(modified_state_dict)
|
||||
hf_itm_model.eval()
|
||||
|
||||
out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
|
||||
out = hf_itm_model(question_input_ids, image, use_itm_head=False)
|
||||
|
||||
assert out[0].item() == 0.2110687494277954
|
||||
assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
|
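A minimal sketch of the `rename_key` helper above applied to one representative BLIP ViT key (the original key name is an assumption for illustration):

```python
# Relies on the rename_key definition above.
print(rename_key("visual_encoder.blocks.0.attn.proj.weight"))
# -> vision_model.encoder.layers.0.self_attn.projection.weight
```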
@ -1,390 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert BLIP-2 checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# pip3 install salesforce-lavis
|
||||
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
|
||||
# to make sure we can compare both original and HF implementation in float32
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BertTokenizer,
|
||||
Blip2Config,
|
||||
Blip2ForConditionalGeneration,
|
||||
Blip2ForImageTextRetrieval,
|
||||
Blip2Processor,
|
||||
Blip2QFormerConfig,
|
||||
Blip2VisionConfig,
|
||||
BlipImageProcessor,
|
||||
OPTConfig,
|
||||
T5Config,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
|
||||
|
||||
def load_demo_image():
|
||||
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
|
||||
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, model_name):
|
||||
rename_keys = []
|
||||
# fmt: off
|
||||
|
||||
# vision encoder
|
||||
rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
|
||||
rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
|
||||
rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
|
||||
rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
|
||||
|
||||
# QFormer
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
|
||||
if "itm" in model_name:
|
||||
rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
|
||||
rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
|
||||
rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
|
||||
rename_keys.append(("text_proj.weight", "text_projection.weight"))
|
||||
rename_keys.append(("text_proj.bias", "text_projection.bias"))
|
||||
|
||||
# fmt: on
|
||||
return rename_keys
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def read_in_q_v_bias(state_dict, config):
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
# read in original q and v biases
|
||||
q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")
|
||||
|
||||
# next, set bias in the state dict
|
||||
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
|
||||
state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
|
||||
|
||||
|
||||
def get_blip2_config(model_name, eos_token_id):
|
||||
image_size = 364 if "coco" in model_name else 224
|
||||
vision_config = Blip2VisionConfig(image_size=image_size).to_dict()
|
||||
|
||||
# make sure the models have proper bos_token_id and eos_token_id set (important for generation)
|
||||
# seems like flan-T5 models don't have bos_token_id properly set?
|
||||
if "opt-2.7b" in model_name:
|
||||
text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
|
||||
elif "opt-6.7b" in model_name:
|
||||
text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
|
||||
elif "t5-xl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "t5-xxl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "itm" in model_name:
|
||||
text_config = {}
|
||||
else:
|
||||
raise ValueError("Model name not supported")
|
||||
|
||||
if "itm" in model_name:
|
||||
config = Blip2Config(
|
||||
vision_config=vision_config,
|
||||
qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
|
||||
)
|
||||
else:
|
||||
config = Blip2Config(vision_config=vision_config, text_config=text_config)
|
||||
|
||||
return config, image_size
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip2_checkpoint(
|
||||
model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to Transformers design.
|
||||
"""
|
||||
if "opt" in model_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
|
||||
elif "itm" in model_name:
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
|
||||
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
|
||||
|
||||
if "itm" in model_name:
|
||||
eos_token_id = None
|
||||
else:
|
||||
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
|
||||
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
|
||||
|
||||
if "itm" in model_name:
|
||||
hf_model = Blip2ForImageTextRetrieval(config).eval()
|
||||
else:
|
||||
hf_model = Blip2ForConditionalGeneration(config).eval()
|
||||
|
||||
model_name_to_original = {
|
||||
"blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
|
||||
"blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
|
||||
"blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
|
||||
"blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
|
||||
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
|
||||
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
|
||||
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
|
||||
"blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
|
||||
"blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
|
||||
}
|
||||
|
||||
name, type = model_name_to_original[model_name]
|
||||
|
||||
# load original model
|
||||
print("Loading original model...")
|
||||
original_model, vis_processors, _ = load_model_and_preprocess(
|
||||
name=name, model_type=type, is_eval=True, device=lavis_device
|
||||
)
|
||||
original_model.eval()
|
||||
print("Done!")
|
||||
|
||||
# update state dict keys
|
||||
state_dict = original_model.state_dict()
|
||||
rename_keys = create_rename_keys(config, model_name)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
|
||||
# some keys can be renamed efficiently
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("Qformer.bert"):
|
||||
key = key.replace("Qformer.bert", "qformer")
|
||||
if "attention.self" in key:
|
||||
key = key.replace("self", "attention")
|
||||
if "opt_proj" in key:
|
||||
key = key.replace("opt_proj", "language_projection")
|
||||
if "t5_proj" in key:
|
||||
key = key.replace("t5_proj", "language_projection")
|
||||
if key.startswith("opt"):
|
||||
key = key.replace("opt", "language")
|
||||
if key.startswith("t5"):
|
||||
key = key.replace("t5", "language")
|
||||
state_dict[key] = val
|
||||
|
||||
# read in qv biases
|
||||
read_in_q_v_bias(state_dict, config)
|
||||
|
||||
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
|
||||
assert len(missing_keys) == 0
|
||||
|
||||
if "itm" in model_name:
|
||||
unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
|
||||
assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
|
||||
else:
|
||||
assert unexpected_keys == ["qformer.embeddings.position_ids"]
|
||||
|
||||
image = load_demo_image()
|
||||
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
|
||||
|
||||
# create processor
|
||||
image_processor = BlipImageProcessor(
|
||||
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
|
||||
)
|
||||
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
|
||||
|
||||
# make sure processor creates exact same pixel values
|
||||
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
|
||||
|
||||
original_model.to(lavis_device)
|
||||
hf_model.to(hf_model_device)
|
||||
|
||||
if "itm" in model_name:
|
||||
caption = "a large fountain spewing water into the air"
|
||||
input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
|
||||
attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)
|
||||
|
||||
with torch.no_grad():
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
|
||||
)
|
||||
logits = hf_model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_image_text_matching_head=True,
|
||||
)
|
||||
|
||||
assert original_logits.shape == logits.logits_per_image.shape
|
||||
print("First values of original logits:", original_logits[0, :3])
|
||||
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||
|
||||
# assert values
|
||||
# cast to same type
|
||||
target_dtype = logits.logits_per_image.dtype
|
||||
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||
|
||||
        original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
        itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
        assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
        print("Looks ok!")

        with torch.no_grad():
            original_logits = original_model(
                {"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
            )
            logits = hf_model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_image_text_matching_head=False,
            )

        assert original_logits.shape == logits.logits_per_image.shape
        print("First values of original logits:", original_logits[0, :3])
        print("First values of HF logits:", logits.logits_per_image[0, :3])

        # assert values
        # cast to same type
        target_dtype = logits.logits_per_image.dtype
        assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
        print("Looks ok!")

    else:
        input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)

        with torch.no_grad():
            if "opt" in model_name:
                original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
                logits = hf_model(pixel_values, input_ids).logits
            else:
                original_logits = original_model(
                    {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
                ).logits
                labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
                logits = hf_model(pixel_values, input_ids, labels=labels).logits

        assert original_logits.shape == logits.shape
        print("First values of original logits:", original_logits[0, :3, :3])
        print("First values of HF logits:", logits[0, :3, :3])

        # assert values
        assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
        print("Looks ok!")

        print("Generating a caption...")
        prompt = "Question: what object is in this image? Answer:"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)

        set_seed(42)

        original_outputs = original_model.generate(
            {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
        )
        outputs = hf_model.generate(
            pixel_values,
            input_ids,
            do_sample=True,
            num_beams=5,
            max_length=30,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.0,
            length_penalty=1.0,
            temperature=1,
        )
        output_text = processor.batch_decode(outputs, skip_special_tokens=True)
        output_text = [text.strip() for text in output_text]
        print("Original generation:", original_outputs)
        print("HF generation:", output_text)

    if pytorch_dump_folder_path is not None:
        processor.save_pretrained(pytorch_dump_folder_path)
        hf_model.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        processor.push_to_hub(f"nielsr/{model_name}")
        hf_model.push_to_hub(f"nielsr/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    choices = [
        "blip2-opt-2.7b",
        "blip2-opt-6.7b",
        "blip2-opt-2.7b-coco",
        "blip2-opt-6.7b-coco",
        "blip2-flan-t5-xl",
        "blip2-flan-t5-xl-coco",
        "blip2-flan-t5-xxl",
        "blip2-itm-vit-g",
        "blip2-itm-vit-g-coco",
    ]
    parser.add_argument(
        "--model_name",
        default="blip2-opt-2.7b",
        choices=choices,
        type=str,
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the model and processor to the hub after converting",
    )
    # note: this script is tested on 2 GPUs, since the models are compared in float32,
    # which requires quite some memory. Loading the original and HF models on
    # separate devices is the easiest way to compare them.
    parser.add_argument(
        "--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
    )
    parser.add_argument(
        "--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
    )

    args = parser.parse_args()

    convert_blip2_checkpoint(
        args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
    )
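Once a checkpoint has been converted and saved, it can be sanity-checked with the public BLIP-2 classes in `transformers`. A minimal sketch, assuming the converted weights were written to a local `./blip2-opt-2.7b` folder (the folder path and the test image URL are illustrative, not part of the script above):

```py
import requests
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

# "./blip2-opt-2.7b" is an assumed local path written by --pytorch_dump_folder_path above.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = Blip2Processor.from_pretrained("./blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("./blip2-opt-2.7b").to(device)

# Standard COCO test image used throughout the transformers docs.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

prompt = "Question: what object is in this image? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
generated_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())
```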
@ -1,254 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigScience BLOOM checkpoint."""

import argparse
import json
import os
import re

import torch

from transformers import BloomConfig, BloomModel
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging


logging.set_verbosity_info()

WEIGHTS_TO_AVERAGE_ENDSWITH = [
    "word_embeddings_layernorm.weight",
    "word_embeddings_layernorm.bias",
    "input_layernorm.weight",
    "input_layernorm.bias",
    "post_attention_layernorm.weight",
    "post_attention_layernorm.bias",
    "self_attention.dense.bias",
    "mlp.dense_4h_to_h.bias",
    "ln_f.weight",
    "ln_f.bias",
]

WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
    "mlp.dense_4h_to_h.weight",
    "self_attention.dense.weight",
]


def layer_name_mapping(key, file):
    """Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
    # Handle first and last layers
    layer_rename_map = {
        "word_embeddings.weight": "word_embeddings.weight",
        "word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
        "word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
        "weight": "ln_f.weight",
        "bias": "ln_f.bias",
    }

    if key in layer_rename_map:
        return layer_rename_map[key]

    # Handle transformer blocks
    layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
    layer_number -= 3
    return f"h.{layer_number}." + key
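
# Illustrative note (not part of the original script): the "- 3" offset above accounts for
# the Megatron-DeepSpeed layout, where transformer block files only start at layer_03.
# For a shard named e.g. "layer_04-model_00-model_states.pt" (assumed naming), a key such
# as "input_layernorm.weight" is therefore remapped to "h.1.input_layernorm.weight".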


def get_dtype_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


def convert_bloom_checkpoint_to_pytorch(
    bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
):
    # Construct model
    if bloom_config_file == "":
        config = BloomConfig()
    else:
        config = BloomConfig.from_json_file(bloom_config_file)

    if shard_model:
        file_names = os.listdir(bloom_checkpoint_path)
        file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))

        index_dict = {"weight_map": {}, "metadata": {}}
        total_size = 0

        missing_keys = None

        config = BloomConfig()

        for j, file in enumerate(file_names):
            print("Processing file: {}".format(file))
            tensors = None

            for i in range(pretraining_tp):
                # load all TP files
                f_name = file.replace("model_00", f"model_0{i}")
                temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")

                # Rename keys in the transformers names
                keys = list(temp.keys())
                for key in keys:
                    temp[layer_name_mapping(key, file)] = temp.pop(key)

                if tensors is None:
                    tensors = temp
                else:
                    for key in tensors.keys():
                        if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                            # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
                            tensors[key] += temp[key]
                        else:
                            # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
                            cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
                            # We concatenate these weights across TP ranks
                            tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)

            # Divide by the number of TP the weights we want to average
            for key in tensors.keys():
                if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                    tensors[key] = tensors[key] / pretraining_tp
            torch.save(
                tensors,
                os.path.join(
                    pytorch_dump_folder_path,
                    "pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(len(file_names)).zfill(5)),
                ),
            )

            for key in tensors.keys():
                value = tensors[key]
                total_size += value.numel() * get_dtype_size(value.dtype)
                if key not in index_dict["weight_map"]:
                    index_dict["weight_map"][key] = "pytorch_model_{}-of-{}.bin".format(
                        str(j + 1).zfill(5), str(len(file_names)).zfill(5)
                    )

        config = BloomConfig()
        pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
        index_dict["metadata"]["total_size"] = total_size
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
        with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
            json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
            f.write(json_config)
    else:
        model = BloomModel(config)

        file_names = os.listdir(bloom_checkpoint_path)
        file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))

        missing_keys = None
        for i, file in enumerate(file_names):
            tensors = None
            for i in range(pretraining_tp):
                # load all TP files
                f_name = file.replace("model_00", f"model_0{i}")
                temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")

                # Rename keys in the transformers names
                keys = list(temp.keys())
                for key in keys:
                    temp[layer_name_mapping(key, file)] = temp.pop(key)

                if tensors is None:
                    tensors = temp
                else:
                    for key in tensors.keys():
                        # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
                        if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                            tensors[key] += temp[key]
                        else:
                            # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
                            cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
                            # We concatenate these weights across TP ranks
                            tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)

            # Divide by the number of TP the weights we want to average
            for key in tensors.keys():
                if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                    tensors[key] = tensors[key] / pretraining_tp

            other_keys = model.load_state_dict(tensors, strict=False)
            assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
            if missing_keys is None:
                missing_keys = set(other_keys.missing_keys)
            else:
                missing_keys = missing_keys.intersection(set(other_keys.missing_keys))

        assert not missing_keys, f"The keys {missing_keys} are missing"

        # Save pytorch-model
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)
        pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
        pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
        print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
        if config.torch_dtype is not None:
            model = model.to(config.torch_dtype)
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print(f"Save configuration file to {pytorch_config_dump_path}")
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--bloom_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the Megatron-LM checkpoint path.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--bloom_config_file",
        default="",
        type=str,
        help=(
            "An optional config json file corresponding to the pre-trained model. \n"
            "This specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--shard_model",
        action="store_true",
        help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
    )
    parser.add_argument(
        "--pretraining_tp",
        default=4,
        type=int,
        help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
    )
    args = parser.parse_args()
    convert_bloom_checkpoint_to_pytorch(
        args.bloom_checkpoint_path,
        args.bloom_config_file,
        args.pytorch_dump_folder_path,
        args.shard_model,
        args.pretraining_tp,
    )
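The merge rule in the loops above is the core of this conversion: layer norm weights and biases are replicated across tensor-parallel ranks and get averaged, while split linear weights are concatenated along the row- or column-parallel dimension. A minimal sketch with toy tensors (the shapes, key names, and rank count are made up for illustration):

```py
import torch

pretraining_tp = 2
# Toy state dicts from two tensor-parallel ranks (shapes and values are illustrative only).
shards = [
    {"h.0.input_layernorm.weight": torch.ones(4), "h.0.mlp.dense_4h_to_h.weight": torch.randn(4, 8)},
    {"h.0.input_layernorm.weight": torch.ones(4), "h.0.mlp.dense_4h_to_h.weight": torch.randn(4, 8)},
]

merged = shards[0]
for temp in shards[1:]:
    for key in merged.keys():
        if key.endswith("input_layernorm.weight"):
            # replicated across ranks: sum here, average below
            merged[key] = merged[key] + temp[key]
        else:
            # row-parallel linear: concatenate along the split (input) dimension
            merged[key] = torch.cat([merged[key], temp[key]], dim=1)

for key in merged.keys():
    if key.endswith("input_layernorm.weight"):
        merged[key] = merged[key] / pretraining_tp

print(merged["h.0.input_layernorm.weight"].shape)    # torch.Size([4])
print(merged["h.0.mlp.dense_4h_to_h.weight"].shape)  # torch.Size([4, 16])
```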
@ -1,145 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Bros checkpoints."""

import argparse

import bros  # original repo
import torch

from transformers import BrosConfig, BrosModel, BrosProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_configs(model_name):
    bros_config = BrosConfig.from_pretrained(model_name)
    return bros_config


def remove_ignore_keys_(state_dict):
    ignore_keys = [
        "embeddings.bbox_sinusoid_emb.inv_freq",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


def rename_key(name):
    if name == "embeddings.bbox_projection.weight":
        name = "bbox_embeddings.bbox_projection.weight"

    if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
        name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"

    if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
        name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"

    return name


def convert_state_dict(orig_state_dict, model):
    # rename keys
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)
        orig_state_dict[rename_key(key)] = val

    # remove ignore keys
    remove_ignore_keys_(orig_state_dict)

    return orig_state_dict


def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
    # load original model
    original_model = bros.BrosModel.from_pretrained(model_name).eval()

    # load HuggingFace model
    bros_config = get_configs(model_name)
    model = BrosModel.from_pretrained(model_name, config=bros_config)
    model.eval()

    state_dict = original_model.state_dict()
    new_state_dict = convert_state_dict(state_dict, model)
    model.load_state_dict(new_state_dict)

    # verify results

    # the original BROS model requires 4 points (8 float values) for each bbox, so prepare a bbox of shape [batch_size, seq_len, 8]
    bbox = torch.tensor(
        [
            [
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                [0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
                [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
                [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
                [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
                [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
                [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            ]
        ]
    )

    processor = BrosProcessor.from_pretrained(model_name)

    encoding = processor("His name is Rocco.", return_tensors="pt")
    encoding["bbox"] = bbox

    original_hidden_states = original_model(**encoding).last_hidden_state
    # pixel_values = processor(image, return_tensors="pt").pixel_values

    last_hidden_states = model(**encoding).last_hidden_state

    assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
        processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model_name",
        default="jinho8345/bros-base-uncased",
        required=False,
        type=str,
        help="Name of the original model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        required=False,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the converted model and processor to the 🤗 hub.",
    )

    args = parser.parse_args()
    convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
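The hard-coded `bbox` tensor above illustrates the input format: BROS expects four corner points (eight normalized floats) per token, i.e. a `[batch_size, seq_len, 8]` tensor, rather than the usual two-corner `(x0, y0, x1, y1)` boxes. A small helper sketching that expansion (the box values below are made up):

```py
import torch

def quad_from_xyxy(boxes):
    """Expand normalized (x0, y0, x1, y1) boxes to the 4-corner
    (x0, y0, x1, y0, x1, y1, x0, y1) layout, shape [batch, seq_len, 8]."""
    x0, y0, x1, y1 = boxes.unbind(-1)
    return torch.stack([x0, y0, x1, y0, x1, y1, x0, y1], dim=-1)

boxes = torch.tensor([[[0.4396, 0.6720, 0.4659, 0.6850]]])  # [batch=1, seq_len=1, 4]
print(quad_from_xyxy(boxes))        # matches the second row of the bbox tensor above
print(quad_from_xyxy(boxes).shape)  # torch.Size([1, 1, 8])
```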
@ -1,59 +0,0 @@
# coding=utf-8
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert T5 checkpoint."""

import argparse

from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = T5ForConditionalGeneration(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
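For context, the same conversion can be sketched directly against the public API this script imports; the checkpoint, config, and output paths below are placeholders, not paths used by the script:

```py
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5

# Placeholder paths: substitute a real TensorFlow checkpoint and its config.json.
config = T5Config.from_json_file("./t5_config.json")
model = T5ForConditionalGeneration(config)
load_tf_weights_in_t5(model, config, "./t5_tf_checkpoint")
model.save_pretrained("./t5_pytorch_dump")
```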
Some files were not shown because too many files have changed in this diff.